diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS
index 666de4bda0..f8b4d905fb 100644
--- a/exir/tests/TARGETS
+++ b/exir/tests/TARGETS
@@ -377,6 +377,7 @@ python_unittest(
         ":test_memory_format_ops_pass_utils",
         "//caffe2:torch",
         "//executorch/extension/pybindings:aten_lib",  # @manual
+        "//pytorch/vision:torchvision",  # @manual
     ],
 )
 
@@ -394,6 +395,7 @@ python_unittest(
         "//executorch/exir/dialects:lib",
         "//executorch/exir/dialects/edge:lib",
         "//executorch/extension/pybindings:portable_lib",  # @manual
+        "//pytorch/vision:torchvision",  # @manual
     ],
 )
 
@@ -404,6 +406,7 @@ python_library(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
         "//executorch/exir:dim_order_utils",
         "//executorch/exir:lib",
         "//executorch/exir/capture:config",
diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py
index 7c9e1bd248..53befded94 100644
--- a/exir/tests/test_memory_format_ops_pass.py
+++ b/exir/tests/test_memory_format_ops_pass.py
@@ -10,6 +10,8 @@
 from typing import Union
 
 import torch
+
+import torchvision
 from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -264,3 +266,65 @@ def call_operator(self, op, args, kwargs, meta):
 
         self.assertTrue(is_contiguous_dim_order(actual))
         self.assertTrue(is_contiguous_dim_order(expected))
+
+    def test_resnet18(self) -> None:
+        model = torchvision.models.resnet18()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
+
+    def test_resnet18_xnnpack(self) -> None:
+        model = torchvision.models.resnet18()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                use_xnnpack=True,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
+
+    def test_mobilenet_v3(self) -> None:
+        model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
+
+    def test_mobilenet_v3_xnnpack(self) -> None:
+        model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                use_xnnpack=True,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
diff --git a/exir/tests/test_memory_format_ops_pass_aten.py b/exir/tests/test_memory_format_ops_pass_aten.py
index ab8716d025..601893fd23 100644
--- a/exir/tests/test_memory_format_ops_pass_aten.py
+++ b/exir/tests/test_memory_format_ops_pass_aten.py
@@ -7,6 +7,7 @@
 import unittest
 
 import torch
+import torchvision
 
 from executorch.exir.tests.test_memory_format_ops_pass_utils import (
     MemoryFormatOpsPassTestUtils,
@@ -77,3 +78,33 @@ def test_op_dim_order_propagation_aten(self) -> None:
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
             ),
         )
+
+    def test_resnet18(self) -> None:
+        model = torchvision.models.resnet18()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
+
+    def test_mobilenet_v3(self) -> None:
+        model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py
index 93d790d491..dc02d25738 100644
--- a/exir/tests/test_memory_format_ops_pass_utils.py
+++ b/exir/tests/test_memory_format_ops_pass_utils.py
@@ -11,7 +11,9 @@
 from typing import Any, Tuple
 
 import torch
-from executorch.exir import to_edge
+
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import to_edge, to_edge_transform_and_lower
 from executorch.exir.capture._config import EdgeCompileConfig
 
 from executorch.exir.dim_order_utils import (
@@ -30,6 +32,10 @@ class MemoryFormatTestSet:
     sample_input: Tuple[Any, ...]
     target_memory_format: torch.memory_format
     _load_for_executorch_from_buffer: Any
+    op_level_check: bool = True
+    use_xnnpack: bool = False
+    rtol: float = 1e-05
+    atol: float = 1e-08
 
 
 class SimpleToCopyContiguousModule(torch.nn.Module):
@@ -63,27 +69,42 @@ class MemoryFormatOpsPassTestUtils:
     def memory_format_test_runner(
         test_class: unittest.TestCase, test_set: MemoryFormatTestSet
     ):
-        aten_op_str = "torch.ops.aten._to_copy.default"
-        edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
-
         before = export(test_set.module, test_set.sample_input)
 
-        # check op strings before
-        FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
-            edge_op_str
-        ).run(before.graph_module.code)
+        if test_set.use_xnnpack:
+            epm = to_edge_transform_and_lower(
+                before,
+                compile_config=EdgeCompileConfig(
+                    _skip_dim_order=False, _check_ir_validity=False
+                ),
+                partitioner=[XnnpackPartitioner()],
+            )
+        else:
+            epm = to_edge(
+                before, compile_config=EdgeCompileConfig(_skip_dim_order=False)
+            )
+
+        # check memory format ops, if needed
+        if test_set.op_level_check:
+            aten_op_str = "torch.ops.aten._to_copy.default"
+            edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
 
-        epm = to_edge(before, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+            # check op strings before
+            FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
+                edge_op_str
+            ).run(before.graph_module.code)
 
-        # check op strings
-        FileCheck().check_not(aten_op_str).check_count(
-            edge_op_str, 1, exactly=True
-        ).run(epm.exported_program().graph_module.code)
+            # check op strings
+            FileCheck().check_not(aten_op_str).check_count(
+                edge_op_str, 1, exactly=True
+            ).run(epm.exported_program().graph_module.code)
 
         # check EdgeOp and the new BackendOp should behave the same
         expected = before.module()(*test_set.sample_input)
         actual = epm.exported_program().module()(*test_set.sample_input)
-        test_class.assertTrue(torch.allclose(actual, expected))
+        test_class.assertTrue(
+            torch.allclose(actual, expected, atol=test_set.atol, rtol=test_set.rtol)
+        )
         test_class.assertEqual(
             is_channel_last_dim_order(actual),
             is_channel_last_dim_order(expected),
@@ -105,7 +126,11 @@ def memory_format_test_runner(
         runtime_output = executorch_module.run_method(
             "forward", tuple(inputs_flattened)
         )[0]
-        test_class.assertTrue(torch.allclose(runtime_output, expected))
+        test_class.assertTrue(
+            torch.allclose(
+                runtime_output, expected, atol=test_set.atol, rtol=test_set.rtol
+            )
+        )
         test_class.assertEqual(
             is_channel_last_dim_order(runtime_output),
             is_channel_last_dim_order(expected),