
Commit

reverted format changes
Eliasj42 committed Feb 21, 2024
1 parent d21fbcb commit 685fd80
Showing 1 changed file with 15 additions and 45 deletions.
60 changes: 15 additions & 45 deletions core/shark_turbine/aot/builtins/jittable.py
@@ -107,9 +107,7 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
# legal). Note that the merger will ignore these since they already
# exist in the target module.
if materialized_global.symbol_name not in cloned_global_symbols:
-materialized_global.global_op.operation.clone(
-    ip=gni.fx_importer._m_ip
-)
+materialized_global.global_op.operation.clone(ip=gni.fx_importer._m_ip)
cloned_global_symbols.add(materialized_global.symbol_name)

# Emit a global load and conversion.
@@ -168,9 +166,7 @@ def __init__(
self.constraints = constraints
self.decomposition_table = decomposition_table
self.wrapped_f = wrapped_f
-self.function_name = (
-    function_name if function_name else wrapped_f.__name__
-)
+self.function_name = function_name if function_name else wrapped_f.__name__
self._passes = set(passes)
for p in passes:
if p not in ALL_PASSES:
@@ -200,9 +196,7 @@ def resolve_call(
flat_pytorch_args = []
flat_ir_args = []
for py_arg in flat_py_args:
-ir_arg, pytorch_arg = self._split_py_arg(
-    py_arg, constraints=constraints
-)
+ir_arg, pytorch_arg = self._split_py_arg(py_arg, constraints=constraints)
flat_ir_args.append(ir_arg)
flat_pytorch_args.append(pytorch_arg)

@@ -220,9 +214,7 @@ def flat_wrapped_f(*args):
# Run pre-processing passes.
transformed_f = flat_wrapped_f
if "functorch_functionalize" in self._passes:
-transformed_f = functorch_functionalize(
-    transformed_f, *flat_pytorch_args
-)
+transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)

for node in transformed_f.graph.nodes:
if node.op == "call_function":
@@ -258,14 +250,10 @@ def flat_wrapped_f(*args):
fx_importer = FxImporter(
context=proc_trace.context,
config_check=False,
-literal_resolver_callback=_make_literal_resolver(
-    proc_trace.module_builder
-),
+literal_resolver_callback=_make_literal_resolver(proc_trace.module_builder),
py_attr_tracker=proc_trace.module_builder.fx_py_attr_tracker,
)
-fx_importer.import_stateless_graph(
-    gm.graph, func_name=self.function_name
-)
+fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name)

# TODO: Real debugging options
# print(fx_importer.module, file=sys.stderr)
@@ -322,17 +310,11 @@ def flat_wrapped_f(*args):

assert len(flat_ir_results) == len(result_tensor_infos)
flat_py_results = []
-for ir_result, result_tensor_info in zip(
-    flat_ir_results, result_tensor_infos
-):
+for ir_result, result_tensor_info in zip(flat_ir_results, result_tensor_infos):
(dtype,) = result_tensor_info
-native_ir_result = type_converter.materialize_torch_to_native(
-    ir_result
-)
+native_ir_result = type_converter.materialize_torch_to_native(ir_result)
if dtype is not None:
-flat_py_results.append(
-    IrImmediateTensor(native_ir_result, dtype)
-)
+flat_py_results.append(IrImmediateTensor(native_ir_result, dtype))
else:
raise TypeError(
f"Unknown PyTorch->IREE value mapping for jittable result: {result_tensor_info}->{native_ir_result}"
@@ -341,9 +323,7 @@ def flat_wrapped_f(*args):
tree_py_results = tree_unflatten(flat_py_results, out_spec)
return tree_py_results

-def _split_py_arg(
-    self, arg, constraints: List[Constraint]
-) -> Tuple[Value, Any]:
+def _split_py_arg(self, arg, constraints: List[Constraint]) -> Tuple[Value, Any]:
if isinstance(arg, IrTensor):
meta_tensor, meta_constraints = arg._to_meta_tensor()
constraints.extend(meta_constraints)
@@ -388,9 +368,7 @@ def merge(self) -> Optional[Operation]:
imported_func_op: Optional[Operation] = None

# Import functions.
-func_ops = _get_top_level_ops(
-    self.from_module_op, func_d.FuncOp.OPERATION_NAME
-)
+func_ops = _get_top_level_ops(self.from_module_op, func_d.FuncOp.OPERATION_NAME)
for func_op in func_ops:
# Pre-rename, check if it is the one we are looking for.
func_name = _get_symbol_name(func_op)
@@ -406,9 +384,7 @@ def merge(self) -> Optional[Operation]:
for from_symbol, to_symbol in self.rename_map.items():
from_name = StringAttr(from_symbol).value
to_name = StringAttr(to_symbol).value
-SymbolTable.replace_all_symbol_uses(
-    from_name, to_name, sym_operation
-)
+SymbolTable.replace_all_symbol_uses(from_name, to_name, sym_operation)

return imported_func_op

@@ -418,9 +394,7 @@ def import_symbol_op(self, symbol_op):
orig_symbol = SymbolTable.get_symbol_name(symbol_op)
orig_symbol_name = StringAttr(orig_symbol).value
# Make sure it is unique.
-new_symbol_name = _uniqueify_name(
-    orig_symbol_name, target_symbol_table
-)
+new_symbol_name = _uniqueify_name(orig_symbol_name, target_symbol_table)
if new_symbol_name != orig_symbol_name:
SymbolTable.set_symbol_name(symbol_op, new_symbol_name)
self._rename(orig_symbol, new_symbol_name)
@@ -429,9 +403,7 @@ def import_symbol_op(self, symbol_op):
self.nested_symbol_ops.append(symbol_op)
target_symbol_table.insert(symbol_op)

-def _rename(
-    self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr
-):
+def _rename(self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr):
from_symbol = self._make_string_attr(from_symbol)
to_symbol = self._make_string_attr(to_symbol)
if from_symbol != to_symbol:
@@ -445,9 +417,7 @@ def _make_string_attr(self, string_attr_or_str: StringAttrOrStr):
return StringAttr(string_attr_or_str)


-def _get_top_level_ops(
-    module_op: Operation, *op_names: str
-) -> Sequence[Operation]:
+def _get_top_level_ops(module_op: Operation, *op_names: str) -> Sequence[Operation]:
results = []
for op_view in module_op.regions[0].blocks[0]:
op = op_view.operation
