
Commit a43b4a6

introduce model-level end2end tests to dim order tests with different delegate (pytorch#6093)

Summary:
Pull Request resolved: pytorch#6093

This diff introduces end-to-end tests covering several model + delegation combinations.

Models: llama2, resnet18, mobilenet_v3
Delegates: no delegate, xnnpack (one combination is sketched below)
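
For context, a minimal self-contained sketch of how one such model + delegate combination is exercised through the updated test utilities (mirroring the tests added in this diff; the test class name here is illustrative):

import unittest

import torch
import torchvision

from executorch.exir.tests.test_memory_format_ops_pass_utils import (
    MemoryFormatOpsPassTestUtils,
    MemoryFormatTestSet,
)
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)


class ExampleDimOrderEnd2EndTest(unittest.TestCase):
    def test_resnet18_xnnpack(self) -> None:
        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
            self,
            MemoryFormatTestSet(
                module=torchvision.models.resnet18().eval(),
                sample_input=(torch.randn(1, 3, 224, 224),),
                target_memory_format=torch.contiguous_format,
                # full-model runs skip the per-op FileCheck assertions
                op_level_check=False,
                # lower through the XNNPACK partitioner instead of no delegate
                use_xnnpack=True,
                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
                # looser tolerances for end-to-end model outputs
                atol=1e-3,
                rtol=1e-3,
            ),
        )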

Reviewed By: digantdesai, larryliu0820

Differential Revision: D64174329

fbshipit-source-id: 0807e0282d136bf1ef6d5be88e0c9f8512580f38
Gasoonjia authored and facebook-github-bot committed Oct 10, 2024
1 parent 69766fb commit a43b4a6
Showing 4 changed files with 138 additions and 15 deletions.
3 changes: 3 additions & 0 deletions exir/tests/TARGETS
@@ -377,6 +377,7 @@ python_unittest(
":test_memory_format_ops_pass_utils",
"//caffe2:torch",
"//executorch/extension/pybindings:aten_lib", # @manual
"//pytorch/vision:torchvision", # @manual
],
)

@@ -394,6 +395,7 @@ python_unittest(
"//executorch/exir/dialects:lib",
"//executorch/exir/dialects/edge:lib",
"//executorch/extension/pybindings:portable_lib", # @manual
"//pytorch/vision:torchvision", # @manual
],
)

@@ -404,6 +406,7 @@ python_library(
],
deps = [
"//caffe2:torch",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/exir:dim_order_utils",
"//executorch/exir:lib",
"//executorch/exir/capture:config",
64 changes: 64 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
@@ -10,6 +10,8 @@
from typing import Union

import torch

import torchvision
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -264,3 +266,65 @@ def call_operator(self, op, args, kwargs, meta):

self.assertTrue(is_contiguous_dim_order(actual))
self.assertTrue(is_contiguous_dim_order(expected))

def test_resnet18(self) -> None:
model = torchvision.models.resnet18()
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
atol=1e-3,
rtol=1e-3,
),
)

def test_resnet18_xnnpack(self) -> None:
model = torchvision.models.resnet18()
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
use_xnnpack=True,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
atol=1e-3,
rtol=1e-3,
),
)

def test_mobilenet_v3(self) -> None:
model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
atol=1e-3,
rtol=1e-3,
),
)

def test_mobilenet_v3_xnnpack(self) -> None:
model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
use_xnnpack=True,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
atol=1e-3,
rtol=1e-3,
),
)
31 changes: 31 additions & 0 deletions exir/tests/test_memory_format_ops_pass_aten.py
@@ -7,6 +7,7 @@
import unittest

import torch
import torchvision

from executorch.exir.tests.test_memory_format_ops_pass_utils import (
MemoryFormatOpsPassTestUtils,
@@ -77,3 +78,33 @@ def test_op_dim_order_propagation_aten(self) -> None:
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_resnet18(self) -> None:
model = torchvision.models.resnet18()
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
atol=1e-3,
rtol=1e-3,
),
)

def test_mobilenet_v3(self) -> None:
model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
atol=1e-3,
rtol=1e-3,
),
)
55 changes: 40 additions & 15 deletions exir/tests/test_memory_format_ops_pass_utils.py
@@ -11,7 +11,9 @@
from typing import Any, Tuple

import torch
from executorch.exir import to_edge

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge, to_edge_transform_and_lower
from executorch.exir.capture._config import EdgeCompileConfig

from executorch.exir.dim_order_utils import (
@@ -30,6 +32,10 @@ class MemoryFormatTestSet:
sample_input: Tuple[Any, ...]
target_memory_format: torch.memory_format
_load_for_executorch_from_buffer: Any
op_level_check: bool = True
use_xnnpack: bool = False
rtol: float = 1e-05
atol: float = 1e-08


class SimpleToCopyContiguousModule(torch.nn.Module):
@@ -63,27 +69,42 @@ class MemoryFormatOpsPassTestUtils:
def memory_format_test_runner(
test_class: unittest.TestCase, test_set: MemoryFormatTestSet
):
aten_op_str = "torch.ops.aten._to_copy.default"
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"

before = export(test_set.module, test_set.sample_input)

# check op strings before
FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
edge_op_str
).run(before.graph_module.code)
if test_set.use_xnnpack:
epm = to_edge_transform_and_lower(
before,
compile_config=EdgeCompileConfig(
_skip_dim_order=False, _check_ir_validity=False
),
partitioner=[XnnpackPartitioner()],
)
else:
epm = to_edge(
before, compile_config=EdgeCompileConfig(_skip_dim_order=False)
)

# check memory format ops, if needed
if test_set.op_level_check:
aten_op_str = "torch.ops.aten._to_copy.default"
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"

epm = to_edge(before, compile_config=EdgeCompileConfig(_skip_dim_order=False))
# check op strings before
FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
edge_op_str
).run(before.graph_module.code)

# check op strings
FileCheck().check_not(aten_op_str).check_count(
edge_op_str, 1, exactly=True
).run(epm.exported_program().graph_module.code)
# check op strings
FileCheck().check_not(aten_op_str).check_count(
edge_op_str, 1, exactly=True
).run(epm.exported_program().graph_module.code)

# check EdgeOp and the new BackendOp should behave the same
expected = before.module()(*test_set.sample_input)
actual = epm.exported_program().module()(*test_set.sample_input)
test_class.assertTrue(torch.allclose(actual, expected))
test_class.assertTrue(
torch.allclose(actual, expected, atol=test_set.atol, rtol=test_set.rtol)
)
test_class.assertEqual(
is_channel_last_dim_order(actual),
is_channel_last_dim_order(expected),
@@ -105,7 +126,11 @@ def memory_format_test_runner(
runtime_output = executorch_module.run_method(
"forward", tuple(inputs_flattened)
)[0]
test_class.assertTrue(torch.allclose(runtime_output, expected))
test_class.assertTrue(
torch.allclose(
runtime_output, expected, atol=test_set.atol, rtol=test_set.rtol
)
)
test_class.assertEqual(
is_channel_last_dim_order(runtime_output),
is_channel_last_dim_order(expected),
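For reference, a small sketch of the dim-order predicates the updated runner asserts on. It assumes the helpers in executorch.exir.dim_order_utils accept tensors, as the runner's calls above suggest:

import torch

from executorch.exir.dim_order_utils import (
    is_channel_last_dim_order,
    is_contiguous_dim_order,
)

x = torch.randn(1, 3, 224, 224)      # default contiguous (NCHW) layout
assert is_contiguous_dim_order(x)    # dim order (0, 1, 2, 3)

y = x.to(memory_format=torch.channels_last)
assert is_channel_last_dim_order(y)  # dim order (0, 2, 3, 1)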
