什么是 SideEffect

在计算机科学中,函数副作用(side effect)指当调用函数时,除了返回可能的函数值之外,还对主调用函数产生附加的影响。例如修改全局变量(函数外的变量),修改参数,向主调方的终端、管道输出字符或改变外部存储信息等。

—— WikiPedia 副作用_(计算机科学)

为什么需要在 PaddleSOT 中使用 SideEffect

我们先来看一个demo, 我们可以看到此时的的SOT并不能准确的还原global并产生副作用.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from sot.translate import symbolic_translate

global_x = 1

def foo():
global global_x
global_x = global_x + global_x
return global_x


def main():
dygraph_out = foo
symbolic_translate_out = symbolic_translate(foo)

print("symbolic_translate_out:", symbolic_translate_out()) # symbolic_translate_out: 1
print("symbolic_translate_out:", symbolic_translate_out()) # symbolic_translate_out: 1
print("dygraph_out:", dygraph_out()) # dygraph_out: 2
print("dygraph_out:", dygraph_out()) # dygraph_out: 4

原生字节码:

1
2
3
4
5
6
7
10           0 LOAD_GLOBAL              0 (global_x)
2 LOAD_GLOBAL 0 (global_x)
4 BINARY_ADD
6 STORE_GLOBAL 0 (global_x)

11 8 LOAD_GLOBAL 0 (global_x)
10 RETURN_VALUE

转写后字节码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
8           0 LOAD_GLOBAL              1 (paddle_set_eval_frame_fn)
2 LOAD_CONST 0 (None)
4 CALL_FUNCTION 1
6 STORE_FAST 0 (___old_eval_frame)
8 LOAD_GLOBAL 2 (__compiled_fn_dummy_func)
10 BUILD_TUPLE 0
12 CALL_FUNCTION 1
14 UNPACK_SEQUENCE 0
16 LOAD_GLOBAL 0 (global_x)
18 LOAD_GLOBAL 1 (paddle_set_eval_frame_fn)
20 LOAD_FAST 0 (___old_eval_frame)
22 CALL_FUNCTION 1
24 POP_TOP
26 RETURN_VALUE

如何实现 SideEffect

其实实现起来也非常简单, 我们只需要重新把数据通过SideEffect, 机制重新给到python的栈结构上就能保证数据的准确性

转写后字节码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
8           0 LOAD_GLOBAL              1 (paddle_set_eval_frame_fn)
2 LOAD_CONST 0 (None)
4 CALL_FUNCTION 1
6 STORE_FAST 0 (___old_eval_frame)
8 LOAD_GLOBAL 2 (__compiled_fn_dummy_func)
10 BUILD_TUPLE 0
12 CALL_FUNCTION 1
14 UNPACK_SEQUENCE 0
16 LOAD_CONST 1 (4)
18 LOAD_CONST 1 (4) # <--- 重新load一次数据
20 STORE_GLOBAL 0 (global_x) # <--- 重新store一次数据
22 LOAD_GLOBAL 1 (paddle_set_eval_frame_fn)
24 LOAD_FAST 0 (___old_eval_frame)
26 CALL_FUNCTION 1
28 POP_TOP
30 RETURN_VALUE

如何知道需要重新加载哪些数据呢 ?

这里是通过STORE_GLOBAL这个字节码来判断是否更新数据的,只有当global变量被更新的时候,我们才会重新加载数据.

当运行到STORE_GLOBAL时会将GlobalVariable中原有的数据更新, 在更新的同时GlobalVariable中的数据域MutableDictLikeData会记录下来更新的数据, 以及版本信息, 以便后续重新加载数据.

如何重新加载数据呢 ?

FunctionGraph类中的restore_side_effects方法中实现副作用机制的还原, 通过MutableDictLikeData中的数据来还原GlobalVariable中的数据, 以及版本信息.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
if isinstance(var, GlobalVariable):
# 根据STORE_GLOBAL记录的value值进行_reconstruct还原数据
# LOAD_CONST or LOAD_FAST
for record in var.proxy.get_last_records():
if isinstance(record, (MutationSet, MutationNew)):
record.value._reconstruct(self.pycode_gen)

# 对其他需要副作用的Variable进行还原
self.restore_side_effects(variables[1:])

# STORE_GLOBAL
for record in var.proxy.get_last_records()[::-1]:
if isinstance(record, (MutationSet, MutationNew)):
self.pycode_gen.gen_store_global(record.key)
if isinstance(record, MutationDel):
self.pycode_gen.gen_delete_global(record.key)

这里store应该要反着才能对应上load的顺序

1
2
3
4
5
6
7
load 1
load 2

下一个 var

store 2
store 1

至此,就完成了整个global副作用的实现.

参考链接