This project explores the deployment of Swin Transformer with TensorRT, including FP16 and INT8 test results.
Swin Transformer (the name Swin stands for Shifted window), from the original GitHub repo and initially described in the arXiv paper, capably serves as a general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connections.
- Please refer to the Data preparation section to prepare the ImageNet-1K dataset.
- Docker setup.

  a) Docker pull and launch. TensorRT 8.5.1.7 is preinstalled in this docker.

  ```bash
  docker pull nvcr.io/nvidia/tensorrt:22.11-py3
  docker run --name tensorrt_22.11_py3_swin -it --rm --gpus "device=0" --network host --shm-size 16g -v /($path_of_your_projects):/root/space/projects nvcr.io/nvidia/tensorrt:22.11-py3 &
  ```

  b) Install the necessary utils:

  ```bash
  pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com
  pip install torch==1.13.0 torchvision==0.14.0
  pip install timm==0.4.12
  pip install termcolor==1.1.0
  pip install pyyaml tqdm yacs
  pip install onnx onnxruntime
  ```
The code structure is as below; focus on the modifications and additions.
```
.
├── config.py               # Add the default config of quantization and onnx export
├── export.py               # Export the PyTorch model to ONNX format
├── calib.sh                # Calibration script
├── models
│   ├── build.py
│   ├── __init__.py
│   └── swin_transformer.py # Build the model and add the quantization operations; modified to export the ONNX model and build the TensorRT engine
├── README.md
├── qat.sh                  # Execute calibration and QAT finetuning
├── trt                     # Directory for TensorRT engine evaluation and visualization
│   ├── debug               # Scripts that use polygraphy to compare the results of the ONNX model and the TRT engine with a fixed input
│   ├── build_engine.py     # Script for engine build
│   ├── engine.py
│   ├── eval_trt.py         # Evaluate the TensorRT engine's accuracy
│   ├── eval_onnxrt.py      # Run the ONNX model and generate the results, just for debugging
├── swin_quant_flow.py      # QAT workflow for swin_transformer (we haven't tried the swin_mlp structure)
└── weights
```
Pay attention to the small modifications below.
- For dynamic batch size support, please refer to the modifications in `models/swin_transformer.py`. The original `window_reverse` does not support dynamic batch because it casts the first dimension of `windows` to an integer:

  ```python
  def window_reverse(windows, window_size, H, W):
      # B = int(windows.shape[0] / (H * W / window_size / window_size))
      # x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
      # x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
      C = int(windows.shape[-1])
      x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
      x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
      return x
  ```
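  Why the cast matters: during `torch.onnx.export` tracing, `int(...)` on a tensor shape is evaluated eagerly and frozen into the graph as a constant, so the exported model would only accept the batch size seen at export time. A tiny illustration (the shapes are just examples for a 224x224 input with window size 7):

  ```python
  import torch

  windows = torch.randn(64, 7, 7, 96)  # (num_windows * B, window_size, window_size, C), B = 1
  B = int(windows.shape[0] / (56 * 56 / 7 / 7))
  # Under tracing, B is a plain Python int (1 here), so windows.view(B, ...) bakes
  # this constant batch size into the ONNX graph; view(-1, ...) stays symbolic.
  ```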
- For fp16 mode: fp16 cannot store very large or very small numbers the way fp32 can, so some specific layers need to be kept in fp32 during the engine build. Falling back the `POW` and `REDUCE` layers to fp32 is enough to fix the accuracy problem and doesn't hurt the performance/throughput. Sometimes, with different weights, you may need to fall back `POW`, `REDUCEMEAN`, `Add` and `Sqrt` to fp32; please refer to the `fix_fp16_network` function in `trt/trt_utils.py`. A sketch of the idea follows.
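  For illustration, a minimal sketch of this fallback pattern with the TensorRT Python API (the real `fix_fp16_network` in `trt/trt_utils.py` may select layers differently; the name matching here is an assumption):

  ```python
  import tensorrt as trt

  def fix_fp16_network(network: trt.INetworkDefinition) -> None:
      # Pin numerically sensitive layers (e.g. LayerNorm's Pow/ReduceMean chain)
      # to FP32 so their intermediate values don't overflow FP16's range.
      for i in range(network.num_layers):
          layer = network.get_layer(i)
          if layer.type in (trt.LayerType.ELEMENTWISE, trt.LayerType.REDUCE) and \
                  ("Pow" in layer.name or "Reduce" in layer.name):
              layer.precision = trt.float32          # compute this layer in FP32
              layer.set_output_type(0, trt.float32)  # keep its output in FP32
  ```

  For these per-layer precisions to take effect, the builder config must also set `trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS` (or at least `PREFER_PRECISION_CONSTRAINTS`).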
If you are using the environment setup above, just skip this part.
- `Exporting the operator roll to ONNX opset version 9 is not supported.`

  A: Please refer to `torch/onnx/symbolic_opset9.py` and add support for exporting `torch.roll`.
- `Node (Concat_264) Op (Concat) [ShapeInferenceError] All inputs to Concat must have same rank.`

  A: Please refer to the modifications in `models/swin_transformer.py`. We use the `input_resolution` and `window_size` to compute `nW`:

  ```python
  if mask is not None:
      nW = int(self.input_resolution[0] * self.input_resolution[1] / self.window_size[0] / self.window_size[1])
      # nW = mask.shape[0]
      attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
      attn = attn.view(-1, self.num_heads, N, N)
      attn = self.softmax(attn)
  ```
- Download the `Swin-T` pretrained model from the Model Zoo.

- `export.py` exports a PyTorch model to ONNX format:

  ```bash
  $ python export.py --eval --cfg SwinTransformer/configs/swin/swin_tiny_patch4_window7_224.yaml --resume ./weights/swin_tiny_patch4_window7_224.pth --data-path /root/space/projects/datasets/imagenet
  ```
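  Internally the export boils down to a `torch.onnx.export` call with a dynamic batch axis. A minimal sketch (the tensor names, the opset, and the timm stand-in model are assumptions; `export.py` builds the repo's own model from the config instead):

  ```python
  import timm   # timm==0.4.12 ships a Swin-T; used here only as a stand-in model
  import torch

  model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=False).eval()
  dummy = torch.randn(1, 3, 224, 224)
  torch.onnx.export(
      model, dummy, "./weights/swin_tiny_patch4_window7_224.onnx",
      input_names=["input"], output_names=["output"],
      # dynamic batch axis so TensorRT can build the engine with batch-size profiles
      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
      opset_version=13,
  )
  ```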
- Build the TensorRT engine with `trt/build_engine.py`:

  ```bash
  $ python trt/build_engine.py --onnx-file ./weights/swin_tiny_patch4_window7_224.onnx --trt-engine ./weights/swin_tiny_patch4_window7_224_batch32_fp32.engine --verbose --mode fp32 --b-opt 32
  ```

  For fp16 mode:

  ```bash
  $ python trt/build_engine.py --onnx-file ./weights/swin_tiny_patch4_window7_224.onnx --trt-engine ./weights/swin_tiny_patch4_window7_224_batch32_fp16.engine --verbose --mode fp16 --b-opt 32
  ```
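  For reference, a condensed sketch of what an engine-build script like `trt/build_engine.py` does (the input tensor name `"input"`, the profile shapes, and the error handling are assumptions):

  ```python
  import tensorrt as trt

  logger = trt.Logger(trt.Logger.WARNING)
  builder = trt.Builder(logger)
  network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
  parser = trt.OnnxParser(network, logger)
  with open("./weights/swin_tiny_patch4_window7_224.onnx", "rb") as f:
      if not parser.parse(f.read()):
          raise RuntimeError(parser.get_error(0))

  config = builder.create_builder_config()
  profile = builder.create_optimization_profile()
  # min / opt / max shapes; --b-opt 32 corresponds to the opt batch size here
  profile.set_shape("input", (1, 3, 224, 224), (32, 3, 224, 224), (32, 3, 224, 224))
  config.add_optimization_profile(profile)
  config.set_flag(trt.BuilderFlag.FP16)  # only for --mode fp16

  with open("./weights/swin_tiny_patch4_window7_224_batch32_fp16.engine", "wb") as f:
      f.write(builder.build_serialized_network(network, config))
  ```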
  You can use `trtexec` to test the throughput of the TensorRT engine:

  ```bash
  $ trtexec --loadEngine=./weights/swin_tiny_patch4_window7_224_batch32_fp16.engine
  ```
- `trt/eval_trt.py` evaluates the accuracy of the TensorRT engine:

  ```bash
  $ python trt/eval_trt.py --engine ./weights/swin_tiny_patch4_window7_224_batch32_fp16.engine --data-path /root/space/projects/datasets/imagenet --batch-size 32
  ```
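  The core of the evaluation is plain TensorRT inference on image batches. Roughly (a simplified sketch, not the exact code in `trt/engine.py`; the binding order, the `pycuda` usage, and the random stand-in batch are assumptions):

  ```python
  import numpy as np
  import pycuda.autoinit  # noqa: F401 -- creates a CUDA context
  import pycuda.driver as cuda
  import tensorrt as trt

  logger = trt.Logger(trt.Logger.WARNING)
  with open("./weights/swin_tiny_patch4_window7_224_batch32_fp16.engine", "rb") as f:
      engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
  context = engine.create_execution_context()
  context.set_binding_shape(0, (32, 3, 224, 224))  # pick a batch size inside the profile

  batch = np.random.rand(32, 3, 224, 224).astype(np.float32)  # stand-in for ImageNet images
  output = np.empty(tuple(context.get_binding_shape(1)), dtype=np.float32)
  d_in, d_out = cuda.mem_alloc(batch.nbytes), cuda.mem_alloc(output.nbytes)
  cuda.memcpy_htod(d_in, np.ascontiguousarray(batch))
  context.execute_v2([int(d_in), int(d_out)])
  cuda.memcpy_dtoh(output, d_out)
  pred = output.argmax(axis=1)  # top-1 predictions, compared with labels to get Acc@1
  ```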
- `trt/eval_onnxrt.py` evaluates the accuracy of the ONNX model, just for debugging:

  ```bash
  $ python trt/eval_onnxrt.py --eval --cfg SwinTransformer/configs/swin/swin_tiny_patch4_window7_224.yaml --resume ./weights/swin_tiny_patch4_window7_224_fixed.onnx --data-path /root/space/projects/datasets/imagenet --batch-size 32
  ```
Accuracy and Speedup Test of the TensorRT engine (A100 40GB, TensorRT 8.5.1.7):

| Model (BS=32) | FP32 (mean latency) | FP16 (mean latency) | FP16 Acc@1 |
|---|---|---|---|
| Swin Tiny | 14.2938 ms | 8.87109 ms | 81.20% |
| Swin Small | 22.9841 ms | 12.9888 ms | 83.20% |
| Swin Base | 31.4782 ms | 17.3515 ms | 85.20% |
| Swin Large | 53.5593 ms | 27.6452 ms | 86.20% |
For INT8, the accuracy after calibration is 80.8% with Swin-T, just as expected, but the speedup is not obvious, so FP16 deployment is highly recommended.
Accuracy Test of the TensorRT engine (T4, TensorRT 8.2):

| SwinTransformer (T4) | Acc@1 | Notes |
|---|---|---|
| PyTorch Pretrained Model | 81.160 | |
| TensorRT Engine (FP32) | 81.156 | |
| TensorRT Engine (FP16) | 81.150 | With POW and REDUCE layers fallback to FP32 |
| TensorRT Engine (INT8 QAT) | - | Finetuned for 1 epoch, got 79.980; need to improve the INT8 throughput first |
Speed Test of the TensorRT engine (T4, TensorRT 8.2):

| SwinTransformer (T4) | FP32 | FP16 | Explicit Quantization (INT8, QAT) |
|---|---|---|---|
| batchsize=1 | 245.388 qps | 510.072 qps | 385.454 qps |
| batchsize=16 | 316.8624 qps | 804.112 qps | 815.606 qps |
| batchsize=64 | 329.13984 qps | 833.4208 qps | 780.006 qps |
| batchsize=256 | 331.9808 qps | 844.10752 qps | - |
Result:

- The accuracy and speedup of FP16 are now as expected; it is highly recommended to deploy Swin Transformer with FP16 precision.
- Compared with FP16, INT8 currently brings no speedup.
The main modifications of `models/swin_transformer.py` are as below (a sketch of the pattern follows the list):

- For the `PatchMerging` block, change `torch.nn.Linear` to `quant_nn.QuantLinear`.
- For the `WindowAttention` block:
  a) For query, key and value, change `torch.nn.Linear` to `quant_nn.QuantLinear`.
  b) Quantize the four inputs of `torch.matmul`.
- For the `MLP` block, change `torch.nn.Linear` to `quant_nn.QuantLinear`.
- For the `SwinTransformerBlock` block, quantize the inputs of the `+` operator.
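For illustration, the replacement pattern looks roughly like this (a sketch against the pytorch-quantization 2.1.2 API; the class and quantizer names are illustrative, not the repo's exact code):

```python
import torch
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import tensor_quant

class QuantMlp(torch.nn.Module):
    """MLP block with torch.nn.Linear swapped for quant_nn.QuantLinear."""
    def __init__(self, in_features, hidden_features):
        super().__init__()
        # QuantLinear fake-quantizes both its input activations and its weights
        self.fc1 = quant_nn.QuantLinear(in_features, hidden_features)
        self.act = torch.nn.GELU()
        self.fc2 = quant_nn.QuantLinear(hidden_features, in_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

# For torch.matmul in WindowAttention, each operand gets its own TensorQuantizer:
desc = tensor_quant.QuantDescriptor(num_bits=8, calib_method="histogram")
q_quant, k_quant = quant_nn.TensorQuantizer(desc), quant_nn.TensorQuantizer(desc)
# attn = torch.matmul(q_quant(q), k_quant(k).transpose(-2, -1))
```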
In order to do the QAT finetuning, some extra utils need to be installed: `tqdm`, `prettytable`, `scipy`, `absl-py`.
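For example:

```bash
$ pip install tqdm prettytable scipy absl-py
```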
- With `swin_quant_flow.py`, wrap a fake-quantized model, calibrate, run QAT finetuning, and export the ONNX model:

  ```bash
  $ ./calib.sh
  ```

  Or you can run calibration and QAT finetuning at the same time:

  ```bash
  $ ./qat.sh
  ```
- Build the TensorRT engine and evaluate it as above, with the same commands.