diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 96e5f313ae..ab757d5f27 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -88,6 +88,8 @@ def compile(
     engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
     engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
     custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
+    refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
+    strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,8 @@ def compile(
         engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
         engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
         custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
+        refit_identical_engine_weights (bool): Whether to build the engine under the assumption that it will only ever be refitted with weights identical to the build-time weights. This permits optimizations that generally refittable engines rule out, and allows a cached engine to be reused across recompilations of the same model.
+        strip_engine_weights (bool): Whether to strip the refittable weights from the serialized engine plan, greatly reducing its size. A weight-stripped engine must be refitted with the original weights before it produces meaningful outputs.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -281,6 +285,8 @@ def compile(
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
+        "refit_identical_engine_weights": refit_identical_engine_weights,
+        "strip_engine_weights": strip_engine_weights,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -522,6 +528,8 @@ def convert_exported_program_to_serialized_trt_engine(
     calibrator: object = None,
     allow_shape_tensors: bool = False,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
+    refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
+    strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -580,6 +588,8 @@ def convert_exported_program_to_serialized_trt_engine(
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
+        refit_identical_engine_weights (bool): Whether to build the engine under the assumption that it will only ever be refitted with weights identical to the build-time weights. This permits optimizations that generally refittable engines rule out, and allows a cached engine to be reused across recompilations of the same model.
+        strip_engine_weights (bool): Whether to strip the refittable weights from the serialized engine plan, greatly reducing its size. A weight-stripped engine must be refitted with the original weights before it produces meaningful outputs.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -653,6 +663,8 @@ def convert_exported_program_to_serialized_trt_engine(
         "dla_local_dram_size": dla_local_dram_size,
         "dla_global_dram_size": dla_global_dram_size,
         "timing_cache_path": timing_cache_path,
+        "refit_identical_engine_weights": refit_identical_engine_weights,
+        "strip_engine_weights": strip_engine_weights,
     }
 
     exported_program = pre_export_lowering(exported_program)
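Taken together, the two new options enable the build-then-refit workflow exercised by the new tests at the end of this diff. A minimal sketch (model and input shapes are placeholders):

```python
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo._refit import refit_module_weights

model = models.resnet18().eval().cuda()
inputs = (torch.randn(1, 3, 224, 224).cuda(),)
exp_program = torch.export.export(model, inputs)

# Build a refittable engine whose serialized plan omits the weights
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    inputs,
    make_refittable=True,                 # required for weight-stripped engines
    strip_engine_weights=True,            # serialize the plan without weights
    refit_identical_engine_weights=True,  # refit weights will match build-time weights
)

# Outputs are meaningless (all zeros) until the engine is refitted
trt_gm = refit_module_weights(trt_gm, exp_program)
out = trt_gm(*inputs)
```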
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 68e446dab5..982e48c1d5 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -40,6 +40,8 @@
 ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
 ENGINE_CACHE_SIZE = 1073741824
 CUSTOM_ENGINE_CACHE = None
+REFIT_IDENTICAL_ENGINE_WEIGHTS = False
+STRIP_ENGINE_WEIGHTS = False
 
 
 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index f8886fbd67..e4458b187b 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -24,9 +24,11 @@
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
+    REFIT_IDENTICAL_ENGINE_WEIGHTS,
     REQUIRE_FULL_COMPILATION,
     REUSE_CACHED_ENGINES,
     SPARSE_WEIGHTS,
+    STRIP_ENGINE_WEIGHTS,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
     USE_FAST_PARTITIONER,
@@ -78,6 +80,8 @@ class CompilationSettings:
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
+        refit_identical_engine_weights (bool): Whether to build the engine assuming refit weights will be identical to the build-time weights
+        strip_engine_weights (bool): Whether to strip the weights from the serialized engine plan
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -112,6 +116,8 @@ class CompilationSettings:
     lazy_engine_init: bool = LAZY_ENGINE_INIT
     cache_built_engines: bool = CACHE_BUILT_ENGINES
     reuse_cached_engines: bool = REUSE_CACHED_ENGINES
+    refit_identical_engine_weights: bool = REFIT_IDENTICAL_ENGINE_WEIGHTS
+    strip_engine_weights: bool = STRIP_ENGINE_WEIGHTS
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -124,6 +130,8 @@ class CompilationSettings:
     "make_refittable",
     "engine_capability",
     "hardware_compatible",
+    "refit_identical_engine_weights",
+    "strip_engine_weights",
 )
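Listing both options in `_SETTINGS_TO_BE_ENGINE_INVARIANT` keeps them among the settings that must match between a cached engine and the current compilation, so an engine built under one refit/strip configuration is never handed to a compilation expecting another. A hypothetical sketch of how such a cache key could be derived (the actual hashing lives in the engine-cache implementation and may differ):

```python
import hashlib

from torch_tensorrt.dynamo._settings import (
    _SETTINGS_TO_BE_ENGINE_INVARIANT,
    CompilationSettings,
)


def engine_cache_key(graph_hash: str, settings: CompilationSettings) -> str:
    # Hypothetical helper: fold every engine-invariant setting into the key so
    # that, e.g., an engine built with strip_engine_weights=True can never
    # collide with one expecting a weight-included plan.
    parts = [graph_hash] + [
        f"{name}={getattr(settings, name)}"
        for name in _SETTINGS_TO_BE_ENGINE_INVARIANT
    ]
    return hashlib.sha256("|".join(parts).encode()).hexdigest()
```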
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index ff35bf39d7..956e21775d 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -18,6 +18,7 @@
 )
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -43,7 +44,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -283,7 +283,16 @@ def _populate_trt_builder_config(
             builder_config.clear_flag(trt.BuilderFlag.TF32)
 
         if self.compilation_settings.make_refittable:
-            builder_config.set_flag(trt.BuilderFlag.REFIT)
+            if version.parse(trt.__version__) >= version.parse("10.0"):
+                if self.compilation_settings.refit_identical_engine_weights:
+                    builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
+                else:
+                    builder_config.set_flag(trt.BuilderFlag.REFIT)
+            else:
+                builder_config.set_flag(trt.BuilderFlag.REFIT)
+
+        if self.compilation_settings.strip_engine_weights:
+            builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)
 
         if strict_type_constraints:
             builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
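For context on the flag selection: under TensorRT 10, `REFIT_IDENTICAL` builds a refittable engine under the promise that refit weights are byte-identical to the build-time weights, which permits optimizations that the general `REFIT` flag rules out, and `STRIP_PLAN` omits the refittable weights from the serialized plan. A condensed restatement of the decision above (assumes an existing `trt.IBuilderConfig`):

```python
import tensorrt as trt
from packaging import version


def apply_refit_flags(
    builder_config: trt.IBuilderConfig,
    refit_identical: bool,
    strip_weights: bool,
) -> None:
    # REFIT_IDENTICAL and STRIP_PLAN are TensorRT >= 10 flags; earlier
    # releases fall back to the general-purpose REFIT flag.
    trt10 = version.parse(trt.__version__) >= version.parse("10.0")
    if trt10 and refit_identical:
        builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
    else:
        builder_config.set_flag(trt.BuilderFlag.REFIT)
    if strip_weights:
        builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)
```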
@@ -573,24 +582,31 @@ def run(
                     "Found the cached engine that corresponds to this graph. It is directly loaded."
                 )
 
-                runtime = trt.Runtime(TRT_LOGGER)
-                engine = runtime.deserialize_cuda_engine(serialized_engine)
+                # the cache stores weight-stripped engines; refit with the current module's weights unless a stripped engine was requested
+                if not self.compilation_settings.strip_engine_weights:
+                    runtime = trt.Runtime(TRT_LOGGER)
+                    engine = runtime.deserialize_cuda_engine(serialized_engine)
 
-                from torch_tensorrt.dynamo._refit import (
-                    _refit_single_trt_engine_with_gm,
-                )
+                    from torch_tensorrt.dynamo._refit import (
+                        _refit_single_trt_engine_with_gm,
+                    )
 
-                # TODO: Fast refit is problematic for now. It will fail if the engine has batch_norm layers.
-                # We set weight_name_map=None to use slow refit anyway for now. Will fix it in the future.
-                _refit_single_trt_engine_with_gm(
-                    new_gm=self.module,
-                    old_engine=engine,
-                    input_list=self.input_specs,
-                    settings=self.compilation_settings,
-                    weight_name_map=None,
-                )
+                    _refit_single_trt_engine_with_gm(
+                        new_gm=self.module,
+                        old_engine=engine,
+                        input_list=self.input_specs,
+                        settings=self.compilation_settings,
+                        weight_name_map=self.weight_name_map,
+                    )
 
-                serialized_engine = engine.serialize()
+                    # serialize the refitted engine; EXCLUDE_WEIGHTS must be cleared so the plan includes the restored weights
+                    serialization_config = engine.create_serialization_config()
+                    serialization_config.clear_flag(
+                        trt.SerializationFlag.EXCLUDE_WEIGHTS
+                    )
+                    serialized_engine = engine.serialize_with_config(
+                        serialization_config
+                    )
 
                 with io.BytesIO() as engine_bytes:
                     engine_bytes.write(serialized_engine)
@@ -632,14 +648,31 @@ def run(
             self._save_timing_cache(
                 builder_config, self.compilation_settings.timing_cache_path
             )
+
             if (
                 self.engine_cache is not None
                 and self.compilation_settings.cache_built_engines
             ):
+                assert (
+                    self.compilation_settings.make_refittable
+                ), "weight-stripped engines must be refittable, please set make_refittable=True"
+
+                # regardless of the compilation settings, the cache always stores the weight-stripped engine
+                if self.compilation_settings.strip_engine_weights:
+                    weight_stripped_serialized_engine = serialized_engine
+                else:
+                    runtime = trt.Runtime(TRT_LOGGER)
+                    engine = runtime.deserialize_cuda_engine(serialized_engine)
+                    serialization_config = engine.create_serialization_config()
+                    serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
+                    weight_stripped_serialized_engine = engine.serialize_with_config(
+                        serialization_config
+                    )
+
                 self.engine_cache.insert(
                     hash_val,
                     (
-                        serialized_engine,
+                        weight_stripped_serialized_engine,
                         self._input_names,
                         self._output_names,
                         self.input_specs,
@@ -653,7 +686,10 @@ def run(
             engine_str = engine_bytes.getvalue()
 
         return TRTInterpreterResult(
-            engine_str, self._input_names, self._output_names, self.weight_name_map
+            engine_str,
+            self._input_names,
+            self._output_names,
+            self.weight_name_map,
         )
 
     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
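Both serialization paths in `run()` revolve around `trt.SerializationFlag.EXCLUDE_WEIGHTS`: on insert, the flag is set so the cache always stores a weight-stripped plan (even for weight-included builds); on a cache hit, the engine is refitted and re-serialized with the flag cleared so callers receive a weight-included plan again. A minimal sketch of that round trip, assuming `engine` is a refittable `trt.ICudaEngine` built with TensorRT 10:

```python
import tensorrt as trt


def serialize_weight_stripped(engine: trt.ICudaEngine) -> bytes:
    # Cache-insert path: store the plan without its refittable weights.
    config = engine.create_serialization_config()
    config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    return bytes(engine.serialize_with_config(config))


def serialize_with_weights(engine: trt.ICudaEngine) -> bytes:
    # Cache-hit path: after refitting, emit a plan with weights included again.
    config = engine.create_serialization_config()
    config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    return bytes(engine.serialize_with_config(config))
```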
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index f0b65b3a6e..06fade9674 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,6 +3,7 @@
 import logging
 from typing import Any, List, Optional, Sequence
 
+import tensorrt as trt
 import torch
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
@@ -18,8 +19,6 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index f74c239550..cc7c5da30e 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -5,6 +5,7 @@
 from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+import tensorrt as trt
 import torch
 import torch_tensorrt
 from torch.nn import Module
@@ -19,8 +20,6 @@
     multi_gpu_device_check,
 )
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
@@ -39,7 +38,7 @@ def __init__(
         *,
         name: str = "",
         settings: CompilationSettings = CompilationSettings(),
-        weight_name_map: Any = None,
+        weight_name_map: Optional[dict[Any, Any]] = None,
     ):
         """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
@@ -52,6 +51,7 @@ def __init__(
         Keyword Arguments:
             name (str): Name for module
             settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
+            weight_name_map (dict): Mapping of engine weight name to state_dict weight name
 
         Example:
 
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 7bf42da7f0..03daeada5f 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -96,6 +96,7 @@ def __init__(
         Keyword Arguments:
             name (str): Name for module
             settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
+            weight_name_map (dict): Mapping of engine weight name to state_dict weight name
 
         Example:
 
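The newly documented `weight_name_map` is what allows fast refit on a cache hit: the refitter can look up each engine weight in the module's `state_dict` by name rather than re-running conversion. Its exact schema is internal to `TRTInterpreter`; purely as a hypothetical illustration of the mapping's shape:

```python
# Hypothetical entries only; the real keys are generated by TRTInterpreter.
weight_name_map = {
    "conv1 KERNEL": "conv1.weight",  # engine-side weight name -> state_dict key
    "conv1 BIAS": "conv1.bias",
}
```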
diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py
index 367f68c1f6..5dcdfe4ae9 100644
--- a/tests/py/dynamo/models/test_engine_cache.py
+++ b/tests/py/dynamo/models/test_engine_cache.py
@@ -206,6 +206,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         for i in range(3):
+            # remove timing cache and reset dynamo for engine caching measurement
             remove_timing_cache()
             torch._dynamo.reset()
             if i == 0:
@@ -220,7 +221,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
             trt_gm = torch_trt.dynamo.compile(
                 exp_program,
                 tuple(inputs),
-                use_python_runtime=False,
+                use_python_runtime=True,
                 enabled_precisions={torch.float},
                 debug=False,
                 min_block_size=1,
@@ -231,7 +232,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
             )
             end.record()
             torch.cuda.synchronize()
-            torch._dynamo.reset()
             times.append(start.elapsed_time(end))
             results.append(trt_gm(*inputs))
 
@@ -285,7 +285,7 @@ def test_dynamo_compile_with_custom_engine_cache(self):
             trt_gm = torch_trt.dynamo.compile(
                 exp_program,
                 tuple(inputs),
-                use_python_runtime=False,
+                use_python_runtime=True,
                 enabled_precisions={torch.float},
                 debug=False,
                 min_block_size=1,
@@ -387,7 +387,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
                 model,
                 backend="tensorrt",
                 options={
-                    "use_python_runtime": True,
+                    "use_python_runtime": False,
                     "enabled_precisions": {torch.float},
                     "debug": False,
                     "min_block_size": 1,
@@ -402,7 +402,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
             results.append(compiled_model(*inputs))  # trigger the compilation
             end.record()
             torch.cuda.synchronize()
-            torch._dynamo.reset()
             times.append(start.elapsed_time(end))
 
         cos_sim = cosine_similarity(results[0], results[1])
@@ -441,7 +440,6 @@ def test_torch_compile_with_custom_engine_cache(self):
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         for i in range(3):
-            # remove timing cache and reset dynamo for engine caching messurement
             if i == 0:
                 cache_built_engines = False
                 reuse_cached_engines = False
@@ -454,7 +452,7 @@ def test_torch_compile_with_custom_engine_cache(self):
                 model,
                 backend="tensorrt",
                 options={
-                    "use_python_runtime": True,
+                    "use_python_runtime": False,
                     "enabled_precisions": {torch.float},
                     "debug": False,
                     "min_block_size": 1,
@@ -501,7 +499,6 @@ def test_torch_compile_change_input_shape(self):
         custom_engine_cache = MyEngineCache(engine_cache_dir)
 
         for i in range(3):
-            # remove timing cache and reset dynamo for engine caching messurement
            inputs = [torch.rand((4 * (i + 1), 3, 224, 224)).to("cuda")]
             compiled_model = torch.compile(
                 model,
diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py
new file mode 100644
index 0000000000..196800e758
--- /dev/null
+++ b/tests/py/dynamo/models/test_weight_stripped_engine.py
@@ -0,0 +1,314 @@
+import os
+import pickle
+import shutil
+import unittest
+
+import torch
+import torch_tensorrt as torch_trt
+import torchvision.models as models
+from torch.testing._internal.common_utils import TestCase
+from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine
+from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
+
+assertions = unittest.TestCase()
+
+
+class TestWeightStrippedEngine(TestCase):
+    def test_weight_stripped_engine_sizes(self):
+        pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        exp_program = torch.export.export(pyt_model, example_inputs)
+        weight_included_engine = convert_exported_program_to_serialized_trt_engine(
+            exp_program,
+            example_inputs,
+            make_refittable=False,
+            strip_engine_weights=False,
+            refit_identical_engine_weights=False,
+        )
+        weight_stripped_engine = convert_exported_program_to_serialized_trt_engine(
+            exp_program,
+            example_inputs,
+            make_refittable=True,
+            strip_engine_weights=True,
+            refit_identical_engine_weights=False,
+        )
+        weight_stripped_refit_identical_engine = (
+            convert_exported_program_to_serialized_trt_engine(
+                exp_program,
+                example_inputs,
+                make_refittable=True,
+                strip_engine_weights=True,
+                refit_identical_engine_weights=True,
+            )
+        )
+        assertions.assertTrue(
+            len(bytes(weight_included_engine)) > len(bytes(weight_stripped_engine)),
+            msg=f"Weight-stripped engine size is not smaller than the weight-included engine size. Weight-included engine size: {len(bytes(weight_included_engine))}, weight-stripped engine size: {len(bytes(weight_stripped_engine))}",
+        )
+        assertions.assertTrue(
+            len(bytes(weight_stripped_engine))
+            > len(bytes(weight_stripped_refit_identical_engine)),
+            msg=f"Weight-stripped refit-identical engine size is not smaller than the weight-stripped engine size. Weight-stripped engine size: {len(bytes(weight_stripped_engine))}, weight-stripped refit-identical engine size: {len(bytes(weight_stripped_refit_identical_engine))}",
+        )
+
+    def test_weight_stripped_engine_results(self):
+        pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        # Mark the dim0 of inputs as dynamic
+        batch = torch.export.Dim("batch", min=1, max=200)
+        exp_program = torch.export.export(
+            pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+
+        inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
+
+        trt_gm = torch_trt.dynamo.compile(
+            exp_program,
+            tuple(inputs),
+            use_python_runtime=True,
+            enabled_precisions={torch.float},
+            debug=False,
+            min_block_size=1,
+            make_refittable=True,
+            cache_built_engines=False,
+            reuse_cached_engines=False,
+            refit_identical_engine_weights=False,
+            strip_engine_weights=True,
+        )
+        output = trt_gm(*inputs)
+        assertions.assertEqual(
+            output.sum(), 0, msg="weight-stripped engine results should be all zeros"
+        )
+
+        from torch_tensorrt.dynamo._refit import refit_module_weights
+
+        refitted_trt_gm = refit_module_weights(trt_gm, exp_program)
+        refitted_output = refitted_trt_gm(*inputs)
+        assertions.assertNotEqual(
+            refitted_output.sum(),
+            0,
+            msg="refitted engine results should not be all zeros",
+        )
+
+        compiled_model = torch.compile(
+            pyt_model,
+            backend="tensorrt",
+            options={
+                "use_python_runtime": False,
+                "enabled_precisions": {torch.float},
+                "debug": False,
+                "min_block_size": 1,
+                "make_refittable": True,
+                "cache_built_engines": False,
+                "reuse_cached_engines": False,
+                "refit_identical_engine_weights": False,
+                "strip_engine_weights": False,
+            },
+        )
+        compiled_model_output = compiled_model(*inputs)
+        cos_sim = cosine_similarity(refitted_output, compiled_model_output)
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"refitted_output doesn't match with compiled_model_output. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    def test_weight_stripped_engine_with_engine_cache(self):
+        pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        exp_program = torch.export.export(pyt_model, example_inputs)
+
+        engine_cache_dir = "/tmp/test_weight_stripped_engine_with_engine_cache"
+        if os.path.exists(engine_cache_dir):
+            shutil.rmtree(engine_cache_dir)
+
+        weight_included_engine = convert_exported_program_to_serialized_trt_engine(
+            exp_program,
+            example_inputs,
+            make_refittable=False,
+            strip_engine_weights=False,
+            refit_identical_engine_weights=False,
+        )
+
+        trt_gm = torch_trt.dynamo.compile(
+            exp_program,
+            tuple(example_inputs),
+            use_python_runtime=True,
+            enabled_precisions={torch.float},
+            debug=False,
+            min_block_size=1,
+            make_refittable=True,
+            refit_identical_engine_weights=True,
+            strip_engine_weights=False,  # engine cache will save the stripped engine even if this is False
+            cache_built_engines=True,
+            reuse_cached_engines=True,
+            engine_cache_dir=engine_cache_dir,
+        )
+        output = trt_gm(*example_inputs)
+
+        blob_path = os.path.join(
+            engine_cache_dir, os.listdir(engine_cache_dir)[0], "blob.bin"
+        )
+        with open(blob_path, "rb") as f:
+            blob = f.read()
+        unpacked = pickle.loads(blob)
+        cached_stripped_engine = unpacked["serialized_engine"]
+
+        assertions.assertTrue(
+            len(bytes(weight_included_engine)) > len(bytes(cached_stripped_engine)),
+            msg=f"Cached engine size is not smaller than the weight-included engine size. Weight-included engine size: {len(bytes(weight_included_engine))}, cached stripped engine size: {len(bytes(cached_stripped_engine))}",
+        )
+        assertions.assertNotEqual(output.sum(), 0, msg="results are all zeros")
+
+    def test_dynamo_compile_with_refittable_weight_stripped_engine(self):
+        pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        exp_program = torch.export.export(pyt_model, args=example_inputs)
+
+        engine_cache_dir = (
+            "/tmp/test_dynamo_compile_with_refittable_weight_stripped_engine"
+        )
+        if os.path.exists(engine_cache_dir):
+            shutil.rmtree(engine_cache_dir)
+
+        def remove_timing_cache(path=TIMING_CACHE_PATH):
+            if os.path.exists(path):
+                os.remove(path)
+
+        # The 1st iteration is to measure the compilation time without engine caching.
+        # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
+        # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
+        # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
+        inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
+        results = []
+        times = []
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        for i in range(3):
+            remove_timing_cache()
+            torch._dynamo.reset()
+            if i == 0:
+                cache_built_engines = False
+                reuse_cached_engines = False
+            else:
+                cache_built_engines = True
+                reuse_cached_engines = True
+
+            torch.cuda.synchronize()
+            start.record()
+            trt_gm = torch_trt.dynamo.compile(
+                exp_program,
+                tuple(inputs),
+                use_python_runtime=True,
+                enabled_precisions={torch.float},
+                debug=False,
+                min_block_size=1,
+                make_refittable=True,
+                refit_identical_engine_weights=False,
+                strip_engine_weights=False,
+                cache_built_engines=cache_built_engines,
+                reuse_cached_engines=reuse_cached_engines,
+                engine_cache_dir=engine_cache_dir,
+            )
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))
+            results.append(trt_gm(*inputs))
+
+        assertions.assertNotEqual(results[0].sum(), 0, msg="results[0] are all zeros")
+        assertions.assertNotEqual(results[1].sum(), 0, msg="results[1] are all zeros")
+        assertions.assertNotEqual(results[2].sum(), 0, msg="results[2] are all zeros")
+
+        cos_sim = cosine_similarity(results[0], results[1])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        cos_sim = cosine_similarity(results[1], results[2])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        assertions.assertTrue(
+            times[0] > times[2],
+            msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
+        )
+
+    def test_torch_compile_with_refittable_weight_stripped_engine(self):
+        pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
+
+        engine_cache_dir = (
+            "/tmp/test_torch_compile_with_refittable_weight_stripped_engine"
+        )
+        if os.path.exists(engine_cache_dir):
+            shutil.rmtree(engine_cache_dir)
+
+        def remove_timing_cache(path=TIMING_CACHE_PATH):
+            if os.path.exists(path):
+                os.remove(path)
+
+        # The 1st iteration is to measure the compilation time without engine caching.
+        # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
+        # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
+        # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
+        inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
+        results = []
+        times = []
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        for i in range(3):
+            remove_timing_cache()
+            torch._dynamo.reset()
+            if i == 0:
+                cache_built_engines = False
+                reuse_cached_engines = False
+            else:
+                cache_built_engines = True
+                reuse_cached_engines = True
+
+            torch.cuda.synchronize()
+            start.record()
+            compiled_model = torch.compile(
+                pyt_model,
+                backend="tensorrt",
+                options={
+                    "use_python_runtime": False,
+                    "enabled_precisions": {torch.float},
+                    "debug": False,
+                    "min_block_size": 1,
+                    "make_refittable": True,
+                    "cache_built_engines": cache_built_engines,
+                    "reuse_cached_engines": reuse_cached_engines,
+                    "engine_cache_dir": engine_cache_dir,
+                    "torch_executed_ops": {"torch.ops.aten.relu.default"},
+                    "refit_identical_engine_weights": True,
+                    "strip_engine_weights": False,
+                },
+            )
+            results.append(compiled_model(*inputs))  # trigger the compilation
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))
+
+        assertions.assertNotEqual(results[0].sum(), 0, msg="results[0] are all zeros")
+        assertions.assertNotEqual(results[1].sum(), 0, msg="results[1] are all zeros")
+        assertions.assertNotEqual(results[2].sum(), 0, msg="results[2] are all zeros")
+
+        cos_sim = cosine_similarity(results[0], results[1])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        cos_sim = cosine_similarity(results[1], results[2])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        assertions.assertTrue(
+            times[0] > times[2],
+            msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
+        )