From cdefdade7e7edb5950913be230a9aae1d61f0c38 Mon Sep 17 00:00:00 2001 From: NotHaozi Date: Thu, 15 Jun 2023 13:35:53 +0000 Subject: [PATCH 1/2] fix unittests --- .gitmodules | 3 + python/CMakeLists.txt | 2 + python/paddle/jit/PaddleSOT | 1 + .../jit/dy2static/program_translator.py | 4 +- python/paddle/jit/symbolic_trace/__init__.py | 24 - .../paddle/jit/symbolic_trace/infer_meta.py | 155 --- .../opcode_translator/__init__.py | 17 - .../opcode_translator/executor/__init__.py | 15 - .../executor/function_graph.py | 294 ----- .../opcode_translator/executor/guard.py | 74 -- .../opcode_translator/executor/instr_flag.py | 32 - .../executor/opcode_executor.py | 1092 ---------------- .../executor/opcode_inline_executor.py | 117 -- .../executor/pycode_generator.py | 349 ------ .../opcode_translator/executor/tracker.py | 172 --- .../executor/tracker_viewer.py | 94 -- .../executor/variable_monkey_patch.py | 75 -- .../opcode_translator/executor/variables.py | 1111 ----------------- .../instruction_utils/__init__.py | 41 - .../instruction_utils/instruction_utils.py | 248 ---- .../instruction_utils/opcode_analysis.py | 85 -- .../instruction_utils/opcode_info.py | 120 -- .../opcode_translator/skip_files.py | 124 -- .../opcode_translator/transform.py | 50 - .../paddle/jit/symbolic_trace/proxy_tensor.py | 102 -- .../symbolic_trace/symbolic/compile_cache.py | 34 - .../symbolic_trace/symbolic/interpreter.py | 112 -- .../symbolic_trace/symbolic/statement_ir.py | 235 ---- .../symbolic/symbolic_context.py | 110 -- python/paddle/jit/symbolic_trace/trace.py | 33 - .../jit/symbolic_trace/utils/.utils.py.swp | Bin 16384 -> 0 bytes .../jit/symbolic_trace/utils/__init__.py | 67 - .../jit/symbolic_trace/utils/exceptions.py | 30 - .../jit/symbolic_trace/utils/monkey_patch.py | 56 - .../symbolic_trace/utils/paddle_api_config.py | 58 - .../utils/paddle_api_info/paddle_api.json | 297 ----- .../paddle_api_info/paddle_tensor_method.json | 189 --- .../symbolic_trace/utils/pycode_inspect.py | 20 - .../paddle/jit/symbolic_trace/utils/utils.py | 230 ---- .../dygraph_to_static_util.py | 18 + test/dygraph_to_static/test_break_continue.py | 3 + test/dygraph_to_static/test_build_strategy.py | 4 + test/dygraph_to_static/test_cache_program.py | 4 +- test/dygraph_to_static/test_cinn_prim.py | 5 + test/dygraph_to_static/test_cinn_prim_gelu.py | 3 + test/dygraph_to_static/test_cinn_prim_mean.py | 5 + .../test_closure_analysis.py | 2 + test/dygraph_to_static/test_container.py | 3 + test/dygraph_to_static/test_convert_call.py | 8 + .../test_cpu_cuda_to_tensor.py | 48 +- test/dygraph_to_static/test_declarative.py | 2 + .../test_decorator_transform.py | 9 +- test/dygraph_to_static/test_fallback.py | 3 + test/dygraph_to_static/test_gradname_parse.py | 16 +- test/dygraph_to_static/test_ifelse.py | 17 +- test/dygraph_to_static/test_mnist.py | 2 + test/dygraph_to_static/test_mobile_net.py | 2 + test/dygraph_to_static/test_op_attr.py | 4 + test/dygraph_to_static/test_tsm.py | 4 +- .../test_write_python_container.py | 16 + 60 files changed, 164 insertions(+), 5886 deletions(-) create mode 160000 python/paddle/jit/PaddleSOT delete mode 100644 python/paddle/jit/symbolic_trace/__init__.py delete mode 100644 python/paddle/jit/symbolic_trace/infer_meta.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/__init__.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/__init__.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/function_graph.py delete mode 100644 
python/paddle/jit/symbolic_trace/opcode_translator/executor/guard.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/instr_flag.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/opcode_executor.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/opcode_inline_executor.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/pycode_generator.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/tracker.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/tracker_viewer.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/variable_monkey_patch.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/executor/variables.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/__init__.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/instruction_utils.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/opcode_analysis.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/opcode_info.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/skip_files.py delete mode 100644 python/paddle/jit/symbolic_trace/opcode_translator/transform.py delete mode 100644 python/paddle/jit/symbolic_trace/proxy_tensor.py delete mode 100644 python/paddle/jit/symbolic_trace/symbolic/compile_cache.py delete mode 100644 python/paddle/jit/symbolic_trace/symbolic/interpreter.py delete mode 100644 python/paddle/jit/symbolic_trace/symbolic/statement_ir.py delete mode 100644 python/paddle/jit/symbolic_trace/symbolic/symbolic_context.py delete mode 100644 python/paddle/jit/symbolic_trace/trace.py delete mode 100644 python/paddle/jit/symbolic_trace/utils/.utils.py.swp delete mode 100644 python/paddle/jit/symbolic_trace/utils/__init__.py delete mode 100644 python/paddle/jit/symbolic_trace/utils/exceptions.py delete mode 100644 python/paddle/jit/symbolic_trace/utils/monkey_patch.py delete mode 100644 python/paddle/jit/symbolic_trace/utils/paddle_api_config.py delete mode 100644 python/paddle/jit/symbolic_trace/utils/paddle_api_info/paddle_api.json delete mode 100644 python/paddle/jit/symbolic_trace/utils/paddle_api_info/paddle_tensor_method.json delete mode 100644 python/paddle/jit/symbolic_trace/utils/pycode_inspect.py delete mode 100644 python/paddle/jit/symbolic_trace/utils/utils.py diff --git a/.gitmodules b/.gitmodules index 8c294e25bd609..7338cb0523d35 100644 --- a/.gitmodules +++ b/.gitmodules @@ -50,3 +50,6 @@ path = third_party/eigen3 url = https://gitlab.com/libeigen/eigen.git ignore = dirty +[submodule "python/paddle/jit/PaddleSOT"] + path = python/paddle/jit/PaddleSOT + url = https://github.com/PaddlePaddle/PaddleSOT diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 8d9073b398417..a394da6c07b89 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -133,6 +133,8 @@ else() ${PADDLE_BINARY_DIR}/python COMMAND cp -r ${PADDLE_SOURCE_DIR}/test ${PADDLE_BINARY_DIR}/ COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel + COMMAND ln -sf ${PADDLE_SOURCE_DIR}/python/paddle/jit/PaddleSOT/sot + ${PADDLE_BINARY_DIR}/python/paddle/jit/sot COMMENT "Packing whl packages------>>>" DEPENDS copy_libpaddle ${FLUID_CORE} framework_py_proto profiler_py_proto pass_desc_py_proto ${PY_FILES}) diff --git 
a/python/paddle/jit/PaddleSOT b/python/paddle/jit/PaddleSOT new file mode 160000 index 0000000000000..32d04cd9e28f8 --- /dev/null +++ b/python/paddle/jit/PaddleSOT @@ -0,0 +1 @@ +Subproject commit 32d04cd9e28f8fd455de266dc9f16b1f76242d79 diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 323d3bb7fa36c..780264fe4ba72 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -676,9 +676,9 @@ def __init__(self, function, input_spec=None, **kwargs): super().__init__(function, input_spec, **kwargs) def _perform_call(self, *args, **kwargs): - from ..symbolic_trace import symbolic_trace + from ..sot import symbolic_translate - traced_fun = symbolic_trace(self._dygraph_function) + traced_fun = symbolic_translate(self._dygraph_function) if self._class_instance is not None: args = (self._class_instance,) + args return traced_fun(*args, **kwargs) diff --git a/python/paddle/jit/symbolic_trace/__init__.py b/python/paddle/jit/symbolic_trace/__init__.py deleted file mode 100644 index 65708e7dfa932..0000000000000 --- a/python/paddle/jit/symbolic_trace/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .proxy_tensor import ProxyTensor -from .trace import symbolic_trace -from .utils import paddle_tensor_method -from .utils.monkey_patch import do_monkey_patch, proxy_tensor_method_builder - -do_monkey_patch(ProxyTensor, paddle_tensor_method, proxy_tensor_method_builder) - -__all__ = [ - "symbolic_trace", -] diff --git a/python/paddle/jit/symbolic_trace/infer_meta.py b/python/paddle/jit/symbolic_trace/infer_meta.py deleted file mode 100644 index 736d0e0d086df..0000000000000 --- a/python/paddle/jit/symbolic_trace/infer_meta.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
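For reference, the `symbolic_translate` entry point wired into program_translator.py above can also be exercised directly. A minimal sketch, not part of the patch: the function body and input shape are illustrative, and it assumes the `sot` package symlinked in CMakeLists.txt above is importable as `paddle.jit.sot`.

    import paddle
    from paddle.jit.sot import symbolic_translate

    def forward(x):
        return paddle.nn.functional.relu(x) + 1

    # symbolic_translate wraps a dygraph function and returns a traced
    # callable, mirroring _perform_call in program_translator.py above.
    traced_forward = symbolic_translate(forward)
    out = traced_forward(paddle.rand([2, 3]))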
- -import paddle -from paddle.fluid.framework import Program -from paddle.utils import flatten - -from .utils import Cache, Singleton, map_if, meta_str, no_eval_frame - - -@Singleton -class InferMetaCache(Cache): - def key_fn(self, *args, **kwargs): - return hash( - (tuple(flatten(args)), tuple(kwargs.keys()), tuple(flatten(kwargs))) - ) - - def value_fn(self, *args, **kwargs): - return infer_meta(*args, **kwargs) - - -class MetaInfo: - def __init__(self, shape, dtype, stop_gradient): - self.shape = shape - self.dtype = dtype - self.stop_gradient = stop_gradient - - @staticmethod - def from_tensor(tensor): - return MetaInfo(tensor.shape, tensor.dtype, tensor.stop_gradient) - - def to_input_spec(self): - return paddle.static.InputSpec( - self.shape, dtype=self.dtype, stop_gradient=self.stop_gradient - ) - - def __repr__(self): - return meta_str(self.shape, self.dtype, self.stop_gradient) - - def __eq__(self, meta): - return ( - self.shape == meta.shape - and self.dtype == meta.dtype - and self.stop_gradient == meta.stop_gradient - ) - - def __hash__(self): - return hash((tuple(self.shape), self.dtype, self.stop_gradient)) - - -@Singleton -class VariableCreator: - def __init__(self): - self.var_cache = {} - self.main_program = Program() - self.startup_program = Program() - - def gen_name(self, meta): - name = f"{meta.dtype}_{meta.stop_gradient}" - for l in meta.shape: - name += f"_{l}" - return name - - def create_var(self, meta): - var = self.main_program.global_block().create_var( - shape=meta.shape, - dtype=meta.dtype, - stop_gradient=meta.stop_gradient, - ) - assert not isinstance( - var, paddle.Tensor - ), "Expect a Variable, but got a Tensor." - return var - - def get_variable(self, meta): - var_feature_name = self.gen_name(meta) - - if var_feature_name not in self.var_cache: - self.var_cache[var_feature_name] = self.create_var(meta) - return self.var_cache[var_feature_name] - - def infer_meta(self, func, *args, **kwargs): - paddle.enable_static() - args, kwargs = convert_to_variable(args), convert_to_variable(kwargs) - - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - if isinstance(func, str): - # TODO(Aurelius84): Is length of args always greater than 0? - # Do we need add condition check here? - out = getattr(args[0], func)(*args[1:], **kwargs) - else: - out = func(*args, **kwargs) - - out = MetaInfo( - list(out.shape), - out.dtype, - out.stop_gradient, - ) - - paddle.disable_static() - return out - - -def convert_to_variable(args): - return map_if( - args, - pred=lambda x: isinstance(x, MetaInfo), - true_fn=lambda x: VariableCreator().get_variable(x), - false_fn=lambda x: x, - ) - - -def convert_to_input_spec(args): - return map_if( - args, - pred=lambda x: isinstance(x, MetaInfo), - true_fn=lambda x: x.to_input_spec(), - false_fn=lambda x: paddle.static.InputSpec.from_tensor(x), - ) - - -@no_eval_frame -def infer_meta(func, *args, **kwargs): - return VariableCreator().infer_meta(func, *args, **kwargs) - - -def infer_meta_for_layer(layer, *args, **kwargs): - assert isinstance( - layer, paddle.nn.Layer - ), f"Expect a Layer, but got {layer}." 
- layer = paddle.jit.to_static(layer, enable_fallback=False) - - args, kwargs = convert_to_input_spec(args), convert_to_input_spec(kwargs) - concrete_program = layer.forward.get_concrete_program(*args, **kwargs)[0] - out = concrete_program.outputs[0] - out = MetaInfo( - list(out.shape), - out.dtype, - out.stop_gradient, - ) - layer.forward.rollback() - return out diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/__init__.py b/python/paddle/jit/symbolic_trace/opcode_translator/__init__.py deleted file mode 100644 index e2acb750f68c8..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .transform import eval_frame_callback - -__all__ = ["eval_frame_callback"] diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/__init__.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/__init__.py deleted file mode 100644 index 28f8ef2efeb12..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . import variable_monkey_patch # noqa F401 diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/function_graph.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/function_graph.py deleted file mode 100644 index fb7e22923cf3f..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/function_graph.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This file is specifically used to handle the problem -# of generating a Graph from a linear function call. 
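The header comment above states the goal of function_graph.py: record a linear sequence of calls and compile an equivalent function over the frame's locals. A framework-free toy sketch of that idea, purely an illustrative analogy (the real implementation below records statements into a SymbolicTraceContext and regenerates bytecode with PyCodeGen):

    import operator

    # Each statement is (fn, input names, output name); replaying them over
    # a dict of named inputs reproduces the traced computation.
    statements = [
        (operator.add, ("x", "y"), "t0"),
        (operator.mul, ("t0", "x"), "t1"),
    ]

    def compile_trace(stmts, outputs):
        def compiled(**env):
            for fn, ins, out in stmts:
                env[out] = fn(*(env[name] for name in ins))
            return tuple(env[name] for name in outputs)
        return compiled

    fn = compile_trace(statements, outputs=("t1",))
    assert fn(x=2, y=3) == (10,)  # (2 + 3) * 2 == 10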
- -from __future__ import annotations - -from collections import namedtuple -from copy import deepcopy -from typing import Any, Callable - -import paddle - -from ...infer_meta import InferMetaCache, infer_meta, infer_meta_for_layer -from ...proxy_tensor import ProxyTensor, ProxyTensorContext -from ...symbolic.statement_ir import Symbol -from ...symbolic.symbolic_context import SymbolicTraceContext -from ...utils import is_paddle_api, log, show_trackers -from .guard import Guard, StringifyExpression, make_guard -from .pycode_generator import PyCodeGen -from .tracker import DummyTracker -from .variables import ( - ContainerVariable, - PaddleLayerVariable, - TensorVariable, - VariableBase, - VariableFactory, - topo_sort_vars, -) - - -def convert_to_meta(inputs): - def func(x): - if isinstance(x, ProxyTensor): - return x.meta - return x - - return paddle.utils.map_structure(func, inputs) - - -def convert_to_symbol(inputs): - def func(x): - if isinstance(x, ProxyTensor): - return Symbol(x.name) - return x - - pack_inputs = [inputs] - ret = paddle.utils.map_structure(func, pack_inputs) - return ret[0] - - -def convert_variable_to_value(inputs): - def func(x): - return x.get_value() - - return paddle.utils.map_structure(func, inputs) - - -class FunctionGraph: - """ - A Graph representation corresponding to each FunctionFrame - The input binding diagram containing the current call represents three parts of output settings, - This Graph can be compiled as a f_locals dependency function which produce the same outputs. - """ - - Memo = namedtuple( - "function_graph_memo", - ['inner_out', 'input_variables', "stmt_ir", "global_guards"], - ) - - def __init__(self, frame): - self.sir_ctx = SymbolicTraceContext() - self.inner_out = set() - self.input_variables = [] - self.pycode_gen = PyCodeGen(frame) - self.py_frame = frame - self.out_var_prefix = "___SIR_out_" - self._global_guarded_variables: list[VariableBase] = [] - - def need_add_input(self, var): - if var.id in self.inner_out: - return False - for v in self.input_variables: - if v.id == var.id: - return False - return True - - def save_memo(self): - """ - Why don't use __deepcopy__: - bacause memo is not a deepcopy, i.e inner_out is only a - shallow copy, SIR is a deepcopy. - """ - saved_stmt_ir = deepcopy(self.sir_ctx.TOS) - return FunctionGraph.Memo( - inner_out=set(self.inner_out), - input_variables=list(self.input_variables), - stmt_ir=saved_stmt_ir, - global_guards=list(self._global_guarded_variables), - ) - - def restore_memo(self, memo): - self.inner_out = memo.inner_out - self.input_variables = memo.input_variables - self.sir_ctx.replace_TOS(memo.stmt_ir) - self._global_guarded_variables = memo.global_guards - - def collect_input_variables(self, inputs: list[VariableBase]): - for inp in inputs: - if isinstance(inp, ContainerVariable): - self.collect_input_variables(inp.get_items()) - if isinstance(inp, VariableBase) and self.need_add_input(inp): - self.input_variables.append(inp) - - @property - def guard_fn(self) -> Guard: - guards = [ - variable.make_stringify_guard() - for variable in topo_sort_vars( - self.input_variables + self._global_guarded_variables - ) - if not isinstance(variable.tracker, DummyTracker) - ] - for guard in guards: - assert isinstance( - guard, StringifyExpression - ), "guard must be StringifyExpression." 
- - return make_guard(guards) - - def start_compile(self, *ret_vars: VariableBase): - ret_items = [ - ret_item - for ret_var in ret_vars - for ret_item in ret_var.flatten_items() - ] - tensor_items = self._find_tensor_outputs(ret_items) - compiled_fn, statment_ir = self.sir_ctx.compile_fn( - [tensor_var.value for tensor_var in tensor_items] - ) - input_names = statment_ir.inputs - compiled_fn_name = statment_ir.name - # prepare function and inputs - self.pycode_gen.gen_load_object(compiled_fn, compiled_fn_name) - for name in input_names: - found = False - for variable in self.input_variables: - if ( - isinstance(variable, (TensorVariable, PaddleLayerVariable)) - and variable.get_symbol().name == name - ): - variable.tracker.gen_instructions(self.pycode_gen) - found = True - break - assert found, f"can't find input {name} in SIR." - # Pack all args into a tuple, because we don't support *args now. - self.pycode_gen.gen_build_tuple(count=len(input_names)) - # call the compiled_fn - self.pycode_gen.gen_call_function(argc=1) - # Store outputs to f_locals - self.pycode_gen.gen_unpack_sequence(count=len(tensor_items)) - for tensor_var in tensor_items: - self.pycode_gen.gen_store_fast(tensor_var.out_var_name) - # restore the outputs. - for ret_var in ret_vars: - ret_var.reconstruct(self.pycode_gen) - - # deal side effect - # TODO(xiongkun): add side effect handle - - tracker_output_path = show_trackers() - if tracker_output_path: - from .tracker_viewer import view_tracker - - view_tracker(list(ret_vars), tracker_output_path, format="png") - - def call_paddle_api( - self, - func: Callable[..., Any], - *args: VariableBase, - **kwargs: VariableBase, - ): - assert is_paddle_api(func) - # not fallback api, start symbolic trace. - # TODO(xiokgun): multi-output support. - # TODO(xiokgun): may have python buildin object inside metas. - # TODO(xiokgun): 4 kinds of python arguments. support it !! - log(3, f"call paddle.api : {func.__name__}", "\n") - self.collect_input_variables(list(args)) - self.collect_input_variables(list(kwargs.values())) - values, kwvalues = ( - convert_variable_to_value(args), - convert_variable_to_value(kwargs), - ) - metas = convert_to_meta(values) - kwmetas = convert_to_meta(kwvalues) - meta = InferMetaCache()(func, *metas, **kwmetas) - result = ProxyTensor(ProxyTensorContext().new_varname(), meta) - inputs_symbols = ( - convert_to_symbol(values), - convert_to_symbol(kwvalues), - ) - log(3, f" inputs : {inputs_symbols}", "\n") - self.sir_ctx.call_API( - func, - inputs=inputs_symbols, - outputs=convert_to_symbol(result), - ) # symbolic only contain symbols. - variable = VariableFactory.from_value( - result, - self, - tracker=DummyTracker(list(args) + list(kwargs.values())), - ) - self._put_inner(variable) - return variable - - def call_tensor_method(self, method_name: str, *args: VariableBase): - self.collect_input_variables(list(args)) - values = convert_variable_to_value(args) - metas = convert_to_meta(values) - meta = infer_meta(method_name, *metas) - result = ProxyTensor(ProxyTensorContext().new_varname(), meta) - self.sir_ctx.call_METHOD( - method_name, - inputs=(convert_to_symbol(values), {}), - outputs=convert_to_symbol(result), - ) # symbolic only contain symbols. 
- variable = VariableFactory.from_value( - result, self, tracker=DummyTracker(list(args)) - ) - self._put_inner(variable) - return variable - - def call_layer( - self, - layer: PaddleLayerVariable, - *args: VariableBase, - **kwargs: VariableBase, - ): - self.collect_input_variables([layer, *args]) - self.collect_input_variables(list(kwargs.values())) - values, kwvalues = ( - convert_variable_to_value(args), - convert_variable_to_value(kwargs), - ) - metas = convert_to_meta(values) - kwmetas = convert_to_meta(kwvalues) - meta = infer_meta_for_layer(layer.value, *metas, **kwmetas) - result = ProxyTensor(ProxyTensorContext().new_varname(), meta) - inputs_symbols = ( - (layer.get_symbol(), *convert_to_symbol(values)), - convert_to_symbol(kwvalues), - ) - self.sir_ctx.call_LAYER( - layer.value.__class__.__name__, - inputs=inputs_symbols, - outputs=convert_to_symbol(result), - ) - variable = VariableFactory.from_value( - result, - self, - tracker=DummyTracker([layer, *args] + list(kwargs.values())), - ) - self._put_inner(variable) - return variable - - def _put_inner(self, var): - self.inner_out.add(var.id) - - def add_global_guarded_variable(self, variable: VariableBase): - self._global_guarded_variables.append(variable) - - def _find_tensor_outputs( - self, outputs: list[VariableBase] - ) -> list[TensorVariable]: - output_tensors: list[TensorVariable] = [] - for output in outputs: - if isinstance(output, TensorVariable) and isinstance( - output.tracker, DummyTracker - ): - output_tensors.append(output) - return output_tensors diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/guard.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/guard.py deleted file mode 100644 index 17cd1697fda9e..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/guard.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import ast -import types -from dataclasses import dataclass -from functools import reduce -from typing import Any, Callable - -from ...utils import InnerError, log - -Guard = Callable[[types.FrameType], bool] - -# NOTE(SigureMo): [How to write Stringify Guard?] -# 1. we should capture free variables manually, the string cannot capture free -# variables automatically. -# 2. Be aware that the comparison logic before and after stringify may be different. -# 3. we should compute as much as possible at "compile time" and encode the -# computation in the Guard string, rather than passing it to runtime to minimize -# runtime overhead. 
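A concrete illustration of the note above, not part of the patch (the locals `x` and `n` and the saved shape are hypothetical): a free variable is captured by hand in `free_vars`, two checks are combined, and the result is eval'ed into a guard, mirroring StringifyExpression.__and__ and make_guard defined below.

    shape_guard = StringifyExpression(
        "frame.f_locals['x'].shape == saved_shape",
        {"saved_shape": [2, 3]},  # free variable captured manually
    )
    type_guard = StringifyExpression("isinstance(frame.f_locals['n'], int)", {})
    combined = shape_guard & type_guard  # exprs joined with " and "
    guard = eval(f"lambda frame: {combined.expr}", combined.free_vars)
    # guard(frame) later re-checks both conditions against a new frame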
- - -@dataclass -class StringifyExpression: - expr: str - free_vars: dict[str, Any] - - def __post_init__(self): - self.check_expr(self.expr) - - def check_expr(self, expr: str): - try: - ast.parse(expr) - except SyntaxError as e: - raise InnerError(f"Invalid expression: {expr}") from e - - def __and__(self, other: StringifyExpression) -> StringifyExpression: - return StringifyExpression( - " and ".join([self.expr, other.expr]), - union_free_vars(self.free_vars, other.free_vars), - ) - - -def union_free_vars(*free_vars: dict[str, Any]): - return {k: v for d in free_vars for k, v in d.items()} - - -def make_guard(stringify_guards: list[StringifyExpression]) -> Guard: - num_guards = len(stringify_guards) - if not num_guards: - return lambda frame: True - union_guard_expr = reduce(lambda x, y: x & y, stringify_guards) - guard_string = f"lambda frame: {union_guard_expr.expr}" - guard = eval( - guard_string, - union_guard_expr.free_vars, - ) - log(3, f"[Guard]: {guard_string}\n") - assert callable(guard), "guard must be callable." - - return guard diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/instr_flag.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/instr_flag.py deleted file mode 100644 index 43b4cd16874b5..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/instr_flag.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flags for instructions - - -class FORMAT_VALUE_FLAG: - FVC_MASK = 0x3 - FVC_NONE = 0x0 - FVC_STR = 0x1 - FVC_REPR = 0x2 - FVC_ASCII = 0x3 - FVS_MASK = 0x4 - FVS_HAVE_SPEC = 0x4 - - -class MAKE_FUNCTION_FLAG: - MF_HAS_CLOSURE = 0x08 - MF_HAS_ANNOTATION = 0x04 - MF_HAS_KWDEFAULTS = 0x02 - MF_HAS_DEFAULTS = 0x01 diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/opcode_executor.py deleted file mode 100644 index 691d0609635d8..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/opcode_executor.py +++ /dev/null @@ -1,1092 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import collections -import dis -import inspect -import operator -import types -from typing import Callable, List, Optional, Tuple - -from ...utils import ( - BreakGraphError, - InnerError, - Singleton, - UnsupportError, - is_strict_mode, - log, - log_do, -) -from ..instruction_utils.instruction_utils import Instruction, get_instructions -from .function_graph import FunctionGraph -from .guard import Guard -from .instr_flag import FORMAT_VALUE_FLAG as FV -from .instr_flag import MAKE_FUNCTION_FLAG as MF -from .pycode_generator import PyCodeGen -from .tracker import ( - BuiltinTracker, - DummyTracker, - GetItemTracker, - GlobalTracker, - LocalTracker, -) -from .variables import ( - CallableVariable, - ConstantVariable, - ConstTracker, - ContainerVariable, - DictIterVariable, - DictVariable, - IterVariable, - ListVariable, - ObjectVariable, - SequenceIterVariable, - TensorIterVariable, - TensorVariable, - TupleVariable, - UserDefinedFunctionVariable, - UserDefinedIterVariable, - VariableBase, - VariableFactory, -) - -CustomCode = collections.namedtuple( - "CustomCode", ["code", "disable_eval_frame"] -) - - -GuardedFunction = Tuple[types.CodeType, Guard] -GuardedFunctions = List[GuardedFunction] -CacheGetter = Callable[ - [types.FrameType, GuardedFunctions], Optional[CustomCode] -] -dummy_guard: Guard = lambda frame: True - -SUPPORT_COMPARE_OP = { - ">": operator.gt, - "<": operator.lt, - ">=": operator.ge, - "<=": operator.le, - "==": lambda x, y: VariableFactory.from_value( - x.value == y.value, None, tracker=DummyTracker([x, y]) - ), - "!=": lambda x, y: VariableFactory.from_value( - x.value != y.value, None, tracker=DummyTracker([x, y]) - ), - "is not": lambda x, y: VariableFactory.from_value( - x.value is not y.value, None, tracker=DummyTracker([x, y]) - ), - "is": lambda x, y: VariableFactory.from_value( - x.value is y.value, None, tracker=DummyTracker([x, y]) - ), -} - - -class Stop: - pass - - -@Singleton -class InstructionTranslatorCache: - cache: dict[types.CodeType, tuple[CacheGetter, GuardedFunctions]] - translate_count: int - - def __init__(self): - self.cache = {} - self.translate_count = 0 - - def clear(self): - self.cache.clear() - self.translate_count = 0 - - def __call__(self, frame) -> CustomCode | None: - code: types.CodeType = frame.f_code - if code not in self.cache: - cache_getter, (new_code, guard_fn) = self.translate(frame) - self.cache[code] = (cache_getter, [(new_code, guard_fn)]) - if cache_getter == self.skip: - return None - return CustomCode(new_code, False) - cache_getter, guarded_fns = self.cache[code] - return cache_getter(frame, guarded_fns) - - def lookup( - self, frame: types.FrameType, guarded_fns: GuardedFunctions - ) -> CustomCode | None: - for code, guard_fn in guarded_fns: - try: - if guard_fn(frame): - log(3, "[Cache]: Cache hit\n") - return CustomCode(code, True) - except Exception as e: - log(3, f"[Cache]: Guard function error: {e}\n") - continue - cache_getter, (new_code, guard_fn) = self.translate(frame) - guarded_fns.append((new_code, guard_fn)) - return CustomCode(new_code, False) - - def skip( - self, frame: types.FrameType, guarded_fns: GuardedFunctions - ) -> CustomCode | None: - log(3, f"[Cache]: Skip frame {frame.f_code.co_name}\n") - return None - - def translate( - self, frame: types.FrameType - ) -> tuple[CacheGetter, GuardedFunction]: - code: types.CodeType = frame.f_code - log(3, "[Cache]: Cache miss\n") - self.translate_count += 1 - - result = start_translate(frame) - if result is None: - return 
self.skip, (code, dummy_guard) - - new_code, guard_fn = result - return self.lookup, (new_code, guard_fn) - - -def start_translate(frame) -> GuardedFunction | None: - simulator = OpcodeExecutor(frame) - try: - new_code, guard_fn = simulator.transform() - log_do(3, lambda: dis.dis(new_code)) - return new_code, guard_fn - except InnerError as e: - raise - # TODO(0x45f): handle BreakGraphError to trigger fallback - except (UnsupportError, BreakGraphError) as e: - if is_strict_mode(): - raise - log( - 2, - f"Unsupport Frame is {frame.f_code}, error message is: {str(e)}\n", - ) - return None - except Exception as e: - raise - - -def tos_op_wrapper(fn): - nargs = len(inspect.signature(fn).parameters) - - def inner(self: OpcodeExecutorBase, instr: Instruction): - args = self.pop_n(nargs) - self.push(fn(*args)) - - return inner - - -def breakoff_graph_with_jump(normal_jump): - """breakoff graph when meet jump.""" - - def jump_instruction_with_fallback(self: OpcodeExecutor, instr): - result = self.peek() - if isinstance(result, TensorVariable): - self.pop() - # fallback when in OpcodeExecutor - # raise error in OpcodeInlineExecutor - self._fallback_in_jump(result, instr) - return Stop() - else: - return normal_jump(self, instr) - - return jump_instruction_with_fallback - - -class OpcodeExecutorBase: - def __init__(self, code: types.CodeType, graph: FunctionGraph): - # fake env for run, new env should be gened by PyCodeGen - self._stack: list[VariableBase] = [] - self._co_consts = [] - self._locals = {} - self._globals = {} - self._builtins = {} - self._lasti = 0 # idx of instruction list - self._code = code - self._instructions = get_instructions(self._code) - self._graph = graph - self.new_code = None - self.guard_fn = None - self._prepare_virtual_env() - - def _prepare_virtual_env(self): - raise NotImplementedError("Please inplement virtual_env.") - - def transform(self): - raise NotImplementedError() - - def run(self): - log(3, f"start execute opcode: {self._code}\n") - self._lasti = 0 - while True: - if self._lasti >= len(self._instructions): - raise InnerError("lasti out of range, InnerError.") - cur_instr = self._instructions[self._lasti] - self._lasti += 1 - is_stop = self.step(cur_instr) - if is_stop: - break - - def step(self, instr): - if not hasattr(self, instr.opname): - raise UnsupportError(f"opcode: {instr.opname} is not supported.") - log(3, f"[TraceExecution]: {instr.opname}, stack is {self._stack}\n") - return getattr(self, instr.opname)(instr) # run single step. 
- - def indexof(self, instr): - return self._instructions.index(instr) - - def pop(self) -> VariableBase: - return self._stack.pop() - - def peek(self) -> VariableBase: - return self._stack[-1] - - def peek_n(self, n) -> list[VariableBase]: - return self._stack[-n:] - - def pop_n(self, n: int) -> list[VariableBase]: - if n == 0: - return [] - retval = self._stack[-n:] - self._stack[-n:] = [] - return retval - - def push(self, val: VariableBase): - self._stack.append(val) - - # unary operators - UNARY_POSITIVE = tos_op_wrapper(operator.pos) - UNARY_NEGATIVE = tos_op_wrapper(operator.neg) - # UNARY_NOT = tos_op_wrapper(operator.not_) - UNARY_INVERT = tos_op_wrapper(operator.invert) - - # binary operators - BINARY_POWER = tos_op_wrapper(operator.pow) - BINARY_MULTIPLY = tos_op_wrapper(operator.mul) - BINARY_MATRIX_MULTIPLY = tos_op_wrapper(operator.matmul) - BINARY_FLOOR_DIVIDE = tos_op_wrapper(operator.floordiv) - BINARY_TRUE_DIVIDE = tos_op_wrapper(operator.truediv) - BINARY_MODULO = tos_op_wrapper(operator.mod) - BINARY_ADD = tos_op_wrapper(operator.add) - BINARY_SUBTRACT = tos_op_wrapper(operator.sub) - BINARY_LSHIFT = tos_op_wrapper(operator.lshift) - BINARY_RSHIFT = tos_op_wrapper(operator.rshift) - BINARY_AND = tos_op_wrapper(operator.and_) - BINARY_OR = tos_op_wrapper(operator.or_) - BINARY_XOR = tos_op_wrapper(operator.xor) - - # inplace operators - # paddle variable do not have inplace operators. For example when call `y **= x`, will call var.__pow__ - INPLACE_POWER = tos_op_wrapper(operator.ipow) - INPLACE_MULTIPLY = tos_op_wrapper(operator.imul) - INPLACE_MATRIX_MULTIPLY = tos_op_wrapper(operator.imatmul) - INPLACE_FLOOR_DIVIDE = tos_op_wrapper(operator.ifloordiv) - INPLACE_TRUE_DIVIDE = tos_op_wrapper(operator.itruediv) - INPLACE_MODULO = tos_op_wrapper(operator.imod) - INPLACE_ADD = tos_op_wrapper(operator.iadd) - INPLACE_SUBTRACT = tos_op_wrapper(operator.isub) - INPLACE_LSHIFT = tos_op_wrapper(operator.ilshift) - INPLACE_RSHIFT = tos_op_wrapper(operator.irshift) - INPLACE_AND = tos_op_wrapper(operator.iand) - INPLACE_OR = tos_op_wrapper(operator.ior) - INPLACE_XOR = tos_op_wrapper(operator.ixor) - - def LOAD_ATTR(self, instr): - attr_name = instr.argval - obj = self.pop() - self.push(getattr(obj, attr_name)) - - def LOAD_FAST(self, instr): - varname = instr.argval - var = self._locals[varname] - self.push(var) - - def LOAD_METHOD(self, instr): - method_name = instr.argval - obj = self.pop() - method = getattr(obj, method_name) - self.push(method) - - def STORE_FAST(self, instr): - """ - TODO: side effect may happen - """ - var = self.pop() - self._locals[instr.argval] = var - - def LOAD_GLOBAL(self, instr): - name = instr.argval - if name in self._globals.keys(): - value = self._globals[name] - else: - value = self._builtins[name] - self.push(value) - - def LOAD_CONST(self, instr): - var = self._co_consts[instr.arg] - self.push(var) - - def BINARY_SUBSCR(self, instr): - key = self.pop() - container = self.pop() - assert isinstance(key, VariableBase) - self._graph.add_global_guarded_variable(key) - self.push(container[key.value]) - - def STORE_SUBSCR(self, instr): - key = self.pop() - container = self.pop() - value = self.pop() - assert isinstance(key, VariableBase) - self._graph.add_global_guarded_variable(key) - container[key.value] = value - - def CALL_FUNCTION(self, instr): - n_args = instr.arg - assert n_args <= len(self._stack) - args = self.pop_n(n_args) - kwargs = {} - fn = self.pop() - if not isinstance(fn, CallableVariable): - raise UnsupportError(f"CALL_FUNCTION: 
{fn} is not callable") - ret = fn(*args, **kwargs) - self.push(ret) - - def CALL_FUNCTION_KW(self, instr): - n_args = instr.arg - assert n_args + 2 <= len(self._stack) - - kwargs_keys = self.pop() - assert isinstance(kwargs_keys, TupleVariable) - assert len(kwargs_keys) > 0 - kwargs_keys = [ - x.value if isinstance(x, VariableBase) else x - for x in kwargs_keys.value - ] - - # split arg_list to args and kwargs - arg_list = self.pop_n(n_args) - args = arg_list[0 : -len(kwargs_keys)] - kwargs_values = arg_list[-len(kwargs_keys) :] - kwargs = dict(zip(kwargs_keys, kwargs_values)) - - fn = self.pop() - if not isinstance(fn, CallableVariable): - raise UnsupportError(f"CALL_FUNCTION_KW: {fn} is not callable.") - ret = fn(*args, **kwargs) - self.push(ret) - - def CALL_FUNCTION_EX(self, instr): - flag = instr.arg - if flag & 0x01: # has kwargs - kwargs_variable = self.pop() - assert isinstance(kwargs_variable, DictVariable) - kwargs = kwargs_variable.get_wrapped_items() - else: - kwargs = {} - - args_variable = self.pop() - assert isinstance(args_variable, TupleVariable) - args = args_variable.get_wrapped_items() - - fn = self.pop() - if not isinstance(fn, CallableVariable): - raise UnsupportError(f"CALL_FUNCTION_EX: {fn} is not callable.") - ret = fn(*args, **kwargs) - self.push(ret) - - def CALL_METHOD(self, instr): - n_args = instr.argval - assert n_args <= len(self._stack) - args = self.pop_n(n_args) - method = self.pop() - if not isinstance(method, CallableVariable): - raise UnsupportError(f"CALL METHOD: {method} is not callable.") - ret = method(*args) - self.push(ret) - - def COMPARE_OP(self, instr): - op = instr.argval - if op in SUPPORT_COMPARE_OP: - right, left = self.pop(), self.pop() - self.push(SUPPORT_COMPARE_OP[op](left, right)) - return - else: - raise UnsupportError( - f"{instr} is not support. may be not a supported compare op." - ) - - @breakoff_graph_with_jump - def JUMP_IF_FALSE_OR_POP(self, instr): - pred_obj = self.peek() - if isinstance(pred_obj, (ConstantVariable, ContainerVariable)): - self._graph.add_global_guarded_variable(pred_obj) - is_jump = not bool(pred_obj) - if is_jump: - self._lasti = self.indexof(instr.jump_to) - else: - self.pop() - return - raise UnsupportError( - "Currently don't support predicate a non-const / non-tensor obj." - ) - - @breakoff_graph_with_jump - def JUMP_IF_TRUE_OR_POP(self, instr): - pred_obj = self.peek() - if isinstance(pred_obj, (ConstantVariable, ContainerVariable)): - self._graph.add_global_guarded_variable(pred_obj) - is_jump = bool(pred_obj) - if is_jump: - self._lasti = self.indexof(instr.jump_to) - else: - self.pop() - return - raise UnsupportError( - "Currently don't support predicate a non-const / non-tensor obj." - ) - - @breakoff_graph_with_jump - def POP_JUMP_IF_FALSE(self, instr): - pred_obj = self.pop() - if isinstance(pred_obj, (ConstantVariable, ContainerVariable)): - self._graph.add_global_guarded_variable(pred_obj) - is_jump = not bool(pred_obj) - if is_jump: - self._lasti = self.indexof(instr.jump_to) - return - raise UnsupportError( - "Currently don't support predicate a non-const / non-tensor obj." - ) - - @breakoff_graph_with_jump - def POP_JUMP_IF_TRUE(self, instr): - pred_obj = self.pop() - if isinstance(pred_obj, (ConstantVariable, ContainerVariable)): - self._graph.add_global_guarded_variable(pred_obj) - is_jump = bool(pred_obj) - if is_jump: - self._lasti = self.indexof(instr.jump_to) - return - raise UnsupportError( - "Currently don't support predicate a non-const / non-tensor obj." 
- ) - - def _fallback_in_jump(self, result, instr): - raise NotImplementedError() - - def JUMP_FORWARD(self, instr): - self._lasti = self.indexof(instr.jump_to) - - def JUMP_ABSOLUTE(self, instr): - self._lasti = self.indexof(instr.jump_to) - - def RETURN_VALUE(self, instr): - assert ( - len(self._stack) == 1 - ), f"Stack must have one element, but get {len(self._stack)} elements." - ret_val = self.pop() - self._graph.start_compile(ret_val) - self._graph.pycode_gen.gen_return() - self.new_code = self._graph.pycode_gen.gen_pycode() - self.guard_fn = self._graph.guard_fn - return Stop() - - def BUILD_LIST(self, instr): - list_size = instr.arg - assert list_size <= len( - self._stack - ), f"OpExecutor want BUILD_LIST with size {list_size}, but current stack do not have enough elems." - val_list = self.pop_n(list_size) - self.push( - VariableFactory.from_value( - val_list, graph=self._graph, tracker=DummyTracker(val_list) - ) - ) - - def BUILD_TUPLE(self, instr): - tuple_size = instr.arg - assert tuple_size <= len( - self._stack - ), f"OpExecutor want BUILD_TUPLE with size {tuple_size}, but current stack do not have enough elems." - val_tuple = self.pop_n(tuple_size) - self.push( - VariableFactory.from_value( - tuple(val_tuple), - graph=self._graph, - tracker=DummyTracker(val_tuple), - ) - ) - - def BUILD_MAP(self, instr): - map_size = instr.arg - built_map = {} - assert map_size * 2 <= len( - self._stack - ), f"OpExecutor want BUILD_MAP with size {map_size} * 2, but current stack do not have enough elems." - val_for_dict = self.pop_n(map_size * 2) - keys = val_for_dict[::2] - values = val_for_dict[1::2] - self.push(self.build_map(keys, values)) - - def BUILD_CONST_KEY_MAP(self, instr): - map_size = instr.arg - assert map_size + 1 <= len( - self._stack - ), f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems." - keys = self.pop().get_items() - assert len(keys) == map_size - values = self.pop_n(map_size) - self.push(self.build_map(keys, values)) - - def build_map( - self, keys: list[VariableBase], values: list[VariableBase] - ) -> VariableBase: - built_map = {} - for key, value in zip(keys, values): - assert isinstance(key, VariableBase) - # Add key to global guarded variable to avoid missing the key guard - self._graph.add_global_guarded_variable(key) - key = key.value - built_map[key] = value - return DictVariable( - built_map, - graph=self._graph, - tracker=DummyTracker(keys + values), - ) - - def _rot_top_n(self, n): - # a1 a2 a3 ... an <- TOS - # the stack changes to - # an a1 a2 a3 an-1 <- TOS - assert ( - len(self._stack) >= n - ), f"There are not enough elements on the stack. {n} is needed." - top = self.pop() - self._stack[-(n - 1) : -(n - 1)] = [top] - - def POP_TOP(self, instr): - self.pop() - - def ROT_TWO(self, instr): - self._rot_top_n(2) - - def ROT_THREE(self, instr): - self._rot_top_n(3) - - def ROT_FOUR(self, instr): - self._rot_top_n(4) - - def UNPACK_SEQUENCE(self, instr): - sequence = self.pop() - - ''' - TODO: To unpack iterator - To unpack is easy, just like: - seq = tuple(sequence.value) - - But what is the `source` when iterator returned a value ? - ''' - if isinstance(sequence, TensorVariable): - # TODO: If need to unpack a Tensor, should have different logic. 
- raise NotImplementedError("Unpack a iterator is not implemented.") - elif isinstance(sequence, (ListVariable, TupleVariable)): - seq = sequence.value - else: - raise NotImplementedError(f"Unpack {sequence} is not implemented.") - - assert ( - len(seq) == instr.arg - ), f"Want unpack {seq} to {instr.arg}, but the len is {len(seq)}." - - for i in range(instr.arg - 1, -1, -1): - self.push( - VariableFactory.from_value( - seq[i], - graph=self._graph, - tracker=GetItemTracker(sequence, i), - ) - ) - - def BUILD_STRING(self, instr): - count = instr.arg - assert count <= len( - self._stack - ), f"OpExecutor want BUILD_STRING with size {count}, but current stack do not have enough elems." - str_list = self.pop_n(count) - new_str = '' - for s in str_list: - assert isinstance(s.value, str) - new_str += s.value - self.push(ConstantVariable.wrap_literal(new_str)) - - def FORMAT_VALUE(self, instr): - - flag = instr.arg - which_conversion = flag & FV.FVC_MASK - have_fmt_spec = bool((flag & FV.FVS_MASK) == FV.FVS_HAVE_SPEC) - - fmt_spec = self.pop().value if have_fmt_spec else "" - value = self.pop() - - if which_conversion == FV.FVC_NONE: - convert_fn = None - elif which_conversion == FV.FVC_STR: - convert_fn = "__str__" - elif which_conversion == FV.FVC_REPR: - convert_fn = "__repr__" - elif which_conversion == FV.FVC_ASCII: - convert_fn = "__ascii__" - else: - raise InnerError( - f"Unexpected conversion flag {flag} for FORMAT_VALUE" - ) - - # different type will lead to different Tracker, so call self.push in different branch - if isinstance(value, ConstantVariable): - result = value.value - if convert_fn is not None: - result = getattr(result, convert_fn)(result) - - if not isinstance(result, str) or fmt_spec != "": - result = format(result, fmt_spec) - - self.push( - VariableFactory.from_value( - result, self._graph, DummyTracker([value]) - ) - ) - else: - raise UnsupportError(f"Do not support format {type(value)} now") - - def build_seq_unpack(self, instr): - oparg = instr.arg - assert oparg <= len(self._stack) - unpack_values = self.pop_n(oparg) - - retval = [] - for item in unpack_values: - assert isinstance(item, (TupleVariable, ListVariable)) - retval.extend(item.get_wrapped_items()) - - if instr.opname in { - "BUILD_TUPLE_UNPACK_WITH_CALL", - "BUILD_TUPLE_UNPACK", - }: - retval = tuple(retval) - - self.push( - VariableFactory.from_value( - retval, self._graph, DummyTracker(unpack_values) - ) - ) - - def BUILD_TUPLE_UNPACK_WITH_CALL(self, instr): - self.build_seq_unpack(instr) - - def BUILD_TUPLE_UNPACK(self, instr): - self.build_seq_unpack(instr) - - def BUILD_LIST_UNPACK(self, instr): - self.build_seq_unpack(instr) - - def BUILD_MAP_UNPACK(self, instr): - oparg = instr.arg - assert oparg <= len(self._stack) - unpack_values = self.pop_n(oparg) - - retval = {} - for item in unpack_values: - assert isinstance(item.value, dict) - retval.update(item.get_wrapped_items()) - - self.push( - VariableFactory.from_value( - retval, self._graph, DummyTracker(unpack_values) - ) - ) - - def BUILD_MAP_UNPACK_WITH_CALL(self, instr): - oparg = instr.arg - assert oparg <= len(self._stack) - unpack_values = self.pop_n(oparg) - - retval = {} - for item in unpack_values: - assert isinstance(item.value, dict) - wrapped_item = item.get_wrapped_items() - if wrapped_item.items() & retval.items(): - raise InnerError( - "BUILD_MAP_UNPACK_WITH_CALL found repeated key." 
- ) - retval.update(wrapped_item) - - self.push( - VariableFactory.from_value( - retval, self._graph, DummyTracker(unpack_values) - ) - ) - - def MAKE_FUNCTION(self, instr): - fn_name = self.pop() - codeobj = self.pop() - global_dict = self._globals - - related_list = [fn_name, codeobj] - - flag = instr.arg - if flag & MF.MF_HAS_CLOSURE: - # closure should be a tuple of Variables - closure_variable = self.pop() - assert isinstance(closure_variable, TupleVariable) - related_list.append(closure_variable) - closure = tuple(closure_variable.get_wrapped_items()) - else: - closure = () - - if flag & MF.MF_HAS_ANNOTATION: - # can not set annotation in python env, skip it - related_list.append(self.pop()) - - if flag & MF.MF_HAS_KWDEFAULTS: - raise UnsupportError( - "Found need func_kwdefaults when MAKE_FUNCTION." - ) - - if flag & MF.MF_HAS_DEFAULTS: - ''' - default_args should have tracker too, like: - - def f(x): - def g(z=x): - pass - ''' - default_args_variable = self.pop() - assert isinstance(default_args_variable, TupleVariable) - related_list.append(default_args_variable) - default_args = tuple(default_args_variable.get_wrapped_items()) - else: - default_args = () - - new_fn = types.FunctionType( - codeobj.value, global_dict, fn_name.value, default_args, closure - ) - - self.push( - UserDefinedFunctionVariable( - new_fn, self._graph, DummyTracker(related_list) - ) - ) - - def BUILD_SLICE(self, instr): - if instr.arg == 3: - step = self.pop() - else: - step = None - stop = self.pop() - start = self.pop() - - related_list = [start, stop, step] if step else [start, stop] - - slice_ = slice(*(x.value for x in related_list)) - - self.push( - VariableFactory.from_value( - slice_, self._graph, DummyTracker(related_list) - ) - ) - - def DUP_TOP(self, instr): - self.push(self.peek()) - - def DUP_TOP_TWO(self, instr): - for ref in self.peek_n(2): - self.push(ref) - - def NOP(self, instr): - pass - - def GET_ITER(self, instr): - iterator = self.pop() - if isinstance(iterator, IterVariable): - return self.push(iterator) - - if isinstance(iterator, (ListVariable, TupleVariable)): - self.push( - SequenceIterVariable( - iterator, self._graph, DummyTracker([iterator]) - ) - ) - elif isinstance(iterator, DictVariable): - self.push( - DictIterVariable( - iterator, self._graph, DummyTracker([iterator]) - ) - ) - elif isinstance(iterator, TensorVariable): - self.push( - TensorIterVariable( - iterator, self._graph, DummyTracker([iterator]) - ) - ) - else: - self.push( - UserDefinedIterVariable( - iterator, self._graph, DummyTracker([iterator]) - ) - ) - - def FOR_ITER(self, instr): - iterator = self.pop() - assert isinstance(iterator, IterVariable) - - # simplely get next - if isinstance(iterator, (SequenceIterVariable, DictIterVariable)): - try: - val, next_iterator = iterator.next() - self.push( - next_iterator - ) # need a new iterator to replace the old one - self.push(val) - except StopIteration: - self._lasti = self.indexof(instr.jump_to) - - # TODO need support TensorIterVariable.next - - else: - self._fallback_in_for_loop(iterator, instr) - return Stop() - - -class OpcodeExecutor(OpcodeExecutorBase): - def __init__(self, frame): - graph = FunctionGraph(frame) - self._frame = frame - super().__init__(frame.f_code, graph) - - def _prepare_virtual_env(self): - for name, value in self._frame.f_locals.items(): - self._locals[name] = VariableFactory.from_value( - value, self._graph, LocalTracker(name) - ) - - for name, value in self._frame.f_globals.items(): - self._globals[name] = 
VariableFactory.from_value( - value, self._graph, GlobalTracker(name) - ) - - for name, value in self._frame.f_builtins.items(): - self._builtins[name] = VariableFactory.from_value( - value, self._graph, BuiltinTracker(name) - ) - - for value in self._code.co_consts: - self._co_consts.append( - VariableFactory.from_value( - value, self._graph, ConstTracker(value) - ) - ) - - def _create_resume_fn(self, index): - pycode_gen = PyCodeGen(self._frame) - fn, inputs = pycode_gen.gen_resume_fn_at(index) - return fn, inputs - - def _fallback_in_jump(self, result, instr): - if_fn, if_inputs = self._create_resume_fn(self.indexof(instr) + 1) - else_fn, else_inputs = self._create_resume_fn( - self.indexof(instr.jump_to) - ) - inputs_name = if_inputs | else_inputs - inputs_var = [ - self._locals[name] - for name in inputs_name - if self._locals[name] is not result - ] - ret_vars = [ - result, - ] + inputs_var - self._graph.start_compile(*ret_vars) - for _ in inputs_var: - self._graph.pycode_gen.gen_pop_top() - - if if_fn is not None: - self._graph.pycode_gen.gen_load_object( - if_fn, if_fn.__code__.co_name - ) - insert_index = len(self._graph.pycode_gen._instructions) - 1 - for name in if_inputs: - self._locals[name].reconstruct(self._graph.pycode_gen) - self._graph.pycode_gen.gen_call_function( - argc=if_fn.__code__.co_argcount - ) - self._graph.pycode_gen.gen_return() - else: - insert_index = len(self._graph.pycode_gen._instructions) - 1 - self._graph.pycode_gen.gen_return() - - if else_fn is not None: - self._graph.pycode_gen.gen_load_object( - else_fn, else_fn.__code__.co_name - ) - jump_to = self._graph.pycode_gen._instructions[-1] - for name in else_inputs: - self._locals[name].reconstruct(self._graph.pycode_gen) - self._graph.pycode_gen.gen_call_function( - argc=else_fn.__code__.co_argcount - ) - self._graph.pycode_gen.gen_return() - else: - self._graph.pycode_gen.gen_return() - jump_to = self._graph.pycode_gen._instructions[-1] - - self._graph.pycode_gen._insert_instr( - insert_index, instr.opname, jump_to=jump_to - ) - - self.new_code = self._graph.pycode_gen.gen_pycode() - self.guard_fn = self._graph.guard_fn - - def transform(self): - self.run() - if self.new_code is None: - raise InnerError("OpExecutor return a empty new_code.") - return self.new_code, self.guard_fn - - def _create_loop_body_fn(self, start, end): - pycode_gen = PyCodeGen(self._frame) - fn, inputs = pycode_gen.gen_loop_body_fn_between(start, end) - return fn, inputs - - def _fallback_in_for_loop(self, iterator, instr): - ''' - instr: the FOR_ITER opcode - - need find out opcodes which unpack value from FOR_ITER, by analysing stack - - case 1: - for i in iter: - - FOR_ITER - STORE_FAST i - - case 2: - for i,j in iter: - - FOR_ITER - UNPACK_SEQUENCE 2 - STORE_FAST i - STORE_FAST j - ''' - unpack_instr_idx = self.indexof(instr) + 1 - curent_stack = 1 - - while True: - if unpack_instr_idx >= len(self._instructions): - raise InnerError("Can not balance stack in loop body.") - cur_instr = self._instructions[unpack_instr_idx] - # do not consider jump instr - stack_effect = dis.stack_effect( - cur_instr.opcode, cur_instr.arg, jump=False - ) - curent_stack += stack_effect - unpack_instr_idx += 1 - if curent_stack == 0: - break - - loop_body, loop_inputs = self._create_loop_body_fn( - unpack_instr_idx, self.indexof(instr.jump_to) - ) - - after_loop_fn, fn_inputs = self._create_resume_fn( - self.indexof(instr.jump_to) - ) - - # 1. 
part before for-loop - inputs_var = [ - self._locals[name] for name in loop_inputs if name in self._locals - ] - self._graph.start_compile(*inputs_var) - - for _ in inputs_var: - self._graph.pycode_gen.gen_pop_top() - - # 2. load iterator to stack - iterator.reconstruct(self._graph.pycode_gen) - - # 3. gen FOR_ITER and unpack data - self._graph.pycode_gen.extend_instrs( - self._instructions[self.indexof(instr) : unpack_instr_idx] - ) - - # 4. call loop body - self._graph.pycode_gen.gen_load_object( - loop_body, loop_body.__code__.co_name - ) - - def update_locals(name, variable): - self._locals[name] = variable - return variable - - for name in loop_inputs: - if name in self._locals: - self._locals[name].reconstruct(self._graph.pycode_gen) - elif name in self._globals: - self._globals[name].reconstruct(self._graph.pycode_gen) - elif name in self._builtins: - self._builtins[name].reconstruct(self._graph.pycode_gen) - else: - variable = update_locals( - name, ObjectVariable(None, self._graph, LocalTracker(name)) - ) - variable.reconstruct(self._graph.pycode_gen) - - self._graph.pycode_gen.gen_call_function( - argc=loop_body.__code__.co_argcount - ) - - # 5. unpack and store - self._graph.pycode_gen.gen_unpack_sequence(len(loop_inputs)) - for name in loop_inputs: - self._graph.pycode_gen.gen_store_fast( - name - ) # TODO: need check data scope with globals, builtins - - # 6. add JUMP_ABSOLUTE - self._graph.pycode_gen.gen_jump_abs(instr) - - # 7. call after_loop_fn - self._graph.pycode_gen.gen_load_object( - after_loop_fn, after_loop_fn.__code__.co_name - ) - - for name in fn_inputs: - if name in self._locals: - self._locals[name].reconstruct(self._graph.pycode_gen) - elif name in self._globals: - self._globals[name].reconstruct(self._graph.pycode_gen) - elif name in self._builtins: - self._builtins[name].reconstruct(self._graph.pycode_gen) - - self._graph.pycode_gen.gen_call_function( - argc=after_loop_fn.__code__.co_argcount - ) - - self._graph.pycode_gen.gen_return() - self.new_code = self._graph.pycode_gen.gen_pycode() - self.guard_fn = self._graph.guard_fn diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/opcode_inline_executor.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/opcode_inline_executor.py deleted file mode 100644 index 2a556fb3662b9..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/opcode_inline_executor.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
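The stack-balance scan used by _fallback_in_for_loop above can be reproduced standalone. A minimal sketch, not part of the patch: `f` is an illustrative function, and CPython 3.8–3.10 bytecode is assumed (where dis.stack_effect accepts the jump keyword).

    import dis

    def f(seq):
        for i, j in seq:
            pass

    instrs = list(dis.get_instructions(f))
    start = next(k for k, ins in enumerate(instrs) if ins.opname == "FOR_ITER")
    depth, k = 1, start + 1  # FOR_ITER left one value on the stack
    while depth != 0:  # scan until the pushed value has been consumed
        ins = instrs[k]
        depth += dis.stack_effect(ins.opcode, ins.arg, jump=False)
        k += 1
    # instrs[start:k] -> FOR_ITER, UNPACK_SEQUENCE, STORE_FAST, STORE_FAST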
- -from __future__ import annotations - -import builtins -import inspect -from typing import TYPE_CHECKING - -from ...utils import log -from .guard import StringifyExpression, union_free_vars -from .opcode_executor import OpcodeExecutorBase, Stop -from .tracker import BuiltinTracker, ConstTracker, DummyTracker, Tracker - -if TYPE_CHECKING: - from .pycode_generator import PyCodeGen - from .variables import FunctionVariable - - -class FunctionGlobalTracker(Tracker): - def __init__(self, fn: FunctionVariable, name: str): - super().__init__([fn]) - self.fn = fn - self.name = name - - def gen_instructions(self, codegen: PyCodeGen): - self.fn.tracker.gen_instructions(codegen) - codegen.gen_load_attr("__globals__") - codegen.gen_load_const(self.name) - codegen.gen_subscribe() - - def trace_value_from_frame(self): - fn_tracer = self.fn.tracker.trace_value_from_frame() - return StringifyExpression( - f"{fn_tracer.expr}.__globals__['{self.name}']", - union_free_vars(fn_tracer.free_vars), - ) - - def __repr__(self) -> str: - return f"FunctionGlobalTracker(fn={self.fn}, name={self.name})" - - -class OpcodeInlineExecutor(OpcodeExecutorBase): - def __init__(self, fn_variable, *args, **kwargs): - self._fn_var = fn_variable - self._fn_value = fn_variable.value - self.return_value = None - super().__init__(fn_variable.get_code(), fn_variable.graph) - self._prepare_locals(*args, **kwargs) - # TODO: consider generator. - - def _prepare_locals(self, *args, **kwargs): - from .variables import VariableBase, VariableFactory - - sig = inspect.signature(self._fn_value) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - for name, value in bound_args.arguments.items(): - assert name in sig.parameters - # Convert varargs and kwargs to Variable - if sig.parameters[name].kind == inspect.Parameter.VAR_POSITIONAL: - tracker = DummyTracker(value) - elif sig.parameters[name].kind == inspect.Parameter.VAR_KEYWORD: - tracker = DummyTracker(list(value.values())) - # Convert default args to Variable - elif not isinstance(value, VariableBase): - tracker = ConstTracker(value) - else: - tracker = value.tracker - value = VariableFactory.from_value(value, self._graph, tracker) - self._locals[name] = value - - log( - 5, f"[INLINE CALL] {self._code.co_name} with locals: ", self._locals - ) - - def _prepare_virtual_env(self): - # prepare globals - from .variables import VariableFactory - - for name, value in self._fn_value.__globals__.items(): - self._globals[name] = VariableFactory.from_value( - value, self._graph, FunctionGlobalTracker(self._fn_var, name) - ) - - # prepare builtins - for name, value in builtins.__dict__.items(): - self._builtins[name] = VariableFactory.from_value( - value, self._graph, BuiltinTracker(name) - ) - - # prepare consts - for value in self._code.co_consts: - self._co_consts.append( - VariableFactory.from_value( - value, self._graph, ConstTracker(value) - ) - ) - - def RETURN_VALUE(self, instr): - self.return_value = self.pop() - return Stop() - - def inline_call(self): - self.run() - return self.return_value diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/pycode_generator.py deleted file mode 100644 index a6f1bccb590c2..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/pycode_generator.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This class is used for abstract code generation: -# We only need to care about what type of bytecode our code needs to generate, -# without worrying about the subscripts of bytecode instructions in the code option. - -from __future__ import annotations - -import dis -import types - -import opcode - -from ...utils import ( - ResumeFnNameFactory, - list_contain_by_id, - list_find_index_by_id, -) -from ..instruction_utils import ( - gen_instr, - get_instructions, - modify_instrs, - modify_vars, -) -from ..instruction_utils.opcode_analysis import read_write_analysis - -''' - code options for PyCodeObject -''' - -pycode_attributes = [ - "co_argcount", - "co_posonlyargcount", - "co_kwonlyargcount", - "co_nlocals", - "co_stacksize", - "co_flags", - "co_code", - "co_consts", - "co_names", - "co_varnames", - "co_filename", - "co_name", - "co_firstlineno", - "co_lnotab", - "co_freevars", - "co_cellvars", -] - - -def gen_code_options(code): - code_options = {} - for k in pycode_attributes: - val = getattr(code, k) - if isinstance(val, tuple): - val = list(val) - code_options[k] = val - return code_options - - -''' - generator a new code object -''' - - -def gen_new_opcode(instrs, code_options, keys): - bytecode, lnotab = assemble(instrs, code_options["co_firstlineno"]) - code_options["co_lnotab"] = lnotab - code_options["co_code"] = bytecode - code_options["co_nlocals"] = len(code_options["co_varnames"]) - code_options["co_stacksize"] = stacksize(instrs) - for key, val in code_options.items(): - if isinstance(val, list): - code_options[key] = tuple(val) - # code_options is a dict, use keys to makesure the input order - return types.CodeType(*[code_options[k] for k in keys]) - - -# list of instructions => bytecode & lnotab -def assemble(instructions, firstlineno): - cur_line = firstlineno - cur_bytecode = 0 - - code = [] - lnotab = [] - - for instr in instructions: - # set lnotab - if instr.starts_line is not None: - line_offset = instr.starts_line - cur_line - bytecode_offset = len(code) - cur_bytecode - - cur_line = instr.starts_line - cur_bytecode = len(code) - - lnotab.extend(modify_lnotab(bytecode_offset, line_offset)) - - # get bytecode - arg = instr.arg or 0 - code.extend((instr.opcode, arg & 0xFF)) - - return bytes(code), bytes(lnotab) - - -def to_byte(num): - if num < 0: - # -1 => 255 - num += 256 - return num - - -def modify_lnotab(byte_offset, line_offset): - if byte_offset > 127: - ret = [] - while byte_offset > 127: - ret.extend((127, 0)) - byte_offset -= 127 - # line_offset might > 127, call recursively - ret.extend(modify_lnotab(byte_offset, line_offset)) - return ret - - if line_offset > 127: - # here byte_offset < 127 - ret = [byte_offset, 127] - line_offset -= 127 - while line_offset > 0: - ret.extend((0, line_offset)) - line_offset -= 127 - return ret - - # both < 127 - return [to_byte(byte_offset), to_byte(line_offset)] - - -# TODO: need to update -def stacksize(instructions): - # two list below shows the possible stack size before 
opcode is called - # the stack size might be different in different branch, so it has max and min - max_stack = [float("-inf")] * len(instructions) - min_stack = [float("inf")] * len(instructions) - - max_stack[0] = 0 - min_stack[0] = 0 - - def update_stacksize(lasti, nexti, stack_effect): - max_stack[nexti] = max( - max_stack[nexti], max_stack[lasti] + stack_effect - ) - min_stack[nexti] = min( - min_stack[nexti], max_stack[lasti] + stack_effect - ) - - for idx in range(len(instructions)): - instr = instructions[idx] - - if idx + 1 < len(instructions): - stack_effect = dis.stack_effect(instr.opcode, instr.arg, jump=False) - update_stacksize(idx, idx + 1, stack_effect) - - if instr.opcode in opcode.hasjabs or instr.opcode in opcode.hasjrel: - stack_effect = dis.stack_effect(instr.opcode, instr.arg, jump=True) - target_idx = instructions.index(instr.jump_to) - update_stacksize(idx, target_idx, stack_effect) - - assert min(min_stack) >= 0 - return max(max_stack) - - -''' - helper to create new code object -''' - - -class PyCodeGen: - def __init__(self, frame): - self._frame = frame - self._origin_code = frame.f_code - self._code_options = gen_code_options(self._origin_code) - self._f_globals = frame.f_globals - self._instructions = [] - self.objname_map = {} # map from name to LOAD_GLOBAL index - - def gen_pycode(self): - """ - return a new pycode, which is runnable. - """ - modify_instrs(self._instructions) - modify_vars(self._instructions, self._code_options) - new_code = gen_new_opcode( - self._instructions, self._code_options, pycode_attributes - ) - return new_code - - def gen_resume_fn_at(self, index): - self._instructions = get_instructions(self._origin_code) - if self._instructions[index].opname == 'RETURN_VALUE': - return None, set() - inputs = read_write_analysis(self._instructions, index) - self._instructions = [ - gen_instr('JUMP_ABSOLUTE', jump_to=self._instructions[index]) - ] + self._instructions - - self._code_options['co_argcount'] = len(inputs) - # inputs should be at the front of the co_varnames - self._code_options['co_varnames'] = tuple( - list(inputs) - + [ - var_name - for var_name in self._origin_code.co_varnames - if var_name not in inputs - ] - ) - fn_name = ResumeFnNameFactory().next() - self._code_options['co_name'] = fn_name - - new_code = self.gen_pycode() - fn = types.FunctionType(new_code, self._f_globals, fn_name) - return fn, inputs - - def gen_loop_body_fn_between(self, start, end): - self._instructions = get_instructions(self._origin_code) - inputs = read_write_analysis(self._instructions, start) - - # del JUMP_ABSOLUTE at self._instructions[end-1] - self._instructions = self._instructions[start : end - 1] - for name in inputs: - self.gen_load_fast(name) - self.gen_build_tuple(len(inputs)) - self.gen_return() - - self._code_options['co_argcount'] = len(inputs) - self._code_options['co_varnames'] = tuple( - list(inputs) - + [ - var_name - for var_name in self._origin_code.co_varnames - if var_name not in inputs - ] - ) - fn_name = ResumeFnNameFactory().next() - self._code_options['co_name'] = fn_name - - new_code = self.gen_pycode() - fn = types.FunctionType(new_code, self._f_globals, fn_name) - return fn, inputs - - def gen_load_const(self, value): - # Python `list.index` will find an item equal to query, i.e. `query == item` - # returns a value of True. Since `1 == True`, this will result in an incorrect - # index. To avoid this problem, we use id for comparison. 
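- # (Editor's illustrative aside, not in the original file: [1.0, True].index(1) - # returns 0 because 1 == 1.0, and [0].index(False) returns 0 because - # False == 0, so only an id()-based lookup reliably finds the exact object.)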
- if not list_contain_by_id(self._code_options["co_consts"], value): - self._code_options["co_consts"].append(value) - idx = list_find_index_by_id(self._code_options["co_consts"], value) - self._add_instr("LOAD_CONST", arg=idx, argval=value) - - def gen_load_global(self, name): - if name not in self._code_options["co_names"]: - self._code_options["co_names"].append(name) - idx = self._code_options["co_names"].index(name) - self._add_instr("LOAD_GLOBAL", arg=idx, argval=name) - - def gen_load_object(self, obj, obj_name): - if obj_name not in self.objname_map: - self._f_globals[obj_name] = obj - self._code_options["co_names"].append(obj_name) - idx = len(self._code_options["co_names"]) - 1 - self.objname_map[obj_name] = idx - idx = self.objname_map[obj_name] - self._add_instr("LOAD_GLOBAL", arg=idx, argval=obj_name) - - def gen_store_fast(self, name): - if name not in self._code_options["co_varnames"]: - self._code_options["co_varnames"].append(name) - idx = self._code_options["co_varnames"].index(name) - self._add_instr("STORE_FAST", arg=idx, argval=name) - - def gen_load_fast(self, name): - assert name in self._code_options["co_varnames"] - idx = self._code_options["co_varnames"].index(name) - self._add_instr("LOAD_FAST", arg=idx, argval=name) - - def gen_load_attr(self, name: str): - if name not in self._code_options["co_names"]: - self._code_options["co_names"].append(name) - idx = self._code_options["co_names"].index(name) - self._add_instr("LOAD_ATTR", arg=idx, argval=name) - - def gen_subscribe(self): - self._add_instr("BINARY_SUBSCR") - - def gen_build_tuple(self, count): - self._add_instr("BUILD_TUPLE", arg=count, argval=count) - - def gen_build_list(self, count): - self._add_instr("BUILD_LIST", arg=count, argval=count) - - def gen_build_map(self, count): - self._add_instr("BUILD_MAP", arg=count, argval=count) - - def gen_unpack_sequence(self, count): - self._add_instr("UNPACK_SEQUENCE", arg=count, argval=count) - - def gen_call_function(self, argc=0): - self._add_instr("CALL_FUNCTION", arg=argc, argval=argc) - - def gen_pop_top(self): - self._add_instr("POP_TOP") - - def gen_return(self): - self._add_instr("RETURN_VALUE") - - def add_pure_instructions(self, instructions): - """ - add instructions and do nothing. - """ - self._instructions.extend(instructions) - - def _add_instr(self, *args, **kwargs): - instr = gen_instr(*args, **kwargs) - self._instructions.append(instr) - - def _insert_instr(self, index, *args, **kwargs): - instr = gen_instr(*args, **kwargs) - self._instructions.insert(index, instr) - - def pprint(self): - for instr in self._instructions: - print(instr.opname, "\t\t", instr.argval) - - def gen_jump_abs(self, jump_to): - instr = gen_instr("JUMP_ABSOLUTE", jump_to=jump_to) - nop = gen_instr("NOP") - self._instructions.extend([instr, nop]) - jump_to.jump_to = nop - - def extend_instrs(self, instrs): - self._instructions.extend(instrs) diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/tracker.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/tracker.py deleted file mode 100644 index c7c0ea02a0799..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/tracker.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import builtins -from typing import TYPE_CHECKING - -from ...utils import InnerError, NameGenerator -from .guard import StringifyExpression, union_free_vars - -if TYPE_CHECKING: - from .pycode_generator import PyCodeGen - from .variables import VariableBase - - -class Tracker: - inputs: list[VariableBase] - name_generator = NameGenerator("tracker_") - - def __init__(self, inputs: list[VariableBase]): - self.inputs = inputs - self.id = Tracker.name_generator.next() - - def gen_instructions(self, codegen: PyCodeGen): - raise NotImplementedError() - - def trace_value_from_frame(self) -> StringifyExpression: - raise NotImplementedError() - - def is_traceable(self): - for input in self.inputs: - if not input.tracker.is_traceable(): - return False - return True - - -class DummyTracker(Tracker): - def __init__(self, inputs: list[VariableBase]): - super().__init__(inputs) - - def gen_instructions(self, codegen: PyCodeGen): - raise InnerError("DummyTracker has no instructions") - - def trace_value_from_frame(self): - raise InnerError("DummyTracker can't trace value from frame") - - def is_traceable(self): - return False - - def __repr__(self) -> str: - return f"DummyTracker(num_inputs={len(self.inputs)})" - - -class LocalTracker(Tracker): - def __init__(self, name: str): - super().__init__([]) - self.name = name - - def gen_instructions(self, codegen: PyCodeGen): - codegen.gen_load_fast(self.name) - - def trace_value_from_frame(self): - return StringifyExpression(f"frame.f_locals['{self.name}']", {}) - - def __repr__(self) -> str: - return f"LocalTracker(name={self.name})" - - -class GlobalTracker(Tracker): - def __init__(self, name): - super().__init__([]) - self.name = name - - def gen_instructions(self, codegen: PyCodeGen): - codegen.gen_load_global(self.name) - - def trace_value_from_frame(self): - return StringifyExpression(f"frame.f_globals['{self.name}']", {}) - - def __repr__(self) -> str: - return f"GlobalTracker(name={self.name})" - - -class BuiltinTracker(Tracker): - def __init__(self, name: str): - super().__init__([]) - self.name = name - - def gen_instructions(self, codegen: PyCodeGen): - codegen.gen_load_global(self.name) - - def trace_value_from_frame(self): - return StringifyExpression( - f"builtins.__dict__[{self.name}]", {"builtins": builtins} - ) - - def __repr__(self) -> str: - return f"BuiltinTracker(name={self.name})" - - -class ConstTracker(Tracker): - def __init__(self, value): - super().__init__([]) - self.value = value - - def gen_instructions(self, codegen: PyCodeGen): - codegen.gen_load_const(self.value) - - def trace_value_from_frame(self): - return StringifyExpression(f"{self.value}", {}) - - def __repr__(self) -> str: - return f"ConstTracker(value={self.value})" - - -class GetAttrTracker(Tracker): - def __init__(self, obj: VariableBase, attr: str): - super().__init__([obj]) - self.obj = obj - self.attr = attr - - def gen_instructions(self, codegen: PyCodeGen): - self.obj.tracker.gen_instructions(codegen) - codegen.gen_load_attr(self.attr) - - def trace_value_from_frame(self): - obj_tracer = 
self.obj.tracker.trace_value_from_frame() - if self.attr.isidentifier(): - expr = f"{obj_tracer.expr}.{self.attr}" - else: - expr = f"getattr({obj_tracer.expr}, '{self.attr}')" - return StringifyExpression( - expr, - union_free_vars(obj_tracer.free_vars), - ) - - def __repr__(self) -> str: - return f"GetAttrTracker(attr={self.attr})" - - -class GetItemTracker(Tracker): - def __init__(self, container_var: VariableBase, key: object): - super().__init__([container_var]) - self.container = container_var - self.key = key - - def gen_instructions(self, codegen: PyCodeGen): - self.container.tracker.gen_instructions(codegen) - codegen.gen_load_const(self.key) - codegen.gen_subscribe() - - def trace_value_from_frame(self): - container_tracer = self.container.tracker.trace_value_from_frame() - return StringifyExpression( - f"{container_tracer.expr}[{self.key!r}]", - union_free_vars(container_tracer.free_vars), - ) - - def __repr__(self) -> str: - return f"GetItemTracker(key={self.key!r})" diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/tracker_viewer.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/tracker_viewer.py deleted file mode 100644 index 5de7f61a955a8..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/tracker_viewer.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import queue -from typing import TYPE_CHECKING - -from .tracker import DummyTracker -from .variables import VariableBase - -SIR_GRAPH_CLUSTER_NAME = "cluster_sir_part" - -if TYPE_CHECKING: - import graphviz - - -def try_import_graphviz(): - try: - import graphviz - - return graphviz - except ImportError: - return None - - -def draw_variable(graph: graphviz.Digraph, var: VariableBase): - # Draw Variable - graph.attr('node', shape='oval', style="solid") - graph.attr('edge', style='solid') - graph.node(var.id, str(var)) - - # Draw Tracker - tracker = var.tracker - if isinstance(tracker, DummyTracker): - graph.attr('edge', style='dashed') - graph.attr('node', style='dashed') - graph.attr('node', shape='rect') - graph.node(tracker.id, str(tracker)) - - # Draw edge (Tracker -> Variable) - graph.edge(tracker.id, var.id) - - # Draw edge (Tracker inputs -> Tracker) - graph.attr('node', shape='oval') - graph.attr('node', shape='oval', style="solid") - for input in tracker.inputs: - graph.edge(input.id, tracker.id) - - -def view_tracker( - root_variables: list[VariableBase], filename: str, format: str -): - # TODO(SigureMo): - # 1. Colorize the trackers - # 2. 
Highlight the user specific node, to speedup debug process - graphviz = try_import_graphviz() - if graphviz is None: - print("Cannot import graphviz, please install it first.") - return - - graph = graphviz.Digraph("graph", filename=filename, format=format) - visited = set() - var_queue = queue.Queue() - for var in root_variables: - var_queue.put(var) - - while not var_queue.empty(): - var = var_queue.get() - if var.id in visited: - continue - visited.add(var.id) - if isinstance(var.tracker, DummyTracker): - with graph.subgraph(name=SIR_GRAPH_CLUSTER_NAME) as sir_part: - sir_part.attr(color='green') - draw_variable(sir_part, var) - else: - draw_variable(graph, var) - for input in var.tracker.inputs: - if input not in var_queue.queue: - var_queue.put(input) - - graph.render(view=False) diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/variable_monkey_patch.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/variable_monkey_patch.py deleted file mode 100644 index 87c7441596ae4..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/variable_monkey_patch.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
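[Editor's note — hedged usage sketch, not part of the original patch. When graphviz is installed, the deleted view_tracker above renders the variable/tracker DAG, grouping DummyTracker-produced nodes into a green "SIR" cluster; a typical call, assuming root_variables was collected from an executor's stack, would be:

    # Writes the DOT source plus tracker_graph.png without opening a viewer.
    view_tracker(root_variables, filename="tracker_graph", format="png")
]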
- -from ...utils.monkey_patch import ( - binary_operator_methods, - do_monkey_patch, - unary_operator_methods, -) -from .variables import ConstantVariable, TensorVariable - - -# TensorVaraible MonkeyPatch -def tensor_variable_unary_method_builder(method_name): - def __impl__(self): - return self.graph.call_tensor_method(method_name, self) - - return __impl__ - - -def tensor_variable_binary_method_builder(method_name): - def __impl__(self, other): - if not isinstance(other, (ConstantVariable, TensorVariable)): - return NotImplemented - return self.graph.call_tensor_method(method_name, self, other) - - return __impl__ - - -do_monkey_patch( - TensorVariable, unary_operator_methods, tensor_variable_unary_method_builder -) -do_monkey_patch( - TensorVariable, - binary_operator_methods, - tensor_variable_binary_method_builder, -) - - -# ConstantVariable MonkeyPatch -def constant_variable_unary_method_builder(method_name): - def __impl__(self): - return self.apply_unary_operator(method_name) - - return __impl__ - - -def constant_variable_binary_method_builder(method_name): - def __impl__(self, other): - return self.apply_binary_operator(other, method_name) - - return __impl__ - - -do_monkey_patch( - ConstantVariable, - unary_operator_methods, - constant_variable_unary_method_builder, -) - -do_monkey_patch( - ConstantVariable, - binary_operator_methods, - constant_variable_binary_method_builder, -) diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/executor/variables.py b/python/paddle/jit/symbolic_trace/opcode_translator/executor/variables.py deleted file mode 100644 index def44ebd6a359..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/executor/variables.py +++ /dev/null @@ -1,1111 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
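[Editor's note — plausible reconstruction, not part of the original patch. The deleted variable_monkey_patch module above depends on do_monkey_patch from utils/monkey_patch.py, whose body is not shown in this hunk; judging from the call sites, it presumably reduces to:

    def do_monkey_patch(cls, method_names, method_builder):
        # Attach a generated dunder (e.g. "__add__") to the Variable class so
        # that Python operators dispatch into the traced graph.
        for method_name in method_names:
            setattr(cls, method_name, method_builder(method_name))
]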
- -from __future__ import annotations - -import collections -import inspect -import types -from queue import Queue -from typing import TYPE_CHECKING, Any, Callable - -import paddle - -from ...infer_meta import MetaInfo -from ...proxy_tensor import ProxyTensor, ProxyTensorContext -from ...symbolic.statement_ir import Symbol -from ...utils import ASSERT, NameGenerator, is_paddle_api, log_do -from ...utils.exceptions import BreakGraphError, FallbackErrorBase, InnerError -from .guard import StringifyExpression, union_free_vars -from .pycode_generator import PyCodeGen -from .tracker import ( - ConstTracker, - DummyTracker, - GetAttrTracker, - GetItemTracker, - Tracker, -) - -if TYPE_CHECKING: - from .function_graph import FunctionGraph - - -ConstTypes = (int, float, str, bool, type(None)) - - -def get_zero_degree_vars( - variables: set[VariableBase], visited_vars: list[VariableBase] -) -> list[VariableBase]: - return [ - var - for var in variables - if var not in visited_vars - and len(set(var.get_traceable_inputs()) - set(visited_vars)) == 0 - ] - - -def topo_sort_vars( - root_vars: list[VariableBase], -) -> list[VariableBase]: - unique_vars = set() - - for var in root_vars: - unique_vars.add(var) - unique_vars |= set(var.flatten_traceable_inputs()) - - topo_ordered_vars = [] - topo_queue = Queue() - for var in get_zero_degree_vars(unique_vars, topo_ordered_vars): - topo_queue.put(var) - - while not topo_queue.empty(): - var = topo_queue.get() - topo_ordered_vars.append(var) - for zero_degree_var in get_zero_degree_vars( - unique_vars, topo_ordered_vars - ): - if ( - zero_degree_var in topo_queue.queue - or zero_degree_var in topo_ordered_vars - ): - continue - topo_queue.put(zero_degree_var) - return topo_ordered_vars - - -class VariableFactory: - registered_funcs: list[Callable] = [] - - @staticmethod - def default_from_value(value, graph, tracker): - return ObjectVariable(value, graph, tracker) - - @staticmethod - def register_from_value(from_value_func: Callable): - VariableFactory.registered_funcs.append(from_value_func) - - @staticmethod - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - for func in VariableFactory.registered_funcs: - var = func(value, graph, tracker) - if var is not None: - return var - return VariableFactory.default_from_value(value, graph, tracker) - - -class VariableBase: - """ - VariableBase is a basic concept and each symbols in VM stack is regarded as - an Variable Object in symblic tracing process. 
- """ - - tracker: Tracker - name_generator = NameGenerator("object_") - - def __init__(self, tracker: Tracker): - self.tracker = tracker - self.id = VariableBase.name_generator.next() - - def __hash__(self): - return hash(self.id) - - def make_stringify_guard(self) -> StringifyExpression: - assert not isinstance( - self.tracker, DummyTracker - ), "Can not make guard from dummy tracker" - - frame_value_tracer = self.tracker.trace_value_from_frame() - log_do( - 4, - lambda: print( - f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" - ), - ) - if isinstance(self, TensorVariable): - return StringifyExpression( - f"str(MetaInfo.from_tensor({frame_value_tracer.expr})) == '{self.get_value().meta}'", - union_free_vars( - {"MetaInfo": MetaInfo}, - frame_value_tracer.free_vars, - ), - ) - if isinstance(self, LayerVariable): - return StringifyExpression( - f"id({frame_value_tracer.expr}) == {id(self.get_value())}", - union_free_vars(frame_value_tracer.free_vars), - ) & StringifyExpression( - f"{frame_value_tracer.expr}.training == {self.get_value().training}", - union_free_vars(frame_value_tracer.free_vars), - ) - return StringifyExpression( - f"{frame_value_tracer.expr} == {self.get_value()}", - union_free_vars(frame_value_tracer.free_vars), - ) - - def get_value(self) -> Any: - raise NotImplementedError() - - def reconstruct(self, codegen: PyCodeGen): - """ - Contruct an opcode and append it into codegen.instructions. - """ - if ( - not isinstance(self.tracker, DummyTracker) - and self.tracker.is_traceable() - ): - self.tracker.gen_instructions(codegen) - else: - self._reconstruct(codegen) - - def _reconstruct(self, codegen: PyCodeGen): - raise NotImplementedError() - - def flatten_items(self) -> list[VariableBase]: - if not isinstance(self, ContainerVariable): - return [self] - flattened_items = [] - for item in self.get_items(): - flattened_items.extend(item.flatten_items()) - return flattened_items - - def get_inputs(self) -> list[VariableBase]: - return self.tracker.inputs - - def get_traceable_inputs(self) -> list[VariableBase]: - if self.tracker.is_traceable: - return [] - - return list( - filter(lambda x: x.tracker.is_traceable(), self.tracker.inputs) - ) - - def flatten_traceable_inputs(self) -> list[VariableBase]: - flattened_traceable_inputs: list[VariableBase] = [self] - if self.tracker.is_traceable: - return flattened_traceable_inputs - - for input in self.get_inputs(): - flattened_traceable_inputs.extend(input.flatten_traceable_inputs()) - return flattened_traceable_inputs - - def call_function(self, *args, **kwargs): - pass - - def getattr(self, *args, **kwargs): - pass - - def getitem(self, *args, **kwargs): - pass - - @VariableFactory.register_from_value - def from_value( - value: Any, - graph: FunctionGraph | None, - tracker: Tracker, - ): - if isinstance(value, VariableBase): - return value - return None - - -class ConstantVariable(VariableBase): - def __init__( - self, - value: Any, - tracker: Tracker, - ): - super().__init__(tracker) - self.value = value - - def get_value(self): - return self.value - - def _reconstruct(self, codegen: PyCodeGen): - codegen.gen_load_const(self.value) - - def __repr__(self) -> str: - return f"ConstantVariable({self.value})" - - def __bool__(self) -> bool: - return bool(self.value) - - def apply_unary_operator(self, magic_name): - operator = getattr(self.value, magic_name) - var = VariableFactory.from_value( - operator(), - None, - tracker=DummyTracker( - [ - self, - ] - ), - ) - return var 
- - def apply_binary_operator(self, other, magic_name): - if not isinstance(other, ConstantVariable): - return NotImplemented - operator = getattr(self.value, magic_name) - var = VariableFactory.from_value( - operator(other.value), None, tracker=DummyTracker([self, other]) - ) - return var - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, ConstTypes): - return ConstantVariable(value, tracker) - return None - - @staticmethod - def wrap_literal(value: Any) -> ConstantVariable: - if isinstance(value, ConstantVariable): - return value - assert isinstance( - value, ConstTypes - ), f"value: {value},type: {type(value)}" - return ConstantVariable(value, ConstTracker(value)) - - -class TensorVariable(VariableBase): - def __init__( - self, - tensor: paddle.Tensor | ProxyTensor, - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tracker) - if isinstance(tensor, paddle.Tensor): - self.value: ProxyTensor = ProxyTensorContext().from_tensor(tensor) - elif isinstance(tensor, ProxyTensor): - self.value = tensor - else: - raise InnerError( - "Required type(tensor) is paddle.Tensor or ProxyTensor, but received {}.".format( - type(tensor).__name__ - ) - ) - self.graph = graph - - def get_value(self): - return self.value - - def get_symbol(self) -> Symbol: - return Symbol(self.value.name) - - @property - def out_var_name(self): - return f"{self.graph.out_var_prefix}{self.value.name}" - - def _reconstruct(self, codegen: PyCodeGen): - codegen.gen_load_fast(self.out_var_name) - - def __repr__(self) -> str: - return f"TensorVariable{self.value.meta}" - - def __getitem__(self, key): - return self.graph.call_tensor_method('__getitem__', self, key) - - @property - def T(self): - perm = list(range(len(self.value.shape) - 1, -1, -1)) - perm_var = VariableFactory.from_value( - perm, self.graph, tracker=ConstTracker(perm) - ) - out = self.graph.call_paddle_api(paddle.transpose, self, perm_var) - return out - - def __getattr__(self, name: str): - if callable(getattr(paddle.Tensor, name)): - return TensorMethodVariable( - self, name, self.graph, tracker=GetAttrTracker(self, name) - ) - else: - return VariableFactory.from_value( - getattr(self.value, name), - self.graph, - tracker=GetAttrTracker(self, name), - ) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, (paddle.Tensor, ProxyTensor)): - assert graph is not None - return TensorVariable(value, graph, tracker) - return None - - -class ContainerVariable(VariableBase): - def get_items(self) -> list[VariableBase]: - raise NotImplementedError() - - def __len__(self): - raise NotImplementedError() - - def __bool__(self): - return len(self) > 0 - - -class ListVariable(ContainerVariable): - def __init__( - self, - val_list: list[VariableBase], - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tracker) - self.graph = graph - # everything in stack is VariableBase, so just accept the input list is ok - self.value = val_list - - def get_value(self): - return [self[i].get_value() for i in range(len(self))] - - def _reconstruct(self, codegen: PyCodeGen): - size = len(self) - for idx in range(size): - self[idx].reconstruct(codegen) - codegen.gen_build_list(size) - - def get_items(self): - size = len(self) - return [self[idx] for idx in range(size)] - - def get_wrapped_items(self): - return self.get_items() - - def __repr__(self) -> str: - return 
f"ListVariable(len={len(self)})" - - def __len__(self): - return len(self.value) - - def __getitem__(self, key): - ''' - we need to make sure that: - before an inplace change happens to ListVariable, - the related items should already be wrapped as VariableBase - - if not, tracker might be set to a wrong elem - ''' - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - - retval = self.value[key] - - # if list is an input of funciton, we need make sure __getitem__ returns a VariableBase - retval = VariableFactory.from_value( - retval, self.graph, tracker=GetItemTracker(self, key) - ) - - return retval - - def __setitem__(self, key, value): - ''' - why __setitem__ is ok: - - case: - def f(x = [t0, t1]) - ... - x[0] = 0 - ... - - 1. if setitem happens after get t0: t0 is a VariableBase (transformed at getitem), so it is ok - 2. if setitem happens before get t0: t0 will not be used - ''' - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key." - ) - - if not isinstance(value, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: received {value} to set value." - ) - self.value[key] = value - - def __delitem__(self, key): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key to delete." - ) - del self.value[key] - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, list): - assert graph is not None - return ListVariable(value, graph=graph, tracker=tracker) - return None - - -class TupleVariable(ContainerVariable): - def __init__( - self, - val_tuple: list[VariableBase], - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tracker) - self.graph = graph - # exactly it is a list (need replace item with VariableBase) - self.value = list(val_tuple) - - def get_value(self): - return tuple(self[i].get_value() for i in range(len(self))) - - def _reconstruct(self, codegen: PyCodeGen): - size = len(self) - for idx in range(size): - self[idx].reconstruct(codegen) - codegen.gen_build_tuple(size) - - def get_items(self): - size = len(self) - return [self[idx] for idx in range(size)] - - def get_wrapped_items(self): - return self.get_items() - - def __repr__(self) -> str: - return f"TupleVariable(len={len(self)})" - - def __len__(self): - return len(self.value) - - def __getitem__(self, key): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - retval = self.value[key] - - return VariableFactory.from_value( - retval, graph=self.graph, tracker=GetItemTracker(self, key) - ) - - def __setitem__(self, key, value): - raise InnerError( - f"[{self.__class__.__name__}]: setitem is not allowed." - ) - - def __delitem__(self, key): - raise InnerError( - f"[{self.__class__.__name__}]: delitem is not allowed." 
- ) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, tuple): - return TupleVariable(value, graph, tracker) - return None - - -class DictVariable(ContainerVariable): - def __init__( - self, - val_dict: dict[object, VariableBase], - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tracker) - self.graph = graph - self.value = val_dict - - def get_value(self): - return {key: self[key].get_value() for key in self.value} - - def _reconstruct(self, codegen: PyCodeGen): - size = len(self) - for key in self.value.keys(): - if not isinstance(key, ConstTypes): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - key_var = ConstantVariable.wrap_literal(key) - value_var = self[key] - key_var.reconstruct(codegen) - value_var.reconstruct(codegen) - codegen.gen_build_map(size) - - def get_items(self): - items = [] - for key in self.value.keys(): - if not isinstance(key, ConstTypes): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - key_var = VariableFactory.from_value( - key, self.graph, tracker=ConstTracker(key) - ) - value_var = self[key] - items.extend([key_var, value_var]) - return items - - def get_wrapped_items(self): - items = {} - for key in self.value.keys(): - if not isinstance(key, ConstTypes): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - items[key] = self[key] - return items - - def __repr__(self) -> str: - return f"DictVariable(len={len(self)})" - - def __len__(self): - return len(self.value) - - def __getitem__(self, key): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - - retval = self.value[key] - - return VariableFactory.from_value( - retval, self.graph, tracker=GetItemTracker(self, key) - ) - - def __setitem__(self, key, value): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - - if not isinstance(value, ConstantVariable): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {value} to set value." - ) - - self.value[key] = value - - def __delitem__(self, key): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key to delete." 
- ) - del self.value[key] - - def __getattr__(self, name): - def keys(self): - raw_list = [ - ConstantVariable(x, ConstTracker(x)) for x in self.value.keys() - ] - key_list = VariableFactory.from_value( - raw_list, self.graph, ConstTracker(raw_list) - ) - return SequenceIterVariable( - key_list, self.graph, DummyTracker([key_list]) - ) - - def values(self): - raw_list = list(self.get_wrapped_items().values()) - value_list = VariableFactory.from_value( - raw_list, self.graph, DummyTracker([self]) - ) - return SequenceIterVariable( - value_list, self.graph, DummyTracker([value_list]) - ) - - def items(self): - keys = [ - ConstantVariable(x, ConstTracker(x)) for x in self.value.keys() - ] - values = list(self.get_wrapped_items().values()) - raw_list = list(zip(keys, values)) - item_list = VariableFactory.from_value( - raw_list, self.graph, DummyTracker([self]) - ) - return SequenceIterVariable( - item_list, self.graph, DummyTracker([item_list]) - ) - - if name == "keys": - return DirectlyCallMethodVariable( - None, - types.MethodType(keys, self), - self.graph, - GetAttrTracker(self, "keys"), - ) - elif name == "values": - return DirectlyCallMethodVariable( - None, - types.MethodType(values, self), - self.graph, - GetAttrTracker(self, "values"), - ) - elif name == "items": - return DirectlyCallMethodVariable( - None, - types.MethodType(items, self), - self.graph, - GetAttrTracker(self, "items"), - ) - else: - raise NotImplementedError( - f"attribute {name} for dict is not implemented" - ) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, dict): - assert graph is not None - return DictVariable(value, graph=graph, tracker=tracker) - - -class CallableVariable(VariableBase): - def __init__(self, graph: FunctionGraph, tracker: Tracker): - super().__init__(tracker) - self.graph = graph - - def __call__(self, *args, **kwargs) -> VariableBase: - return self.call_function(*args, **kwargs) - - def call_function(self, *args, **kwargs): - raise NotImplementedError("call_function is not implemented.") - - -class FunctionVariable(CallableVariable): - def __init__( - self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker - ): - super().__init__(graph, tracker) - self.value = fn - - def get_value(self): - return self.value - - def get_code(self) -> types.CodeType: - return self.value.__code__ - - -class PaddleApiVariable(FunctionVariable): - def __init__( - self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker - ): - super().__init__(fn, graph, tracker) - - def call_function(self, *args, **kwargs): - return self.graph.call_paddle_api(self.value, *args, **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - # This should be front of FunctionVariable to avoid conflict. 
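- # (Editor's clarification, not in the original file: VariableFactory.from_value - # tries the registered hooks in registration order and returns the first - # non-None result, so PaddleApiVariable must be defined, and therefore - # registered, before the generic UserDefinedFunctionVariable below.)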
- if callable(value) and is_paddle_api(value): - return PaddleApiVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"PaddleApiVariable({self.value.__name__})" - - -class UserDefinedFunctionVariable(FunctionVariable): - def __init__( - self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker - ): - super().__init__(fn, graph, tracker) - - def call_function(self, *args, **kwargs) -> VariableBase: - from .opcode_inline_executor import OpcodeInlineExecutor - - if self.value is ASSERT: - return self.value(args[0].value) - - checkpoint = self.graph.save_memo() - try: - inline_executor = OpcodeInlineExecutor(self, *args, **kwargs) - output = inline_executor.inline_call() - except FallbackErrorBase as e: - self.graph.restore_memo(checkpoint) - raise BreakGraphError( - f"{self.value} is raise a inline call error. {e}" - ) - return output - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, (types.FunctionType)): - return UserDefinedFunctionVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"UserDefinedFunctionVariable({self.value.__name__})" - - -class MethodVariable(CallableVariable): - def __init__( - self, - bound_instance: VariableBase, - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(graph, tracker) - self.bound_instance = bound_instance - - -class TensorMethodVariable(MethodVariable): - def __init__( - self, - tensor: TensorVariable, - method_name: str, - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tensor, graph, tracker) - self.tensor = tensor - self.method_name = method_name - - def get_value(self): - return getattr(self.tensor, self.method_name) - - def call_function(self, *args, **kwargs): - return self.graph.call_tensor_method( - self.method_name, self.tensor, *args, **kwargs - ) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if inspect.ismethod(value) and isinstance( - value.__self__, paddle.Tensor - ): - # NOTE(SigureMo): Since the method_self need method_var as the obj - # of the tracker, we need to temporarily set the tracker of method_self - # to DummyTracker, and set it to GetAttrTracker after method_var is created. 
- method_self = TensorVariable( - value.__self__, graph, DummyTracker([]) - ) - method_var = TensorMethodVariable( - method_self, - value.__name__, - graph, - tracker, - ) - method_self.tracker = GetAttrTracker(method_var, "__self__") - return method_var - return None - - def __repr__(self) -> str: - return f"TensorMethodVariable({self.method_name})" - - -class UserDefinedMethodVariable(MethodVariable): - def __init__( - self, bound_instance, fn, graph: FunctionGraph, tracker: Tracker - ): - super().__init__(bound_instance, graph, tracker) - self.bound_instance = bound_instance - self.fn = fn - - def get_value(self): - return self.fn.__get__( - self.bound_instance, self.bound_instance.__class__ - ) - - def call_function(self, *args, **kwargs): - fn_var = UserDefinedFunctionVariable( - self.fn, self.graph, GetAttrTracker(self, "__func__") - ) - - return fn_var(*(self.bound_instance, *args), **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if inspect.ismethod(value): - method_self = VariableFactory.from_value( - value.__self__, graph, DummyTracker([]) - ) - method_var = UserDefinedMethodVariable( - method_self, - value.__func__, - graph, - tracker, - ) - method_self.tracker = GetAttrTracker(method_var, "__self__") - return method_var - return None - - def __repr__(self) -> str: - return f"UserDefinedMethodVariable({self.fn.__name__})" - - -class DirectlyCallMethodVariable(MethodVariable): - def __init__( - self, bound_instance, fn, graph: FunctionGraph, tracker: Tracker - ): - super().__init__(bound_instance, graph, tracker) - self.bound_instance = bound_instance - self.fn = fn - - def get_value(self): - return self.fn.__get__( - self.bound_instance, self.bound_instance.__class__ - ) - - def call_function(self, *args, **kwargs): - return self.fn() - - -class LayerVariable(CallableVariable): - def __init__( - self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker - ): - super().__init__(graph, tracker) - self.value = layer - - def get_value(self): - return self.value - - def __getattr__(self, name: str): - if not hasattr(self.value, name): - raise InnerError(f"LayerVariable {self} has no attribute {name}") - attr = getattr(self.value, name) - if inspect.ismethod(attr): - return UserDefinedMethodVariable( - self, attr.__func__, self.graph, GetAttrTracker(self, name) - ) - return VariableFactory.from_value( - attr, self.graph, tracker=GetAttrTracker(self, name) - ) - - -class PaddleLayerVariable(LayerVariable): - def __init__( - self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker - ): - super().__init__(layer, graph, tracker) - self.name = self.graph.sir_ctx.new_layername() - - def get_symbol(self) -> Symbol: - return Symbol(self.name) - - def call_function(self, *args, **kwargs): - # TODO: Remove this trick after we support for-loop. - if isinstance(self.value, paddle.nn.Sequential): - assert len(args) == 1, "Sequential only accept one input" - input = args[0] - for i, layer in enumerate(self.value._sub_layers.values()): - layer_var = VariableFactory.from_value( - layer, self.graph, tracker=GetItemTracker(self, i) - ) - assert isinstance(layer_var, LayerVariable) - input = layer_var(input) - return input - return self.graph.call_layer(self, *args, **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - # TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer. 
- if isinstance(value, paddle.nn.Layer) and value.__module__.startswith( - "paddle.nn." - ): - return PaddleLayerVariable(value, graph, tracker) - return None - - def __getattr__(self, name: str): - if not hasattr(self.value, name): - raise InnerError( - f"PaddleLayerVariable {self} has no attribute {name}" - ) - attr = getattr(self.value, name) - return VariableFactory.from_value( - attr, self.graph, tracker=GetAttrTracker(self, name) - ) - - def __repr__(self) -> str: - return f"PaddleLayerVariable({self.value.__class__.__name__})" - - -class UserDefinedLayerVariable(LayerVariable): - def __init__( - self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker - ): - super().__init__(layer, graph, tracker) - - def call_function(self, *args, **kwargs): - fn_var = UserDefinedFunctionVariable( - self.value.__class__.__call__, - self.graph, - GetAttrTracker(self, "__call__"), - ) - - return fn_var(*(self, *args), **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance( - value, paddle.nn.Layer - ) and not value.__module__.startswith("paddle.nn."): - return UserDefinedLayerVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"UserDefinedLayerVariable({self.value.__class__.__name__})" - - -class BuiltinVariable(CallableVariable): - def __init__( - self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker - ): - super().__init__(graph, tracker) - self.value = fn - - def call_function(self, *args, **kwargs): - # TODO(0x45f): For builtin functions, may have 3 different ways to process as below: - # 1. Simulation execution: ensure correct simulation execution and handle trackers with care - # 2. Trigger the paddle api call - # 3. 
Trigger fallback - args = [ - arg.value if isinstance(arg, ConstantVariable) else arg - for arg in args - ] - kwargs = { - k: (v.value if isinstance(v, ConstantVariable) else v) - for k, v in kwargs.items() - } - return self.value(*args, **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, (types.BuiltinFunctionType)): - return BuiltinVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"BuiltinVariable({self.value.__name__})" - - -class SliceVariable(VariableBase): - def __init__(self, slice_, graph, tracker): - super().__init__(tracker) - self.value = slice_ - self.graph = graph - - def __repr__(self) -> str: - return f"SliceVariable({self.value})" - - def get_value(self): - return self.value - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, slice): - return SliceVariable(value, graph, tracker) - return None - - -class ModuleVariable(VariableBase): - def __init__(self, func, graph, tracker): - super().__init__(tracker) - self.value = func - self.graph = graph - - def get_value(self): - return self.value - - def __getattr__(self, name: str): - if not hasattr(self.value, name): - raise InnerError(f"ModuleVariable {self} has no attribute {name}") - attr = getattr(self.value, name) - return VariableFactory.from_value( - attr, self.graph, tracker=GetAttrTracker(self, name) - ) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, types.ModuleType): - return ModuleVariable(value, graph, tracker) - return None - - -class ObjectVariable(VariableBase): - def __init__(self, obj, graph, tracker): - super().__init__(tracker) - self.value = obj - self.graph = graph - - def __repr__(self) -> str: - return f"ObjectVariable({self.value})" - - def __getattr__(self, name: str): - if not hasattr(self.value, name): - raise InnerError(f"ObjectVariable {self} has no attribute {name}") - attr = getattr(self.value, name) - return VariableFactory.from_value( - attr, self.graph, tracker=GetAttrTracker(self, name) - ) - - -class IterVariable(VariableBase): - def __init__(self, obj, graph, tracker): - super().__init__(tracker) - self.hold = obj - self.graph = graph - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, collections.Iterable): - return UserDefinedIterVariable(value, graph, tracker) - return None - - -class SequenceIterVariable(IterVariable): - def __init__(self, obj, graph, tracker): - super().__init__(obj, graph, tracker) - self.idx = 0 - - def next(self): - if self.idx < len(self.hold): - val = self.hold[self.idx] - new_iter = SequenceIterVariable( - self.hold, self.graph, DummyTracker([self]) - ) - new_iter.idx = self.idx + 1 - return val, new_iter - else: - raise StopIteration() - - -class DictIterVariable(IterVariable): - def __init__(self, obj, graph, tracker): - super().__init__(obj, graph, tracker) - self.key_list = list(self.hold) - self.idx = 0 - - def next(self): - if self.idx < len(self.key_list): - val = self.key_list[self.idx] - new_iter = DictIterVariable( - self.hold, self.graph, DummyTracker([self]) - ) - new_iter.idx = self.idx + 1 - return val, new_iter - else: - raise StopIteration() - - -class TensorIterVariable(IterVariable): - def __init__(self, obj, graph, tracker): - 
super().__init__(obj, graph, tracker) - - -# what UserDefinedIterVariable holds doesn't matter, because use user defined iterator will trigger break graph -class UserDefinedIterVariable(IterVariable): - def __init__(self, obj, graph, tracker): - super().__init__(obj, graph, tracker) diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/__init__.py b/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/__init__.py deleted file mode 100644 index 7e0b7e98a1696..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .instruction_utils import ( - Instruction, - convert_instruction, - gen_instr, - get_instructions, - instrs_info, - modify_extended_args, - modify_instrs, - modify_vars, - relocate_jump_target, - replace_instr, - reset_offset, -) - -__all__ = [ - "Instruction", - "convert_instruction", - "gen_instr", - "get_instructions", - "modify_instrs", - "modify_vars", - "reset_offset", - "relocate_jump_target", - "modify_extended_args", - "replace_instr", - "instrs_info", -] diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/instruction_utils.py deleted file mode 100644 index d1264cdff6a71..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/instruction_utils.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
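[Editor's note — hedged usage sketch, not part of the original patch. The helpers re-exported above wrap CPython's dis module; with the deleted implementation below, a typical round trip looks like:

    code = compile("a = 1\nb = a + 1\n", "<demo>", "exec")
    instrs = get_instructions(code)  # jump targets resolved, EXTENDED_ARG folded away
    print(instrs_info(instrs))       # dis-style listing of the converted Instructions
]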
- -from __future__ import annotations - -import dataclasses -import dis -from typing import Any - -from .opcode_info import ABS_JUMP, ALL_JUMP, REL_JUMP - - -@dataclasses.dataclass -class Instruction: - opcode: int - opname: str - arg: int | None - argval: Any - offset: int | None = None - starts_line: int | None = None - is_jump_target: bool = False - jump_to: Instruction | None = None - is_generated: bool = True - - # for analys EXTENDED_ARG - first_ex_arg: Instruction | None = None - ex_arg_for: Instruction | None = None - - # used in modify_extended_args - def __hash__(self): - return id(self) - - -def gen_instr(name, arg=None, argval=None, gened=True, jump_to=None): - return Instruction( - opcode=dis.opmap[name], - opname=name, - arg=arg, - argval=argval, - is_generated=gened, - jump_to=jump_to, - ) - - -def convert_instruction(instr): - return Instruction( - instr.opcode, - instr.opname, - instr.arg, - instr.argval, - instr.offset, - instr.starts_line, - instr.is_jump_target, - jump_to=None, - is_generated=False, - ) - - -def get_instructions(code): - # instrs do not contain EXTENDED_ARG - instrs = list(map(convert_instruction, dis.get_instructions(code))) - for instr in instrs: - # for 3.8, see dis.py - if instr.opname in ALL_JUMP: - if instr.opname in REL_JUMP: - origin_jump_target = instr.offset + 2 + instr.arg - - elif instr.opname in ABS_JUMP: - origin_jump_target = instr.arg - - jump_offset = origin_jump_target - while instrs[jump_offset // 2].opname == "EXTENDED_ARG": - jump_offset += 2 - - if origin_jump_target != jump_offset: - # copy infos from EXETENDED_ARG to other opcode - if instrs[origin_jump_target // 2].is_jump_target: - instrs[jump_offset // 2].is_jump_target = instrs[ - origin_jump_target // 2 - ].is_jump_target - if instrs[origin_jump_target // 2].starts_line: - instrs[jump_offset // 2].starts_line = instrs[ - origin_jump_target // 2 - ].starts_line - - instr.jump_to = instrs[jump_offset // 2] - - ''' - if the origin opcode contains EXTENDED_ARG, it should be like: - >> EXTENDED_ARG 1 - XX 388 <- 256 + 132 - filter all EXTENDED_ARG here - ''' - instrs = [x for x in instrs if x.opname != "EXTENDED_ARG"] - return instrs - - -''' - modify instructions: - 1. reset offset - 2. relocate jump target - 3. 
add EXTENDED_ARG instruction if needed -''' - - -def modify_instrs(instructions): - modify_completed = False - while not modify_completed: - reset_offset(instructions) - relocate_jump_target(instructions) - modify_completed = modify_extended_args(instructions) - - -def reset_offset(instructions): - for idx, instr in enumerate(instructions): - instr.offset = idx * 2 - - -def relocate_jump_target(instuctions): - extended_arg = [] - for instr in instuctions: - if instr.opname == "EXTENDED_ARG": - extended_arg.append(instr) - continue - - if instr.opname in ALL_JUMP: - # if jump target has extended_arg, should jump to the first extended_arg opcode - jump_target = ( - instr.jump_to.offset - if instr.jump_to.first_ex_arg is None - else instr.jump_to.first_ex_arg.offset - ) - - if instr.opname in REL_JUMP: - new_arg = jump_target - instr.offset - 2 - elif instr.opname in ABS_JUMP: - new_arg = jump_target - - if extended_arg: - instr.arg = new_arg & 0xFF - new_arg = new_arg >> 8 - for ex in reversed(extended_arg): - ex.arg = new_arg & 0xFF - new_arg = new_arg >> 8 - - # need more extended_args instr - # set arg in the first extended_arg - if new_arg > 0: - extended_arg[0].arg += new_arg << 8 - else: - instr.arg = new_arg - - extended_arg.clear() - - -def modify_extended_args(instructions): - modify_completed = True - extend_args_record = {} - for instr in instructions: - if instr.arg and instr.arg >= 256: # more than one byte - _instrs = [ - instr - ] # replace instr with _instrs later (it is a set of instrs), all operations will be recorded in extend_args_record - val = instr.arg - instr.arg = val & 0xFF - val = val >> 8 - while val > 0: - _instrs.append(gen_instr("EXTENDED_ARG", arg=val & 0xFF)) - val = val >> 8 - - extend_args_record.update({instr: list(reversed(_instrs))}) - - if extend_args_record: - # if new EXTENDED_ARG inserted, we need update offset and jump target - modify_completed = False - - def bind_ex_arg_with_instr(ex_arg, instr): - # move opcode info to EXTENDED_ARG - ex_arg.starts_line = instr.starts_line - instr.starts_line = None - ex_arg.is_jump_target = instr.is_jump_target - instr.is_jump_target = False - - if instr.ex_arg_for is not None: - # instr is also an ex_arg for another instr - instr.ex_arg_for.first_ex_arg = ex_arg - ex_arg.ex_arg_for = instr.ex_arg_for - instr.ex_arg_for = None - else: - instr.first_ex_arg = ex_arg - ex_arg.ex_arg_for = instr - - for key, val in extend_args_record.items(): - bind_ex_arg_with_instr(val[0], key) - replace_instr(instructions, instr=key, new_instr=val) - - return modify_completed - - -def modify_vars(instructions, code_options): - co_names = code_options['co_names'] - co_varnames = code_options['co_varnames'] - for instrs in instructions: - if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST': - instrs.arg = co_varnames.index(instrs.argval) - elif instrs.opname == 'LOAD_GLOBAL': - instrs.arg = co_names.index(instrs.argval) - - -''' - utils -''' - - -def replace_instr(instructions, instr, new_instr): - idx = instructions.index(instr) - instructions[idx, idx + 1] = new_instr - - -def instrs_info(instrs): - ret = [] - for idx, instr in enumerate(instrs): - if instr.starts_line is not None: - ret.append("") - ret.append( - "{line:<8s}{is_jump_target:>2s}{offset:>4d} {opname:<30s}{arg:<4s}{argval}".format( - line=str(instr.starts_line) if instr.starts_line else "", - is_jump_target=">>" if instr.is_jump_target else " ", - offset=instr.offset - if instr.offset or instr.offset == 0 - else -1, - opname=instr.opname, - 
arg=str(instr.arg) if instr.arg else "", - argval=f"({instr.argval})" if instr.argval else "", - ) - ) - return "\n".join(ret) diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/opcode_analysis.py b/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/opcode_analysis.py deleted file mode 100644 index 0a3a3e1599f0b..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/opcode_analysis.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import dis - -# TODO: Refactor this file - - -HASLOCAL_OPCODES = set(dis.haslocal) -HASFREE_OPCODES = set(dis.hasfree) -COMPARE_OPCODES = set(dis.cmp_op) -HASJREL_OPCODES = set(dis.hasjrel) -HASJABS_OPCODES = set(dis.hasjabs) -JUMP_OPCODES = HASJREL_OPCODES | HASJABS_OPCODES - - -def calc_offset_from_bytecode_offset(bytecode_offset: int) -> int: - # Calculate the index from bytecode offset, because it have 2 bytes per instruction - # TODO: Change this for Python 3.11+. - return bytecode_offset // 2 - - -def calc_jump_target( - instructions: list[dis.Instruction], current_instr_idx: int -) -> int: - """ - Handle the case where the jump target is in the middle of an extended arg. - """ - num_instr = len(instructions) - # For each opcode, at most three prefixal EXTENDED_ARG are allowed, so we - # need to check at most 4 instructions. 
- # See more details in https://docs.python.org/3.10/library/dis.html#opcode-EXTENDED_ARG - for i in range(current_instr_idx, min(current_instr_idx + 4, num_instr)): - if instructions[i].opcode != dis.EXTENDED_ARG: - return i - else: - raise ValueError("Could not find jump target") - - -def read_write_analysis( - instructions: list[dis.Instruction], - current_instr_idx: int, - stop_instr_idx: int = None, -): - writes = set() - reads = set() - visited = set() - - def walk(start): - end = len(instructions) if stop_instr_idx is None else stop_instr_idx - for i in range(start, end): - if i in visited: - continue - visited.add(i) - - instr = instructions[i] - if instr.opcode in HASLOCAL_OPCODES | HASFREE_OPCODES: - if ( - instr.opname.startswith("LOAD") - and instr.argval not in writes - ): - reads.add(instr.argval) - elif instr.opname.startswith("STORE"): - writes.add(instr.argval) - elif instr.opcode in JUMP_OPCODES: - target_idx = calc_offset_from_bytecode_offset(instr.argval) - target_idx = calc_jump_target(instructions, target_idx) - # Fork to two branches, jump or not - walk(target_idx) - - walk(current_instr_idx) - return reads diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/opcode_info.py b/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/opcode_info.py deleted file mode 100644 index 37fc7bfbd3b8a..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/instruction_utils/opcode_info.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
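The `read_write_analysis` helper above collects the variable names a bytecode range reads before writing them, forking at every jump so both branches are covered; those names become the inputs of a generated resume function. A simplified, paddle-free sketch of the core idea (linear scan only, no jump forking; CPython 3.8-3.10 opcode names assumed):

import dis

def reads_before_writes(func):
    reads, writes = set(), set()
    for instr in dis.get_instructions(func):
        if instr.opname in ("LOAD_FAST", "LOAD_DEREF"):
            if instr.argval not in writes:
                reads.add(instr.argval)  # read before any write: a true input
        elif instr.opname in ("STORE_FAST", "STORE_DEREF"):
            writes.add(instr.argval)
    return reads

def f(x):
    y = x + 1      # reads x, writes y
    return y + x   # y was already written here, so only x is an input

print(reads_before_writes(f))  # {'x'}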
- -import opcode - -UNARY = { - "UNARY_POSITIVE", - "UNARY_NEGATIVE", - "UNARY_NOT", - "UNARY_INVERT", -} - -BINARY = { - "BINARY_MATRIX_MULTIPLY", - "BINARY_POWER", - "BINARY_MULTIPLY", - "BINARY_MODULO", - "BINARY_ADD", - "BINARY_SUBTRACT", - "BINARY_SUBSCR", - "BINARY_FLOOR_DIVIDE", - "BINARY_TRUE_DIVIDE", - "BINARY_LSHIFT", - "BINARY_RSHIFT", - "BINARY_AND", - "BINARY_XOR", - "BINARY_OR", -} - -INPLACE = { - "INPLACE_MATRIX_MULTIPLY", - "INPLACE_FLOOR_DIVIDE", - "INPLACE_TRUE_DIVIDE", - "INPLACE_ADD", - "INPLACE_SUBTRACT", - "INPLACE_MULTIPLY", - "INPLACE_MODULO", - "INPLACE_POWER", - "INPLACE_LSHIFT", - "INPLACE_RSHIFT", - "INPLACE_AND", - "INPLACE_XOR", - "INPLACE_OR", -} - -CALL = { - "CALL_FUNCTION", - "CALL_FUNCTION_KW", - "CALL_FUNCTION_EX", - "CALL_METHOD", -} - -COMPARE = { - "COMPARE_OP", -} - -IMPORT = { - "IMPORT_FROM", -} - -ITER = { - "FOR_ITER", -} - -LOAD = { - "LOAD_BUILD_CLASS", - "LOAD_CONST", - "LOAD_NAME", - "LOAD_ATTR", - "LOAD_GLOBAL", - "LOAD_FAST", - "LOAD_CLOSURE", - "LOAD_DEREF", - "LOAD_CLASSDEREF", - "LOAD_METHOD", -} - -MAKE_FUNCTION = { - "MAKE_FUNCTION", -} - -UNPACK = { - "UNPACK_SEQUENCE", - "UNPACK_EX", -} - - -PUSH_ONE = ( - UNARY - | BINARY - | INPLACE - | CALL - | COMPARE - | IMPORT - | ITER - | LOAD - | MAKE_FUNCTION -) -PUSH_ARG = UNPACK - -ALL_WITH_PUSH = PUSH_ONE | PUSH_ARG - -REL_JUMP = {opcode.opname[x] for x in opcode.hasjrel} -ABS_JUMP = {opcode.opname[x] for x in opcode.hasjabs} -ALL_JUMP = REL_JUMP | ABS_JUMP - -RETURN = { - "RETURN_VALUE", -} diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/skip_files.py b/python/paddle/jit/symbolic_trace/opcode_translator/skip_files.py deleted file mode 100644 index c878b6d7d312f..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/skip_files.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
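The `REL_JUMP` and `ABS_JUMP` sets above are derived from the interpreter's own opcode tables rather than hard-coded, so they automatically track the running Python version. A quick way to inspect them (sample output is indicative of CPython 3.8; the contents differ across versions, and `hasjabs` is empty on 3.11+):

import opcode

REL_JUMP = {opcode.opname[x] for x in opcode.hasjrel}
ABS_JUMP = {opcode.opname[x] for x in opcode.hasjabs}

print(sorted(REL_JUMP))  # e.g. ['FOR_ITER', 'JUMP_FORWARD', 'SETUP_FINALLY', ...]
print(sorted(ABS_JUMP))  # e.g. ['JUMP_ABSOLUTE', 'POP_JUMP_IF_FALSE', ...]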
- -import abc -import codecs -import collections -import contextlib -import copy -import copyreg -import dataclasses -import enum -import functools -import importlib -import inspect -import linecache -import logging -import multiprocessing -import operator -import os -import posixpath -import random -import re -import selectors -import signal -import sre_compile -import sre_parse -import sys -import tempfile -import threading -import tokenize -import traceback -import types -import typing -import unittest -import uuid -import weakref - -import _collections_abc -import _weakrefset -import decorator -import numpy - - -def _strip_init_py(s): - return re.sub(r"__init__.py$", "", s) - - -def _module_dir(m: types.ModuleType): - return _strip_init_py(m.__file__) - - -skip_file_names = { - _module_dir(m) - for m in ( - abc, - collections, - contextlib, - copy, - copyreg, - dataclasses, - enum, - functools, - importlib, - inspect, - linecache, - logging, - multiprocessing, - numpy, - operator, - os, - posixpath, - random, - re, - selectors, - sre_compile, - sre_parse, - signal, - tempfile, - threading, - tokenize, - traceback, - types, - typing, - unittest, - weakref, - _collections_abc, - _weakrefset, - decorator, - codecs, - uuid, - ) -} - - -symbolic_trace_path = os.path.dirname(__file__).rpartition("/")[0] + "/" -paddle_path = sys.modules["paddle"].__file__.rpartition("/")[0] + "/" - -skip_file_names.add(symbolic_trace_path) -skip_file_names.add(paddle_path) -skip_file_names.add( - "") - -skip_file_name_re = re.compile( - f"^({'|'.join(map(re.escape, skip_file_names))})" -) - - -def need_skip_path(filepath): - if not filepath.startswith("<"): - filepath = os.path.abspath(filepath) - return bool(skip_file_name_re.match(filepath)) diff --git a/python/paddle/jit/symbolic_trace/opcode_translator/transform.py b/python/paddle/jit/symbolic_trace/opcode_translator/transform.py deleted file mode 100644 index d3e5140378c1f..0000000000000 --- a/python/paddle/jit/symbolic_trace/opcode_translator/transform.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
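`need_skip_path` above folds every skipped directory into a single compiled regex, so deciding whether a frame belongs to the stdlib, numpy, or paddle itself is one anchored match per call. A small usage sketch with illustrative paths (the two directories below are placeholders, not the real computed set):

import os
import re

# Two hypothetical directories to skip, mirroring how skip_file_names is built.
skip_dirs = ["/usr/lib/python3.8/", "/opt/paddle/python/paddle/"]
skip_re = re.compile(f"^({'|'.join(map(re.escape, skip_dirs))})")

def need_skip_path(filepath: str) -> bool:
    if not filepath.startswith("<"):  # leave "<string>", "<frozen ...>" as-is
        filepath = os.path.abspath(filepath)
    return bool(skip_re.match(filepath))

print(need_skip_path("/usr/lib/python3.8/json/decoder.py"))  # True
print(need_skip_path("/home/user/model.py"))                 # False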
- -import dis - -from ..utils import log, log_do -from .executor.opcode_executor import InstructionTranslatorCache -from .skip_files import need_skip_path - - -def eval_frame_callback(frame): - # is generator - if frame.f_code.co_flags & 0x20 > 0: - return None - - if not need_skip_path(frame.f_code.co_filename): - log( - 2, - "[eval_frame_callback] start to translate: " - + frame.f_code.co_name - + "\n", - ) - - log(8, "[transform_opcode] old_opcode: " + frame.f_code.co_name + "\n") - log_do(8, lambda: dis.dis(frame.f_code)) - - new_code = InstructionTranslatorCache()(frame) - - log( - 7, - "\n[transform_opcode] new_opcode: " + frame.f_code.co_name + "\n", - ) - if new_code is not None: - log_do(7, lambda: dis.dis(new_code.code)) - else: - log_do(7, f"Skip frame: {frame.f_code.co_name}") - - return new_code - return None diff --git a/python/paddle/jit/symbolic_trace/proxy_tensor.py b/python/paddle/jit/symbolic_trace/proxy_tensor.py deleted file mode 100644 index a6b1bcfe589b3..0000000000000 --- a/python/paddle/jit/symbolic_trace/proxy_tensor.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import paddle - -from .infer_meta import MetaInfo -from .utils import NameGenerator, Singleton, log - - -# global variables -@Singleton -class ProxyTensorContext: - def __init__(self): - self.reset() - - def reset(self): - self.runtime_name_to_proxy_tensor: dict[str, ProxyTensor] = {} - self.runtime_proxy_tensor_to_name: dict[int, str] = {} - self.tensor_to_proxy_tensor: dict[int, ProxyTensor] = {} - self.var_name_generator = NameGenerator("var_") - - def new_varname(self): - return self.var_name_generator.next() - - def from_tensor(self, tensor) -> ProxyTensor: - # TODO: don't have the same name. 
- if self.tensor_to_proxy_tensor.get(id(tensor), None) is not None: - return self.tensor_to_proxy_tensor[id(tensor)] - - # TODO(id may have collision) - name = self.new_varname() - proxy_tensor = ProxyTensor(name, MetaInfo.from_tensor(tensor)) - self.tensor_to_proxy_tensor[id(tensor)] = proxy_tensor - proxy_tensor.set_value(tensor) - return proxy_tensor - - def bind_name_to_proxy_tensor(self, name, proxy_tensor): - self.runtime_name_to_proxy_tensor[name] = proxy_tensor - self.runtime_proxy_tensor_to_name[id(proxy_tensor)] = name - - def clear_proxy_tensor_by_name(self, name): - log(3, f"[GC] trying to GC {name}\n") - proxy_tensor = self.runtime_name_to_proxy_tensor[name] - proxy_tensor_id = id(proxy_tensor) - has_value = proxy_tensor.value() is not None - eager_tensor_id = id(proxy_tensor.value()) - - del self.runtime_name_to_proxy_tensor[name] - del self.runtime_proxy_tensor_to_name[proxy_tensor_id] - if has_value and eager_tensor_id in self.tensor_to_proxy_tensor: - del self.tensor_to_proxy_tensor[eager_tensor_id] - log(3, f"[GC] {name} GCed\n") - - def get_runtime(self): - return self.runtime_name_to_proxy_tensor - - -class ProxyTensor: - def __init__(self, name, meta): - self.name: str = name - self.meta: MetaInfo = meta - self.value_: paddle.Tensor = None - ProxyTensorContext().bind_name_to_proxy_tensor(name, self) - - @property - def shape(self): - # TODO(xiongkun) consider dynamic shape. - return self.meta.shape - - @property - def ndim(self): - return len(self.meta.shape) - - @property - def dtype(self): - return self.meta.dtype - - def set_value(self, value): - """ - value is a eager tensor. - when a proxytensor have value, it means it can be evaluated outer to_static. - """ - self.value_ = value - - def clear_value(self): - self.value_ = None - - def value(self): - return self.value_ diff --git a/python/paddle/jit/symbolic_trace/symbolic/compile_cache.py b/python/paddle/jit/symbolic_trace/symbolic/compile_cache.py deleted file mode 100644 index 84e5d85c8ecf3..0000000000000 --- a/python/paddle/jit/symbolic_trace/symbolic/compile_cache.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle - -from ..utils import Cache, Singleton -from .interpreter import compile_sir - - -@Singleton -class CompileSIRCache(Cache): - def __init__(self): - super().__init__(weak=False) - - def key_fn(self, context, sir_name): - sir = context.get_sir(sir_name) - hash_key = hash(str(sir)) - return hash_key - - def value_fn(self, context, sir_name): - return paddle.jit.to_static( - compile_sir(context, sir_name), enable_fallback=False - ) diff --git a/python/paddle/jit/symbolic_trace/symbolic/interpreter.py b/python/paddle/jit/symbolic_trace/symbolic/interpreter.py deleted file mode 100644 index f828d530b1e35..0000000000000 --- a/python/paddle/jit/symbolic_trace/symbolic/interpreter.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle - -from ..utils import map_if -from .statement_ir import SIRRuntimeCache, Symbol - - -def replace_symbol(values, state): - return map_if( - values, - pred=lambda x: isinstance(x, Symbol), - true_fn=lambda x: state[x.name], - false_fn=lambda x: x, - ) - - -class Interpreter: - def __init__(self, symbolic_context): - self._context = symbolic_context - - def get_sir(self, name): - return self._context.get_sir(name) - - def run_sir(self, name, state): - SIR = self.get_sir(name) - gc_pass(SIR) - for stmt in SIR.statements: - inputs = replace_symbol(stmt.inputs, state) - outs = getattr(self, stmt.type)(stmt, inputs) - - def _set(v, s): - state[s.name] = v - - map_if( - outs, - stmt.outputs, - pred=lambda v, s: isinstance(s, Symbol), - true_fn=lambda v, s: _set(v, s), - false_fn=lambda v, s: None, - ) - # fetch outputs - return replace_symbol(SIR.outputs, state) - - def call(self, stmt, inputs): - SIR = self.get_sir(stmt.name) - state = prepare_state(SIR, inputs) - return self.run_sir(stmt.name, state) - - def api(self, stmt, inputs): - args, kwargs = inputs - return stmt.name(*args, **kwargs) - - def method(self, stmt, inputs): - args, kwargs = inputs - var = args[0] - return getattr(var, stmt.name)(*args[1:], **kwargs) - - def layer(self, stmt, inputs): - args, kwargs = inputs - layer, args = args[0], args[1:] - return layer(*args, **kwargs) - - def delete(self, stmt, inputs): - pass - - -def gc_pass(sir): - pass - - -def compile_sir(context, name): - @paddle.jit.not_to_static - def wrapper(args): - """ - This function will be decorated by paddle.to_static. - so the args is variables, not eager tensors. - """ - interpreter = Interpreter(context) - SIR = interpreter.get_sir(name) - state = prepare_state(SIR, args) - return interpreter.run_sir(name, state) - - return wrapper - - -def prepare_state(SIR, inputs): - state = {} - - # update free vars if exsits - if SIRRuntimeCache().has_key(SIR.name): - free_var_seeker = SIRRuntimeCache().get_free_vars(SIR.name) - if free_var_seeker: - state = free_var_seeker() - - # bind inputs - for sir_inp, inp in zip(SIR.inputs, inputs): - state[sir_inp.name] = inp - - return state diff --git a/python/paddle/jit/symbolic_trace/symbolic/statement_ir.py b/python/paddle/jit/symbolic_trace/symbolic/statement_ir.py deleted file mode 100644 index 3b4b377851aaa..0000000000000 --- a/python/paddle/jit/symbolic_trace/symbolic/statement_ir.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -""" -THIS FILE IS PRIVATE !! - -use interface in symbolic_context.py first. -""" -from __future__ import annotations - -from copy import deepcopy - -from paddle.utils import flatten, is_sequence, map_structure - -from ..utils import NameGenerator, Singleton - - -class Symbol: - """ - we need this class to distinguish the string and `math variable` - """ - - def __init__(self, name): - self.name = name - - def __str__(self): - return self.name - - def __repr__(self): - return str(self) - - def __eq__(self, other): - if isinstance(other, str): - return self.name == other - return self.name == other.name - - def __hash__(self): - return hash(self.name) - - def __deepcopy__(self, memo=None): - return Symbol(self.name) - - -class Statement: - def __init__(self, type, name, inputs, outputs): - assert type in ["call", "api", "method", "layer"] - self.name = name - self.inputs = inputs # (list of Symbols, dict of Symbols) - self.outputs = outputs # list of Symbol | PythonObj - self.type = type - - def __deepcopy__(self, memo=None): - return Statement( - self.type, self.name, deepcopy(self.inputs), deepcopy(self.outputs) - ) - - def __str__(self): - def to_string(inps): - if isinstance(inps, str) or not is_sequence(inps): - return inps.__str__() - inps = (x.__str__() for x in inps) - return ", ".join(inps) - - name = ( - self.name - if isinstance(self.name, str) - else "paddle." + self.name.__name__ - ) - return "{} || {} = {} ({}) ".format( - self.type + " " * (10 - len(self.type)), - to_string(self.outputs), - name, - to_string(self.inputs), - ) - - def __repr__(self): - return self.__str__() - - -class StatementIR: - """ - Don't create by yourself, just use the StatementIRCache.get() - """ - - def __init__(self, name): - self.name = name - self.inputs = [] # list of Symbol | PythonObj - self.outputs = [] # list of Symbol | PythonObj - self.statements = [] # list of Statement - - def __deepcopy__(self, memo=None): - new_sir = StatementIR(self.name) - new_sir.inputs = deepcopy(self.inputs) - new_sir.outputs = deepcopy(self.outputs) - new_sir.statements = deepcopy(self.statements) - return new_sir - - def add_input(self, input): - self.inputs.append(input) - - def add_output(self, output): - self.outputs.append(output) - - def add_statement(self, statement): - assert isinstance(statement, Statement) - self.statements.append(statement) - - def analyse_inputs(self): - used_symbols = set() - generated_symbols = set() - for stmt in self.statements: - for inp in flatten(stmt.inputs): - if isinstance(inp, Symbol): - used_symbols.add(inp) - for out in flatten(stmt.outputs): - if isinstance(out, Symbol): - generated_symbols.add(out) - - input_symbols = list(used_symbols - generated_symbols) - input_symbols = sorted(input_symbols, key=lambda x: x.name) - return input_symbols - - def __str__(self): - strs = [] - strs.append("StatmentIR: %s" % self.name) - strs.append(f" inputs: {map_structure(lambda x: x.name, self.inputs)}") - strs.append( - f" outputs: {map_structure(lambda x: x.name, self.outputs)}" - ) - strs.append(" statements: ") - for stmt in self.statements: - strs.append(f" {stmt}") - return "\n".join(strs) - - def __repr__(self): - return self.__str__() - - -@Singleton -class StatementIRFactory: - def __init__(self): - self.cache = {} - self.name_generator = NameGenerator("SIR_") - - def __getitem__(self, key): - return self.cache[key] - - def create(self, input_name=None): - if input_name: 
- name = input_name - else: - name = self.name_generator.next() - - sir = StatementIR(name) - self.cache[name] = sir - return sir - - def update(self, stmt_ir): - name = stmt_ir.name - self.cache[name] = stmt_ir - - def clear(self): - want_clear = [ - key - for key in self.cache.keys() - if self.name_generator.match_name(key) - ] - for key in want_clear: - del self.cache[key] - - -@Singleton -class SIRRuntimeCache: - def __init__(self): - self.cache = {} - # { name : (inputs, outputs, free_vars) } - # inputs : can be used when call_SIR, if free_vars exist - # outputs : used for generator new ProxyTensor output before fallback - # free_vars: (name, function) - - def __getitem__(self, key): - return self.cache[key] - - def has_key(self, key): - return key in self.cache.keys() - - def set_origin_inputs(self, key, inputs): - if key in self.cache.keys(): - val = self.cache[key] - self.cache[key] = (inputs, val[1], val[2]) - else: - self.cache[key] = (inputs, None, None) - - def set_origin_outputs(self, key, outputs): - if key in self.cache.keys(): - val = self.cache[key] - self.cache[key] = (val[0], outputs, val[2]) - else: - self.cache[key] = (None, outputs, None) - - def set_free_vars(self, key, free_vars): - if key in self.cache.keys(): - val = self.cache[key] - self.cache[key] = (val[0], val[1], free_vars) - else: - self.cache[key] = (None, None, free_vars) - - def get_origin_inputs(self, key): - if key in self.cache.keys(): - return self.cache[key][0] - else: - return None - - def get_origin_outputs(self, key): - if key in self.cache.keys(): - return self.cache[key][1] - else: - return None - - def get_free_vars(self, key): - if key in self.cache.keys(): - return self.cache[key][2] - else: - return None diff --git a/python/paddle/jit/symbolic_trace/symbolic/symbolic_context.py b/python/paddle/jit/symbolic_trace/symbolic/symbolic_context.py deleted file mode 100644 index e618316b4768a..0000000000000 --- a/python/paddle/jit/symbolic_trace/symbolic/symbolic_context.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import paddle - -from ..utils import NameGenerator, log -from .compile_cache import CompileSIRCache -from .statement_ir import Statement, StatementIR, StatementIRFactory, Symbol - - -class SymbolicTraceContext: - def __init__(self): - self.reset() - - def reset(self): - self.statement_factory = StatementIRFactory() - self.statement_factory.clear() - self.sir_stack = [self.statement_factory.create()] - self.layer_name_generator = NameGenerator("layer_") - - @property - def TOS(self): - return self.sir_stack[-1] - - def new_layername(self): - return self.layer_name_generator.next() - - def call_SIR(self, sirname, inputs, outputs): - stmt = Statement("call", sirname, inputs, outputs) - self.TOS.add_statement(stmt) - - def call_API(self, api, inputs, outputs): - assert callable(api), "call_API must receive a paddle api." 
- stmt = Statement("api", api, inputs, outputs) - self.TOS.add_statement(stmt) - - def call_METHOD(self, method_name, inputs, outputs): - assert isinstance( - method_name, str - ), "call_METHOD must method api name. string." - assert isinstance( - inputs[0][0], Symbol - ), "call_METHOD must first augument must be Symbol Variable." - stmt = Statement("method", method_name, inputs, outputs) - self.TOS.add_statement(stmt) - - def call_LAYER(self, layer_name, inputs, outputs): - stmt = Statement("layer", layer_name, inputs, outputs) - self.TOS.add_statement(stmt) - - def get_sir(self, name): - return self.statement_factory[name] - - def reset_TOS(self): - self.sir_stack.pop() - self.sir_stack.append(self.statement_factory.create()) - - def replace_TOS(self, sir): - """Use deepcopyed sir to replace the TOS. - This function will update statment_factory. - """ - self.sir_stack.pop() - self.sir_stack.append(sir) - self.statement_factory.update(sir) - - def compile_do_nothing(self, ret_vals): - def dummy_func(*args, **kwargs): - return [] - - # return None function - dummy_stmt_ir = StatementIR("dummy_func") - dummy_stmt_ir.outputs = [] - dummy_stmt_ir.inputs = [] - return dummy_func, dummy_stmt_ir - - def compile_fn(self, ret_vals): - """ - start compile and return the python function, which must can be to_static without errors. - """ - cur_sir: StatementIR = self.TOS - # step0: if no statement, return a dummy function - if len(cur_sir.statements) == 0: - return self.compile_do_nothing(ret_vals) - # step1: analyse sir inputs and outputs - cur_sir.inputs = cur_sir.analyse_inputs() - # TODO: output analysis - cur_sir.outputs = paddle.utils.map_structure( - lambda x: Symbol(x.name), ret_vals - ) - log(1, "start subgraph compile and execution.\n") - log(1, self.TOS, "\n") - # step2: call compile_sir and get python function, third cache is triggered here. - static_func = CompileSIRCache()(self, cur_sir.name) - # step3: GC and reset TOS - # self.reset_TOS() - - return static_func, cur_sir diff --git a/python/paddle/jit/symbolic_trace/trace.py b/python/paddle/jit/symbolic_trace/trace.py deleted file mode 100644 index bb1d491dd7be6..0000000000000 --- a/python/paddle/jit/symbolic_trace/trace.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
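`SymbolicTraceContext` above records each traced call as a Statement on the top-of-stack StatementIR, and `compile_fn` later derives that IR's inputs and hands it to the cached compiler. A stripped-down, paddle-free sketch of the recording flow (all `Mini*` names are hypothetical stand-ins):

class MiniStatement:
    def __init__(self, type, name, inputs, outputs):
        self.type, self.name = type, name
        self.inputs, self.outputs = inputs, outputs

class MiniSIR:
    def __init__(self, name):
        self.name = name
        self.statements = []

class MiniContext:
    def __init__(self):
        self.sir_stack = [MiniSIR("SIR_0")]

    @property
    def TOS(self):
        return self.sir_stack[-1]  # the StatementIR currently being built

    def call_API(self, api, inputs, outputs):
        self.TOS.statements.append(MiniStatement("api", api, inputs, outputs))

ctx = MiniContext()
ctx.call_API("paddle.add", ["var_0", "var_1"], ["var_2"])
ctx.call_API("paddle.tanh", ["var_2"], ["var_3"])
print([s.name for s in ctx.TOS.statements])  # ['paddle.add', 'paddle.tanh']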
-
-import paddle
-
-from .opcode_translator import eval_frame_callback
-from .proxy_tensor import ProxyTensorContext
-
-
-def symbolic_trace(func):
-    def impl(*args, **kwargs):
-        ProxyTensorContext().reset()
-        paddle.fluid.core.set_eval_frame(eval_frame_callback)
-        try:
-            outs = func(*args, **kwargs)
-        except Exception as e:
-            raise e
-        finally:
-            paddle.fluid.core.set_eval_frame(None)
-        return outs
-
-    return impl
diff --git a/python/paddle/jit/symbolic_trace/utils/.utils.py.swp b/python/paddle/jit/symbolic_trace/utils/.utils.py.swp
deleted file mode 100644
index 3f2b59a3680ec8bb64f1e0247ea453c89d1a1d67..0000000000000000000000000000000000000000
Binary files a/python/paddle/jit/symbolic_trace/utils/.utils.py.swp and /dev/null differ
diff --git a/python/paddle/jit/symbolic_trace/utils/__init__.py b/python/paddle/jit/symbolic_trace/utils/__init__.py
deleted file mode 100644
index 1bb5244181261..0000000000000
--- a/python/paddle/jit/symbolic_trace/utils/__init__.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .exceptions import BreakGraphError, InnerError, UnsupportError
-from .utils import (
-    ASSERT,
-    Cache,
-    NameGenerator,
-    ResumeFnNameFactory,
-    Singleton,
-    count_if,
-    execute_time,
-    freeze_structure,
-    in_paddle_module,
-    is_fallback_api,
-    is_paddle_api,
-    is_proxy_tensor,
-    is_strict_mode,
-    list_contain_by_id,
-    list_find_index_by_id,
-    log,
-    log_do,
-    map_if,
-    meta_str,
-    no_eval_frame,
-    paddle_tensor_method,
-    show_trackers,
-)
-
-__all__ = [
-    "InnerError",
-    "UnsupportError",
-    "BreakGraphError",
-    "Singleton",
-    "NameGenerator",
-    "log",
-    "log_do",
-    "no_eval_frame",
-    "is_paddle_api",
-    "in_paddle_module",
-    "is_fallback_api",
-    "is_proxy_tensor",
-    "map_if",
-    "count_if",
-    "freeze_structure",
-    "Cache",
-    "execute_time",
-    "meta_str",
-    "is_strict_mode",
-    "paddle_tensor_method",
-    "ASSERT",
-    "ResumeFnNameFactory",
-    "list_contain_by_id",
-    "list_find_index_by_id",
-    "show_trackers",
-]
diff --git a/python/paddle/jit/symbolic_trace/utils/exceptions.py b/python/paddle/jit/symbolic_trace/utils/exceptions.py
deleted file mode 100644
index 40edbcc170731..0000000000000
--- a/python/paddle/jit/symbolic_trace/utils/exceptions.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class FallbackErrorBase(Exception): - pass - - -class InnerError(FallbackErrorBase): - pass - - -class UnsupportError(FallbackErrorBase): - pass - - -# raise in inline function call strategy. -class BreakGraphError(FallbackErrorBase): - pass diff --git a/python/paddle/jit/symbolic_trace/utils/monkey_patch.py b/python/paddle/jit/symbolic_trace/utils/monkey_patch.py deleted file mode 100644 index dd9e369b34c39..0000000000000 --- a/python/paddle/jit/symbolic_trace/utils/monkey_patch.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .utils import no_eval_frame - - -# The MoneyPatch module adds methods to a class. -def proxy_tensor_method_builder(method_name): - @no_eval_frame - def __impl__(self, other): - return self.call_method(method_name, self, other) - - return __impl__ - - -def do_monkey_patch(cls, patch_names, method_builder): - for method_name in patch_names: - setattr(cls, method_name, method_builder(method_name)) - - -binary_operator_methods = [ - '__add__', - '__sub__', - '__rsub__', - '__radd__', - '__mul__', - '__rmul__', - '__gt__', - '__xor__', - '__or__', - '__and__', - '__mod__', - '__matmul__', - '__pow__', - '__floordiv__', - '__truediv__', - '__lshift__', - '__rshift__', -] - -unary_operator_methods = [ - '__invert__', - '__neg__', - '__pos__', -] diff --git a/python/paddle/jit/symbolic_trace/utils/paddle_api_config.py b/python/paddle/jit/symbolic_trace/utils/paddle_api_config.py deleted file mode 100644 index e12949731e6d1..0000000000000 --- a/python/paddle/jit/symbolic_trace/utils/paddle_api_config.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
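`do_monkey_patch` above installs one generated dunder method per name in `binary_operator_methods`, so every operator applied to a ProxyTensor funnels through a single recording hook. A toy version of the same pattern, using a stand-in `Recorded` class instead of ProxyTensor:

def method_builder(method_name):
    def __impl__(self, other):
        # Every patched operator routes through this one hook.
        print(f"recorded {method_name}")
        return Recorded(f"({self.repr_} {method_name} {other.repr_})")
    return __impl__

def do_monkey_patch(cls, patch_names, builder):
    for name in patch_names:
        setattr(cls, name, builder(name))

class Recorded:
    def __init__(self, repr_):
        self.repr_ = repr_

do_monkey_patch(Recorded, ["__add__", "__mul__"], method_builder)

out = Recorded("x") + Recorded("y") * Recorded("z")
print(out.repr_)  # (x __add__ (y __mul__ z))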
- -import json -import os -import sys -import warnings - -paddle_api_file_path = os.path.join( - os.path.dirname(__file__), "paddle_api_info", "paddle_api.json" -) -with open(paddle_api_file_path, "r") as file: - paddle_api = json.load(file) - -# tensor_methods skipped __iadd__ __isub__, because variable do not support inplace operators -paddle_tensor_method_file_path = os.path.join( - os.path.dirname(__file__), "paddle_api_info", "paddle_tensor_method.json" -) -# TODO(Aurelius84): Can we automitically parse the apis list from dir(paddle.tensor). -with open(paddle_tensor_method_file_path, "r") as file: - paddle_tensor_method = json.load(file) - -paddle_api_list = set() -for module_name in paddle_api.keys(): - # it should already be imported - if module_name in sys.modules.keys(): - module = sys.modules[module_name] - apis = paddle_api[module_name] - for api in apis: - if api in module.__dict__.keys(): - obj = module.__dict__[api] - paddle_api_list.add(obj) - else: - warnings.warn(f"{module_name} not imported.") - -# TODO(Aurelius84): It seems that we use it to judge 'in_paddle_module()'. -# Bug what does 'is_paddle_module' really means? Is all paddle.xx sub module -# considered as paddle module? -paddle_api_module_prefix = { - "paddle.nn.functional", - "paddle.nn.layer.activation", -} - -fallback_list = { - print, - # paddle.utils.map_structure, -} diff --git a/python/paddle/jit/symbolic_trace/utils/paddle_api_info/paddle_api.json b/python/paddle/jit/symbolic_trace/utils/paddle_api_info/paddle_api.json deleted file mode 100644 index ff70d9def1297..0000000000000 --- a/python/paddle/jit/symbolic_trace/utils/paddle_api_info/paddle_api.json +++ /dev/null @@ -1,297 +0,0 @@ -{ - "paddle.nn.functional": [ - "normalize", - "max_unpool2d", - "avg_pool2d", - "adaptive_max_pool3d", - "conv1d", - "adaptive_avg_pool3d", - "adaptive_avg_pool1d", - "rrelu", - "dropout", - "diag_embed", - "cosine_similarity", - "fold", - "cross_entropy", - "hardsigmoid", - "mse_loss", - "avg_pool3d", - "tanhshrink", - "hardshrink", - "dropout2d", - "max_unpool1d", - "conv2d", - "upsample", - "conv2d_transpose", - "pad", - "hardswish", - "conv3d_transpose", - "logsigmoid", - "pixel_unshuffle", - "one_hot", - "hardtanh", - "adaptive_avg_pool2d", - "prelu", - "tanh", - "triplet_margin_with_distance_loss", - "softsign", - "local_response_norm", - "gelu", - "adaptive_max_pool2d", - "interpolate", - "relu6", - "max_unpool3d", - "adaptive_max_pool1d", - "leaky_relu", - "linear", - "silu", - "log_softmax", - "sigmoid", - "conv1d_transpose", - "celu", - "softmax", - "avg_pool1d", - "alpha_dropout", - "instance_norm", - "softshrink", - "grid_sample", - "affine_grid", - "elu_", - "softplus", - "relu_", - "elu", - "margin_ranking_loss", - "smooth_l1_loss", - "relu", - "rrelu_", - "conv3d", - "dropout3d", - "selu", - "bilinear", - "nll_loss", - "pixel_shuffle", - "glu" - ], - "paddle": [ - "to_tensor", - "asinh", - "take_along_axis", - "lgamma", - "numel", - "floor", - "frac", - "histogram", - "tile", - "broadcast_to", - "is_tensor", - "matmul", - "log", - "tan", - "isinf", - "triu", - "randn", - "minimum", - "deg2rad", - "no_grad", - "log2", - "rsqrt", - "bincount", - "pow", - "linspace", - "atanh", - "logit", - "logspace", - "lerp", - "prod", - "sign", - "greater_than", - "sum", - "sqrt", - "not_equal", - "rot90", - "as_real", - "multinomial", - "sinh", - "triu_indices", - "as_complex", - "nonzero", - "expm1", - "square", - "tril", - "any", - "empty", - "is_complex", - "amax", - "einsum", - "stack", - "zeros_like", - "argmax", - 
"cumsum", - "conj", - "index_select", - "inner", - "isnan", - "clone", - "clip", - "bernoulli", - "rad2deg", - "isclose", - "topk", - "bitwise_or", - "divide", - "amin", - "ones_like", - "reshape", - "atan", - "roll", - "greater_equal", - "imag", - "addmm", - "cross", - "less_than", - "normal", - "bitwise_not", - "atan2", - "ceil", - "mean", - "full_like", - "mode", - "unsqueeze", - "squeeze", - "allclose", - "maximum", - "fmax", - "zeros", - "masked_select", - "transpose", - "aaa", - "bitwise_xor", - "var", - "gcd", - "tril_indices", - "digamma", - "diff", - "flip", - "moveaxis", - "bmm", - "isfinite", - "dist", - "cosh", - "floor_divide", - "cos", - "concat", - "acosh", - "cumprod", - "sin", - "eye", - "tensordot", - "neg", - "logical_and", - "tanh", - "ones", - "median", - "logcumsumexp", - "abs", - "trunc", - "unbind", - "reciprocal", - "flatten", - "add", - "kthvalue", - "arange", - "angle", - "randperm", - "remainder", - "full", - "acos", - "argmin", - "renorm", - "nanmedian", - "multiply", - "lcm", - "asin", - "sort", - "set_default_dtype", - "nanmean", - "set_grad_enabled", - "mm", - "heaviside", - "real", - "log10", - "is_floating_point", - "std", - "diagonal", - "unique_consecutive", - "movedim", - "erfinv", - "trace", - "rand", - "mv", - "logical_or", - "nansum", - "diagflat", - "fmin", - "empty_like", - "logical_xor", - "get_default_dtype", - "log1p", - "seed", - "bitwise_and", - "less_equal", - "equal", - "t", - "logsumexp", - "logical_not", - "diag", - "exp", - "outer", - "complex", - "where", - "dtype", - "poisson", - "all", - "Tensor", - "erf", - "kron", - "dot", - "gather" - ], - "paddle.linalg": [ - "det", - "pinv", - "slogdet", - "triangular_solve", - "lstsq", - "matrix_power", - "eig", - "cholesky_solve", - "lu_unpack", - "inv", - "norm", - "cholesky" - ], - "paddle.signal": [ - "istft" - ], - "paddle.fft": [ - "rfftn", - "irfftn", - "ihfft", - "fftfreq", - "rfft", - "fft", - "irfft2", - "ifft2", - "rfft2", - "ifftn", - "hfft", - "rfftfreq", - "ifft", - "fftn", - "fft2", - "irfft" - ], - "paddle.nn.functional.conv": [ - "_conv_nd" - ] -} diff --git a/python/paddle/jit/symbolic_trace/utils/paddle_api_info/paddle_tensor_method.json b/python/paddle/jit/symbolic_trace/utils/paddle_api_info/paddle_tensor_method.json deleted file mode 100644 index 4d3097af4a11a..0000000000000 --- a/python/paddle/jit/symbolic_trace/utils/paddle_api_info/paddle_tensor_method.json +++ /dev/null @@ -1,189 +0,0 @@ -[ - "__ne__", - "__add__", - "__floordiv__", - "__le__", - "__mod__", - "__rmul__", - "__gt__", - "__rdiv__", - "__radd__", - "__div__", - "__pow__", - "__ge__", - "__rsub__", - "__truediv__", - "__sub__", - "__mul__", - "__lt__", - "__rtruediv__", - "__matmul__", - "__rpow__", - "__eq__", - - "heaviside", - "nanmean", - "log", - "log2", - "rot90", - "greater_than", - "kthvalue", - "scale_", - "arcsinh", - "atan2", - "diagonal", - "clip", - "movedim", - "tan", - "lerp", - "remainder", - "tolist", - "tile", - "var", - "fill_diagonal_", - "digamma", - "isclose", - "addmm", - "mod", - "is_floating_point", - "logical_or", - "exp_", - "gcd", - "trace", - "index_select", - "median", - "transpose", - "deg2rad", - "logsumexp", - "allclose", - "tanh_", - "acosh", - "floor_", - "isnan", - "nansum", - "tanh", - "bitwise_and", - "any", - "asinh", - "logit", - "abs", - "bmm", - "lerp_", - "less_equal", - "dot", - "neg", - "bincount", - "exponential_", - "prod", - "expand_as", - "sort", - "squeeze_", - "ceil", - "is_complex", - "diff", - "numel", - "log10", - "reciprocal_", - "max", - "square", - 
"chunk", - "where", - "all", - "sqrt", - "dim", - "outer", - "amin", - "asin", - "lcm", - "isfinite", - "logical_and", - "mean", - "cumsum", - "sign", - "renorm", - "acos", - "less_than", - "sum", - "argsort", - "atan", - "nonzero", - "cross", - "fmax", - "clip_", - "cosh", - "rad2deg", - "std", - "argmax", - "exp", - "erfinv_", - "mode", - "unbind", - "lu", - "reciprocal", - "uniform_", - "not_equal", - "register_hook", - "conj", - "argmin", - "arctanh", - "erf", - "logcumsumexp", - "sinh", - "cholesky", - "angle", - "floor_divide", - "sin", - "broadcast_to", - "remainder_", - "unsqueeze_", - "inner", - "pin_memory", - "flip", - "rsqrt_", - "roll", - "divide", - "multiply", - "dist", - "item", - "matmul", - "topk", - "ceil_", - "equal", - "logical_xor", - "squeeze", - "t", - "log1p", - "atanh", - "minimum", - "floor", - "equal_all", - "histogram", - "take_along_axis", - "cos", - "erfinv", - "greater_equal", - "flatten", - "isinf", - "unique_consecutive", - "trunc", - "unsqueeze", - "frac", - "zero_", - "fmin", - "nanmedian", - "fill_", - "mm", - "pow", - "rsqrt", - "bitwise_not", - "logical_not", - "amax", - "maximum", - "lgamma", - "inverse", - "matrix_power", - "arccosh", - "element_size", - "sqrt_", - "masked_select" -] diff --git a/python/paddle/jit/symbolic_trace/utils/pycode_inspect.py b/python/paddle/jit/symbolic_trace/utils/pycode_inspect.py deleted file mode 100644 index 89dcbc0f60e4a..0000000000000 --- a/python/paddle/jit/symbolic_trace/utils/pycode_inspect.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import types - - -def is_generator(code: types.CodeType): - co_generator = 0x20 - return (code.co_flags & co_generator) > 0 diff --git a/python/paddle/jit/symbolic_trace/utils/utils.py b/python/paddle/jit/symbolic_trace/utils/utils.py deleted file mode 100644 index 6867ef431f778..0000000000000 --- a/python/paddle/jit/symbolic_trace/utils/utils.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import inspect -import os -import time -from typing import Any, Generic, TypeVar -from weakref import WeakValueDictionary - -from frozendict import frozendict - -import paddle -from paddle.utils import flatten, map_structure - -from .paddle_api_config import paddle_tensor_method # noqa: F401 -from .paddle_api_config import ( - fallback_list, - paddle_api_list, - paddle_api_module_prefix, -) - -T = TypeVar("T") - - -class Singleton(Generic[T]): - def __init__(self, cls: type[T]): - self._cls = cls - self._instance = {} - - def __call__(self) -> T: - if self._cls not in self._instance: - self._instance[self._cls] = self._cls() - return self._instance[self._cls] - - -class NameGenerator: - def __init__(self, prefix): - self.counter = 0 - self.prefix = prefix - - def next(self): - name = self.prefix + str(self.counter) - self.counter += 1 - return name - - def match_name(self, name: str) -> bool: - return name.startswith(self.prefix) - - -@Singleton -class ResumeFnNameFactory: - def __init__(self) -> None: - self.gen = NameGenerator('__resume_fn_') - - def next(self): - return self.gen.next() - - -def log(level, *args): - cur_level = int(os.environ.get("LOG_LEVEL", "0")) - if level <= cur_level: - print(*args, end="") - - -def log_do(level, fn): - cur_level = int(os.environ.get("LOG_LEVEL", "0")) - if level <= cur_level: - fn() - - -def no_eval_frame(func): - def no_eval_frame_func(*args, **kwargs): - old_cb = paddle.fluid.core.set_eval_frame(None) - try: - retval = func(*args, **kwargs) - except: - raise - finally: - paddle.fluid.core.set_eval_frame(old_cb) - return retval - - return no_eval_frame_func - - -def is_paddle_api(func): - if isinstance(func, paddle.nn.Layer): # ignore all the classes - return False - if hasattr(func, "__self__"): # ignore all the methods - return False - if inspect.isclass( - func - ): # paddle.Tensor should not be wrapped, but how about other situations? 
-        return False
-    return in_paddle_module(func) or func in paddle_api_list
-
-
-def in_paddle_module(func):
-    if hasattr(func, "__module__"):
-        module_str = func.__module__
-        log(5, "find paddle function with __module__: ", module_str, "\n")
-        if hasattr(func, "__name__"):
-            log(
-                5, " with __name__ : ", func.__name__, "\n"
-            )
-        log(5, " with results : ")
-        for prefix in paddle_api_module_prefix:
-            if module_str.startswith(prefix):
-                log(5, " True\n")
-                return True
-    log(5, " False\n")
-    return False
-
-
-def is_fallback_api(func):
-    return func in fallback_list
-
-
-def is_proxy_tensor(obj):
-    return hasattr(obj, "_proxy_tensor_")
-
-
-def map_if(*structures, pred, true_fn, false_fn):
-    def replace(*args):
-        if pred(*args):
-            return true_fn(*args)
-        return false_fn(*args)
-
-    return map_structure(replace, *structures)
-
-
-def count_if(*structures, pred):
-    def is_true(*args):
-        if pred(*args):
-            return 1
-        return 0
-
-    return sum(flatten(map_structure(is_true, *structures)))
-
-
-def freeze_structure(structure):
-    """
-    only support list / dict and its nested structure
-    """
-    if isinstance(structure, (list, tuple)):
-        return tuple(map(freeze_structure, structure))
-    if isinstance(structure, dict):
-        return frozendict(
-            {k: freeze_structure(v) for k, v in structure.items()}
-        )
-    # if isinstance(structure, types.CodeType):
-    #     return id(structure)
-    return structure
-
-
-class Cache:
-    def __init__(self, weak=False):
-        if not weak:
-            self.cache = {}
-        else:
-            self.cache = WeakValueDictionary()
-        self.hit_num = 0
-
-    def __call__(self, *args, **kwargs):
-        cache_key = self.key_fn(*args, **kwargs)
-        if cache_key in self.cache:
-            log(5, "cache hit: ", cache_key, "\n")
-            self.hit_num += 1
-            return self.cache[cache_key]
-        value = self.value_fn(*args, **kwargs)
-        self.cache[cache_key] = value
-        return value
-
-    def clear(self):
-        self.cache.clear()
-        self.hit_num = 0
-
-    def key_fn(self, *args, **kwargs):
-        raise NotImplementedError()
-
-    def value_fn(self, *args, **kwargs):
-        raise NotImplementedError()
-
-
-def execute_time(func):
-    def wrapper(*args, **kwargs):
-        start_time = time.time()
-        result = func(*args, **kwargs)
-        end_time = time.time()
-        execution_time = end_time - start_time
-        print("Execute time:", execution_time)
-        return result
-
-    return wrapper
-
-
-def meta_str(shape, dtype, stop_gradient):
-    return f"(shape: {shape}, dtype: {dtype}, stop_gradient: {stop_gradient})"
-
-
-def is_strict_mode():
-    return os.environ.get("STRICT_MODE", "0") == "1"
-
-
-def show_trackers() -> str | None:
-    return os.environ.get("SHOW_TRACKERS", None)
-
-
-def ASSERT(input: bool):
-    assert input
-
-
-def list_find_index_by_id(li: list[Any], item: Any) -> int:
-    return [id(it) for it in li].index(id(item))
-
-
-def list_contain_by_id(li: list[Any], item: Any) -> int:
-    return id(item) in [id(it) for it in li]
diff --git a/test/dygraph_to_static/dygraph_to_static_util.py b/test/dygraph_to_static/dygraph_to_static_util.py
index c21c37daccdbc..5abc549e380e4 100644
--- a/test/dygraph_to_static/dygraph_to_static_util.py
+++ b/test/dygraph_to_static/dygraph_to_static_util.py
@@ -73,3 +73,21 @@ def impl(*args, **kwargs):
             func(*args, **kwargs)
 
     return impl
+
+
+def sot_only_test(func):
+    """
+    run this test function in sot only mode.
+    Usage:
+
+    class TestA(unittest.TestCase):
+        @sot_only_test
+        def test_sot_only(self):
+            pass
+    """
+
+    def impl(*args, **kwargs):
+        if os.environ.get("ENABLE_FALL_BACK", "True") == "True":
+            func(*args, **kwargs)
+
+    return impl
diff --git a/test/dygraph_to_static/test_break_continue.py b/test/dygraph_to_static/test_break_continue.py
index df4490e9e4977..946dfb9850c59 100644
--- a/test/dygraph_to_static/test_break_continue.py
+++ b/test/dygraph_to_static/test_break_continue.py
@@ -15,6 +15,7 @@
 import unittest
 
 import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest
 
 import paddle
 from paddle import fluid
@@ -25,12 +26,14 @@
 np.random.seed(SEED)
 
 
+@dy2static_unittest
 class TestDy2staticException(unittest.TestCase):
     def setUp(self):
         self.x = np.random.random([10, 16]).astype('float32')
         self.dyfunc = None
         self.error = "Your if/else have different number of return value."
 
+    @ast_only_test
     def test_error(self):
         if self.dyfunc:
             with self.assertRaisesRegex(Dygraph2StaticException, self.error):
diff --git a/test/dygraph_to_static/test_build_strategy.py b/test/dygraph_to_static/test_build_strategy.py
index 39f4504375467..14e29b9ef4508 100644
--- a/test/dygraph_to_static/test_build_strategy.py
+++ b/test/dygraph_to_static/test_build_strategy.py
@@ -15,11 +15,13 @@
 import unittest
 
 import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest
 from test_resnet import ResNetHelper
 
 import paddle
 
 
+@dy2static_unittest
 class TestResnetWithPass(unittest.TestCase):
     def setUp(self):
         self.build_strategy = paddle.static.BuildStrategy()
@@ -64,6 +66,7 @@ def verify_predict(self):
             ),
         )
 
+    @ast_only_test
     def test_resnet(self):
         static_loss = self.train(to_static=True)
         dygraph_loss = self.train(to_static=False)
@@ -77,6 +80,7 @@ def test_resnet(self):
         )
         self.verify_predict()
 
+    @ast_only_test
     def test_in_static_mode_mkldnn(self):
         paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
         try:
diff --git a/test/dygraph_to_static/test_cache_program.py b/test/dygraph_to_static/test_cache_program.py
index 69100e60efb8d..750de4efda40e 100644
--- a/test/dygraph_to_static/test_cache_program.py
+++ b/test/dygraph_to_static/test_cache_program.py
@@ -16,6 +16,7 @@
 from collections import Counter
 
 import numpy as np
+from dygraph_to_static_util import dy2static_unittest
 from test_fetch_feed import Linear, Pool2D
 
 import paddle
@@ -24,6 +25,7 @@
 from paddle.jit.dy2static import convert_to_static
 
 
+@dy2static_unittest
 class TestCacheProgram(unittest.TestCase):
     def setUp(self):
         self.batch_num = 5
@@ -36,7 +38,7 @@ def test_cache(self):
         with fluid.dygraph.guard(fluid.CPUPlace()):
             static_net = self.dygraph_class()
             for batch_id in range(self.batch_num):
-                out = static_net(self.data)
+                out = static_net(paddle.to_tensor(self.data))
                 # Check outputs
                 prev_out = cur_out
                 cur_out = out
diff --git a/test/dygraph_to_static/test_cinn_prim.py b/test/dygraph_to_static/test_cinn_prim.py
index 388cb67c66f43..f72ddbbffc2cb 100644
--- a/test/dygraph_to_static/test_cinn_prim.py
+++ b/test/dygraph_to_static/test_cinn_prim.py
@@ -15,6 +15,7 @@
 import unittest
 
 import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest
 
 import paddle
 import paddle.nn.functional as F
@@ -38,6 +39,7 @@ def forward(self, x):
         return out
 
 
+@dy2static_unittest
 class TestPrimForward(unittest.TestCase):
     """
     This case only tests prim_forward + to_static + cinn.
Thus we need to @@ -88,6 +90,7 @@ def check_prim(self, net, use_prim): # Ensure that softmax is splitted into small ops self.assertTrue('softmax' not in fwd_ops) + @ast_only_test def test_cinn_prim_forward(self): dy_res = self.train(use_prim=False) cinn_res = self.train(use_prim=True) @@ -98,6 +101,7 @@ def test_cinn_prim_forward(self): ) +@dy2static_unittest class TestPrimForwardAndBackward(unittest.TestCase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph @@ -153,6 +157,7 @@ def check_prim(self, net, use_prim): if op != "matmul_v2_grad": self.assertTrue("_grad" not in op) + @ast_only_test def test_cinn_prim(self): dy_res = self.train(use_prim=False) cinn_res = self.train(use_prim=True) diff --git a/test/dygraph_to_static/test_cinn_prim_gelu.py b/test/dygraph_to_static/test_cinn_prim_gelu.py index 0f764c0745dac..88fa501f7696b 100644 --- a/test/dygraph_to_static/test_cinn_prim_gelu.py +++ b/test/dygraph_to_static/test_cinn_prim_gelu.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle import paddle.nn.functional as F @@ -52,6 +53,7 @@ def forward(self, x): return out +@dy2static_unittest class TestPrimForwardAndBackward(unittest.TestCase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph @@ -104,6 +106,7 @@ def check_prim(self, net, use_prim): # Ensure that gelu is splitted into small ops self.assertTrue('gelu' not in fwd_ops) + @ast_only_test def test_cinn_prim(self): for shape in self.shapes: for dtype in self.dtypes: diff --git a/test/dygraph_to_static/test_cinn_prim_mean.py b/test/dygraph_to_static/test_cinn_prim_mean.py index c920ce9b6dccf..65451ffad5911 100644 --- a/test/dygraph_to_static/test_cinn_prim_mean.py +++ b/test/dygraph_to_static/test_cinn_prim_mean.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle import tensor @@ -54,6 +55,7 @@ def forward(self, x): return out +@dy2static_unittest class TestPrimForward(unittest.TestCase): """ This case only tests prim_forward + to_static + cinn. Thus we need to @@ -110,6 +112,7 @@ def check_prim(self, net, use_prim): # Ensure that reduce_mean is splitted into small ops self.assertTrue('reduce_mean' not in fwd_ops) + @ast_only_test def test_cinn_prim_forward(self): for shape in self.shapes: for dtype in self.dtypes: @@ -131,6 +134,7 @@ def test_cinn_prim_forward(self): ) +@dy2static_unittest class TestPrimForwardAndBackward(unittest.TestCase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph @@ -183,6 +187,7 @@ def check_prim(self, net, use_prim): # Ensure that reduce_mean is splitted into small ops self.assertTrue('reduce_mean' not in fwd_ops) + @ast_only_test def test_cinn_prim(self): for shape in self.shapes: for dtype in self.dtypes: diff --git a/test/dygraph_to_static/test_closure_analysis.py b/test/dygraph_to_static/test_closure_analysis.py index bed90ccbe47fe..3fe282887e107 100644 --- a/test/dygraph_to_static/test_closure_analysis.py +++ b/test/dygraph_to_static/test_closure_analysis.py @@ -13,6 +13,7 @@ # limitations under the License. 
import inspect +import os import unittest from numpy import append @@ -325,4 +326,5 @@ def vlist_of_dict(x): if __name__ == '__main__': + os.environ['ENABLE_FALL_BACK'] = "False" unittest.main() diff --git a/test/dygraph_to_static/test_container.py b/test/dygraph_to_static/test_container.py index 34da0ebc2c71f..8170c55f59d6e 100644 --- a/test/dygraph_to_static/test_container.py +++ b/test/dygraph_to_static/test_container.py @@ -17,6 +17,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle @@ -69,6 +70,7 @@ def forward(self, x): return self.layers(x) +@dy2static_unittest class TestSequential(unittest.TestCase): def setUp(self): paddle.set_device('cpu') @@ -104,6 +106,7 @@ def _run(self, to_static): return out + @ast_only_test def test_train(self): paddle.jit.set_code_level(100) dy_out = self._run(to_static=False) diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py index 11f947d183243..7b54ea5956134 100644 --- a/test/dygraph_to_static/test_convert_call.py +++ b/test/dygraph_to_static/test_convert_call.py @@ -16,6 +16,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle import paddle.jit.dy2static as _jst @@ -252,6 +253,7 @@ def test_code(self): ) +@dy2static_unittest class TestNotToConvert2(TestRecursiveCall2): def set_func(self): self.net = NotToStaticHelper() @@ -264,7 +266,9 @@ def test_conversion_options(self): self.assertIsNotNone(options) self.assertTrue(options.not_convert) + @ast_only_test def test_code(self): + self.dygraph_func = paddle.jit.to_static(self.net.sum) # check 'if statement' is not converted self.assertIn("if x.shape[0] > 1", self.dygraph_func.code) @@ -277,19 +281,23 @@ def forward(self, x): return x +@dy2static_unittest class TestConvertPaddleAPI(unittest.TestCase): + @ast_only_test def test_functional_api(self): func = paddle.nn.functional.relu func = paddle.jit.to_static(func) self.assertNotIn("_jst.IfElse", func.code) self.assertIn("if in_dynamic_mode()", func.code) + @ast_only_test def test_class_api(self): bn = paddle.nn.SyncBatchNorm(2) paddle.jit.to_static(bn) self.assertNotIn("_jst.IfElse", bn.forward.code) self.assertIn("if in_dynamic_mode()", bn.forward.code) + @ast_only_test def test_class_patch_api(self): paddle.nn.SyncBatchNorm.forward = forward bn = paddle.nn.SyncBatchNorm(2) diff --git a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py index f8d15971a7bc0..dfd2aae4ed62c 100644 --- a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py +++ b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py @@ -15,6 +15,11 @@ import unittest import numpy as np +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + sot_only_test, +) import paddle @@ -28,8 +33,8 @@ def func(x): return x x = paddle.to_tensor([3]) - print(paddle.jit.to_static(func).code) - print(paddle.jit.to_static(func)(x)) + # print(paddle.jit.to_static(func).code) + # print(paddle.jit.to_static(func)(x)) class TestToTensor(unittest.TestCase): @@ -41,7 +46,7 @@ def func(x): return x x = paddle.to_tensor([3]) - print(paddle.jit.to_static(func).code) + # print(paddle.jit.to_static(func).code) np.testing.assert_allclose( paddle.jit.to_static(func)(x).numpy(), np.array([1, 2, 3, 4]), @@ -49,7 +54,9 @@ def func(x): ) +@dy2static_unittest class TestToTensor1(unittest.TestCase): + @ast_only_test def test_to_tensor_with_variable_list(self): def 
func(x):
 ones = paddle.to_tensor([1])
@@ -61,28 +68,59 @@ def func(x):
 return x

 x = paddle.to_tensor([3])
- print(paddle.jit.to_static(func).code)
 np.testing.assert_allclose(
 paddle.jit.to_static(func)(x).numpy(),
 np.array([1, 2, 3, 4]),
 rtol=1e-05,
 )
+ @sot_only_test
+ def test_to_tensor_with_variable_list_sot(self):
+ def func(x):
+ ones = paddle.to_tensor([1])
+ twos = paddle.to_tensor([2])
+ """ We ignore the [3] and [4]; they will be assigned to variables and regarded as scalars.
+ TODO: deal with this case after 0-dim tensor support is developed.
+ """
+ x = paddle.to_tensor([ones, twos, [3], [4]])
+ return x
+ x = paddle.to_tensor([3])
+ np.testing.assert_allclose(
+ paddle.jit.to_static(func)(x),
+ np.array([[1], [2], [3], [4]]),
+ rtol=1e-05,
+ )
+
+
+@dy2static_unittest
 class TestToTensor2(unittest.TestCase):
+ @ast_only_test
 def test_to_tensor_with_variable_list(self):
 def func(x):
 x = paddle.to_tensor([[1], [2], [3], [4]])
 return x
 x = paddle.to_tensor([3])
- print(paddle.jit.to_static(func).code)
 np.testing.assert_allclose(
 paddle.jit.to_static(func)(x).numpy(),
 np.array([[1], [2], [3], [4]]),
 rtol=1e-05,
 )
+ @sot_only_test
+ def test_to_tensor_with_variable_list_sot(self):
+ def func(x):
+ x = paddle.to_tensor([[1], [2], [3], [4]])
+ return x
+
+ x = paddle.to_tensor([3])
+ np.testing.assert_allclose(
+ paddle.jit.to_static(func)(x),
+ np.array([[1], [2], [3], [4]]),
+ rtol=1e-05,
+ )
+
 if __name__ == '__main__':
 unittest.main()
diff --git a/test/dygraph_to_static/test_declarative.py b/test/dygraph_to_static/test_declarative.py
index ddd8680dc3c10..936a71236b44f 100644
--- a/test/dygraph_to_static/test_declarative.py
+++ b/test/dygraph_to_static/test_declarative.py
@@ -30,6 +30,8 @@ from paddle.nn import Layer
 from paddle.static import InputSpec
+os.environ['ENABLE_FALL_BACK'] = "False" # NOTE: ast only
+
 class SimpleNet(Layer):
 def __init__(self):
diff --git a/test/dygraph_to_static/test_decorator_transform.py b/test/dygraph_to_static/test_decorator_transform.py
index 6add36fd9e09e..8568fa1e181b4 100644
--- a/test/dygraph_to_static/test_decorator_transform.py
+++ b/test/dygraph_to_static/test_decorator_transform.py
@@ -19,6 +19,7 @@ import decos
 import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest
 import paddle
@@ -147,7 +148,6 @@ def fun8(x, y=0):
 return a
-@paddle.jit.to_static
 def forward():
 funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8]
 out = []
@@ -166,7 +166,6 @@ def fun9():
 print('in fun9 want contextmanager warning')
-@paddle.jit.to_static
 def warn1():
 fun9()
@@ -182,9 +181,10 @@ def deco_with_paddle_api():
 return fun10()
+@dy2static_unittest
 class TestDecoratorTransform(unittest.TestCase):
 def test_deco_transform(self):
- outs = forward()
+ outs = paddle.jit.to_static(forward)()
 np.testing.assert_allclose(outs[0], np.array(3), rtol=1e-05)
 np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05)
 np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05)
@@ -194,11 +194,12 @@ def test_deco_transform(self):
 np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05)
 np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05)
+ @ast_only_test
 def test_contextmanager_warning(self):
 paddle.disable_static()
 with warnings.catch_warnings(record=True) as w:
 warnings.simplefilter("always")
- warn1()
+ paddle.jit.to_static(warn1)()
 flag = False
 for warn in w:
 if (
diff --git a/test/dygraph_to_static/test_fallback.py b/test/dygraph_to_static/test_fallback.py
index e4dc0114054ad..03602a586182c 100644
---
a/test/dygraph_to_static/test_fallback.py +++ b/test/dygraph_to_static/test_fallback.py @@ -16,6 +16,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test import paddle @@ -84,6 +85,7 @@ def test_case_net_fallback(self): u_net(self.x).numpy(), ) + @ast_only_test def test_case_net_error(self): s_net = SuppportNet() u_net = UnsuppportNet() @@ -110,6 +112,7 @@ def test_case_training(self): np.testing.assert_allclose(u_net(self.x).numpy(), [1, 1]) assert u_net.training is False, "Training must be false." + @ast_only_test def test_case_save_error(self): """ test the save will raise error. diff --git a/test/dygraph_to_static/test_gradname_parse.py b/test/dygraph_to_static/test_gradname_parse.py index 743f92f758f4f..ca63511a5f6a3 100644 --- a/test/dygraph_to_static/test_gradname_parse.py +++ b/test/dygraph_to_static/test_gradname_parse.py @@ -16,9 +16,9 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle -from paddle import ParamAttr from paddle.nn import BatchNorm, Linear @@ -28,13 +28,9 @@ def __init__(self): self.linear0 = Linear(100, 50) self.linear1 = Linear(50, 10) - param_attr0 = ParamAttr(name="aaaprefix_bn_scale") - bias_attr0 = ParamAttr(name="aaaprefix_bn_offset") - self.bn0 = BatchNorm(50, param_attr=param_attr0, bias_attr=bias_attr0) + self.bn0 = BatchNorm(50) - param_attr1 = ParamAttr(name="bn_scale") - bias_attr1 = ParamAttr(name="bn_offset") - self.bn1 = BatchNorm(10, param_attr=param_attr1, bias_attr=bias_attr1) + self.bn1 = BatchNorm(10) def forward(self, x): x1 = self.linear0(x) @@ -45,6 +41,7 @@ def forward(self, x): return dx[0] +@dy2static_unittest class TestGradNameParse(unittest.TestCase): def test_grad_name_parse(self): net = SimpleNet() @@ -72,6 +69,7 @@ def tanh_high_order_grad(x): return paddle.grad(y, x, create_graph=True)[0] +@dy2static_unittest class TestTanhHighOrderGrad(unittest.TestCase): def setUp(self): self.func = tanh_high_order_grad @@ -116,10 +114,11 @@ def test_run(self): def matmul_high_order_grad(x, y): z = paddle.matmul(x, y) - g = paddle.grad(z, [x, y], create_graph=False) + g = paddle.grad(z, [x, y], create_graph=True) return g[0] +@dy2static_unittest class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad): def setUp(self): self.func = matmul_high_order_grad @@ -139,6 +138,7 @@ def setUp(self): self.dy2st_grad_input = (x2,) +@dy2static_unittest class TestMatMulHighOrderGrad2(TestTanhHighOrderGrad): def setUp(self): self.func = matmul_high_order_grad diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py index 722240050d116..9ca60bf383a0a 100644 --- a/test/dygraph_to_static/test_ifelse.py +++ b/test/dygraph_to_static/test_ifelse.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test, dy2static_unittest from ifelse_simple_func import ( NetWithControlFlowIf, add_fn, @@ -54,12 +55,14 @@ place = fluid.CPUPlace() +@dy2static_unittest class TestDy2staticException(unittest.TestCase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = None self.error = "Your if/else have different number of return value." 
+ @ast_only_test def test_error(self): if self.dyfunc: with self.assertRaisesRegex(Dygraph2StaticException, self.error): @@ -413,10 +416,11 @@ def case_func(training): self.assertEqual(paddle.jit.to_static(case_func)(True), -2) +@dy2static_unittest class TestDy2StIfElseRetInt1(unittest.TestCase): def setUp(self): self.x = np.random.random([5]).astype('float32') - self.dyfunc = dyfunc_ifelse_ret_int1 + self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int1) self.out = self.get_dy2stat_out() def get_dy2stat_out(self): @@ -426,7 +430,9 @@ def get_dy2stat_out(self): paddle.jit.enable_to_static(False) return out + @ast_only_test def test_ast_to_func(self): + self.setUp() self.assertIsInstance(self.out[0], (paddle.Tensor, core.eager.Tensor)) self.assertIsInstance(self.out[1], int) @@ -438,21 +444,26 @@ def setUp(self): self.dyfunc = dyfunc_ifelse_ret_int2 +@dy2static_unittest class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1): def setUp(self): self.x = np.random.random([5]).astype('float32') - self.dyfunc = dyfunc_ifelse_ret_int3 + self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3) self.out = self.get_dy2stat_out() + @ast_only_test def test_ast_to_func(self): + self.setUp() self.assertIsInstance(self.out, (paddle.Tensor, core.eager.Tensor)) +@dy2static_unittest class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1): def setUp(self): self.x = np.random.random([5]).astype('float32') - self.dyfunc = dyfunc_ifelse_ret_int4 + self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int4) + @ast_only_test def test_ast_to_func(self): paddle.jit.enable_to_static(True) with self.assertRaises(Dygraph2StaticException): diff --git a/test/dygraph_to_static/test_mnist.py b/test/dygraph_to_static/test_mnist.py index f324919e3cc14..2f33516e89dbf 100644 --- a/test/dygraph_to_static/test_mnist.py +++ b/test/dygraph_to_static/test_mnist.py @@ -18,6 +18,7 @@ from time import time import numpy as np +from dygraph_to_static_util import ast_only_test from predictor_utils import PredictorTools import paddle @@ -158,6 +159,7 @@ def train_static(self): def train_dygraph(self): return self.train(to_static=False) + @ast_only_test def test_mnist_to_static(self): dygraph_loss = self.train_dygraph() static_loss = self.train_static() diff --git a/test/dygraph_to_static/test_mobile_net.py b/test/dygraph_to_static/test_mobile_net.py index 7f0efd67620be..f5b99c23cc463 100644 --- a/test/dygraph_to_static/test_mobile_net.py +++ b/test/dygraph_to_static/test_mobile_net.py @@ -19,6 +19,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test from predictor_utils import PredictorTools import paddle @@ -713,6 +714,7 @@ def assert_same_predict(self, model_name): ), ) + @ast_only_test def test_mobile_net(self): # MobileNet-V1 self.assert_same_loss("MobileNetV1") diff --git a/test/dygraph_to_static/test_op_attr.py b/test/dygraph_to_static/test_op_attr.py index d474d80b63e60..25194397e4651 100644 --- a/test/dygraph_to_static/test_op_attr.py +++ b/test/dygraph_to_static/test_op_attr.py @@ -14,6 +14,8 @@ import unittest +from dygraph_to_static_util import ast_only_test + import paddle from paddle.static import InputSpec @@ -75,6 +77,7 @@ def expected_results(self): 'elementwise_sub': self.sub_attrs, } + @ast_only_test def test_set_op_attrs(self): net = NetWithOpAttr(self.in_num, self.out_num) # set attrs @@ -116,6 +119,7 @@ def check_op_attrs(self, main_program): else: self.assertEqual(op_val, expect_val) + @ast_only_test def test_set_op_attrs_with_sub_block(self): net = NetWithOpAttr(self.in_num, 
self.out_num) # set attrs diff --git a/test/dygraph_to_static/test_tsm.py b/test/dygraph_to_static/test_tsm.py index d0892a50fdd35..9b04d39c493c2 100644 --- a/test/dygraph_to_static/test_tsm.py +++ b/test/dygraph_to_static/test_tsm.py @@ -45,7 +45,9 @@ def parse_args(): default=fluid.is_compiled_with_cuda(), help='default use gpu.', ) - args = parser.parse_args(['--config', 'tsm.yaml']) + args = parser.parse_args( + ['--config', __file__.rpartition('/')[0] + '/tsm.yaml'] + ) return args diff --git a/test/dygraph_to_static/test_write_python_container.py b/test/dygraph_to_static/test_write_python_container.py index a2aa94886e5d4..a175b881d86c7 100644 --- a/test/dygraph_to_static/test_write_python_container.py +++ b/test/dygraph_to_static/test_write_python_container.py @@ -14,6 +14,12 @@ import unittest +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + sot_only_test, +) + import paddle @@ -93,6 +99,7 @@ def func_ifelse_write_nest_list_dict(x): return res +@dy2static_unittest class TestWriteContainer(unittest.TestCase): def setUp(self): self.set_func() @@ -110,6 +117,15 @@ def get_raw_value(self, container, getitem_path): out = out[path] return out + @sot_only_test + def test_write_container_sot(self): + func_static = paddle.jit.to_static(self.func) + input = paddle.to_tensor([1, 2, 3]) + out_static = self.get_raw_value(func_static(input), self.getitem_path) + out_dygraph = self.get_raw_value(self.func(input), self.getitem_path) + self.assertEqual(out_static, out_dygraph) + + @ast_only_test def test_write_container(self): func_static = paddle.jit.to_static(self.func) input = paddle.to_tensor([1, 2, 3]) From b61b1aa62bc8cf0264f35e380379c0a38ad470c2 Mon Sep 17 00:00:00 2001 From: NotHaozi Date: Fri, 16 Jun 2023 06:46:14 +0000 Subject: [PATCH 2/2] fix unittests --- test/dygraph_to_static/test_resnet.py | 5 +++++ test/dygraph_to_static/test_resnet_v2.py | 4 ++++ test/dygraph_to_static/test_rollback.py | 2 ++ test/dygraph_to_static/test_save_inference_model.py | 3 +++ test/dygraph_to_static/test_save_load.py | 3 +++ 5 files changed, 17 insertions(+) diff --git a/test/dygraph_to_static/test_resnet.py b/test/dygraph_to_static/test_resnet.py index 9fcc1a803f822..2e245b37ef004 100644 --- a/test/dygraph_to_static/test_resnet.py +++ b/test/dygraph_to_static/test_resnet.py @@ -19,6 +19,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test from predictor_utils import PredictorTools import paddle @@ -413,6 +414,7 @@ def verify_predict(self): ), ) + @ast_only_test def test_resnet(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) @@ -426,6 +428,7 @@ def test_resnet(self): ) self.verify_predict() + @ast_only_test def test_resnet_composite_backward(self): core._set_prim_backward_enabled(True) static_loss = self.train(to_static=True) @@ -440,6 +443,7 @@ def test_resnet_composite_backward(self): ), ) + @ast_only_test def test_resnet_composite_forward_backward(self): core._set_prim_all_enabled(True) static_loss = self.train(to_static=True) @@ -454,6 +458,7 @@ def test_resnet_composite_forward_backward(self): ), ) + @ast_only_test def test_in_static_mode_mkldnn(self): fluid.set_flags({'FLAGS_use_mkldnn': True}) try: diff --git a/test/dygraph_to_static/test_resnet_v2.py b/test/dygraph_to_static/test_resnet_v2.py index 2efbe46cedfec..943dc6d21ca95 100644 --- a/test/dygraph_to_static/test_resnet_v2.py +++ b/test/dygraph_to_static/test_resnet_v2.py @@ -19,6 +19,7 @@ import unittest import numpy as np +from 
dygraph_to_static_util import ast_only_test from predictor_utils import PredictorTools import paddle @@ -412,6 +413,7 @@ def verify_predict(self): ), ) + @ast_only_test def test_resnet(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) @@ -425,6 +427,7 @@ def test_resnet(self): ) self.verify_predict() + @ast_only_test def test_resnet_composite(self): core._set_prim_backward_enabled(True) core._add_skip_comp_ops("batch_norm") @@ -440,6 +443,7 @@ def test_resnet_composite(self): ), ) + @ast_only_test def test_in_static_mode_mkldnn(self): paddle.fluid.set_flags({'FLAGS_use_mkldnn': True}) try: diff --git a/test/dygraph_to_static/test_rollback.py b/test/dygraph_to_static/test_rollback.py index c418a850d5aaf..443fffb9d134c 100644 --- a/test/dygraph_to_static/test_rollback.py +++ b/test/dygraph_to_static/test_rollback.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test import paddle from paddle.jit.dy2static.program_translator import StaticFunction @@ -88,6 +89,7 @@ class TestRollBackNet(unittest.TestCase): def setUp(self): paddle.set_device("cpu") + @ast_only_test def test_net(self): net = paddle.jit.to_static(Net()) x = paddle.randn([3, 4]) diff --git a/test/dygraph_to_static/test_save_inference_model.py b/test/dygraph_to_static/test_save_inference_model.py index 1d1826c1fd8f6..bacaa59e89b0a 100644 --- a/test/dygraph_to_static/test_save_inference_model.py +++ b/test/dygraph_to_static/test_save_inference_model.py @@ -17,6 +17,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test import paddle from paddle import fluid @@ -53,6 +54,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @ast_only_test def test_save_inference_model(self): fc_size = 20 x_data = np.random.random((fc_size, fc_size)).astype('float32') @@ -145,6 +147,7 @@ def load_and_run_inference( class TestPartialProgramRaiseError(unittest.TestCase): + @ast_only_test def test_param_type(self): paddle.jit.enable_to_static(True) x_data = np.random.random((20, 20)).astype('float32') diff --git a/test/dygraph_to_static/test_save_load.py b/test/dygraph_to_static/test_save_load.py index 1f07fe0124502..f02867808d6f5 100644 --- a/test/dygraph_to_static/test_save_load.py +++ b/test/dygraph_to_static/test_save_load.py @@ -17,6 +17,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test from test_fetch_feed import Linear import paddle @@ -115,6 +116,7 @@ def test_save_load_same_result(self): dygraph_loss.numpy(), static_loss.numpy(), rtol=1e-05 ) + @ast_only_test def test_save_load_prim(self): with fluid.dygraph.guard(place): self.x = paddle.randn([4, 2, 6, 6], dtype="float32") @@ -155,6 +157,7 @@ def test_save_load_prim(self): self.assertIn("pool2d", load_op_type_list) np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05) + @ast_only_test def test_save_load_prim_with_hook(self): with fluid.dygraph.guard(place): self.x = paddle.randn([4, 2, 6, 6], dtype="float32")