Skip to content

Commit

Permalink
[SOT] Compile with graph size check (#58538)
Browse files Browse the repository at this point in the history
* update

* update

* update

* update

* update

* update

* add test

* update

* update
  • Loading branch information
feifei-111 authored Nov 2, 2023
1 parent e54e7aa commit a827e97
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 115 deletions.
91 changes: 88 additions & 3 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
map_if,
tmp_name_guard,
)
from ..instruction_utils import get_instructions
from .guard import Guard, StringifyExpression, make_guard
from .mutable_data import MutationDel, MutationNew, MutationSet
from .pycode_generator import PyCodeGen
Expand Down Expand Up @@ -241,20 +242,88 @@ def guard_fn(self) -> Guard:

return make_guard(guards)

def start_compile_with_name_store(self, ret_vars, to_store_vars):
def _restore_origin_opcode(self, stack_vars, store_var_info, instr_idx):
    """Fall back to the original bytecode, resuming at ``instr_idx``.

    Replays the original instructions up to (but excluding) ``instr_idx``,
    re-materializing any ``NullVariable`` placeholders found in the frame's
    locals/globals, patches jump targets that pointed at the cut-off
    instruction, and registers a hook that appends the remaining original
    instructions later.  Finally, the values currently on the simulated
    stack are spilled into freshly named fast-locals so generated code can
    reload them afterwards.

    Args:
        stack_vars: variables currently on the simulated value stack
            (spilled here in reverse order, i.e. top of stack first).
        store_var_info: mapping of variable -> local-variable name; this
            method fills it in for every variable in ``stack_vars``.
        instr_idx: index into the original instruction list at which the
            compiled region ends and original execution should resume.

    Returns:
        A ``VariableLoader`` that can re-load the spilled variables (and
        reconstruct ``NullVariable`` placeholders) via generated opcodes.
    """

    class VariableLoader:
        # Helper returned to the caller: knows how to emit the load
        # opcodes for variables that were stored by the enclosing method.
        def __init__(self, store_var_info, pycode_gen):
            self._store_var_info = store_var_info
            self._pycode_gen: PyCodeGen = pycode_gen

        def load(self, var, allow_push_null=True):
            if isinstance(var, NullVariable):
                # PUSH_NULL is an opcode
                if allow_push_null:
                    var.reconstruct(self._pycode_gen)
                else:
                    # Avoid passing NULL as a parameter to the resume function
                    self._pycode_gen.gen_load_null_variable()
                return
            # only restored vars in stack, so used gen_load to process global var
            self._pycode_gen.gen_load(self._store_var_info[var])

    origin_instr = get_instructions(self.pycode_gen._origin_code)

    # Replay the prefix of the original code.  A LOAD_FAST/LOAD_GLOBAL whose
    # target is a NullVariable placeholder is replaced by that variable's own
    # reconstruction; every other instruction is copied verbatim.
    for instr in origin_instr[0:instr_idx]:
        if (
            instr.opname == 'LOAD_FAST'
            and instr.argval in self.pycode_gen._frame.f_locals.keys()
            and isinstance(
                self.pycode_gen._frame.f_locals[instr.argval], NullVariable
            )
        ):
            self.pycode_gen._frame.f_locals[instr.argval].reconstruct(
                self.pycode_gen
            )
        elif (
            instr.opname == 'LOAD_GLOBAL'
            and instr.argval in self.pycode_gen._frame.f_globals.keys()
            and isinstance(
                self.pycode_gen._frame.f_globals[instr.argval], NullVariable
            )
        ):
            self.pycode_gen._frame.f_globals[instr.argval].reconstruct(
                self.pycode_gen
            )
        else:
            self.pycode_gen.extend_instrs([instr])

    # Insert a NOP to serve as a stable replacement jump target for the
    # instruction at instr_idx, which is not copied into the new code.
    nop = self.pycode_gen._add_instr("NOP")

    # Redirect every jump that targeted the cut-off instruction to the NOP.
    for instr in origin_instr:
        if instr.jump_to == origin_instr[instr_idx]:
            instr.jump_to = nop

    # Defer appending the tail of the original instructions (everything
    # after instr_idx) until the hook is run by pycode_gen.
    self.pycode_gen.hooks.append(
        lambda: self.pycode_gen.extend_instrs(
            iter(origin_instr[instr_idx + 1 :])
        )
    )

    # NOTE(review): presumably re-enables the eval-frame hook before
    # resuming original execution — confirm against PyCodeGen.
    self.pycode_gen.gen_enable_eval_frame()

    name_gen = NameGenerator("__start_compile_saved_orig_")

    # Spill the current stack values (top of stack first) into uniquely
    # named fast-locals, recording the name for later reload.
    for var in stack_vars[::-1]:
        store_var_info[var] = name_gen.next()
        self.pycode_gen.gen_store_fast(store_var_info[var])

    return VariableLoader(store_var_info, self.pycode_gen)

def _build_compile_fn_with_name_store(self, ret_vars, to_store_vars):
class VariableLoader:
def __init__(self, index_for_load, pycode_gen):
self._index_for_load = index_for_load
self._pycode_gen: PyCodeGen = pycode_gen

def load(self, var, allow_push_null=True):
if isinstance(var, NullVariable):
# PUSH_NULL is an opcode
if allow_push_null:
var.reconstruct(self._pycode_gen)
else:
# Avoid passing NULL as a parameter to the resume function
self._pycode_gen.gen_load_null_variable()
return
# all vars to be load are saved by this function, so load_fast is correct
self._pycode_gen.gen_load_fast(self._index_for_load[var.id])

# var_id -> local_name mapping
Expand All @@ -264,7 +333,8 @@ def load(self, var, allow_push_null=True):
)
self.start_compile(*(ret_vars + to_store_vars))
name_gen = NameGenerator("__start_compile_saved_")
for var in to_store_vars:

for var in to_store_vars[::-1]:
index_for_load[var.id] = name_gen.next()

def _log_fn():
Expand All @@ -275,8 +345,8 @@ def _log_fn():

log_do(4, _log_fn)

for var in to_store_vars[::-1]:
self.pycode_gen.gen_store_fast(index_for_load[var.id])

return VariableLoader(index_for_load, self.pycode_gen)

@event_register("start_compile", event_level=2)
Expand Down Expand Up @@ -552,13 +622,28 @@ def _find_tensor_outputs(
Args:
outputs: output variables
"""

def collect_related_dummy_tensor(var):
    # Depth-first walk: gather every TensorVariable reachable through
    # DummyTracker inputs starting from ``var``.
    if not isinstance(var.tracker, DummyTracker):
        return []
    if isinstance(var, TensorVariable):
        return [var]
    collected = []
    for source in var.tracker.inputs:
        collected += collect_related_dummy_tensor(source)
    return collected

output_tensors: OrderedSet[TensorVariable] = OrderedSet()
# Find Tensor Variables from outputs.
for output in outputs:
if isinstance(output.tracker, DummyTracker):
if isinstance(output, TensorVariable):
output_tensors.add(output)
else:
for inp in output.tracker.inputs:
for _var in collect_related_dummy_tensor(inp):
output_tensors.add(_var)
# Guard output that can not be traced.
self.add_global_guarded_variable(output)
# Find Tensor Variables from side effects Variables.
Expand Down
Loading

0 comments on commit a827e97

Please sign in to comment.