Fix scalar arithmetic and add test cases (#6224)
Summary:
Add UnsqueezeScalarPlaceholdersPass to make scalars rank 1. Add MatchArgRanksPass to guarantee the same rank for all inputs to ops that require it.
Additional fixes to make the scalar tests pass.

Map which cases work and which don't.

Signed-off-by: Erik Lundell <erik.lundell@arm.com>
Change-Id: I4ea5e189e26cf7aff391ec153d525b2fb61aa16f

Fix shape issues

Change-Id: I0b8588cd5f8b284c25e806bb83bc788067d5b649

Pull Request resolved: #6224

Reviewed By: mergennachin

Differential Revision: D64427014

Pulled By: digantdesai

fbshipit-source-id: 5295e9ffab1d848b111e0cb01aa0ce9142c20781
Erik-Lundell authored and facebook-github-bot committed Oct 17, 2024
1 parent 5f12f28 commit 6669e18
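
As context for the summary above, a minimal sketch of the scalar-arithmetic pattern these passes make lowerable (module name and shapes are illustrative, not from the commit):

    import torch

    class AddScalar(torch.nn.Module):
        def forward(self, x):
            # The Python scalar becomes a rank-0 operand in the exported
            # graph; the Arm/TOSA lowering needs tensor operands of equal,
            # non-zero rank, which the new passes establish.
            return x + 1.0

    ep = torch.export.export(AddScalar(), (torch.ones(2, 2),))
    print(ep.graph)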
Showing 12 changed files with 476 additions and 32 deletions.
8 changes: 2 additions & 6 deletions backends/arm/_passes/annotate_channels_last_dim_order_pass.py
@@ -9,6 +9,7 @@
from typing import cast

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.tosa_quant_utils import dq_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.pass_base import ExportPass, PassResult
@@ -52,12 +53,7 @@ def call(self, graph_module: torch.fx.GraphModule):
NHWC_Order = (0, 2, 3, 1)
HWCM_Order = (2, 3, 0, 1)
for node in graph_module.graph.nodes:
if isinstance(
node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
):
node_data = node.meta["val"][0].data
else:
node_data = node.meta["val"].data
node_data = get_first_fake_tensor(node).data

if len(node_data.shape) == 4:
dim_order = NHWC_Order
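
As an aside, the dim orders above are permutations applied to NCHW tensors; a standalone illustration (not part of the diff):

    import torch

    x = torch.randn(2, 8, 16, 16)   # NCHW
    nhwc = x.permute(0, 2, 3, 1)    # NHWC_Order = (0, 2, 3, 1)
    assert nhwc.shape == (2, 16, 16, 8)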
8 changes: 7 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
@@ -22,6 +22,7 @@
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
InsertSqueezeAfterSumPass,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
ConvertMeanDimToAveragePool,
)
@@ -30,6 +31,9 @@
ScalarsToAttributePass,
)
from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
UnsqueezeScalarPlaceholdersPass,
)
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.pass_manager import PassManager
@@ -45,10 +49,12 @@ def transform_to_backend_pipeline(
):
"""Apply passes before transforming program to backend"""
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(ConvertMeanDimToAveragePool())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(InsertSqueezeAfterSumPass())
self.add_pass(ConvertSplitToSlicePass())
@@ -61,6 +67,6 @@ def transform_to_backend_pipeline(
return self._transform(exported_program.graph_module)

def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
self.add_pass(DecomposeDivPass())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeDivPass())
return self._transform(graph_module)
20 changes: 20 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
@@ -7,9 +7,11 @@
from typing import Optional

import torch
import torch.fx

from executorch.exir.dialects._ops import ops as exir_ops
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor


def create_node(
@@ -64,3 +66,21 @@ def insert_q_dq_pair(
# node's first use
q.args = (anchor,) + q_params
return dq


def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
"""
Returns a FakeTensor from the meta field of 'node'.
If the node's meta contains multiple fake tensors, the first one is returned.
"""
if isinstance(
node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
):
fake_tensor = node.meta["val"][0]
else:
fake_tensor = node.meta["val"]

assert isinstance(
fake_tensor, FakeTensor
), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
return fake_tensor
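
A hypothetical usage sketch for the new helper (assumes an ExecuTorch checkout on the path; module and shapes are illustrative):

    import torch
    from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

    class M(torch.nn.Module):
        def forward(self, x):
            return x * 2.0

    # Inspect the shape recorded for each call_function node.
    gm = torch.export.export(M(), (torch.ones(2, 3),)).graph_module
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(node.name, get_first_fake_tensor(node).shape)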
9 changes: 6 additions & 3 deletions backends/arm/_passes/decompose_div_pass.py
@@ -8,15 +8,18 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

edge_div_ops = (exir_ops.edge.aten.div.Tensor,)
aten_div_ops = (torch.ops.aten.div.Tensor, torch.ops.aten.div_.Tensor)


def get_div_decomposition(op) -> tuple:
"""
Returns the pair (reciprocal_op, mul_op), where the ops depend on whether
the div op is in exir_ops or torch.ops.aten.
"""
if op == exir_ops.edge.aten.div.Tensor:
if op in edge_div_ops:
return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
if op == torch.ops.aten.div.Tensor:
if op in aten_div_ops:
return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
raise RuntimeError(f"Can't get div decomposition for op {op}")

@@ -33,7 +36,7 @@ class DecomposeDivPass(ExportPass):
"""

def call_operator(self, op, args, kwargs, meta):
if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
if op not in (edge_div_ops + aten_div_ops):
return super().call_operator(op, args, kwargs, meta)

reciprocal_op, mul_op = get_div_decomposition(op)
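
The rewrite performed by this pass, div(x, y) -> mul(x, reciprocal(y)), can be sanity-checked with plain tensor ops (standalone sketch, not from the commit):

    import torch

    x = torch.randn(4)
    y = torch.rand(4) + 0.5  # keep divisors away from zero
    assert torch.allclose(x / y, x * torch.reciprocal(y))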
126 changes: 126 additions & 0 deletions backends/arm/_passes/match_arg_ranks_pass.py
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node


class MatchArgRanksPass(ExportPass):
"""
For ops in 'targeted_ops', make sure that the inputs share the same rank.
New singleton dimensions are inserted at the beginning of the lower-rank shapes.
"""

def __init__(self, exported_program):
super().__init__()
self.exported_program = exported_program

targeted_ops = [
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.div.Tensor,
]

def _match_op_rank(self, graph_module, node, arg, max_rank):
"""
In graph_module, insert a view between arg and node to make the
rank of arg match the other args to node.
"""
shape = get_first_fake_tensor(arg).shape
rank = len(shape)
new_shape = list([1] * (max_rank - rank) + list(shape))
with graph_module.graph.inserting_before(node):
view = create_node(
graph_module.graph,
exir_ops.edge.aten.view_copy.default,
args=(arg, new_shape),
kwargs={},
)
node.replace_input_with(arg, view)

def _match_buffer_rank(self, arg, max_rank):
"""
Change arg's fake tensor meta to match max_rank if:
- arg is found in inputs_to_buffers or inputs_to_parameters.
"""
fake_tensor = get_first_fake_tensor(arg)
shape = fake_tensor.shape
rank = len(shape)
new_shape = list([1] * (max_rank - rank) + list(shape))

buffer_name = None
if arg.name in self.exported_program.graph_signature.inputs_to_buffers:
buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
arg.name
]
elif arg.name in self.exported_program.graph_signature.inputs_to_parameters:
buffer_name = self.exported_program.graph_signature.inputs_to_parameters[
arg.name
]
if buffer_name:
new_tensor = self.exported_program.state_dict[buffer_name].reshape(
new_shape
)
self.exported_program.state_dict[buffer_name] = new_tensor
arg.meta["val"] = fake_tensor.fake_mode.from_tensor(
new_tensor, static_shapes=True
)

def call(self, graph_module: GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
node = cast(Node, node)

if node.op != "call_function" or node.target not in self.targeted_ops:
continue

# Calculate max rank of all inputs to node
max_rank = 1
for arg in node.args:
if isinstance(arg, Node):
shape = get_first_fake_tensor(arg).shape
max_rank = max(max_rank, len(shape))

# Adjust output shape of args if needed.
for arg in node.args:
if not isinstance(arg, Node):
continue
shape = get_first_fake_tensor(arg).shape
rank = len(shape)
if rank == max_rank:
continue

# If the argument is call_function, match shape by inserting view node.
if arg.op == "call_function":
self._match_op_rank(graph_module, node, arg, max_rank)
else:
# If the argument is a buffer or parameter, adjust shape by changing the fake tensor meta.
self._match_buffer_rank(arg, max_rank)

graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)

def ensures(self, graph_module):
for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in self.targeted_ops:
continue
arg0_rank = node.args[0].meta["val"].dim()
arg1_rank = node.args[1].meta["val"].dim()
if arg0_rank != arg1_rank:
raise ValueError(
"Arguments of arithmetic operators need to have the same rank!"
)
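
Both _match_op_rank and _match_buffer_rank compute the new shape with the same left-padding rule; in isolation:

    # Prepend singleton dimensions until the shape reaches max_rank.
    shape = (3, 4)  # rank 2
    max_rank = 4
    new_shape = [1] * (max_rank - len(shape)) + list(shape)
    assert new_shape == [1, 1, 3, 4]
    # The reshaped argument is broadcast-compatible with the other
    # operands of add/sub/mul/div, so results are unchanged.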
8 changes: 6 additions & 2 deletions backends/arm/_passes/scalars_to_attribute_pass.py
@@ -7,7 +7,7 @@
from typing import cast, Union

import torch
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
@@ -22,10 +22,14 @@ class ScalarsToAttributePass(ExportPass):

targeted_ops = [
torch.ops.aten.add.Tensor,
torch.ops.aten.add_.Tensor,
torch.ops.aten.sub.Tensor,
torch.ops.aten.sub_.Tensor,
torch.ops.aten.rsub.Scalar,
torch.ops.aten.mul.Tensor,
torch.ops.aten.mul_.Tensor,
torch.ops.aten.div.Tensor,
torch.ops.aten.div_.Tensor,
]

def call(self, graph_module: GraphModule) -> PassResult:
@@ -37,7 +41,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
biggest_rank = 1
for arg in n.args:
if isinstance(arg, Node):
_, shape, _ = extract_tensor_meta(arg.meta)
shape = get_first_fake_tensor(arg).shape
biggest_rank = max(biggest_rank, len(shape))

new_args = []
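
The list now also covers in-place variants, which compute the same values as their out-of-place counterparts (a quick standalone check):

    import torch

    x = torch.ones(2)
    y = x.clone()
    y.add_(1.0)                     # traces as aten.add_.Tensor
    assert torch.equal(y, x + 1.0)  # aten.add.Tensor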
53 changes: 53 additions & 0 deletions backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
@@ -0,0 +1,53 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class UnsqueezeScalarPlaceholdersPass(ExportPass):
"""
Placeholders that have node.meta["val"].shape = () cause issues later in the lowering.
This pass unsqueezes the placeholders to make sure shape is at least (1,).
"""

def __init__(self, exported_program):
self.exported_program = exported_program
super().__init__()

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if node.op != "placeholder":
continue
rank = node.meta["val"].dim()
if rank == 0:
if not (
node.name in self.exported_program.graph_signature.inputs_to_buffers
or node.name
in self.exported_program.graph_signature.inputs_to_parameters
):
continue
tensor = self.exported_program.state_dict[node.name]
if tensor.dim() == 0:
self.exported_program.state_dict[node.name] = tensor.unsqueeze(0)
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
tensor.unsqueeze(0), static_shapes=True
)
else:
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
tensor, static_shapes=True
)

graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)

def ensures(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if node.op == "placeholder":
rank = node.meta["val"].dim()
if rank == 0:
raise ValueError("Placeholders of rank 0 are not supported!")
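
At the tensor level, the pass amounts to a rank-0 to rank-1 unsqueeze (standalone sketch):

    import torch

    t = torch.tensor(2.0)  # shape (), rank 0
    assert t.dim() == 0
    t1 = t.unsqueeze(0)    # shape (1,), rank 1
    assert t1.shape == (1,)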
backends/arm/quantizer/quantization_annotation/mul_annotator.py
@@ -24,7 +24,7 @@ def _annotate_mul(

annotated_partitions = []
for node in gm.graph.nodes:
if node.target not in (torch.ops.aten.mul.Tensor,):
if node.target not in (torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor):
continue
mul_node = node
annotated_partitions.append([mul_node])
15 changes: 5 additions & 10 deletions backends/arm/quantizer/quantization_annotation/sub_annotator.py
@@ -6,8 +6,6 @@

# pyre-unsafe

import itertools
import operator
from typing import Callable, List, Optional

import torch
@@ -16,7 +14,6 @@
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import GraphModule, Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


@register_annotator("sub")
@@ -25,14 +22,12 @@ def _annotate_sub(
quantization_config: QuantizationConfig,
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
sub_partitions = get_source_partitions(
gm.graph, [operator.sub, torch.sub, operator.isub], filter_fn
)
sub_partitions = list(itertools.chain.from_iterable(sub_partitions.values()))
annotated_partitions = []
for sub_partition in sub_partitions:
annotated_partitions.append(sub_partition.nodes)
sub_node = sub_partition.output_nodes[0]
for node in gm.graph.nodes:
if node.target not in (torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor):
continue
annotated_partitions.append(node)
sub_node = node
if arm_quantizer_utils.is_annotated(sub_node):
continue
