
Run test_base_fp8 for compute capability 8.9 or later
HolyWu committed Sep 17, 2024
1 parent bc93437 commit 4e79a7a
Showing 1 changed file with 12 additions and 17 deletions.
tests/py/dynamo/models/test_models_export.py (29 changes: 12 additions & 17 deletions)
@@ -11,7 +11,6 @@
 import torchvision.models as models
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 from transformers import BertModel
-from transformers.utils.fx import symbolic_trace as transformers_trace

 from packaging.version import Version

@@ -196,16 +195,18 @@ def test_resnet18_half(ir):


 @unittest.skipIf(
-    torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
-    "FP8 compilation in Torch-TRT is not supported on cards older than Hopper",
+    torch.cuda.get_device_capability() < (8, 9),
+    "FP8 quantization requires compute capability 8.9 or later",
 )
 @unittest.skipIf(
     not importlib.util.find_spec("modelopt"),
-    reason="ModelOpt is necessary to run this test",
+    "ModelOpt is required to run this test",
 )
 @pytest.mark.unit
 def test_base_fp8(ir):
-    import modelopt
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+    from torch.export._trace import _export

     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
@@ -219,9 +220,6 @@ def forward(self, x):
             x = self.linear2(x)
             return x

-    import modelopt.torch.quantization as mtq
-    from modelopt.torch.quantization.utils import export_torch_mode
-
     def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
@@ -236,7 +234,7 @@ def calibrate_loop(model):

     with torch.no_grad():
         with export_torch_mode():
-            exp_program = torch.export.export(model, (input_tensor,))
+            exp_program = _export(model, (input_tensor,))
             trt_model = torchtrt.dynamo.compile(
                 exp_program,
                 inputs=[input_tensor],
@@ -247,7 +245,7 @@ def calibrate_loop(model):
                 reuse_cached_engines=False,
             )
             outputs_trt = trt_model(input_tensor)
-            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)


 @unittest.skipIf(
@@ -258,7 +256,9 @@ def calibrate_loop(model):
 )
 @pytest.mark.unit
 def test_base_int8(ir):
-    import modelopt
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+    from torch.export._trace import _export

     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
@@ -272,9 +272,6 @@ def forward(self, x):
             x = self.linear2(x)
             return x

-    import modelopt.torch.quantization as mtq
-    from modelopt.torch.quantization.utils import export_torch_mode
-
     def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
@@ -289,8 +286,6 @@ def calibrate_loop(model):

     with torch.no_grad():
         with export_torch_mode():
-            from torch.export._trace import _export
-
             exp_program = _export(model, (input_tensor,))
             trt_model = torchtrt.dynamo.compile(
                 exp_program,
@@ -302,4 +297,4 @@ def calibrate_loop(model):
                 reuse_cached_engines=False,
             )
             outputs_trt = trt_model(input_tensor)
-            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
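For context, a minimal sketch (not part of the commit) of the gating pattern the new decorator relies on: torch.cuda.get_device_capability() returns a (major, minor) tuple, and Python's lexicographic tuple comparison makes (8, 9) admit Ada-generation GPUs (SM 8.9) and anything newer, such as Hopper (SM 9.0). The test class and method names below are illustrative only, and the is_available() guard is an extra safety check that the commit itself does not add.

import unittest

import torch


# Sketch of the new gate: get_device_capability() returns a (major, minor) tuple,
# e.g. (8, 6) for RTX 3090, (8, 9) for L40/RTX 4090, (9, 0) for H100, and tuple
# comparison skips anything below SM 8.9.
@unittest.skipIf(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9),
    "FP8 quantization requires compute capability 8.9 or later",
)
class TestFP8Gate(unittest.TestCase):  # hypothetical class, for illustration only
    def test_runs_only_on_sm89_or_newer(self):
        self.assertGreaterEqual(torch.cuda.get_device_capability(), (8, 9))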

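A note on the loosened tolerance, since the diff only shows the numbers: torch.allclose passes when |input - other| <= atol + rtol * |other| holds element-wise, so raising rtol from 1e-3 to 5e-3 (atol stays at 1e-2) gives the quantized FP8/INT8 outputs roughly half a percent of extra relative headroom. The values below are made up for illustration and are not taken from the test.

import torch

# Made-up example of what the rtol bump changes: allclose passes when
# |a - b| <= atol + rtol * |b| for every element.
a = torch.tensor([100.0])
b = torch.tensor([100.4])  # 0.4% relative error, far beyond atol=1e-2 alone

print(torch.allclose(a, b, rtol=1e-3, atol=1e-2))  # False: 0.4 > 0.01 + 0.001 * 100.4
print(torch.allclose(a, b, rtol=5e-3, atol=1e-2))  # True:  0.4 <= 0.01 + 0.005 * 100.4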