torch2trt is a PyTorch to TensorRT converter which utilizes the TensorRT Python API. The converter is

- Easy to use - Convert modules with a single function call `torch2trt`
- Easy to extend - Write your own layer converter in Python and register it with `@tensorrt_converter`
If you find an issue, please let us know!
Please note that this converter has limited coverage of TensorRT / PyTorch. We created it primarily to easily optimize the models used in the JetBot project. If you find the converter helpful with other models, please let us know.
Below are some usage examples; for more, check out the notebooks.
```python
import torch
from torch2trt import torch2trt
from torchvision.models.alexnet import alexnet

# create some regular pytorch model...
model = alexnet(pretrained=True).eval().cuda()

# create example data
x = torch.ones((1, 3, 224, 224)).cuda()

# convert to TensorRT feeding sample data as input
model_trt = torch2trt(model, [x])
```
We can execute the returned `TRTModule` just like the original PyTorch model:

```python
y = model(x)
y_trt = model_trt(x)

# check the output against PyTorch
print(torch.max(torch.abs(y - y_trt)))
```
We can save the model as a `state_dict`.

```python
torch.save(model_trt.state_dict(), 'alexnet_trt.pth')
```
We can load the saved model into a `TRTModule`:

```python
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load('alexnet_trt.pth'))
```
We tested the converter against these models using the `test.sh` script. You can generate the results by calling

```bash
./test.sh TEST_OUTPUT.md
```
The results below show the throughput in FPS. You can find the raw output, which includes latency, in the benchmarks folder.
Model | Nano (PyTorch) | Nano (TensorRT) | Xavier (PyTorch) | Xavier (TensorRT) |
---|---|---|---|---|
alexnet | 46.4 | 69.9 | 250 | 580 |
squeezenet1_0 | 44 | 137 | 130 | 890 |
squeezenet1_1 | 76.6 | 248 | 132 | 1390 |
resnet18 | 29.4 | 90.2 | 140 | 712 |
resnet34 | 15.5 | 50.7 | 79.2 | 393 |
resnet50 | 12.4 | 34.2 | 55.5 | 312 |
resnet101 | 7.18 | 19.9 | 28.5 | 170 |
resnet152 | 4.96 | 14.1 | 18.9 | 121 |
densenet121 | 11.5 | 41.9 | 23.0 | 168 |
densenet169 | 8.25 | 33.2 | 16.3 | 118 |
densenet201 | 6.84 | 25.4 | 13.3 | 90.9 |
densenet161 | 4.71 | 15.6 | 17.2 | 82.4 |
vgg11 | 8.9 | 18.3 | 85.2 | 201 |
vgg13 | 6.53 | 14.7 | 71.9 | 166 |
vgg16 | 5.09 | 11.9 | 61.7 | 139 |
vgg19 | | | 54.1 | 121 |
vgg11_bn | 8.74 | 18.4 | 81.8 | 201 |
vgg13_bn | 6.31 | 14.8 | 68.0 | 166 |
vgg16_bn | 4.96 | 12.0 | 58.5 | 140 |
vgg19_bn | | | 51.4 | 121 |
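Because the table reports throughput, the TensorRT speedup on a given device is the TensorRT FPS divided by the PyTorch FPS for that device. The short script below is only a convenience sketch (it is not part of `test.sh` or the benchmark tooling); it computes that ratio for a few rows copied from the table.

```python
# Throughput in FPS copied from a few rows of the benchmark table:
# (model, Nano PyTorch, Nano TensorRT, Xavier PyTorch, Xavier TensorRT)
ROWS = [
    ("alexnet", 46.4, 69.9, 250, 580),
    ("resnet18", 29.4, 90.2, 140, 712),
    ("resnet50", 12.4, 34.2, 55.5, 312),
    ("densenet121", 11.5, 41.9, 23.0, 168),
]


def speedup(pytorch_fps: float, tensorrt_fps: float) -> float:
    """Ratio of TensorRT throughput to PyTorch throughput on the same device."""
    return tensorrt_fps / pytorch_fps


def speedup_table(rows):
    """Map each model name to its (Nano, Xavier) speedup, rounded to 2 places."""
    return {
        name: (round(speedup(nano_pt, nano_trt), 2), round(speedup(xav_pt, xav_trt), 2))
        for name, nano_pt, nano_trt, xav_pt, xav_trt in rows
    }


if __name__ == "__main__":
    for name, (nano, xavier) in speedup_table(ROWS).items():
        print(f"{name:12s} Nano x{nano:<5} Xavier x{xavier}")
```

For example, resnet18 runs about 3.07x faster on Nano and about 5.09x faster on Xavier after conversion.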
To install without compiling plugins, call the following:

```bash
git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
sudo python setup.py install
```
To install with plugins to support some operations in PyTorch that are not natively supported with TensorRT, call the following. This currently only includes a plugin for `torch.nn.functional.interpolate`.

```bash
sudo apt-get install libprotobuf* protobuf-compiler ninja-build
git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
sudo python setup.py install --plugins
```
torch2trt is tested against a system configured with the JetCard setup. Different system configurations may require additional steps.
This converter works by attaching conversion functions (like `convert_ReLU`) to the original PyTorch functional calls (like `torch.nn.ReLU.forward`). The sample input data is passed through the network, just as before, except now whenever a registered function (`torch.nn.ReLU.forward`) is encountered, the corresponding converter (`convert_ReLU`) is also called afterwards. The converter is passed the arguments and return value of the original PyTorch function, as well as the TensorRT network that is being constructed. The input tensors to the original PyTorch function are modified to have an attribute `_trt`, which is the TensorRT counterpart to the PyTorch tensor. The conversion function uses this `_trt` to add layers to the TensorRT network, and then sets the `_trt` attribute for relevant output tensors. Once the model is fully executed, the final tensors returned are marked as outputs of the TensorRT network, and the optimized TensorRT engine is built.
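The mechanism described above can be sketched in plain Python, with no PyTorch or TensorRT involved. Everything in this block (`ToyTensor`, `ToyNetwork`, `ToyContext`, `ToyReLU`, `toy_converter`, `convert`) is hypothetical and is not torch2trt's actual implementation; it only mimics the idea: a registered method is temporarily replaced by a wrapper that runs the original and then calls the converter with a context holding the network, arguments and return value, and the `_trt` attributes chain the layers together.

```python
class ToyTensor:
    """Stands in for a PyTorch tensor; `_trt` holds its toy-network counterpart."""
    def __init__(self, value):
        self.value = value
        self._trt = None


class ToyNetwork:
    """Records layer names instead of building a real TensorRT network."""
    def __init__(self):
        self.layers = []

    def add_layer(self, kind, inputs):
        name = f"{kind}_{len(self.layers)}"
        self.layers.append((name, kind, inputs))
        return name


class ToyContext:
    """Hypothetical context with the same field names as ConversionContext."""
    def __init__(self, network, args, kwargs, ret):
        self.network = network
        self.method_args = args
        self.method_kwargs = kwargs
        self.method_return = ret


class ToyReLU:
    """Plays the role of torch.nn.ReLU: calling the module calls forward()."""
    def forward(self, x):
        return ToyTensor(max(x.value, 0.0))

    def __call__(self, x):
        return self.forward(x)


CONVERTERS = {}  # (owner, attribute name) -> converter function


def toy_converter(owner, attr):
    """Hypothetical stand-in for @tensorrt_converter: registers a converter."""
    def register(converter):
        CONVERTERS[(owner, attr)] = converter
        return converter
    return register


@toy_converter(ToyReLU, "forward")
def convert_ToyReLU(ctx):
    inp = ctx.method_args[1]  # method_args[0] is the module itself (self)
    out = ctx.method_return
    out._trt = ctx.network.add_layer("relu", [inp._trt])


def convert(module, sample):
    """Run the module on sample data with every registered method patched."""
    network = ToyNetwork()
    sample._trt = network.add_layer("input", [])
    originals = {}
    for (owner, attr), converter in CONVERTERS.items():
        original = getattr(owner, attr)
        originals[(owner, attr)] = original

        def patched(*args, _orig=original, _conv=converter, **kwargs):
            ret = _orig(*args, **kwargs)  # run the original function as before
            _conv(ToyContext(network, args, kwargs, ret))  # then convert
            return ret

        setattr(owner, attr, patched)
    try:
        out = module(sample)
    finally:
        for (owner, attr), original in originals.items():
            setattr(owner, attr, original)  # restore the unpatched methods
    return network, [out._trt]  # final tensors become the network outputs
```

Calling `convert(ToyReLU(), ToyTensor(-2.0))` records an `input_0` layer followed by a `relu_1` layer fed by it, returns `["relu_1"]` as the outputs, and leaves `ToyReLU.forward` restored afterwards.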
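Here we show how to add a converter for the `ReLU` module using the TensorRT Python API.

```python
import tensorrt as trt
from torch2trt import tensorrt_converter

@tensorrt_converter('torch.nn.ReLU.forward')
def convert_ReLU(ctx):
    # method_args[0] is the ReLU module itself (self); [1] is the input tensor
    input = ctx.method_args[1]
    output = ctx.method_return
    # add a TensorRT ReLU activation fed by the input's TensorRT counterpart
    layer = ctx.network.add_activation(input=input._trt, type=trt.ActivationType.RELU)
    # expose the layer's output as the output tensor's TensorRT counterpart
    output._trt = layer.get_output(0)
```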
The converter takes one argument, a `ConversionContext`, which will contain the following:

- `ctx.network` - The TensorRT network that is being constructed.
- `ctx.method_args` - Positional arguments that were passed to the specified PyTorch function. The `_trt` attribute is set for relevant input tensors.
- `ctx.method_kwargs` - Keyword arguments that were passed to the specified PyTorch function.
- `ctx.method_return` - The value returned by the specified PyTorch function. The converter must set the `_trt` attribute where relevant.
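Because the same PyTorch argument may arrive positionally (in `ctx.method_args`) or by keyword (in `ctx.method_kwargs`), a converter often has to check both. Note also that for a method such as `torch.nn.ReLU.forward`, `method_args[0]` is the module itself, which is why `convert_ReLU` reads `method_args[1]`. The helper below, `get_arg_sketch`, is hypothetical and shown only under that assumption; the context is mocked with `SimpleNamespace` for a call like `torch.flatten(x, start_dim=1)`, where a plain string stands in for the tensor.

```python
from types import SimpleNamespace

_MISSING = object()


def get_arg_sketch(ctx, name, pos, default=_MISSING):
    """Hypothetical helper: fetch an argument of the original PyTorch call
    whether it was passed by keyword or by position."""
    if name in ctx.method_kwargs:
        return ctx.method_kwargs[name]
    if pos < len(ctx.method_args):
        return ctx.method_args[pos]
    if default is _MISSING:
        raise TypeError(f"argument {name!r} was not passed and has no default")
    return default


# Mock context for torch.flatten(x, start_dim=1); network and method_return
# are left empty because only the argument lookup is being illustrated.
ctx = SimpleNamespace(
    network=None,
    method_args=("x",),
    method_kwargs={"start_dim": 1},
    method_return=None,
)

start_dim = get_arg_sketch(ctx, "start_dim", pos=1, default=0)  # from kwargs: 1
end_dim = get_arg_sketch(ctx, "end_dim", pos=2, default=-1)     # default: -1
```

Passing the PyTorch default explicitly (here `0` and `-1`, matching the signature `torch.flatten(input, start_dim=0, end_dim=-1)`) keeps the converter correct when the caller omits an argument.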
Please see this folder for more examples.
Try to specify dynamic sizes in one variable and use it everywhere, rather than calling Tensor.size() on a chain of tensors. TRT can get confused and lose track of the dimension in the latter case, while the former keeps the shape inference tree shallow.
Help TRT avoid layer duplication by caching the results of reshapes and type conversions. The fewer layers there are, the easier it is for TRT to optimize.
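The caching tip amounts to memoizing layer construction: if a converter asks for the same reshape or type conversion of the same tensor twice, it should get back the layer it already built. The sketch below uses a toy network that only counts layers; `CountingNetwork` and `LayerCache` are hypothetical names, not part of torch2trt or TensorRT, and parameters must be hashable (for example, shapes as tuples).

```python
class CountingNetwork:
    """Toy stand-in for a TensorRT network that just records the layers added."""
    def __init__(self):
        self.layers = []

    def add_layer(self, op, source, param):
        self.layers.append((op, source, param))
        return f"{op}_{len(self.layers) - 1}"


class LayerCache:
    """Hypothetical helper: reuse the layer already built for an identical
    (op, source, param) request instead of adding a duplicate."""
    def __init__(self, network):
        self.network = network
        self._cache = {}

    def get_or_add(self, op, source, param):
        key = (op, source, param)
        if key not in self._cache:
            self._cache[key] = self.network.add_layer(op, source, param)
        return self._cache[key]


net = CountingNetwork()
cache = LayerCache(net)
a = cache.get_or_add("reshape", "x", (1, -1))
b = cache.get_or_add("reshape", "x", (1, -1))  # cache hit: no new layer
c = cache.get_or_add("cast", "x", "float16")   # different op: new layer
```

Here `a` and `b` are the same layer (`reshape_0`) and only two layers are added in total.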
Watch out for shape inference explosion in the logs.
Turn on debug_sync and check the logs for debugging.
Check stderr for TRT error messages (they don't show up in Python). In Colab, go to Runtime > View runtime logs.
[TensorRT] ERROR: ../builder/cudnnBuilderBlockChooser.cpp (127) - Assertion Error in buildMemGraph: 0 (mg.nodes[mg.regionIndices[outputRegion]].size == mg.nodes[mg.regionIndices[inputRegion]].size)
means you're using too many shuffles with hard indices (values other than 0 or -1) in the reshape_dims, so TRT is finding a conflict.
[TensorRT] ERROR: Internal error: could not find any implementation for node (Unnamed Layer* 622) [ElementWise], try increasing the workspace size with IBuilder::setMaxWorkspaceSize()
means you're using too many shuffles with soft (0 or -1) indices in the reshape_dims, so TRT has too long a chain of inferred dimensions.
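The two error explanations above distinguish "hard" reshape dimensions from "soft" ones (0 or -1). The pure-Python sketch below is not TensorRT code; it spells out, as we understand TensorRT's shuffle-layer placeholder rules, how a `reshape_dims` entry of 0 copies the input dimension at the same index, -1 is inferred from the total volume, and any other value is a fixed size. `resolve_reshape_dims` is a hypothetical name. A hard size that does not divide the volume has no valid resolution, while every soft entry adds another dimension that must be inferred.

```python
from math import prod


def resolve_reshape_dims(input_shape, reshape_dims):
    """Resolve reshape_dims against input_shape: 0 copies the input dimension
    at the same index (soft), -1 is inferred from the remaining volume (soft),
    and any other value is a fixed size (hard)."""
    if list(reshape_dims).count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    resolved = []
    for i, d in enumerate(reshape_dims):
        if d == 0:
            if i >= len(input_shape):
                raise ValueError(f"0 at index {i} has no input dimension to copy")
            resolved.append(input_shape[i])  # soft: copied from the input
        else:
            resolved.append(d)  # hard size, or -1 still to be inferred
    total = prod(input_shape)
    if -1 in resolved:
        known = prod(d for d in resolved if d != -1)
        if known == 0 or total % known != 0:
            raise ValueError(
                f"cannot infer -1 when reshaping {list(input_shape)} to {list(reshape_dims)}"
            )
        resolved[resolved.index(-1)] = total // known
    elif prod(resolved) != total:
        raise ValueError(f"cannot reshape {list(input_shape)} to {list(reshape_dims)}")
    return resolved
```

For an input of shape `[8, 3, 4]`, `[0, -1]` resolves to `[8, 12]` and `[2, -1]` to `[2, 48]`, whereas `[5, -1]` raises because 5 does not divide the 96 elements.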