diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py
index 1dbe77cd48052..185ea35867a1d 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py
@@ -40,6 +40,7 @@
     map_if,
     tmp_name_guard,
 )
+from ..instruction_utils import get_instructions
 from .guard import Guard, StringifyExpression, make_guard
 from .mutable_data import MutationDel, MutationNew, MutationSet
 from .pycode_generator import PyCodeGen
@@ -241,7 +242,73 @@ def guard_fn(self) -> Guard:
 
         return make_guard(guards)
 
-    def start_compile_with_name_store(self, ret_vars, to_store_vars):
+    def _restore_origin_opcode(self, stack_vars, store_var_info, instr_idx):
+        class VariableLoader:
+            def __init__(self, store_var_info, pycode_gen):
+                self._store_var_info = store_var_info
+                self._pycode_gen: PyCodeGen = pycode_gen
+
+            def load(self, var, allow_push_null=True):
+                if isinstance(var, NullVariable):
+                    # PUSH_NULL is a real opcode, so a NullVariable can be reconstructed directly
+                    if allow_push_null:
+                        var.reconstruct(self._pycode_gen)
+                    else:
+                        # Avoid passing NULL as a parameter to the resume function
+                        self._pycode_gen.gen_load_null_variable()
+                    return
+                # only vars on the stack are restored here, so use gen_load to also handle global vars
+                self._pycode_gen.gen_load(self._store_var_info[var])
+
+        origin_instr = get_instructions(self.pycode_gen._origin_code)
+
+        for instr in origin_instr[0:instr_idx]:
+            if (
+                instr.opname == 'LOAD_FAST'
+                and instr.argval in self.pycode_gen._frame.f_locals.keys()
+                and isinstance(
+                    self.pycode_gen._frame.f_locals[instr.argval], NullVariable
+                )
+            ):
+                self.pycode_gen._frame.f_locals[instr.argval].reconstruct(
+                    self.pycode_gen
+                )
+            elif (
+                instr.opname == 'LOAD_GLOBAL'
+                and instr.argval in self.pycode_gen._frame.f_globals.keys()
+                and isinstance(
+                    self.pycode_gen._frame.f_globals[instr.argval], NullVariable
+                )
+            ):
+                self.pycode_gen._frame.f_globals[instr.argval].reconstruct(
+                    self.pycode_gen
+                )
+            else:
+                self.pycode_gen.extend_instrs([instr])
+
+        nop = self.pycode_gen._add_instr("NOP")
+
+        for instr in origin_instr:
+            if instr.jump_to == origin_instr[instr_idx]:
+                instr.jump_to = nop
+
+        self.pycode_gen.hooks.append(
+            lambda: self.pycode_gen.extend_instrs(
+                iter(origin_instr[instr_idx + 1 :])
+            )
+        )
+
+        self.pycode_gen.gen_enable_eval_frame()
+
+        name_gen = NameGenerator("__start_compile_saved_orig_")
+
+        for var in stack_vars[::-1]:
+            store_var_info[var] = name_gen.next()
+            self.pycode_gen.gen_store_fast(store_var_info[var])
+
+        return VariableLoader(store_var_info, self.pycode_gen)
+
+    def _build_compile_fn_with_name_store(self, ret_vars, to_store_vars):
         class VariableLoader:
             def __init__(self, index_for_load, pycode_gen):
                 self._index_for_load = index_for_load
@@ -249,12 +316,14 @@ def __init__(self, index_for_load, pycode_gen):
 
             def load(self, var, allow_push_null=True):
                 if isinstance(var, NullVariable):
+                    # PUSH_NULL is a real opcode, so a NullVariable can be reconstructed directly
                     if allow_push_null:
                         var.reconstruct(self._pycode_gen)
                     else:
                         # Avoid passing NULL as a parameter to the resume function
                         self._pycode_gen.gen_load_null_variable()
                     return
+                # all vars to be loaded are saved by this function, so LOAD_FAST is correct
                 self._pycode_gen.gen_load_fast(self._index_for_load[var.id])
 
         # var_id -> local_name mapping
@@ -264,7 +333,8 @@ def load(self, var, allow_push_null=True):
         )
         self.start_compile(*(ret_vars + to_store_vars))
         name_gen = NameGenerator("__start_compile_saved_")
-        for var in to_store_vars:
+
+        for var in to_store_vars[::-1]:
             index_for_load[var.id] = name_gen.next()
 
         def _log_fn():
@@ -275,8 +345,8 @@ def _log_fn():
 
         log_do(4, _log_fn)
 
-        for var in to_store_vars[::-1]:
             self.pycode_gen.gen_store_fast(index_for_load[var.id])
+
         return VariableLoader(index_for_load, self.pycode_gen)
 
     @event_register("start_compile", event_level=2)
@@ -552,6 +622,18 @@ def _find_tensor_outputs(
         Args:
             outputs: output variables
         """
+
+        def collect_related_dummy_tensor(var):
+            if isinstance(var.tracker, DummyTracker):
+                if isinstance(var, TensorVariable):
+                    return [var]
+                else:
+                    retval = []
+                    for inp in var.tracker.inputs:
+                        retval.extend(collect_related_dummy_tensor(inp))
+                    return retval
+            return []
+
         output_tensors: OrderedSet[TensorVariable] = OrderedSet()
         # Find Tensor Variables from outputs.
         for output in outputs:
@@ -559,6 +641,9 @@ def _find_tensor_outputs(
             if isinstance(output, TensorVariable):
                 output_tensors.add(output)
             else:
+                for inp in output.tracker.inputs:
+                    for _var in collect_related_dummy_tensor(inp):
+                        output_tensors.add(_var)
                 # Guard output that can not be traced.
                 self.add_global_guarded_variable(output)
         # Find Tensor Variables from side effects Variables.
diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
index ea76642a671db..052b89c1cc1e1 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
@@ -227,7 +227,6 @@ def jump_break_graph_decorator(normal_jump: Callable):
     def inner(self: OpcodeExecutor, instr: Instruction):
         result = self.stack.top
         if isinstance(result, TensorVariable):
-            self.stack.pop()
             # fallback when in OpcodeExecutor
             # raise error in OpcodeInlineExecutor
             log(3, "[BreakGraph] jump break graph, because if tensor\n")
@@ -327,7 +326,11 @@ class OpcodeExecutorBase:
 
     """
 
+    class EmptyCode:
+        pass
+
     call_stack: list[OpcodeExecutorBase] = []
+    empty_code = EmptyCode()
 
     @staticmethod
     def validate_value(value):
@@ -352,7 +355,7 @@ def __init__(self, code: types.CodeType, graph: FunctionGraph):
         self._current_line: int = -1
         self._instructions = get_instructions(self._code)
         self._graph = graph
-        self.new_code: types.CodeType | None = None
+        self.new_code: types.CodeType | None = self.empty_code
         self.guard_fn = None
         self._name = "Executor"
         self._call_shape: tuple[
@@ -1506,7 +1509,39 @@ def _prepare_virtual_env(self):
             )
         )
 
-    def _create_resume_fn(self, index, stack_size=0):
+    def gen_compute_in_break_with_name_store(self, restore_names, instr_idx):
+        """
+        branch 1: if the graph size is too small, just run in dygraph
+        branch 2: if the graph is big enough, create a compiled_fn
+
+        This API generates opcodes for both situations; the generated code
+        does the same thing as the original code.
+
+        restore_names:
+            the names used in resume functions; branch 2 will restore these values,
+            branch 1 also needs these names to generate opcodes, but they do not
+            need to be restored
+        instr_idx:
+            the index for branch 1 to find the boundary and copy the original opcodes
+        """
+        if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get():
+            store_var_info = {}
+            for name in restore_names:
+                _var = self.get_var(name)
+                if _var not in self.stack:
+                    store_var_info[_var] = name
+            return self._graph._restore_origin_opcode(
+                list(self.stack), store_var_info, instr_idx
+            )
+        else:
+            store_vars = list(self.stack)
+            for name in restore_names:
+                _var = self.get_var(name)
+                if _var not in self.stack:
+                    store_vars.append(_var)
+            return self._graph._build_compile_fn_with_name_store([], store_vars)
+
+    def _create_resume_fn(self, index, stack_size):
         """
         Create a resume function and its inputs at the specified index.
 
@@ -1523,7 +1558,7 @@ def _create_resume_fn(self, index, stack_size=0):
 
         return fn, inputs
 
     @fallback_when_occur_error
-    def _break_graph_in_jump(self, result: VariableBase, instr: Instruction):
+    def _break_graph_in_jump(self, result: TensorVariable, instr: Instruction):
         """
         Break the graph at a JUMP instruction.
 
@@ -1533,7 +1568,10 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction):
         """
         self._graph.add_global_guarded_variable(result)
 
-        stack_size = len(self.stack)
+        # minus the bool value
+        stack_size = len(self.stack) - 1
+
+        # gen call static fn opcode
         if_fn, if_inputs = self._create_resume_fn(
             self.indexof(instr) + 1, stack_size
         )
         else_fn, else_inputs = self._create_resume_fn(
             self.indexof(instr.jump_to), stack_size
         )
 
-        # gen call static fn opcode
-        inputs_name = if_inputs | else_inputs
-        inputs_var = [
-            self.get_var(name)
-            for name in inputs_name
-            if self.get_var(name) is not result
-        ]
-        ret_vars = [
-            result,
-        ] + inputs_var
-        # Collect all the to store variables.
-        store_vars = []
-        for stack_arg in self.stack:
-            store_vars.append(stack_arg)
-        for name in inputs_name:
-            store_vars.append(self.get_var(name))
+        inputs_names = if_inputs | else_inputs
 
-        var_loader = self._graph.start_compile_with_name_store(
-            ret_vars, store_vars
+        var_loader = self.gen_compute_in_break_with_name_store(
+            inputs_names, self.indexof(instr)
         )
-        # only pop the input of if/else resume fn, and keep the bool tensor result on the stack
-        for _ in inputs_var:
-            self._graph.pycode_gen.gen_pop_top()
+
+        var_loader.load(result)
+        # the result is consumed by the jump opcode, so it should not be an input of resume_fn
+        self.stack.pop()
 
         # gen call if/else resume fn opcode
         if if_fn is not None:
@@ -1633,31 +1657,13 @@ def _break_graph_in_call(
         self.stack = origin_stack
 
         # gen call static fn opcode
-        ret_vars = [
-            arg
-            for arg in self.stack
-            if isinstance(arg, (TensorVariable, ContainerVariable))
-        ]
+
         resume_input_name = analysis_inputs(self._instructions, index + 1)
-        ret_vars = ret_vars + [
-            self.get_var(name)
-            for name in resume_input_name
-            if self.get_var(name) not in ret_vars
-        ]
-        # Collect all the to store variables.
-        store_vars = []
-        for stack_arg in self.stack:
-            store_vars.append(stack_arg)
-        for name in resume_input_name:
-            store_vars.append(self.get_var(name))
-        var_loader = self._graph.start_compile_with_name_store(
-            ret_vars, store_vars
+        var_loader = self.gen_compute_in_break_with_name_store(
+            resume_input_name, self.indexof(instr)
         )
-        for _ in ret_vars:
-            self._graph.pycode_gen.gen_pop_top()
-
         # gen graph break call fn opcode
         stack_effect = calc_stack_effect(instr)
         pop_n = push_n - stack_effect
@@ -1670,11 +1676,12 @@ def _break_graph_in_call(
 
         # gen call resume fn opcode
         # NOTE(SigureMo): In Python 3.11,we need generate KW_NAMES if the call shape is not None.
         self._graph.pycode_gen.gen_kw_names(self._call_shape)
-        self._graph.pycode_gen.add_pure_instructions([instr])
+        self._graph.pycode_gen.extend_instrs([instr])
         self.stack.pop_n(pop_n)
         stack_size = len(self.stack) + push_n
         resume_fn, _ = self._create_resume_fn(index + 1, stack_size)
+
         if resume_fn:
             self._graph.pycode_gen.gen_load_object(
                 resume_fn, resume_fn.__code__.co_name
@@ -1697,30 +1704,12 @@ def transform(self):
         self.run()
-        if self.new_code is None:
+        if self.new_code is self.empty_code:
             raise InnerError("OpExecutor return a empty new_code.")
-        # stopped by RETURN_VALUE and has sir len is enough => disable_eval_frame
-        simulate_complete = bool(self.stop_state == "Return")
-        if simulate_complete:
-            if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get():
-                raise FallbackError(
-                    "Fallback after simulate for reasons.",
-                    disable_eval_frame=True,
-                )
-            else:
-                # if simulate stop with graph successfully, the all codes will be
-                # surrounded by the eval_frame triggers which exist in self.new_code
-                # we need not set disable_eval_frame=False here (for it already is)
-                return (
-                    CustomCode(self.new_code, True),
-                    self.guard_fn,
-                )
-        else:
-            # if return because breakgraph, need open eval_frame
-            return (
-                CustomCode(self.new_code, False),
-                self.guard_fn,
-            )
+        return (
+            CustomCode(self.new_code, self.new_code is None),
+            self.guard_fn,
+        )
 
     def _gen_loop_body_between(
         self, inputs: list, for_iter_idx: int, start: int, end: int
@@ -1837,9 +1826,9 @@ def _break_graph_in_for_loop(
         log(3, "[Resumed Function]: break graph in loop create loop body as\n")
         log_do(3, lambda: dis.dis(loop_body_fn))
 
-        # 0.3 create after loop part function
+        # 0.3 create after loop part function, minus 1 for the iterator
         after_loop_fn, fn_inputs = self._create_resume_fn(
-            loop_body_end_idx, len(self.stack)
+            loop_body_end_idx, len(self.stack) - 1
         )
 
         total_inputs = OrderedSet(list(fn_inputs) + list(loop_body_inputs[:-1]))
@@ -1850,23 +1839,17 @@ def _break_graph_in_for_loop(
             for name in total_inputs
             if name in chain(self._locals, self._cells)
         ]
-        ret_vars = [self.get_var(name) for name in ret_names]
-        store_vars = [ret_vars[idx] for idx in range(len(ret_names))]
-        store_vars.extend(iter(self.stack))
-        store_vars.append(iterator.get_hold())
-        var_loader = self._graph.start_compile_with_name_store(
-            ret_vars, store_vars
-        )
-        for _ in ret_vars:
-            self._graph.pycode_gen.gen_pop_top()
+        var_loader = self.gen_compute_in_break_with_name_store(
+            ret_names, self.indexof(for_iter)
+        )
 
-        # 2. restore vars
-        for idx in range(len(ret_names)):
-            var_loader.load(ret_vars[idx])
-            self._graph.pycode_gen.gen_store(ret_names[idx], self._code)
+        # 2. restore vars with their original names
+        for name in ret_names:
+            var_loader.load(self.get_var(name))
+            self._graph.pycode_gen.gen_store(name, self._code)
 
-        # 3. setup vars which is created in loop
+        # 3. set up vars created in the loop as Undefined
         undefined_names = set()
         for name in loop_body_inputs[:-1]:
             if not self.has_var(name, all_used_vars[name]):
                 undefined_names.add(name)
                 self._graph.pycode_gen.gen_load_const(SotUndefinedVar())
                 self._graph.pycode_gen.gen_store(name, self._code)
@@ -1874,12 +1857,9 @@ def _break_graph_in_for_loop(
 
-        # close eval_frame
-        # TODO: need support effective strategies
-        # self._graph.pycode_gen.gen_disable_eval_frame()
-
         # 4.1 load iterator
-        iterator.reconstruct(self._graph.pycode_gen)
+        var_loader.load(iterator)
+        self.stack.pop()
 
         # 4.2 gen FOR_ITER and unpack data
         self._graph.pycode_gen.extend_instrs(
@@ -1923,10 +1903,6 @@ def _break_graph_in_for_loop(
         for_iter.jump_to = nop
         jump_if_break.jump_to = nop
 
-        # open eval_frame
-        # TODO: need support effective strategies
-        # self._graph.pycode_gen.gen_enable_eval_frame()
-
         # 8. call after_loop_fn
         self._graph.pycode_gen.gen_load_object(
             after_loop_fn, after_loop_fn.__code__.co_name
@@ -2045,17 +2021,20 @@ def FOR_ITER(self, instr):
         try:
             if not isinstance(iterator, SequenceIterVariable):
-                raise BreakGraphError()
+                raise BreakGraphError(
+                    f"Cannot simulate iterator of {type(iterator)}."
+                )
 
             backup_iter_idx = iterator.idx
             self._inline_call_for_loop(iterator, instr)
             self._lasti = self.indexof(instr.jump_to)
         except BreakGraphError as e:
-            log(3, f"{e}")
+            log(3, f"[FOR_ITER] simulating the for loop failed: {e}\n")
             if backup_iter_idx:
                 iterator.idx = backup_iter_idx
 
             self._graph.remove_global_guarded_variable(iterator)
+            self.stack.push(iterator)
             self._break_graph_in_for_loop(iterator, instr)
             return Stop(state="BreakGraph")
@@ -2064,8 +2043,12 @@ def RETURN_VALUE(self, instr: Instruction):
             len(self.stack) == 1
         ), f"Stack must have one element, but get {len(self.stack)} elements."
         ret_val = self.stack.pop()
-        self._graph.start_compile(ret_val)
-        self._graph.pycode_gen.gen_return()
-        self.new_code = self._graph.pycode_gen.gen_pycode()
+        if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get():
+            py_codegen = PyCodeGen(self._frame)
+            self.new_code = py_codegen.replace_null_variable()
+        else:
+            self._graph.start_compile(ret_val)
+            self._graph.pycode_gen.gen_return()
+            self.new_code = self._graph.pycode_gen.gen_pycode()
         self.guard_fn = self._graph.guard_fn
         return Stop(state="Return")
diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py
index 3e2032dcc3a80..29764afdca4eb 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py
@@ -433,6 +433,7 @@ def __init__(
         self._f_globals = frame.f_globals
         self._instructions = []
         self.disable_eval_frame = disable_eval_frame
+        self.hooks = []
         if self.disable_eval_frame:
             self.gen_disable_eval_frame()
@@ -493,16 +494,21 @@ def gen_pycode(self) -> types.CodeType:
         Returns:
             CodeType: The generated code object.
         """
+        for hook in self.hooks:
+            hook()
+        self.hooks.clear()
+
         self.insert_prefix_instructions()
         modify_instrs(self._instructions)
         modify_vars(self._instructions, self._code_options)
         new_code = gen_new_opcode(
             self._instructions, self._code_options, PYCODE_ATTRIBUTES
         )
+
         return new_code
 
     def gen_resume_fn_at(
-        self, index: int, stack_size: int = 0
+        self, index: int, stack_size: int
     ) -> tuple[None | types.FunctionType, OrderedSet[str]]:
         """
         Generates a resume function at the specified index in the instruction list.
@@ -515,6 +521,7 @@ def gen_resume_fn_at(
             tuple: The resume function object and the inputs to the function.
 
         """
+        self._instructions = get_instructions(self._origin_code)
         # TODO(dev): could give an example code here?
         if self._instructions[index].opname == 'RETURN_VALUE':
             return None, OrderedSet()
@@ -522,6 +529,7 @@ def gen_resume_fn_at(
         inputs = analysis_inputs(self._instructions, index)
         fn_name = ResumeFnNameFactory().next()
         stack_arg_str = fn_name + '_stack_{}'
+
         self._instructions = (
             [
                 gen_instr('LOAD_FAST', argval=stack_arg_str.format(i))
@@ -538,13 +546,12 @@ def gen_resume_fn_at(
             + list(inputs)
             + [
                 var_name
-                for var_name in self._origin_code.co_varnames
+                for var_name in self._code_options['co_varnames']
                 if var_name not in inputs
             ]
         )
 
         self.update_code_name(fn_name, is_resumed_fn=True)
-
         new_code = self.gen_pycode()
         if len(new_code.co_freevars) + len(new_code.co_cellvars) > 0:
             raise FallbackError("Break graph in closure is not support.")
@@ -1019,12 +1026,6 @@ def gen_return(self):
     def gen_get_iter(self):
         self._add_instr("GET_ITER")
 
-    def add_pure_instructions(self, instructions):
-        """
-        add instructions and do nothing.
-        """
-        self._instructions.extend(instructions)
-
     def _add_instr(self, *args, **kwargs):
         instr = gen_instr(*args, **kwargs)
         self._instructions.append(instr)
@@ -1062,8 +1063,17 @@ def replace_null_variable(self):
             ):
                 has_null_variable = True
                 self._frame.f_locals[instr.argval].reconstruct(self)
+            elif (
+                instr.opname == 'LOAD_GLOBAL'
+                and instr.argval in self._frame.f_globals.keys()
+                and isinstance(
+                    self._frame.f_globals[instr.argval], NullVariable
+                )
+            ):
+                has_null_variable = True
+                self._frame.f_globals[instr.argval].reconstruct(self)
             else:
-                self.add_pure_instructions([instr])
+                self.extend_instrs([instr])
 
         if has_null_variable:
             new_code = self.gen_pycode()
diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py
index e7389de5b8805..e7dec76fbea78 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py
@@ -206,6 +206,9 @@ def top(self, value):
         assert len(self) > 0, "stack is empty"
         self.peek[1] = value
 
+    def __contains__(self, value):
+        return value in self._data
+
     def __iter__(self):
         return iter(self._data)
 
diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py
index f39343acec358..ecc2e3216f7e4 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py
@@ -266,7 +266,7 @@ def __init__(
 
     def call_function(self, /, *args, **kwargs):
         if is_break_graph_tensor_methods(self.method_name):
-            raise BreakGraphError()
+            raise BreakGraphError("call break_graph_tensor_method.")
         return self.graph.call_tensor_method(self.method_name, *args, **kwargs)
 
     def bind(self, instance: VariableBase, name: str):
diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py
index f89adaeb089de..02fc91e62873b 100644
--- a/python/paddle/jit/sot/utils/__init__.py
+++ b/python/paddle/jit/sot/utils/__init__.py
@@ -21,6 +21,7 @@
     ENV_STRICT_MODE,
     cost_model_guard,
     strict_mode_guard,
+    min_graph_size_guard,
 )
 from .exceptions import (  # noqa: F401
     BreakGraphError,
diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py
index 303e3af2a20f3..a7d8ceafb7f0c 100644
--- a/python/paddle/jit/sot/utils/envs.py
+++ b/python/paddle/jit/sot/utils/envs.py
@@ -41,3 +41,9 @@ def cost_model_guard(value: bool):
 def strict_mode_guard(value: bool):
     with EnvironmentVariableGuard(ENV_STRICT_MODE, value):
         yield
+
+
+@contextmanager
+def min_graph_size_guard(value: int):
+    with EnvironmentVariableGuard(ENV_MIN_GRAPH_SIZE, value):
+        yield
diff --git a/test/sot/test_min_graph_size.py b/test/sot/test_min_graph_size.py
new file mode 100644
index 0000000000000..04a90f326d855
--- /dev/null
+++ b/test/sot/test_min_graph_size.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# GET_ITER (new)
+# FOR_ITER (new)
+
+from __future__ import annotations
+
+import unittest
+
+from test_case_base import TestCaseBase
+
+import paddle
+from paddle.jit import sot
+from paddle.jit.sot.utils import min_graph_size_guard
+
+
+def case_for(x, vars):
+    x = x + 1
+    sot.psdb.breakgraph()
+    for y in vars:
+        x += y
+    return x
+
+
+def case_if(x):
+    x = x + 1
+    if x > 5:
+        x += 3
+    else:
+        x += 4
+    return x
+
+
+def case_call(x):
+    y = paddle.to_tensor(x.numpy())
+    x += y
+    return x
+
+
+def case_all(x, vars):
+    x = x + 1
+    for y in vars:
+        z = paddle.to_tensor(x.numpy())
+        x += z
+        x += y
+        if x > 5:
+            x += y
+        else:
+            x += 3
+    return x
+
+
+class TestMinGraphSize(TestCaseBase):
+    @min_graph_size_guard(10)
+    def test_cases(self):
+        x = paddle.to_tensor(1)
+        self.assert_results(case_for, x, [1, 2, 3])
+        self.assert_results(case_if, x)
+        self.assert_results(case_call, x)
+        self.assert_results(case_all, x, [4, 5, 6])
+
+
+if __name__ == "__main__":
+    unittest.main()
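
Usage sketch (illustrative, not part of the patch): how the new min_graph_size_guard is meant to be driven from user code, assuming the public symbolic_translate entry point exported by paddle.jit.sot; small_fn is a hypothetical stand-in. With the guard active, any traced graph smaller than the threshold takes the new fallback branches above (RETURN_VALUE and gen_compute_in_break_with_name_store) and replays the original bytecode in dygraph mode instead of being compiled.

    import paddle
    from paddle.jit.sot import symbolic_translate
    from paddle.jit.sot.utils import min_graph_size_guard

    def small_fn(x):
        # only one op is traced here; with a threshold of 10 the graph is
        # considered too small to compile, so SOT restores and runs the
        # original opcodes eagerly instead of emitting a compiled function
        return x + 1

    with min_graph_size_guard(10):
        y = symbolic_translate(small_fn)(paddle.to_tensor(1.0))

Because the guard is built on contextlib.contextmanager, it also works as a decorator, which is how test_min_graph_size.py applies it (@min_graph_size_guard(10) on test_cases).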