feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167

Open · wants to merge 2 commits into base: main
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -88,6 +88,8 @@ def compile(
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,8 @@ def compile(
engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
refit_identical_engine_weights (bool): Whether to build the engine so that it can only be refitted with weights identical to the build-time weights (TensorRT's REFIT_IDENTICAL builder flag). This preserves optimizations that a fully refittable engine would lose, and is sufficient when the same model is rebuilt with different inputs but unchanged weights, e.g. when reusing cached engines.
strip_engine_weights (bool): Whether to strip the refittable weights from the serialized engine (TensorRT's STRIP_PLAN builder flag). Useful when the weights are stored or shipped separately from the engine; the engine must be refitted with the original weights before execution.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -281,6 +285,8 @@ def compile(
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"refit_identical_engine_weights": refit_identical_engine_weights,
"strip_engine_weights": strip_engine_weights,
}

settings = CompilationSettings(**compilation_options)
@@ -522,6 +528,8 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator: object = None,
allow_shape_tensors: bool = False,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -580,6 +588,8 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
refit_identical_engine_weights (bool): Whether to build the engine so that it can only be refitted with weights identical to the build-time weights (TensorRT's REFIT_IDENTICAL builder flag). This preserves optimizations that a fully refittable engine would lose, and is sufficient when the same model is rebuilt with different inputs but unchanged weights, e.g. when reusing cached engines.
strip_engine_weights (bool): Whether to strip the refittable weights from the serialized engine (TensorRT's STRIP_PLAN builder flag). Useful when the weights are stored or shipped separately from the engine; the engine must be refitted with the original weights before execution.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
@@ -653,6 +663,8 @@ def convert_exported_program_to_serialized_trt_engine(
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"timing_cache_path": timing_cache_path,
"refit_identical_engine_weights": refit_identical_engine_weights,
"strip_engine_weights": strip_engine_weights,
}

exported_program = pre_export_lowering(exported_program)
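For context, a minimal usage sketch of the two new compile options (the toy model and inputs are illustrative, not part of this diff):

```python
import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]
exp_program = torch.export.export(model, tuple(inputs))

# Weight-stripped, refit-identical build; weight-stripped engines are expected
# to be refittable, hence make_refittable=True.
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refittable=True,
    refit_identical_engine_weights=True,  # maps to trt.BuilderFlag.REFIT_IDENTICAL on TensorRT >= 10
    strip_engine_weights=True,            # maps to trt.BuilderFlag.STRIP_PLAN
)
# NOTE: a weight-stripped engine must be refitted with the original weights before execution.
```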
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -40,6 +40,8 @@
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
ENGINE_CACHE_SIZE = 1073741824
CUSTOM_ENGINE_CACHE = None
REFIT_IDENTICAL_ENGINE_WEIGHTS = False
STRIP_ENGINE_WEIGHTS = False


def default_device() -> Device:
8 changes: 8 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -24,9 +24,11 @@
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
REFIT_IDENTICAL_ENGINE_WEIGHTS,
REQUIRE_FULL_COMPILATION,
REUSE_CACHED_ENGINES,
SPARSE_WEIGHTS,
STRIP_ENGINE_WEIGHTS,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
USE_FAST_PARTITIONER,
@@ -78,6 +80,8 @@ class CompilationSettings:
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
refit_identical_engine_weights (bool): Whether to build the engine assuming refits will use weights identical to the build-time weights
strip_engine_weights (bool): Whether to strip the refittable weights from the serialized engine
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -112,6 +116,8 @@ class CompilationSettings:
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
refit_identical_engine_weights: bool = REFIT_IDENTICAL_ENGINE_WEIGHTS
strip_engine_weights: bool = STRIP_ENGINE_WEIGHTS


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -124,6 +130,8 @@ class CompilationSettings:
"make_refittable",
"engine_capability",
"hardware_compatible",
"refit_identical_engine_weights",
"strip_engine_weights",
)
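Adding the two new flags to _SETTINGS_TO_BE_ENGINE_INVARIANT keeps them in the engine-cache key, so builds that differ only in these options do not share a cached engine. A rough illustration of the idea (hypothetical helper, not the library's actual hashing code):

```python
import hashlib
import json

def cache_key(graph_hash: str, settings: dict, invariant_keys: tuple) -> str:
    # Only the engine-invariant subset of the compilation settings participates
    # in the key, e.g. strip_engine_weights=True vs. False yield different entries.
    invariant = {k: settings[k] for k in invariant_keys if k in settings}
    payload = graph_hash + json.dumps(invariant, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode()).hexdigest()
```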


74 changes: 55 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -18,6 +18,7 @@
)

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
@@ -43,7 +44,6 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -283,7 +283,16 @@ def _populate_trt_builder_config(
builder_config.clear_flag(trt.BuilderFlag.TF32)

if self.compilation_settings.make_refittable:
builder_config.set_flag(trt.BuilderFlag.REFIT)
if version.parse(trt.__version__) >= version.parse("10.0"):
if self.compilation_settings.refit_identical_engine_weights:
builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
else:
builder_config.set_flag(trt.BuilderFlag.REFIT)
else:
builder_config.set_flag(trt.BuilderFlag.REFIT)

if self.compilation_settings.strip_engine_weights:
builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)

if strict_type_constraints:
builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
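For context, the equivalent flags in the plain TensorRT 10 builder API (a sketch; the network here is unpopulated and would need layers or an ONNX parse before building):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)  # populate via layer APIs or the ONNX parser
config = builder.create_builder_config()

# REFIT_IDENTICAL: refittable, but only with weights identical to the build-time
# weights, which lets the builder keep weight-dependent optimizations.
config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
# STRIP_PLAN: omit the refittable weights from the serialized plan.
config.set_flag(trt.BuilderFlag.STRIP_PLAN)

plan = builder.build_serialized_network(network, config)  # weight-stripped serialized engine (IHostMemory)
```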
@@ -573,24 +582,31 @@ def run(
"Found the cached engine that corresponds to this graph. It is directly loaded."
)

runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)
# refit the cached engine with the new graph module
if not self.compilation_settings.strip_engine_weights:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)
from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

# TODO: Fast refit is problematic for now. It will fail if the engine has batch_norm layers.
# We set weight_name_map=None to use slow refit anyway for now. Will fix it in the future.
_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=None,
)
_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=self.weight_name_map,
)

serialized_engine = engine.serialize()
# Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(
trt.SerializationFlag.EXCLUDE_WEIGHTS
)
serialized_engine = engine.serialize_with_config(
serialization_config
)

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
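The refitted engine is then re-serialized with EXCLUDE_WEIGHTS cleared, so the cached weight-stripped plan comes back to the caller as a full engine. For reference, refitting a weight-stripped engine directly through TensorRT looks roughly like this (a sketch; the cache path and weights_by_name helper are hypothetical placeholders):

```python
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)

with open("engine_weight_stripped.bin", "rb") as f:  # hypothetical path to a cached stripped plan
    stripped_plan = f.read()

# hypothetical helper returning {engine weight name: np.ndarray} for the build-time weights
weights_by_name = load_build_time_weights()

runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(stripped_plan)

refitter = trt.Refitter(engine, logger)
for name in refitter.get_all_weights():
    refitter.set_named_weights(name, trt.Weights(np.ascontiguousarray(weights_by_name[name])))
assert refitter.refit_cuda_engine(), "refit failed: some weights were not supplied"
```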
@@ -632,14 +648,31 @@ def run(
self._save_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)

if (
self.engine_cache is not None
and self.compilation_settings.cache_built_engines
):
assert (
self.compilation_settings.make_refittable
), "weight-stripped engines must be refittable, please set make_refittable=True"

# Regardless of the compilation settings, always cache the weight-stripped engine
if self.compilation_settings.strip_engine_weights:
weight_stripped_serialized_engine = serialized_engine
else:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)
serialization_config = engine.create_serialization_config()
serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
weight_stripped_serialized_engine = engine.serialize_with_config(
serialization_config
)

self.engine_cache.insert(
hash_val,
(
serialized_engine,
weight_stripped_serialized_engine,
self._input_names,
self._output_names,
self.input_specs,
@@ -653,7 +686,10 @@ def run(
engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(
engine_str, self._input_names, self._output_names, self.weight_name_map
engine_str,
self._input_names,
self._output_names,
self.weight_name_map,
)

def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,6 +3,7 @@
import logging
from typing import Any, List, Optional, Sequence

import tensorrt as trt
import torch
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
@@ -18,8 +19,6 @@
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs

import tensorrt as trt

logger = logging.getLogger(__name__)


py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -5,6 +5,7 @@
from tempfile import tempdir
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
import torch
import torch_tensorrt
from torch.nn import Module
@@ -19,8 +20,6 @@
multi_gpu_device_check,
)

import tensorrt as trt

logger = logging.getLogger(__name__)


@@ -39,7 +38,7 @@ def __init__(
*,
name: str = "",
settings: CompilationSettings = CompilationSettings(),
weight_name_map: Any = None,
weight_name_map: Optional[dict[Any, Any]] = None,
):
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
@@ -52,6 +51,7 @@ def __init__(
Keyword Arguments:
name (str): Name for module
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
weight_name_map (dict): Mapping of engine weight name to state_dict weight name

Example:

1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -96,6 +96,7 @@ def __init__(
Keyword Arguments:
name (str): Name for module
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
weight_name_map (dict): Mapping of engine weight name to state_dict weight name

Example:

13 changes: 5 additions & 8 deletions tests/py/dynamo/models/test_engine_cache.py
@@ -206,6 +206,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for i in range(3):
# remove timing cache and reset dynamo for engine caching measurement
remove_timing_cache()
torch._dynamo.reset()
if i == 0:
@@ -220,7 +221,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
trt_gm = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=False,
use_python_runtime=True,
enabled_precisions={torch.float},
debug=False,
min_block_size=1,
@@ -231,7 +232,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
)
end.record()
torch.cuda.synchronize()
torch._dynamo.reset()
times.append(start.elapsed_time(end))
results.append(trt_gm(*inputs))

@@ -285,7 +285,7 @@ def test_dynamo_compile_with_custom_engine_cache(self):
trt_gm = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=False,
use_python_runtime=True,
enabled_precisions={torch.float},
debug=False,
min_block_size=1,
@@ -387,7 +387,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
model,
backend="tensorrt",
options={
"use_python_runtime": True,
"use_python_runtime": False,
"enabled_precisions": {torch.float},
"debug": False,
"min_block_size": 1,
@@ -402,7 +402,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
results.append(compiled_model(*inputs)) # trigger the compilation
end.record()
torch.cuda.synchronize()
torch._dynamo.reset()
times.append(start.elapsed_time(end))

cos_sim = cosine_similarity(results[0], results[1])
@@ -441,7 +440,6 @@ def test_torch_compile_with_custom_engine_cache(self):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for i in range(3):
# remove timing cache and reset dynamo for engine caching measurement
if i == 0:
cache_built_engines = False
reuse_cached_engines = False
@@ -454,7 +452,7 @@ def test_torch_compile_with_custom_engine_cache(self):
model,
backend="tensorrt",
options={
"use_python_runtime": True,
"use_python_runtime": False,
"enabled_precisions": {torch.float},
"debug": False,
"min_block_size": 1,
@@ -501,7 +499,6 @@ def test_torch_compile_change_input_shape(self):

custom_engine_cache = MyEngineCache(engine_cache_dir)
for i in range(3):
# remove timing cache and reset dynamo for engine caching measurement
inputs = [torch.rand((4 * (i + 1), 3, 224, 224)).to("cuda")]
compiled_model = torch.compile(
model,
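The tests above rely on a MyEngineCache helper defined elsewhere in test_engine_cache.py. A rough sketch of such a custom cache, assuming the save()/load() overrides and import path used in the upstream engine-caching example (both are assumptions, not shown in this diff):

```python
import os
from typing import Optional

# Import path assumed; adjust to wherever BaseEngineCache is exposed in your install.
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache


class MyEngineCache(BaseEngineCache):
    """Stores each serialized (weight-stripped) engine blob as <hash>.bin on disk."""

    def __init__(self, engine_cache_dir: str) -> None:
        self.engine_cache_dir = engine_cache_dir
        os.makedirs(engine_cache_dir, exist_ok=True)

    def save(self, hash: str, blob: bytes, *args, **kwargs) -> None:
        with open(os.path.join(self.engine_cache_dir, f"{hash}.bin"), "wb") as f:
            f.write(blob)

    def load(self, hash: str, *args, **kwargs) -> Optional[bytes]:
        path = os.path.join(self.engine_cache_dir, f"{hash}.bin")
        if not os.path.exists(path):
            return None
        with open(path, "rb") as f:
            return f.read()
```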