Skip to content

Commit

Permalink
[SOT] fix load null in resume function (#59297)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Nov 28, 2023
1 parent 7896c17 commit a7d2fc7
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from ..custom_code import CustomCode
from .guard import Guard
from .opcode_executor import OpcodeExecutor, OpcodeExecutorBase
from .pycode_generator import PyCodeGen

GuardedFunction = Tuple[CustomCode, Guard]
GuardedFunctions = List[GuardedFunction]
Expand Down Expand Up @@ -213,16 +212,12 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction:
f"Unsupport Frame is {frame.f_code}, error message is: \n"
+ "".join(traceback.format_exception(type(e), e, e.__traceback__)),
)

# NOTE: If resume fn need fallback, we should replace NullVariable using NULL otherwise will fail to run
py_codegen = PyCodeGen(frame)
new_code = py_codegen.replace_null_variable()
# simulation not complete, not sure whether this code has sir, set disable_eval_frame = False
guard_fn = (
dummy_guard if e.disable_eval_frame is False else simulator.guard_fn
)
return (
CustomCode(new_code, e.disable_eval_frame),
CustomCode(None, e.disable_eval_frame),
guard_fn,
)
except Exception as e:
Expand Down
51 changes: 8 additions & 43 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,11 @@ def __init__(self, store_var_info, pycode_gen):
self._store_var_info = store_var_info
self._pycode_gen: PyCodeGen = pycode_gen

def load(self, var, allow_push_null=True):
def load(self, var):
if isinstance(var, NullVariable):
# PUSH_NULL is an opcode
if allow_push_null:
var.reconstruct(self._pycode_gen)
else:
# Avoid passing NULL as a parameter to the resume function
self._pycode_gen.gen_load_null_variable()
var.reconstruct(self._pycode_gen)
return
# only restored vars in stack, so used gen_load to process global var
self._pycode_gen.gen_load(self._store_var_info[var])
self._pycode_gen.gen_load(self._store_var_info[var.id])

origin_instrs = get_instructions(self.pycode_gen._origin_code)

Expand All @@ -270,30 +264,7 @@ def load(self, var, allow_push_null=True):
if restore_instr_names[-2:] == ["KW_NAMES", "PRECALL"]:
restore_instrs = restore_instrs[:-2]

for instr in restore_instrs:
if (
instr.opname == 'LOAD_FAST'
and instr.argval in self.pycode_gen._frame.f_locals.keys()
and isinstance(
self.pycode_gen._frame.f_locals[instr.argval], NullVariable
)
):
self.pycode_gen._frame.f_locals[instr.argval].reconstruct(
self.pycode_gen
)
elif (
instr.opname == 'LOAD_GLOBAL'
and instr.argval in self.pycode_gen._frame.f_globals.keys()
and isinstance(
self.pycode_gen._frame.f_globals[instr.argval], NullVariable
)
):
self.pycode_gen._frame.f_globals[instr.argval].reconstruct(
self.pycode_gen
)
else:
self.pycode_gen.extend_instrs([instr])

self.pycode_gen.extend_instrs(restore_instrs)
nop = self.pycode_gen._add_instr("NOP")

for instr in origin_instrs:
Expand All @@ -311,8 +282,8 @@ def load(self, var, allow_push_null=True):
name_gen = NameGenerator("__start_compile_saved_orig_")

for var in stack_vars[::-1]:
store_var_info[var] = name_gen.next()
self.pycode_gen.gen_store_fast(store_var_info[var])
store_var_info[var.id] = name_gen.next()
self.pycode_gen.gen_store_fast(store_var_info[var.id])

return VariableLoader(store_var_info, self.pycode_gen)

Expand All @@ -324,15 +295,9 @@ def __init__(self, index_for_load, pycode_gen):

def load(self, var, allow_push_null=True):
if isinstance(var, NullVariable):
# PUSH_NULL is an opcode
if allow_push_null:
var.reconstruct(self._pycode_gen)
else:
# Avoid passing NULL as a parameter to the resume function
self._pycode_gen.gen_load_null_variable()
var.reconstruct(self._pycode_gen)
return
# all vars to be load are saved by this function, so load_fast is correct
self._pycode_gen.gen_load_fast(self._index_for_load[var.id])
self._pycode_gen.gen_load(self._index_for_load[var.id])

