diff --git a/exir/backend/test/op_partitioner_demo.py b/exir/backend/test/op_partitioner_demo.py
index e8ae1e85b2..dc20c03e68 100644
--- a/exir/backend/test/op_partitioner_demo.py
+++ b/exir/backend/test/op_partitioner_demo.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 from typing import Callable, Dict, final, List, Optional, Tuple
 
 import torch
@@ -24,6 +25,7 @@ from executorch.exir.graph_module import get_control_flow_submodules
 from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.export import ExportedProgram
 
+from torch.fx.passes.infra.partitioner import Partition
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
 
 
@@ -145,10 +147,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
 @final
 class NonDecompTestPartitioner(Partitioner):
     """
-    Partitions all add/mul nodes regardless of order
+    Test partitioner that preserves supported aten ops from decomposition for
+    delegate consumption. Ensures each non-decomposed edge op gets its own delegate.
     """
 
     def __init__(self) -> None:
+        self.supported_non_decomposed_edge_ops = edge_ops_non_decomposed
         self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
         self.delegation_spec = DelegationSpec(
             BackendWithCompilerDemo.__name__,
@@ -171,14 +175,29 @@ def filter_ops(node: torch.fx.Node) -> bool:
 
         return (ops_not_to_decompose, filter_ops)
 
+    def _generate_single_node_partition(
+        self, gm: torch.fx.GraphModule
+    ) -> List[Partition]:
+        partitions = []
+        partition_id = itertools.count()
+        nodes_seen = set()
+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target in self.supported_non_decomposed_edge_ops
+                and node not in nodes_seen
+            ):
+                partitions.append(Partition(nodes=[node], id=next(partition_id)))
+                nodes_seen.add(node)
+
+        return partitions
+
     def _partition_graph_module(
         self,
         graph_module: torch.fx.GraphModule,
     ) -> Dict[str, DelegationSpec]:
         partition_tags: Dict[str, DelegationSpec] = {}
-        partition_list = generate_pattern_op_partitions(
-            graph_module, op_support=self.op_support
-        )
+        partition_list = self._generate_single_node_partition(graph_module)
         for partition in partition_list:
             for node in partition.nodes:
                 delegation_tag = f"tag{partition.id}"
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 1e21bd1993..3f0f0155ec 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -859,6 +859,66 @@ def _sanity_check_graph_for_non_decomp_ops(
         logging.warning(warning_str)
 
 
+def _gen_edge_manager_for_partitioners(
+    partitioner: Dict[str, List[Partitioner]],
+    aten_programs: Dict[str, ExportedProgram],
+    config: EdgeCompileConfig,
+    constant_methods: Optional[Dict[str, Any]],
+) -> "EdgeProgramManager":
+    """
+    Generates an EdgeProgramManager from aten_programs for subsequent lowering
+    with the partitioners specified by partitioner.
+
+    Partitioners specify which nodes should not be decomposed from the original
+    aten programs. This is done through two passes of run_decompositions:
+    - The first pass preserves all aten targets specified by the partitioners,
+      protecting them from nested decompositions.
+    - The second pass uses the check_op fn provided by the partitioners to perform
+      additional checks on nodes with preserved aten targets. Nodes that pass are
+      replaced with transformed ops so that they survive the second pass of
+      decompositions.
+    """
+    ops_set_to_not_decompose_by_program = {}
+    edge_programs: Dict[str, ExportedProgram] = {}
+    for name, program in aten_programs.items():
+        if partitioner is not None:
+            # Preserve all ops listed by all partitioners first.
+            all_ops_no_decomp = set()
+            for curr_partitioner in partitioner.get(name, []):
+                curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
+                all_ops_no_decomp |= set(curr_ops_no_decomp)
+
+            program = program.run_decompositions(
+                _default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp)
+            )
+            # Among all the preserved aten ops, use the check_op_fn to do an additional
+            # check on which ops need to be preserved and which ops need to be decomposed.
+            # Those which are truly preserved will be replaced with transformed ops.
+            ops_set_to_not_decompose_by_program[name] = (
+                _replace_aten_ops_with_transformed_ops(name, program, partitioner)
+            )
+            program = program.run_decompositions(_default_decomposition_table())
+
+            _restore_transformed_ops_to_aten_ops(program)
+
+        edge_programs[name] = program
+
+        edge_programs[name] = _generate_edge_program(
+            name,
+            config,
+            program,
+            list(ops_set_to_not_decompose_by_program.get(name, [])),
+        )
+
+    edge_manager = EdgeProgramManager(
+        edge_programs,
+        constant_methods,
+        config,
+        list(set().union(*ops_set_to_not_decompose_by_program.values())),
+    )
+    return edge_manager
+
+
 def _to_edge_transform_and_lower(
     programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
     transform_passes: Optional[
@@ -909,8 +969,6 @@ def _to_edge_transform_and_lower(
     Returns:
         EdgeProgramManager
     """
-    ops_set_to_not_decompose = set()
-
     assert not isinstance(constant_methods, EdgeCompileConfig)
     config = compile_config or EdgeCompileConfig()
     if not isinstance(programs, dict):
@@ -923,31 +981,8 @@ def _to_edge_transform_and_lower(
     else:
         partitioner = {}
 
-    ops_set_to_not_decompose_by_program = {}
-    edge_programs: Dict[str, ExportedProgram] = {}
-    for name, program in aten_programs.items():
-        if partitioner is not None:
-            ops_set_to_not_decompose_by_program[name] = (
-                _replace_aten_ops_with_transformed_ops(name, program, partitioner)
-            )
-            program = program.run_decompositions(_default_decomposition_table())
-
-            _restore_transformed_ops_to_aten_ops(program)
-
-        edge_programs[name] = program
-
-        edge_programs[name] = _generate_edge_program(
-            name,
-            config,
-            program,
-            list(ops_set_to_not_decompose_by_program.get(name, [])),
-        )
-
-    edge_manager = EdgeProgramManager(
-        edge_programs,
-        constant_methods,
-        config,
-        list(set().union(*ops_set_to_not_decompose_by_program.values())),
+    edge_manager = _gen_edge_manager_for_partitioners(
+        partitioner, aten_programs, config, constant_methods
     )
 
     if transform_passes is not None:
diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py
index 6afa5e1965..e3d2ca9b8c 100644
--- a/exir/program/test/test_program.py
+++ b/exir/program/test/test_program.py
@@ -6,7 +6,7 @@
 
 # pye-strict
 
-import operator
+import copy
 import unittest
 from typing import Any, Dict
 
@@ -27,6 +27,7 @@
     ExecutorchProgramManager,
     to_edge,
 )
+from executorch.exir.tracer import _default_decomposition_table
 from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
 
 from executorch.extension.pybindings.portable_lib import (
@@ -102,6 +103,19 @@ def _get_random_inputs(cls):
         return (x,)
 
 
+class TestLSTM(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
+
+    def forward(self, x):
+        return self.lstm(x)
+
+    @classmethod
+    def _get_random_inputs(cls):
+        return (torch.rand(1, 10, 8),)
+
+
 class WrapperModule(torch.nn.Module):
     def __init__(self, fn):
         super().__init__()
@@ -550,23 +564,64 @@ def _use_foo_add(a: torch.Tensor, b: torch.Tensor):
         except SpecViolationError:
             self.fail("Should not error out on custom op")
 
+    def get_num_nondecomposed_ops(self, ep, partitioner):
+        # Count the number of aten ops that the partitioner can delegate.
+        # We do this by running run_decompositions() with the preserved ops given
+        # to us by the partitioner, then counting the preserved aten ops
+        # that pass the filter_ops fn given by the partitioner.
+        reference_ep = copy.deepcopy(ep)
+        aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
+        reference_decomp_ep = reference_ep.run_decompositions(
+            decomp_table=_default_decomposition_table(),
+            _preserve_ops=tuple(aten_ops_not_decomposed),
+        )
+        num_non_decomposed_aten_ops = 0
+        for node in reference_decomp_ep.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target in aten_ops_not_decomposed
+                and (filter_ops(node) if filter_ops else True)
+            ):
+                num_non_decomposed_aten_ops += 1
+        return num_non_decomposed_aten_ops
+
     def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module):
         # This is the pre-dispatch export that we will be switching to primarily
         # in the near future. The input to _to_edge_transform_and_lower needs to
         # be a graph generated by this pre dispatch export.
         ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
+        non_decomp_partitioner = NonDecompTestPartitioner()
+
+        num_non_decomposed_aten_ops = self.get_num_nondecomposed_ops(
+            ep, non_decomp_partitioner
+        )
+
+        # run to_edge_transform_and_lower
         edge = _to_edge_transform_and_lower(
             ep,
             compile_config=EdgeCompileConfig(),
             partitioner=[NonDecompTestPartitioner()],
         )
+
+        # Check that non_decomposed_edge_ops are all consumed by the delegate.
+        non_decomposed_edge_ops = (
+            non_decomp_partitioner.supported_non_decomposed_edge_ops
+        )
+        for node in edge.exported_program().graph.nodes:
+            if node.op == "call_function":
+                self.assertTrue(node.target not in non_decomposed_edge_ops)
+
+        # Check that the number of call_delegate nodes equals the number of
+        # non-decomposed aten ops we counted above.
+        num_call_delegates = 0
         for node in edge.exported_program().graph_module.graph.nodes:
-            # There should only be a single call_function node in the graph
-            # and that should be a call_delegate node.
-            if node.op == "call_function" and node.target != operator.getitem:
-                self.assertEqual(
-                    node.target, torch.ops.higher_order.executorch_call_delegate
-                )
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.higher_order.executorch_call_delegate
+            ):
+                num_call_delegates += 1
+
+        self.assertEqual(num_call_delegates, num_non_decomposed_aten_ops)
 
     def test_to_edge_transform_and_lower(self):
         self._test_model_with_non_decomp_partitioner(TestLinear())
@@ -577,6 +632,8 @@ def test_to_edge_transform_and_lower(self):
 
         self._test_model_with_non_decomp_partitioner(TestUpsample())
 
+        self._test_model_with_non_decomp_partitioner(TestLSTM())
+
     def test_to_edge_transform_and_lower_with_exception(self):
         class TestLinear(torch.nn.Module):
             def __init__(self):
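
A note for backend authors building on this change: the contract the new code path relies on is the ops_to_not_decompose hook on Partitioner, which returns the aten ops to preserve plus an optional per-node filter. Below is a minimal sketch of a custom partitioner implementing that hook, assuming the Partitioner and PartitionResult interfaces from executorch.exir.backend.partitioner as exercised in this diff; the class name and the float-input filter are illustrative only, not upstream code.

from typing import Callable, List, Optional, Tuple

import torch
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from torch.export import ExportedProgram


class PreserveLinearPartitioner(Partitioner):
    """Hypothetical partitioner that keeps aten.linear whole for a delegate."""

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        def filter_ops(node: torch.fx.Node) -> bool:
            # Illustrative second-stage check: only keep linear nodes whose
            # first input is a floating-point tensor.
            if not node.args or not isinstance(node.args[0], torch.fx.Node):
                return False
            val = node.args[0].meta.get("val", None)
            return val is not None and val.is_floating_point()

        return ([torch.ops.aten.linear.default], filter_ops)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # Tagging logic elided; a real partitioner would tag the preserved
        # nodes for its delegate here (see NonDecompTestPartitioner above).
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags={}
        )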
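The preservation mechanism both decomposition passes in _gen_edge_manager_for_partitioners rely on is the private _preserve_ops argument to ExportedProgram.run_decompositions. Below is a standalone sketch of its observable effect, assuming a pre-dispatch export entry point equivalent to the _export(..., pre_dispatch=True) call in the test (torch.export._trace._export at roughly this revision; the exact module path is version-dependent and private).

import torch
from executorch.exir.tracer import _default_decomposition_table
from torch.export._trace import _export  # private pre-dispatch export entry point


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


ep = _export(TinyLinear(), (torch.rand(2, 4),), pre_dispatch=True)

# Without _preserve_ops, aten.linear would decompose during this pass
# (e.g. into permute/addmm); with it, the op survives intact and can
# later be consumed whole by a delegate.
ep = ep.run_decompositions(
    _default_decomposition_table(),
    _preserve_ops=(torch.ops.aten.linear.default,),
)
print([n.target for n in ep.graph.nodes if n.op == "call_function"])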
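Putting the pieces together, the sketch below shows roughly how the updated test drives the private _to_edge_transform_and_lower entry point and counts delegates. It reuses TinyLinear and _export from the previous sketch (a fresh pre-dispatch export, since _to_edge_transform_and_lower runs the decomposition passes itself) and mirrors the assertions in _test_model_with_non_decomp_partitioner rather than introducing new API.

import torch
from executorch.exir import EdgeCompileConfig
from executorch.exir.backend.test.op_partitioner_demo import NonDecompTestPartitioner
from executorch.exir.program._program import _to_edge_transform_and_lower
from torch.export._trace import _export  # private pre-dispatch export entry point

ep = _export(TinyLinear(), (torch.rand(2, 4),), pre_dispatch=True)

edge = _to_edge_transform_and_lower(
    ep,
    compile_config=EdgeCompileConfig(),
    partitioner=[NonDecompTestPartitioner()],
)

# Each aten op the partitioner preserved should now sit inside its own
# single-node delegate, appearing as an executorch_call_delegate node
# rather than as a top-level edge op.
num_call_delegates = sum(
    1
    for n in edge.exported_program().graph.nodes
    if n.op == "call_function"
    and n.target == torch.ops.higher_order.executorch_call_delegate
)
print(f"lowered {num_call_delegates} preserved op(s) into delegates")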