formatting
Eliasj42 committed Feb 20, 2024
1 parent c1932d1 commit d21fbcb
Showing 1 changed file with 57 additions and 17 deletions.
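Note: this commit only re-wraps long lines; there are no behavior changes. The wrapping style (call arguments broken onto their own indented lines, ternaries parenthesized) matches what black produces. A minimal sketch of reproducing one of the re-wraps below with black's Python API, assuming a 79-column width (the repo's configured width may differ):

import black

# The statement from __init__ below, reproduced at its real indentation
# depth via a dummy class so black sees the same effective line length.
src = (
    "class C:\n"
    "    def __init__(self, function_name, wrapped_f):\n"
    "        self.function_name = function_name if function_name else wrapped_f.__name__\n"
)
print(black.format_str(src, mode=black.Mode(line_length=79)))
# Prints the parenthesized form seen in this diff:
#         self.function_name = (
#             function_name if function_name else wrapped_f.__name__
#         )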
core/shark_turbine/aot/builtins/jittable.py (74 changes: 57 additions & 17 deletions)
@@ -7,7 +7,17 @@

 """Tracing builtins."""

-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)

 import torch
 from torch._decomp import get_decompositions
@@ -97,7 +107,9 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
             # legal). Note that the merger will ignore these since they already
             # exist in the target module.
             if materialized_global.symbol_name not in cloned_global_symbols:
-                materialized_global.global_op.operation.clone(ip=gni.fx_importer._m_ip)
+                materialized_global.global_op.operation.clone(
+                    ip=gni.fx_importer._m_ip
+                )
                 cloned_global_symbols.add(materialized_global.symbol_name)

             # Emit a global load and conversion.
@@ -156,7 +168,9 @@ def __init__(
         self.constraints = constraints
         self.decomposition_table = decomposition_table
         self.wrapped_f = wrapped_f
-        self.function_name = function_name if function_name else wrapped_f.__name__
+        self.function_name = (
+            function_name if function_name else wrapped_f.__name__
+        )
         self._passes = set(passes)
         for p in passes:
             if p not in ALL_PASSES:
@@ -186,7 +200,9 @@ def resolve_call(
         flat_pytorch_args = []
         flat_ir_args = []
         for py_arg in flat_py_args:
-            ir_arg, pytorch_arg = self._split_py_arg(py_arg, constraints=constraints)
+            ir_arg, pytorch_arg = self._split_py_arg(
+                py_arg, constraints=constraints
+            )
             flat_ir_args.append(ir_arg)
             flat_pytorch_args.append(pytorch_arg)
@@ -204,14 +220,16 @@ def flat_wrapped_f(*args):
         # Run pre-processing passes.
         transformed_f = flat_wrapped_f
         if "functorch_functionalize" in self._passes:
-            transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
+            transformed_f = functorch_functionalize(
+                transformed_f, *flat_pytorch_args
+            )

         for node in transformed_f.graph.nodes:
             if node.op == "call_function":
                 if node.target == torch._ops.ops.aten.lift_fresh_copy.default:
                     node.target = torch._ops.ops.aten.clone.default
         transformed_f.recompile()

         # Ask dynamo to give us an aten graph.
         # TODO: Cache this for repeated calls.
         logger.debug("Performing dynamo.export(constraints=%r)", constraints)
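The loop kept as context in this hunk rewrites aten.lift_fresh_copy nodes to aten.clone before export. For readers unfamiliar with FX graph surgery, here is a self-contained toy that retargets a node the same way (torch.add swapped for torch.sub; an illustration, not the project's code):

import torch
import torch.fx

def f(x):
    return torch.add(x, 1)

gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.add:
        node.target = torch.sub  # retarget in place, like lift_fresh_copy -> clone
gm.recompile()

print(gm(torch.tensor([3.0])))  # tensor([2.]) -- the graph now subtracts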
@@ -240,10 +258,14 @@ def flat_wrapped_f(*args):
         fx_importer = FxImporter(
             context=proc_trace.context,
             config_check=False,
-            literal_resolver_callback=_make_literal_resolver(proc_trace.module_builder),
+            literal_resolver_callback=_make_literal_resolver(
+                proc_trace.module_builder
+            ),
             py_attr_tracker=proc_trace.module_builder.fx_py_attr_tracker,
         )
-        fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name)
+        fx_importer.import_stateless_graph(
+            gm.graph, func_name=self.function_name
+        )

         # TODO: Real debugging options
         # print(fx_importer.module, file=sys.stderr)
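The literal_resolver_callback plumbed in above lets the importer route tensor literals it encounters back to the module's materialized globals instead of inlining them. A rough sketch of that callback contract (an assumption for illustration; FxImporter's actual callback also receives the GraphNodeImporter and returns an MLIR Value):

from typing import Any, Callable, Optional

# Hypothetical stand-in types for illustration only.
IrValue = str
Resolver = Callable[[Any], Optional[IrValue]]

def import_literal(py_value: Any, resolver: Resolver) -> IrValue:
    resolved = resolver(py_value)
    if resolved is not None:
        # The callback claimed the value, e.g. emitted a global load for a
        # tracked tensor (see _make_literal_resolver above).
        return resolved
    # Default path: materialize the value as an inline literal.
    return f"literal<{py_value!r}>"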
@@ -300,11 +322,17 @@ def flat_wrapped_f(*args):

         assert len(flat_ir_results) == len(result_tensor_infos)
         flat_py_results = []
-        for ir_result, result_tensor_info in zip(flat_ir_results, result_tensor_infos):
+        for ir_result, result_tensor_info in zip(
+            flat_ir_results, result_tensor_infos
+        ):
             (dtype,) = result_tensor_info
-            native_ir_result = type_converter.materialize_torch_to_native(ir_result)
+            native_ir_result = type_converter.materialize_torch_to_native(
+                ir_result
+            )
             if dtype is not None:
-                flat_py_results.append(IrImmediateTensor(native_ir_result, dtype))
+                flat_py_results.append(
+                    IrImmediateTensor(native_ir_result, dtype)
+                )
             else:
                 raise TypeError(
                     f"Unknown PyTorch->IREE value mapping for jittable result: {result_tensor_info}->{native_ir_result}"
@@ -313,7 +341,9 @@
         tree_py_results = tree_unflatten(flat_py_results, out_spec)
         return tree_py_results

-    def _split_py_arg(self, arg, constraints: List[Constraint]) -> Tuple[Value, Any]:
+    def _split_py_arg(
+        self, arg, constraints: List[Constraint]
+    ) -> Tuple[Value, Any]:
         if isinstance(arg, IrTensor):
             meta_tensor, meta_constraints = arg._to_meta_tensor()
             constraints.extend(meta_constraints)
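The tree_unflatten call above is the tail end of a flatten/compute/unflatten round trip: the caller's nested arguments are flattened for tracing, and the flat IR results are folded back into the original nesting via the recorded out_spec. The same round trip with torch's (private) pytree helpers:

import torch.utils._pytree as pytree

nested = {"a": 1, "b": (2, 3)}
flat, spec = pytree.tree_flatten(nested)      # [1, 2, 3] plus a TreeSpec
results = [x * 10 for x in flat]              # stand-in for the traced call
print(pytree.tree_unflatten(results, spec))   # {'a': 10, 'b': (20, 30)}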
@@ -358,7 +388,9 @@ def merge(self) -> Optional[Operation]:
         imported_func_op: Optional[Operation] = None

         # Import functions.
-        func_ops = _get_top_level_ops(self.from_module_op, func_d.FuncOp.OPERATION_NAME)
+        func_ops = _get_top_level_ops(
+            self.from_module_op, func_d.FuncOp.OPERATION_NAME
+        )
         for func_op in func_ops:
             # Pre-rename, check if it is the one we are looking for.
             func_name = _get_symbol_name(func_op)
@@ -374,7 +406,9 @@
             for from_symbol, to_symbol in self.rename_map.items():
                 from_name = StringAttr(from_symbol).value
                 to_name = StringAttr(to_symbol).value
-                SymbolTable.replace_all_symbol_uses(from_name, to_name, sym_operation)
+                SymbolTable.replace_all_symbol_uses(
+                    from_name, to_name, sym_operation
+                )

         return imported_func_op
@@ -384,7 +418,9 @@ def import_symbol_op(self, symbol_op):
         orig_symbol = SymbolTable.get_symbol_name(symbol_op)
         orig_symbol_name = StringAttr(orig_symbol).value
         # Make sure it is unique.
-        new_symbol_name = _uniqueify_name(orig_symbol_name, target_symbol_table)
+        new_symbol_name = _uniqueify_name(
+            orig_symbol_name, target_symbol_table
+        )
         if new_symbol_name != orig_symbol_name:
             SymbolTable.set_symbol_name(symbol_op, new_symbol_name)
             self._rename(orig_symbol, new_symbol_name)
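_uniqueify_name (defined elsewhere in this file) guarantees the imported symbol does not collide with one already present in the target symbol table. A hypothetical sketch of the idea, using a plain set where the real helper probes an MLIR SymbolTable:

def uniqueify_name(name: str, existing: set) -> str:
    # Hypothetical reimplementation for illustration; the separator and
    # probing strategy of the real helper may differ.
    if name not in existing:
        return name
    i = 1
    while f"{name}${i}" in existing:
        i += 1
    return f"{name}${i}"

print(uniqueify_name("forward", {"forward"}))  # forward$1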
@@ -393,7 +429,9 @@ def import_symbol_op(self, symbol_op):
         self.nested_symbol_ops.append(symbol_op)
         target_symbol_table.insert(symbol_op)

-    def _rename(self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr):
+    def _rename(
+        self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr
+    ):
         from_symbol = self._make_string_attr(from_symbol)
         to_symbol = self._make_string_attr(to_symbol)
         if from_symbol != to_symbol:
@@ -407,7 +445,9 @@ def _make_string_attr(self, string_attr_or_str: StringAttrOrStr):
         return StringAttr(string_attr_or_str)


-def _get_top_level_ops(module_op: Operation, *op_names: str) -> Sequence[Operation]:
+def _get_top_level_ops(
+    module_op: Operation, *op_names: str
+) -> Sequence[Operation]:
     results = []
     for op_view in module_op.regions[0].blocks[0]:
         op = op_view.operation
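The collapsed diff cuts _get_top_level_ops off after the loop header. A plausible completion of the body, assuming the op.name attribute of the MLIR Python bindings (a sketch; the real code may differ):

def _get_top_level_ops_sketch(module_op, *op_names):
    results = []
    for op_view in module_op.regions[0].blocks[0]:
        op = op_view.operation
        if op.name in op_names:  # e.g. "func.func"
            results.append(op)
    return results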