From f4200a62b8cb12f0bb1c4ac474c3f93e0977d6ea Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 22 Jul 2024 09:00:01 +0000 Subject: [PATCH] improve --- .../executor/variable_dispatch.py | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py index 5e029e25c88153..16a484496c9873 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py @@ -583,18 +583,6 @@ def dispatch_reversed(var: ContainerVariable): ), ) -# bool -Dispatcher.register( - bool, - ("ContainerVariable | SymbolicVariable",), - lambda var: var.bool(), -) -Dispatcher.register( - operator.truth, - ("ConstantVariable | SymbolicVariable",), - lambda var: var.bool(), -) - # str Dispatcher.register( str, @@ -981,16 +969,20 @@ def tensor_mod_dispatcher( ), ) # Symbolic -Dispatcher.register( - float, - ("SymbolicVariable",), - lambda var: var.float(), -) -Dispatcher.register( - int, - ("SymbolicVariable",), - lambda var: var.int(), -) +for unary_fn in fallback_tensor_unary_method: + Dispatcher.register( + unary_fn, + ("SymbolicVariable",), + partial( + lambda fn, var: VariableFactory.from_value( + fn(var.get_py_value()), + var.graph, + tracker=DummyTracker([var]), + ), + unary_fn, + ), + ) + for binary_fn in BINARY_OPS: for magic_method in magic_method_builtin_dispatch(binary_fn): if magic_method.name not in get_tensor_methods():