check to_copy args in vulkan_partitioner (#6215)
Summary:
Pull Request resolved: #6215

In the exir dialect, to_copy doesn't have a dtype arg; the target dtype is instead inferred from the dtype of the output tensor. The args will be of length 1, with the sole arg being the input tensor. Thus the previous check always returned False, as len(args) is never greater than 1.
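For context (not part of this commit), a minimal sketch of how an edge-dialect _to_copy node can be inspected. The CastModule example and the torch.export/to_edge driver code below are illustrative assumptions, not code from this change:

# Hypothetical sketch: export a small cast module to the edge (exir) dialect and
# inspect the _to_copy node. Per the summary above, node.args holds only the input
# tensor node, so the target dtype has to be read from the output FakeTensor stored
# in node.meta["val"].
import torch
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops

class CastModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.float16)

ep = torch.export.export(CastModule(), (torch.rand(4, dtype=torch.float32),))
edge = to_edge(ep)
for node in edge.exported_program().graph.nodes:
    if node.op == "call_function" and node.target == exir_ops.edge.aten._to_copy.default:
        print(len(node.args))          # 1 -- just the input tensor node
        print(node.meta["val"].dtype)  # torch.float16, read from the output tensor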

Reviewed By: SS-JIA

Differential Revision: D64267104

fbshipit-source-id: 62267ecae167aef7ecc415bf83f1fb024d66244f
nathanaelsee authored and facebook-github-bot committed Oct 15, 2024
1 parent 5c3439d commit 708c6b6
Showing 1 changed file with 20 additions and 5 deletions.
backends/vulkan/partitioner/vulkan_partitioner.py (20 additions & 5 deletions)
@@ -144,9 +144,24 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
 
         return False
 
-    def is_valid_to_copy(self, node: torch.fx.node) -> bool:  # pyre-ignore[11]
-        # lower only if floating point dtype conversion
-        return len(node.args) > 1 and node.args[1] in (torch.float32, torch.float16)
+    def is_valid_to_copy(self, node: torch.fx.Node) -> bool:
+        float_dtypes = [torch.float16, torch.float32]
+
+        if len(node.args) != 1:
+            return False
+
+        in_arg = node.args[0]
+        if not isinstance(in_arg, torch.fx.Node):
+            return False
+
+        in_tensor = in_arg.meta.get("val", None)
+        out_tensor = node.meta.get("val", None)
+
+        if isinstance(in_tensor, FakeTensor) and isinstance(out_tensor, FakeTensor):
+            if out_tensor.dtype in float_dtypes and in_tensor.dtype in float_dtypes:
+                return True
+
+        return False
 
     def is_node_supported(
         self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
@@ -174,13 +189,13 @@ def _is_node_supported(
         if target not in VulkanSupportedOperators._ops:
             return False
 
-        features = VulkanSupportedOperators._ops[target]
-
         if target == exir_ops.edge.aten._to_copy.default and not self.is_valid_to_copy(
            node
        ):
            return False
 
+        features = VulkanSupportedOperators._ops[target]
+
         if self.require_dynamic_shapes and not features.supports_dynamic_shape:
            return False
 
