-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix scalar arithmetic and add test cases (#6224)
Summary: Add UnsquezeScalarPlaceholders pass to make scalars rank 1 Add MatchShapesPass to guarantee same rank for all inputs for ops that require it. Additional fixes to make Scalar tests pass Map which cases work and which don't. Signed-off-by: Erik Lundell <erik.lundell@arm.com> Change-Id: I4ea5e189e26cf7aff391ec153d525b2fb61aa16f Fix shape issues Change-Id: I0b8588cd5f8b284c25e806bb83bc788067d5b649 Pull Request resolved: #6224 Reviewed By: mergennachin Differential Revision: D64427014 Pulled By: digantdesai fbshipit-source-id: 5295e9ffab1d848b111e0cb01aa0ce9142c20781
- Loading branch information
1 parent
5f12f28
commit 6669e18
Showing
12 changed files
with
476 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import cast | ||
|
||
from executorch.backends.arm._passes.arm_pass_utils import ( | ||
create_node, | ||
get_first_fake_tensor, | ||
) | ||
|
||
from executorch.exir.dialects._ops import ops as exir_ops | ||
|
||
from executorch.exir.pass_base import ExportPass, PassResult | ||
from torch.fx import GraphModule, Node | ||
|
||
|
||
class MatchArgRanksPass(ExportPass):
    """
    For ops in 'targeted_ops', make sure that the inputs share the same rank.
    New dimensions are inserted at the beginning of lower-rank inputs' shapes
    (i.e. the shapes are left-padded with 1s) until every input to the op has
    the same rank.
    """

    def __init__(self, exported_program):
        super().__init__()
        # Kept so buffer/parameter tensors in the state dict can be reshaped
        # in _match_buffer_rank.
        self.exported_program = exported_program

    # Elementwise binary ops whose TOSA lowering requires equal-rank inputs.
    targeted_ops = [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
    ]

    def _match_op_rank(self, graph_module, node, arg, max_rank):
        """
        In graph_module, insert a view between arg and node to make the
        rank of arg match the other args to node.
        """
        shape = get_first_fake_tensor(arg).shape
        rank = len(shape)
        # Left-pad the shape with 1s up to max_rank; this preserves the
        # element count, so a view is sufficient.
        new_shape = [1] * (max_rank - rank) + list(shape)
        with graph_module.graph.inserting_before(node):
            view = create_node(
                graph_module.graph,
                exir_ops.edge.aten.view_copy.default,
                args=(arg, new_shape),
                kwargs={},
            )
            node.replace_input_with(arg, view)

    def _match_buffer_rank(self, arg, max_rank):
        """
        Change arg's fake tensor meta to match max_rank if:
        - arg is found in inputs_to_buffers or inputs_to_parameters.

        The underlying tensor in the state dict is reshaped in place so that
        the graph meta and the stored data stay consistent.
        """
        fake_tensor = get_first_fake_tensor(arg)
        shape = fake_tensor.shape
        rank = len(shape)
        new_shape = [1] * (max_rank - rank) + list(shape)

        # Resolve the placeholder name to its state-dict key, whether it is a
        # buffer or a parameter.
        buffer_name = None
        if arg.name in self.exported_program.graph_signature.inputs_to_buffers:
            buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
                arg.name
            ]
        elif arg.name in self.exported_program.graph_signature.inputs_to_parameters:
            buffer_name = self.exported_program.graph_signature.inputs_to_parameters[
                arg.name
            ]
        # If arg is a plain user input (no buffer_name), nothing is changed here.
        if buffer_name:
            new_tensor = self.exported_program.state_dict[buffer_name].reshape(
                new_shape
            )
            self.exported_program.state_dict[buffer_name] = new_tensor
            arg.meta["val"] = fake_tensor.fake_mode.from_tensor(
                new_tensor, static_shapes=True
            )

    def call(self, graph_module: GraphModule) -> PassResult:
        """Equalize the ranks of all inputs to every targeted op in the graph."""
        for node in graph_module.graph.nodes:
            node = cast(Node, node)

            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue

            # Calculate max rank of all inputs to node
            max_rank = 1
            for arg in node.args:
                if isinstance(arg, Node):
                    shape = get_first_fake_tensor(arg).shape
                    max_rank = max(max_rank, len(shape))

            # Adjust output shape of args if needed.
            for arg in node.args:
                if not isinstance(arg, Node):
                    continue
                shape = get_first_fake_tensor(arg).shape
                rank = len(shape)
                if rank == max_rank:
                    continue

                # If the argument is call_function, match shape by inserting view node.
                if arg.op == "call_function":
                    self._match_op_rank(graph_module, node, arg, max_rank)
                else:
                    # If the argument is a buffer or parameter, adjust shape by changing the fake tensor meta.
                    self._match_buffer_rank(arg, max_rank)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)

    def ensures(self, graph_module):
        """Verify post-condition: both inputs of every targeted op share a rank.

        Raises:
            ValueError: if any targeted op still has mismatched input ranks.
        """
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue
            arg0_rank = node.args[0].meta["val"].dim()
            arg1_rank = node.args[1].meta["val"].dim()
            if arg0_rank != arg1_rank:
                raise ValueError(
                    "Arguments of arithmetic operators need to have the same rank!"
                )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
53 changes: 53 additions & 0 deletions
53
backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from executorch.exir.pass_base import ExportPass, PassResult | ||
|
||
|
||
class UnsqueezeScalarPlaceholdersPass(ExportPass):
    """
    Placeholders that have node.meta["val"].shape = () cause issues later in the lowering.
    This pass unsqueezes the placeholders to make sure shape is at least (1,).
    """

    def __init__(self, exported_program):
        self.exported_program = exported_program
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule):
        signature = self.exported_program.graph_signature
        for node in graph_module.graph.nodes:
            # Only rank-0 placeholders need any work.
            if node.op != "placeholder" or node.meta["val"].dim() != 0:
                continue
            # Skip plain user inputs; only buffers and parameters are rewritten.
            is_state_input = (
                node.name in signature.inputs_to_buffers
                or node.name in signature.inputs_to_parameters
            )
            if not is_state_input:
                continue
            # NOTE(review): the state dict is indexed by the placeholder name
            # here rather than the mapped buffer/parameter name — confirm this
            # matches how the exported program keys its state dict.
            tensor = self.exported_program.state_dict[node.name]
            fake_mode = node.meta["val"].fake_mode
            if tensor.dim() == 0:
                unsqueezed = tensor.unsqueeze(0)
                self.exported_program.state_dict[node.name] = unsqueezed
                node.meta["val"] = fake_mode.from_tensor(
                    unsqueezed, static_shapes=True
                )
            else:
                # Data already has rank >= 1; only refresh the fake-tensor meta.
                node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)

    def ensures(self, graph_module: torch.fx.GraphModule):
        """Verify post-condition: no placeholder is left with rank 0.

        Raises:
            ValueError: if a rank-0 placeholder remains after the pass.
        """
        for node in graph_module.graph.nodes:
            if node.op != "placeholder":
                continue
            if node.meta["val"].dim() == 0:
                raise ValueError("Placeholders of rank 0 are not supported!")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.