use _preserve_ops for to_edge_transform_and_lower (#4273)
Summary:
Pull Request resolved: #4273

## Motivation
`run_decompositions()` has a new `_preserve_ops` functionality which allows us to specify which ops we want to refrain from decomposing. This is super helpful for the to_edge_transform_and_lower API because it allows us to prevent decompositions that occur beyond the first level.

For example, consider LSTM. When exported using torch.export, we see a torch.ops.aten.LSTM() operator in the graph. When running decompositions, this is decomposed into linear, which is then further decomposed into addmm. Since the linear op is produced by decomposing LSTM and does not exist until after we run_decompositions(), we cannot apply our trick of changing its namespace to prevent its decomposition. However, now using `_preserve_ops=(torch.ops.aten.linear.default,)` we are able to prevent this second level of decomposition.
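
A minimal sketch of that flow (hedged: `_preserve_ops` is a private keyword of `ExportedProgram.run_decompositions()` as of this change, and the toy module here is illustrative; the test diff below shows the exact usage):

```python
import torch

from executorch.exir.tracer import _default_decomposition_table


class LSTMModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    def forward(self, x):
        return self.lstm(x)


ep = torch.export.export(LSTMModel(), (torch.rand(1, 10, 8),))

# Without _preserve_ops, the LSTM op decomposes into aten.linear, which is
# then decomposed again into aten.addmm. Preserving aten.linear keeps the
# intermediate op alive even though it only appears after the first level
# of decomposition runs.
ep = ep.run_decompositions(
    _default_decomposition_table(),
    _preserve_ops=(torch.ops.aten.linear.default,),
)
```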

## API Implementation Change
In the implementation we do two passes. In the first pass, we run_decompositions() while preserving all aten ops specified by our partitioners via `_preserve_ops`. In the second pass, we further filter which aten ops should be preserved using the check_op_fn given to us by the partitioners, and then use our namespace trick to prevent the decomposition of every aten op that passes check_op_fn.
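
Condensed, the two passes look roughly like the sketch below. The helper names are taken from the `_program.py` diff further down; the import paths for those private helpers are assumptions, and all surrounding plumbing is elided:

```python
from typing import Dict, List

from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.program._program import (
    _replace_aten_ops_with_transformed_ops,
    _restore_transformed_ops_to_aten_ops,
)
from executorch.exir.tracer import _default_decomposition_table
from torch.export import ExportedProgram


def two_pass_decompose(
    name: str,
    program: ExportedProgram,
    partitioner: Dict[str, List[Partitioner]],
) -> ExportedProgram:
    # Pass 1: preserve every aten op any partitioner asks for, so that ops
    # produced by nested decompositions (e.g. aten.linear out of LSTM)
    # survive run_decompositions().
    all_ops_no_decomp = set()
    for curr_partitioner in partitioner.get(name, []):
        ops, _ = curr_partitioner.ops_to_not_decompose(program)
        all_ops_no_decomp |= set(ops)
    program = program.run_decompositions(
        _default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp)
    )

    # Pass 2: each partitioner's check_op_fn filters the preserved ops; the
    # survivors are renamed into a transformed namespace so that the second
    # run_decompositions() leaves them alone, then renamed back afterwards.
    _replace_aten_ops_with_transformed_ops(name, program, partitioner)
    program = program.run_decompositions(_default_decomposition_table())
    _restore_transformed_ops_to_aten_ops(program)
    return program
```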

## Testing Changes
To strengthen our tests, I first changed the functionality of the NonDecompTestPartitioner. We now partition only pre-decomp aten ops, and each of these ops lives within its own delegate (this gives us a 1:1 mapping between call_delegate nodes and pre-decomp aten nodes). In testing, this allows us to verify that the number of preserved ops is correct by counting the number of delegate calls.

In testing we then count the number of aten ops which should be preserved, and check after the fact that each of these ops is (see the sketch below):
1. No longer in the graph after to_edge_transform_and_lower
2. Transformed into a call_delegate node
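
A distilled, hypothetical version of that check (the real assertions live in the `test_program.py` diff below; the helper name and parameters here are illustrative):

```python
import torch


def check_preserved_ops_delegated(edge, non_decomposed_edge_ops, expected_count):
    """`edge` is the EdgeProgramManager returned by _to_edge_transform_and_lower."""
    num_call_delegates = 0
    for node in edge.exported_program().graph.nodes:
        if node.op == "call_function":
            # (1) preserved edge ops must no longer appear in the toplevel graph
            assert node.target not in non_decomposed_edge_ops
            # (2) count the delegate calls they were folded into
            if node.target == torch.ops.higher_order.executorch_call_delegate:
                num_call_delegates += 1
    # 1:1 mapping: one call_delegate per preserved pre-decomp aten op
    assert num_call_delegates == expected_count
```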

Reviewed By: tarun292

Differential Revision: D59786323

fbshipit-source-id: 7ea946e0d5afc8ebddd26913f6e843305116ad3b
mcr229 authored and facebook-github-bot committed Jul 17, 2024
1 parent b448254 commit c3357e1
Showing 3 changed files with 148 additions and 36 deletions.
27 changes: 23 additions & 4 deletions exir/backend/test/op_partitioner_demo.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Callable, Dict, final, List, Optional, Tuple

import torch
@@ -24,6 +25,7 @@
from executorch.exir.graph_module import get_control_flow_submodules
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import ExportedProgram
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


@@ -145,10 +147,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
@final
class NonDecompTestPartitioner(Partitioner):
    """
    Partitions all add/mul nodes regardless of order
    Non Decomp Test Partitioner, preserves aten ops from decomposition for delegate
    consumption. Ensures that non_decomposed_edge_ops are all within their own delegate
    """

    def __init__(self) -> None:
        self.supported_non_decomposed_edge_ops = edge_ops_non_decomposed
        self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
        self.delegation_spec = DelegationSpec(
            BackendWithCompilerDemo.__name__,
@@ -171,14 +175,29 @@ def filter_ops(node: torch.fx.Node) -> bool:

        return (ops_not_to_decompose, filter_ops)

    def _generate_single_node_partition(
        self, gm: torch.fx.GraphModule
    ) -> List[Partition]:
        partitions = []
        partition_id = itertools.count()
        nodes_seen = set()
        for node in gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target in self.supported_non_decomposed_edge_ops
                and node not in nodes_seen
            ):
                partitions.append(Partition(nodes=[node], id=next(partition_id)))
                nodes_seen.add(node)

        return partitions

    def _partition_graph_module(
        self,
        graph_module: torch.fx.GraphModule,
    ) -> Dict[str, DelegationSpec]:
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            graph_module, op_support=self.op_support
        )
        partition_list = self._generate_single_node_partition(graph_module)
        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"tag{partition.id}"
89 changes: 62 additions & 27 deletions exir/program/_program.py
@@ -859,6 +859,66 @@ def _sanity_check_graph_for_non_decomp_ops(
        logging.warning(warning_str)


def _gen_edge_manager_for_partitioners(
    partitioner: Dict[str, List[Partitioner]],
    aten_programs: Dict[str, ExportedProgram],
    config: EdgeCompileConfig,
    constant_methods: Optional[Dict[str, Any]],
) -> "EdgeProgramManager":
    """
    Generates EdgeProgramManager for subsequent lowering to the
    partitioners specified by partitioner. The EdgeProgramManager is generated from
    aten_programs.
    Partitioners specify what nodes should not be decomposed from the original aten programs.
    This is done through two passes of run_decompositions.
        - First pass preserves all aten_targets specified by partitioners to preserve
          them from nested decompositions
        - Second pass uses check_op fn provided by partitioners to perform additional checks
          on nodes with preserved aten targets. They are then replaced with transformed ops to
          keep them through the second pass of decompositions
    """
    ops_set_to_not_decompose_by_program = {}
    edge_programs: Dict[str, ExportedProgram] = {}
    for name, program in aten_programs.items():
        if partitioner is not None:
            # preserve all ops listed by all partitioners first
            all_ops_no_decomp = set()
            for curr_partitioner in partitioner.get(name, []):
                curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
                all_ops_no_decomp |= set(curr_ops_no_decomp)

            program = program.run_decompositions(
                _default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp)
            )
            # Among all the preserved aten ops, use the check_op_fn to do an additional
            # check on which ops need to be preserved and which ops need to be decomposed
            # Those which are truly preserved will be replaced with transformed ops
            ops_set_to_not_decompose_by_program[name] = (
                _replace_aten_ops_with_transformed_ops(name, program, partitioner)
            )
        program = program.run_decompositions(_default_decomposition_table())

        _restore_transformed_ops_to_aten_ops(program)

        edge_programs[name] = program

        edge_programs[name] = _generate_edge_program(
            name,
            config,
            program,
            list(ops_set_to_not_decompose_by_program.get(name, [])),
        )

    edge_manager = EdgeProgramManager(
        edge_programs,
        constant_methods,
        config,
        list(set().union(*ops_set_to_not_decompose_by_program.values())),
    )
    return edge_manager


def _to_edge_transform_and_lower(
    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
    transform_passes: Optional[
@@ -909,8 +969,6 @@ def _to_edge_transform_and_lower(
    Returns:
        EdgeProgramManager
    """
    ops_set_to_not_decompose = set()

    assert not isinstance(constant_methods, EdgeCompileConfig)
    config = compile_config or EdgeCompileConfig()
    if not isinstance(programs, dict):
@@ -923,31 +981,8 @@
    else:
        partitioner = {}

    ops_set_to_not_decompose_by_program = {}
    edge_programs: Dict[str, ExportedProgram] = {}
    for name, program in aten_programs.items():
        if partitioner is not None:
            ops_set_to_not_decompose_by_program[name] = (
                _replace_aten_ops_with_transformed_ops(name, program, partitioner)
            )
        program = program.run_decompositions(_default_decomposition_table())

        _restore_transformed_ops_to_aten_ops(program)

        edge_programs[name] = program

        edge_programs[name] = _generate_edge_program(
            name,
            config,
            program,
            list(ops_set_to_not_decompose_by_program.get(name, [])),
        )

    edge_manager = EdgeProgramManager(
        edge_programs,
        constant_methods,
        config,
        list(set().union(*ops_set_to_not_decompose_by_program.values())),
    edge_manager = _gen_edge_manager_for_partitioners(
        partitioner, aten_programs, config, constant_methods
    )

    if transform_passes is not None:
68 changes: 63 additions & 5 deletions exir/program/test/test_program.py
@@ -6,7 +6,7 @@

# pyre-strict

import operator
import copy
import unittest
from typing import Any, Dict

@@ -27,6 +27,7 @@
    ExecutorchProgramManager,
    to_edge,
)
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier

from executorch.extension.pybindings.portable_lib import (
@@ -102,6 +103,19 @@ def _get_random_inputs(cls):
        return (x,)


class TestLSTM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    def forward(self, x):
        return self.lstm(x)

    @classmethod
    def _get_random_inputs(cls):
        return (torch.rand(1, 10, 8),)


class WrapperModule(torch.nn.Module):
    def __init__(self, fn):
        super().__init__()
@@ -550,23 +564,65 @@ def _use_foo_add(a: torch.Tensor, b: torch.Tensor):
        except SpecViolationError:
            self.fail("Should not error out on custom op")

    def get_num_nondecomposed_ops(self, ep, partitioner):
        # count the number of aten ops that the partitioner can delegate
        # we do this by running run_decompositions() with the preserved ops given
        # to us by the partitioner. Then we count the number of preserved aten ops
        # which pass the filter_ops fn given by the partitioner
        reference_ep = copy.deepcopy(ep)
        aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
        reference_decomp_ep = reference_ep.run_decompositions(
            decomp_table=_default_decomposition_table(),
            _preserve_ops=tuple(aten_ops_not_decomposed),
        )
        num_non_decomposed_aten_ops = 0
        for node in reference_decomp_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target in aten_ops_not_decomposed
                and (filter_ops(node) if filter_ops else True)
            ):
                num_non_decomposed_aten_ops += 1
        return num_non_decomposed_aten_ops

    def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module):
        # This is the pre-dispatch export that we will be switching to primarily
        # in the near future. The input to _to_edge_transform_and_lower needs to
        # be a graph generated by this pre-dispatch export.
        ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
        non_decomp_partitioner = NonDecompTestPartitioner()

        num_non_decomposed_aten_ops = self.get_num_nondecomposed_ops(
            ep, non_decomp_partitioner
        )

        # run to_edge_transform_and_lower
        edge = _to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(),
            partitioner=[NonDecompTestPartitioner()],
        )
        # Check that non_decomposed_edge_ops are all consumed by the delegate
        non_decomposed_edge_ops = (
            non_decomp_partitioner.supported_non_decomposed_edge_ops
        )
        for node in edge.exported_program().graph.nodes:
            if node.op == "call_function":
                self.assertTrue(node.target not in non_decomposed_edge_ops)

        # check that the number of call_delegate_nodes is equal to the number of
        # non_decomposed_aten_ops we found above
        num_call_delegates = 0
        for node in edge.exported_program().graph_module.graph.nodes:
            # There should only be a single call_function node in the graph
            # and that should be a call_delegate node.
            if node.op == "call_function" and node.target != operator.getitem:
                self.assertEqual(
                    node.target, torch.ops.higher_order.executorch_call_delegate
                )
            if (
                node.op == "call_function"
                and node.target == torch.ops.higher_order.executorch_call_delegate
            ):
                num_call_delegates += 1

        self.assertEqual(num_call_delegates, num_non_decomposed_aten_ops)

    def test_to_edge_transform_and_lower(self):
        self._test_model_with_non_decomp_partitioner(TestLinear())
@@ -577,6 +633,8 @@ def test_to_edge_transform_and_lower(self):

        self._test_model_with_non_decomp_partitioner(TestUpsample())

        self._test_model_with_non_decomp_partitioner(TestLSTM())

    def test_to_edge_transform_and_lower_with_exception(self):
        class TestLinear(torch.nn.Module):
            def __init__(self):
