From 685fd802612b2703c730fae2bb05ff64db35dbbf Mon Sep 17 00:00:00 2001 From: Elias Joseph Date: Tue, 20 Feb 2024 19:31:27 -0800 Subject: [PATCH] reverted format changes --- core/shark_turbine/aot/builtins/jittable.py | 60 ++++++--------------- 1 file changed, 15 insertions(+), 45 deletions(-) diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py index f81bc57ba..58c9fa790 100644 --- a/core/shark_turbine/aot/builtins/jittable.py +++ b/core/shark_turbine/aot/builtins/jittable.py @@ -107,9 +107,7 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]: # legal). Note that the merger will ignore these since they already # exist in the target module. if materialized_global.symbol_name not in cloned_global_symbols: - materialized_global.global_op.operation.clone( - ip=gni.fx_importer._m_ip - ) + materialized_global.global_op.operation.clone(ip=gni.fx_importer._m_ip) cloned_global_symbols.add(materialized_global.symbol_name) # Emit a global load and conversion. @@ -168,9 +166,7 @@ def __init__( self.constraints = constraints self.decomposition_table = decomposition_table self.wrapped_f = wrapped_f - self.function_name = ( - function_name if function_name else wrapped_f.__name__ - ) + self.function_name = function_name if function_name else wrapped_f.__name__ self._passes = set(passes) for p in passes: if p not in ALL_PASSES: @@ -200,9 +196,7 @@ def resolve_call( flat_pytorch_args = [] flat_ir_args = [] for py_arg in flat_py_args: - ir_arg, pytorch_arg = self._split_py_arg( - py_arg, constraints=constraints - ) + ir_arg, pytorch_arg = self._split_py_arg(py_arg, constraints=constraints) flat_ir_args.append(ir_arg) flat_pytorch_args.append(pytorch_arg) @@ -220,9 +214,7 @@ def flat_wrapped_f(*args): # Run pre-processing passes. transformed_f = flat_wrapped_f if "functorch_functionalize" in self._passes: - transformed_f = functorch_functionalize( - transformed_f, *flat_pytorch_args - ) + transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args) for node in transformed_f.graph.nodes: if node.op == "call_function": @@ -258,14 +250,10 @@ def flat_wrapped_f(*args): fx_importer = FxImporter( context=proc_trace.context, config_check=False, - literal_resolver_callback=_make_literal_resolver( - proc_trace.module_builder - ), + literal_resolver_callback=_make_literal_resolver(proc_trace.module_builder), py_attr_tracker=proc_trace.module_builder.fx_py_attr_tracker, ) - fx_importer.import_stateless_graph( - gm.graph, func_name=self.function_name - ) + fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name) # TODO: Real debugging options # print(fx_importer.module, file=sys.stderr) @@ -322,17 +310,11 @@ def flat_wrapped_f(*args): assert len(flat_ir_results) == len(result_tensor_infos) flat_py_results = [] - for ir_result, result_tensor_info in zip( - flat_ir_results, result_tensor_infos - ): + for ir_result, result_tensor_info in zip(flat_ir_results, result_tensor_infos): (dtype,) = result_tensor_info - native_ir_result = type_converter.materialize_torch_to_native( - ir_result - ) + native_ir_result = type_converter.materialize_torch_to_native(ir_result) if dtype is not None: - flat_py_results.append( - IrImmediateTensor(native_ir_result, dtype) - ) + flat_py_results.append(IrImmediateTensor(native_ir_result, dtype)) else: raise TypeError( f"Unknown PyTorch->IREE value mapping for jittable result: {result_tensor_info}->{native_ir_result}" @@ -341,9 +323,7 @@ def flat_wrapped_f(*args): tree_py_results = tree_unflatten(flat_py_results, out_spec) return tree_py_results - def _split_py_arg( - self, arg, constraints: List[Constraint] - ) -> Tuple[Value, Any]: + def _split_py_arg(self, arg, constraints: List[Constraint]) -> Tuple[Value, Any]: if isinstance(arg, IrTensor): meta_tensor, meta_constraints = arg._to_meta_tensor() constraints.extend(meta_constraints) @@ -388,9 +368,7 @@ def merge(self) -> Optional[Operation]: imported_func_op: Optional[Operation] = None # Import functions. - func_ops = _get_top_level_ops( - self.from_module_op, func_d.FuncOp.OPERATION_NAME - ) + func_ops = _get_top_level_ops(self.from_module_op, func_d.FuncOp.OPERATION_NAME) for func_op in func_ops: # Pre-rename, check if it is the one we are looking for. func_name = _get_symbol_name(func_op) @@ -406,9 +384,7 @@ def merge(self) -> Optional[Operation]: for from_symbol, to_symbol in self.rename_map.items(): from_name = StringAttr(from_symbol).value to_name = StringAttr(to_symbol).value - SymbolTable.replace_all_symbol_uses( - from_name, to_name, sym_operation - ) + SymbolTable.replace_all_symbol_uses(from_name, to_name, sym_operation) return imported_func_op @@ -418,9 +394,7 @@ def import_symbol_op(self, symbol_op): orig_symbol = SymbolTable.get_symbol_name(symbol_op) orig_symbol_name = StringAttr(orig_symbol).value # Make sure it is unique. - new_symbol_name = _uniqueify_name( - orig_symbol_name, target_symbol_table - ) + new_symbol_name = _uniqueify_name(orig_symbol_name, target_symbol_table) if new_symbol_name != orig_symbol_name: SymbolTable.set_symbol_name(symbol_op, new_symbol_name) self._rename(orig_symbol, new_symbol_name) @@ -429,9 +403,7 @@ def import_symbol_op(self, symbol_op): self.nested_symbol_ops.append(symbol_op) target_symbol_table.insert(symbol_op) - def _rename( - self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr - ): + def _rename(self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr): from_symbol = self._make_string_attr(from_symbol) to_symbol = self._make_string_attr(to_symbol) if from_symbol != to_symbol: @@ -445,9 +417,7 @@ def _make_string_attr(self, string_attr_or_str: StringAttrOrStr): return StringAttr(string_attr_or_str) -def _get_top_level_ops( - module_op: Operation, *op_names: str -) -> Sequence[Operation]: +def _get_top_level_ops(module_op: Operation, *op_names: str) -> Sequence[Operation]: results = [] for op_view in module_op.regions[0].blocks[0]: op = op_view.operation