# var_id -> local_name mapping
index_for_load = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1534,7 +1534,7 @@ def gen_compute_in_break_with_name_store(self, restore_names, instr_idx):
for name in restore_names:
_var = self.get_var(name)
if _var not in self.stack:
store_var_info[_var] = name
store_var_info[_var.id] = name
return self._graph._restore_origin_opcode(
list(self.stack), store_var_info, instr_idx
)
Expand Down Expand Up @@ -1601,9 +1601,7 @@ def _break_graph_in_jump(self, result: TensorVariable, instr: Instruction):
)
insert_index = len(self._graph.pycode_gen._instructions) - 1
for i, stack_arg in enumerate(self.stack):
var_loader.load(
stack_arg, allow_push_null=i >= len(self.stack) - 1
)
var_loader.load(stack_arg)
for name in if_inputs:
var_loader.load(self.get_var(name))
self._graph.pycode_gen.gen_call_function(
Expand All @@ -1620,9 +1618,7 @@ def _break_graph_in_jump(self, result: TensorVariable, instr: Instruction):
)
jump_to = self._graph.pycode_gen._instructions[-1]
for i, stack_arg in enumerate(self.stack):
var_loader.load(
stack_arg, allow_push_null=i >= len(self.stack) - 1
)
var_loader.load(stack_arg)
for name in else_inputs:
var_loader.load(self.get_var(name))
self._graph.pycode_gen.gen_call_function(
Expand Down Expand Up @@ -1674,9 +1670,7 @@ def _break_graph_in_call(
pop_n = push_n - stack_effect

for i, stack_arg in enumerate(self.stack):
var_loader.load(
stack_arg, allow_push_null=i >= len(self.stack) - pop_n
)
var_loader.load(stack_arg)

# gen call resume fn opcode
# NOTE(SigureMo): In Python 3.11,we need generate KW_NAMES if the call shape is not None.
Expand Down Expand Up @@ -2049,8 +2043,7 @@ def RETURN_VALUE(self, instr: Instruction):
), f"Stack must have one element, but get {len(self.stack)} elements."
ret_val = self.stack.pop()
if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get():
py_codegen = PyCodeGen(self._frame)
self.new_code = py_codegen.replace_null_variable()
self.new_code = None
else:
self._graph.start_compile(ret_val)
self._graph.pycode_gen.gen_return()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -827,20 +827,6 @@ def gen_import_name(self, name: str):
idx = self._code_options["co_names"].index(name)
self._add_instr("IMPORT_NAME", arg=idx, argval=name)

def gen_push_null(self):
if sys.version_info >= (3, 11):
self._add_instr("PUSH_NULL")
else:
# There is no PUSH_NULL bytecode before python3.11, so we push
# a NULL element to the stack through the following bytecode
self.gen_load_const(0)
self.gen_load_const(None)
self.gen_import_name('sys')
self.gen_store_fast('sys')
self.gen_load_fast('sys')
self.gen_load_method('getsizeof')
self.gen_pop_top()

def gen_store_fast(self, name):
if name not in self._code_options["co_varnames"]:
self._code_options["co_varnames"].append(name)
Expand Down Expand Up @@ -1064,40 +1050,3 @@ def extend_instrs(self, instrs):

def pop_instr(self):
self._instructions.pop()

def replace_null_variable(self):
"""
Replace all NullVariables in the bytecode.
Returns:
Optional[Tuple[Any, Callable]]: The new code object and its guard function, or None if no dummy variables are found.
"""
from .variables.basic import NullVariable

instructions = get_instructions(self._origin_code)
has_null_variable = False
for instr in instructions:
if (
instr.opname == 'LOAD_FAST'
and instr.argval in self._frame.f_locals.keys()
and isinstance(self._frame.f_locals[instr.argval], NullVariable)
):
has_null_variable = True
self._frame.f_locals[instr.argval].reconstruct(self)
elif (
instr.opname == 'LOAD_GLOBAL'
and instr.argval in self._frame.f_globals.keys()
and isinstance(
self._frame.f_globals[instr.argval], NullVariable
)
):
has_null_variable = True
self._frame.f_globals[instr.argval].reconstruct(self)
else:
self.extend_instrs([instr])

if has_null_variable:
new_code = self.gen_pycode()
return new_code
else:
return None
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,13 @@ def __init__(self):
# TODO: graph should be not None
super().__init__(None, DanglingTracker())

def __call__(self, *args, **kwargs):
func = args[0]
assert callable(func)
return func(*args[1:], **kwargs)

def reconstruct(self, codegen: PyCodeGen):
codegen.gen_push_null()
codegen.gen_load_null_variable()


class CellVariable(VariableBase):
Expand Down

0 comments on commit a7d2fc7

Please sign in to comment.