Support GPU INT-8 quantization (#15)
* support quantization
fix some stupid bugs
use opset 13 (onnx)

* add quantization demo

* add dependency

* qdqroberta

* update quantization notebook

* update quantization notebook

* update quantization notebook

* bump VERSION

* delete old script

* cleaning

* pin ORT to 1.9.0, 1.10.0 seems to be buggy

* modify text

* update tuto

* update tuto

* update tuto
pommedeterresautee committed Dec 8, 2021
1 parent e1b2f38 commit ad837a9
Showing 12 changed files with 7,513 additions and 67 deletions.
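
In short, the commit routes GPU INT-8 support through a new `-q`/`--quantization` flag in `convert.py`: quantizer nodes are exported as ONNX Q/DQ operators via opset 13, and TensorRT then builds the engine with INT8 enabled instead of FP16. A minimal sketch of the export half of that path, condensed from the diffs below (`model_pytorch` and `inputs_pytorch` are placeholders for the usual Hugging Face model and tokenized dummy inputs); the TensorRT half is sketched further down, next to `build_engine`:

```python
# Minimal sketch of the quantized export path added in convert.py below.
# model_pytorch is expected to be a QDQ-instrumented model (e.g. QDQRoberta),
# inputs_pytorch the usual dict of tokenized dummy tensors.
from pytorch_quantization.nn import TensorQuantizer
from transformer_deploy.backends.ort_utils import convert_to_onnx


def export_quantized_onnx(model_pytorch, inputs_pytorch, output_path: str = "model_qdq.onnx") -> None:
    TensorQuantizer.use_fb_fake_quant = True   # export quantizer nodes as ONNX Q/DQ ops
    convert_to_onnx(
        model_pytorch=model_pytorch,
        output_path=output_path,
        inputs_pytorch=inputs_pytorch,
        opset=13,                              # Q/DQ export needs opset 13; 12 otherwise
    )
    TensorQuantizer.use_fb_fake_quant = False  # restore the default once exported
```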
1 change: 1 addition & 0 deletions .gitignore
@@ -216,3 +216,4 @@ cython_debug/
.idea/
TensorRT/
triton_models/
demo/roberta-*/
14 changes: 12 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# From 🤗 to 🤯, Hugging Face Transformer submillisecond inference️ and deployment to production
# Hugging Face Transformer submillisecond inference️ and deployment to production: 🤗 → 🤯

[![tests](https://github.com/ELS-RD/transformer-deploy/actions/workflows/python-app.yml/badge.svg)](https://github.com/ELS-RD/transformer-deploy/actions/workflows/python-app.yml) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](./LICENCE) [![Python 3.6](https://img.shields.io/badge/python-3.6-blue.svg)](https://www.python.org/downloads/release/python-360/)

@@ -14,6 +14,7 @@
* [🐍 TensorRT usage in Python script](#tensorrt-usage-in-python-script)
* [⏱ benchmarks](#benchmarks)
* [🤗 end to end reproduction of Infinity Hugging Face demo](./demo/README.md) (to replay [Medium article](https://towardsdatascience.com/hugging-face-transformer-inference-under-1-millisecond-latency-e1be0057a51c?source=friends_link&sk=cd880e05c501c7880f2b9454830b8915))
* [🏎️ end to end GPU quantization tutorial](./demo/quantization_end_to_end.ipynb)

#### Why this tool?

@@ -85,7 +86,16 @@ With the single command below, you will:
* **generate** configuration files for Triton inference server

```shell
convert_model -m roberta-large-mnli --backend tensorrt onnx pytorch --seq-len 16 128 128 --batch-size 1 32 32
convert_model -m roberta-large-mnli --backend tensorrt onnx --seq-len 16 128 128 --batch-size 1 32 32
# ...
# Inference done on NVIDIA GeForce RTX 3090
# latencies:
# [Pytorch (FP32)] mean=123.26ms, sd=3.35ms, min=117.84ms, max=136.12ms, median=122.09ms, 95p=129.50ms, 99p=131.24ms
# [Pytorch (FP16)] mean=78.41ms, sd=2.83ms, min=75.58ms, max=88.48ms, median=77.28ms, 95p=84.66ms, 99p=85.97ms
# [TensorRT (FP16)] mean=182.99ms, sd=3.15ms, min=175.75ms, max=191.58ms, median=182.32ms, 95p=188.37ms, 99p=190.80ms
# [ONNX Runtime (vanilla)] mean=119.03ms, sd=8.27ms, min=112.15ms, max=185.57ms, median=116.51ms, 95p=129.18ms, 99p=167.70ms
# [ONNX Runtime (optimized)] mean=53.82ms, sd=0.81ms, min=52.79ms, max=58.27ms, median=53.74ms, 95p=55.38ms, 99p=57.29ms

```

> **16 128 128** -> minimum, optimal, maximum sequence length, to help TensorRT better optimize your model
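
The **16 128 128** triplet maps directly onto the min/optimal/max shapes of a TensorRT optimization profile (see `build_engine` in `trt_utils.py` further down). A rough sketch of that mapping, assuming the usual `(batch, sequence)` input layout; the input names are assumptions:

```python
import tensorrt as trt

# Hypothetical sketch: how min/opt/max sequence lengths and batch sizes
# feed a TensorRT optimization profile (input names are assumptions).
def add_profile(builder: trt.Builder, config: trt.IBuilderConfig,
                seq_len=(16, 128, 128), batch=(1, 32, 32)) -> None:
    profile = builder.create_optimization_profile()
    for name in ("input_ids", "attention_mask"):
        profile.set_shape(
            name,
            min=(batch[0], seq_len[0]),   # smallest shape TensorRT must support
            opt=(batch[1], seq_len[1]),   # shape the kernels are tuned for
            max=(batch[2], seq_len[2]),   # largest shape TensorRT must support
        )
    config.add_optimization_profile(profile)
```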
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.1.1
0.2.0
2 changes: 1 addition & 1 deletion demo/README.md
@@ -40,7 +40,7 @@ docker run -it --rm --gpus all \
-v $PWD:/project ghcr.io/els-rd/transformer-deploy:0.1.1 \
bash -c "cd /project && \
convert_model -m \"philschmid/MiniLM-L6-H384-uncased-sst2\" \
--backend tensorrt onnx pytorch \
--backend tensorrt onnx \
--seq-len 16 128 128"
```

5,777 changes: 5,777 additions & 0 deletions demo/quantization_end_to_end.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion requirements.txt
@@ -9,6 +9,6 @@ sympy
coloredlogs
pytest
colored
black
black[jupyter]
isort
flake8
3 changes: 1 addition & 2 deletions requirements_gpu.txt
@@ -1,10 +1,9 @@
onnx
onnxruntime-gpu
onnxruntime-gpu==1.9.0
nvidia-pyindex
tritonclient[all]
pycuda
torch==1.10.0+cu113
nvidia-pyindex
nvidia-tensorrt
onnx_graphsurgeon
polygraphy
1,631 changes: 1,631 additions & 0 deletions src/transformer_deploy/QDQModels/QDQRoberta.py

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions src/transformer_deploy/QDQModels/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2021, Lefebvre Sarrut Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
8 changes: 5 additions & 3 deletions src/transformer_deploy/backends/ort_utils.py
@@ -36,7 +36,9 @@ def create_model_for_provider(path: str, provider_to_use: str) -> InferenceSessi
return InferenceSession(path, options, providers=provider_to_use)


def convert_to_onnx(model_pytorch: PreTrainedModel, output_path: str, inputs_pytorch: OD[str, torch.Tensor]) -> None:
def convert_to_onnx(
model_pytorch: PreTrainedModel, output_path: str, inputs_pytorch: OD[str, torch.Tensor], opset: int = 12
) -> None:
# dynamic axis == variable length axis
dynamic_axis = OrderedDict()
for k in inputs_pytorch.keys():
@@ -47,7 +49,7 @@ def convert_to_onnx(model_pytorch: PreTrainedModel, output_path: str, inputs_pyt
model_pytorch, # model to optimize
args=tuple(inputs_pytorch.values()), # tuple of multiple inputs
f=output_path, # output path / file object
opset_version=12, # the ONNX version to use
opset_version=opset, # the ONNX opset version to use: 13 for quantized models, 12 otherwise
do_constant_folding=True, # simplify model (replace constant expressions)
input_names=list(inputs_pytorch.keys()), # input names
output_names=["output"], # output axis name
@@ -65,7 +67,7 @@ def optimize_onnx(onnx_path: str, onnx_optim_fp16_path: str, use_cuda: bool) ->
model_type="bert",
use_gpu=use_cuda,
opt_level=1,
num_heads=0, # automatic detection
num_heads=0, # automatic detection doesn't work with opset 13
hidden_size=0, # automatic detection
optimization_options=optimization_options,
)
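
Since automatic head detection reportedly breaks with opset 13 (per the comment above), one possible workaround — not part of this commit, just a sketch using the same `onnxruntime.transformers` optimizer called here — is to pass the architecture values explicitly, e.g. for `roberta-large`:

```python
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

# Sketch (an assumption, not part of this commit): when automatic detection
# fails, num_heads / hidden_size can be supplied explicitly.
optimization_options = FusionOptions("bert")
optimized_model = optimizer.optimize_model(
    input="model_qdq.onnx",
    model_type="bert",
    use_gpu=True,
    opt_level=1,
    num_heads=16,        # roberta-large: 16 attention heads
    hidden_size=1024,    # roberta-large: hidden size of 1024
    optimization_options=optimization_options,
)
optimized_model.save_model_to_file("model_qdq_optim.onnx")
```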
56 changes: 28 additions & 28 deletions src/transformer_deploy/backends/trt_utils.py
@@ -77,20 +77,20 @@ def setup_binding_shapes(
host_inputs: List[np.ndarray],
input_binding_idxs: List[int],
output_binding_idxs: List[int],
):
) -> Tuple[List[np.ndarray], List[DeviceAllocation]]:
# explicitly set dynamic input shapes, so dynamic output shapes can be computed internally
for host_input, binding_index in zip(host_inputs, input_binding_idxs):
context.set_binding_shape(binding_index, host_input.shape)
assert context.all_binding_shapes_specified
host_outputs = []
device_outputs = []
host_outputs: List[np.ndarray] = []
device_outputs: List[DeviceAllocation] = []
for binding_index in output_binding_idxs:
output_shape = context.get_binding_shape(binding_index)
# allocate buffers to hold output results after copying back to host
buffer = np.empty(output_shape, dtype=np.float32)
host_outputs.append(buffer)
# allocate output buffers on device
device_outputs.append(cuda.mem_alloc(buffer.nbytes))
# allocate buffers to hold output results after copying back to host
buffer = np.empty(output_shape, dtype=np.float32)
host_outputs.append(buffer)
# allocate output buffers on device
device_outputs.append(cuda.mem_alloc(buffer.nbytes))
return host_outputs, device_outputs


@@ -136,6 +136,8 @@ def build_engine(
optimal_shape: Tuple[int, int],
max_shape: Tuple[int, int],
workspace_size: int,
fp16: bool,
int8: bool,
) -> ICudaEngine:
with trt.Builder(logger) as builder: # type: Builder
with builder.create_network(
@@ -144,22 +146,20 @@
with trt.OnnxParser(network_definition, logger) as parser: # type: OnnxParser
builder.max_batch_size = max_shape[0] # max batch size
config: IBuilderConfig = builder.create_builder_config()
# config.min_timing_iterations = 1
# config.avg_timing_iterations = 1
config.max_workspace_size = workspace_size
# to enable complete trt inspector debugging, only for TensorRT >= 8.2
# config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
# disable CUDNN optimizations
config.set_tactic_sources(
tactic_sources=1 << int(trt.TacticSource.CUBLAS) | 1 << int(trt.TacticSource.CUBLAS_LT)
)
# config.set_flag(trt.BuilderFlag.INT8)
# config.set_quantization_flag(trt.QuantizationFlag.CALIBRATE_BEFORE_FUSION)
# config.int8_calibrator = Calibrator()
config.set_flag(trt.BuilderFlag.FP16)
if int8:
config.set_flag(trt.BuilderFlag.INT8)
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)
# https://github.com/NVIDIA/TensorRT/issues/1196 (sometimes big diff in output when using FP16)
config.set_flag(trt.BuilderFlag.STRICT_TYPES)
config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)
with open(onnx_file_path, "rb") as f:
parser.parse(f.read())
profile: IOptimizationProfile = builder.create_optimization_profile()
@@ -171,12 +171,8 @@
max=max_shape,
)
config.add_optimization_profile(profile)
# for i in range(network.num_layers):
# layer: ILayer = network.get_layer(i)
# if "gemm" in str(layer.name).lower():
# for g in range(layer.num_outputs):
# layer.precision = trt.DataType.FLOAT
network_definition = fix_fp16_network(network_definition)
if fp16:
network_definition = fix_fp16_network(network_definition)
trt_engine = builder.build_serialized_network(network_definition, config)
engine: ICudaEngine = runtime.deserialize_cuda_engine(trt_engine)
assert engine is not None, "error during engine generation, check error messages above :-("
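
For reference, a usage sketch of the updated `build_engine` signature with the new `fp16`/`int8` flags, mirroring how `convert.py` calls it further down; the argument names hidden by the truncated hunks (`runtime`, `onnx_file_path`, `logger`, `min_shape`) are assumptions:

```python
import tensorrt as trt
from transformer_deploy.backends.trt_utils import build_engine, save_engine

trt_logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(trt_logger)

# Sketch: a quantized build disables FP16 and enables INT8, mirroring
# fp16=not args.quantization / int8=args.quantization in convert.py.
engine = build_engine(
    runtime=runtime,                      # assumed argument name
    onnx_file_path="model_qdq.onnx",
    logger=trt_logger,
    min_shape=(1, 16),                    # assumed argument name
    optimal_shape=(32, 128),
    max_shape=(32, 128),
    workspace_size=10000 * 1024 * 1024,   # 10000 MiB, the convert.py default
    fp16=False,
    int8=True,
)
save_engine(engine=engine, engine_file_path="model.plan")
```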
@@ -200,16 +196,20 @@ def infer_tensorrt(
output_binding_idxs: List[int],
stream: Stream,
) -> np.ndarray:
# warning: small change in output if int64 is used instead of int32
input_list: List[ndarray] = [tensor.astype(np.int32) for tensor in host_inputs.values()]
# allocate GPU memory for input tensors
device_inputs = [cuda.mem_alloc(tensor.nbytes) for tensor in input_list]
for h_input, d_input in zip(input_list, device_inputs):
cuda.memcpy_htod_async(d_input, h_input) # host to GPU
input_list: List[ndarray] = list()
device_inputs: List[DeviceAllocation] = list()
for tensor in host_inputs.values():
# warning: small change in output if int64 is used instead of int32
tensor_int32: np.ndarray = np.asarray(tensor, dtype=np.int32)
input_list.append(tensor_int32)
# allocate GPU memory for input tensors
device_input: DeviceAllocation = cuda.mem_alloc(tensor_int32.nbytes)
device_inputs.append(device_input)
cuda.memcpy_htod_async(device_input, tensor_int32.ravel(), stream)
# calculate input shape, bind it, allocate GPU memory for the output
host_outputs, device_outputs = setup_binding_shapes(context, input_list, input_binding_idxs, output_binding_idxs)
bindings = device_inputs + device_outputs
context.execute_async_v2(bindings, stream.handle)
assert context.execute_async_v2(bindings, stream_handle=stream.handle), "failure during execution of inference"
for h_output, d_output in zip(host_outputs, device_outputs):
cuda.memcpy_dtoh_async(h_output, d_output) # GPU to host
stream.synchronize() # sync all CUDA ops
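
`setup_binding_shapes` and `infer_tensorrt` both expect the engine's input/output binding indices; a small sketch of how they could be collected with the standard TensorRT bindings API (not code from this commit):

```python
from typing import List, Tuple

import tensorrt as trt

# Sketch (not from this commit): split an engine's bindings into input and
# output index lists, as expected by setup_binding_shapes / infer_tensorrt.
def get_binding_idxs(engine: trt.ICudaEngine) -> Tuple[List[int], List[int]]:
    input_idxs: List[int] = []
    output_idxs: List[int] = []
    for idx in range(engine.num_bindings):
        if engine.binding_is_input(idx):
            input_idxs.append(idx)
        else:
            output_idxs.append(idx)
    return input_idxs, output_idxs
```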
71 changes: 42 additions & 29 deletions src/transformer_deploy/convert.py
@@ -25,6 +25,7 @@
import tensorrt as trt
import torch
from pycuda._driver import Stream
from pytorch_quantization.nn import TensorQuantizer
from tensorrt.tensorrt import IExecutionContext, Logger, Runtime
from torch.cuda import get_device_name
from torch.cuda.amp import autocast
@@ -47,12 +48,13 @@ def main():
parser = argparse.ArgumentParser(
description="optimize and deploy transformers", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("-m", "--model", required=True, help="path to model or URL to Hugging Face Hub")
parser.add_argument("-m", "--model", required=True, help="path to model or URL to Hugging Face hub")
parser.add_argument("-t", "--tokenizer", help="path to tokenizer or URL to Hugging Face hub")
parser.add_argument(
"--auth-token",
default=None,
help=(
"HuggingFace Hub auth token. Set to `None` (default) for public models. "
"Hugging Face Hub auth token. Set to `None` (default) for public models. "
"For private models, use `True` to use local cached token, or a string of your HF API token"
),
)
@@ -72,6 +74,7 @@
type=int,
nargs=3,
)
parser.add_argument("-q", "--quantization", action="store_true", help="int-8 GPU quantization support")
parser.add_argument("-w", "--workspace-size", default=10000, help="workspace size in MiB (TensorRT)", type=int)
parser.add_argument("-o", "--output", default="triton_models", help="name to be used for ")
parser.add_argument("-n", "--name", default="transformer", help="model name to be used in triton server")
Expand All @@ -81,7 +84,7 @@ def main():
default=["onnx"],
help="backend to use. One of [onnx,tensorrt, pytorch] or all",
nargs="*",
choices=["onnx", "tensorrt", "pytorch"],
choices=["onnx", "tensorrt"],
)
parser.add_argument("--nb-instances", default=1, help="# of model instances, may improve troughput", type=int)
parser.add_argument("--warmup", default=100, help="# of inferences to warm each model", type=int)
@@ -107,7 +110,8 @@ def main():
tensorrt_path = os.path.join(args.output, "model.plan")

assert torch.cuda.is_available(), "CUDA is not available. Please check your CUDA installation"
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(args.model, use_auth_token=auth_token)
tokenizer_path = args.tokenizer if args.tokenizer else args.model
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_auth_token=auth_token)
input_names: List[str] = tokenizer.model_input_names
logging.info(f"axis: {input_names}")
include_token_ids = "token_type_ids" in input_names
@@ -130,16 +134,45 @@
logging.info(f"[Pytorch] input shape {inputs_pytorch['input_ids'].shape}")
logging.info(f"[Pytorch] output shape: {output_pytorch.shape}")
# create onnx model and compare results
convert_to_onnx(model_pytorch=model_pytorch, output_path=onnx_model_path, inputs_pytorch=inputs_pytorch)
opset = 12
if args.quantization:
TensorQuantizer.use_fb_fake_quant = True
opset = 13

convert_to_onnx(
model_pytorch=model_pytorch, output_path=onnx_model_path, inputs_pytorch=inputs_pytorch, opset=opset
)
if args.quantization:
TensorQuantizer.use_fb_fake_quant = False
onnx_model = create_model_for_provider(path=onnx_model_path, provider_to_use="CUDAExecutionProvider")
output_onnx = onnx_model.run(None, inputs_onnx)
assert np.allclose(a=output_onnx, b=output_pytorch, atol=args.atol)
del onnx_model
if "pytorch" not in args.backend:
del model_pytorch

timings = {}

with torch.inference_mode():
for _ in range(args.warmup):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
time_buffer = []
for _ in range(args.nb_measures):
with track_infer_time(time_buffer):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
timings["Pytorch (FP32)"] = time_buffer
with autocast():
for _ in range(args.warmup):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
time_buffer = []
for _ in range(args.nb_measures):
with track_infer_time(time_buffer):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
timings["Pytorch (FP16)"] = time_buffer
del model_pytorch

if "tensorrt" in args.backend:
trt_logger: Logger = trt.Logger(trt.Logger.INFO if args.verbose else trt.Logger.WARNING)
runtime: Runtime = trt.Runtime(trt_logger)
@@ -151,6 +184,8 @@ def main():
optimal_shape=tensor_shapes[1],
max_shape=tensor_shapes[2],
workspace_size=args.workspace_size * 1024 * 1024,
fp16=not args.quantization,
int8=args.quantization,
)
save_engine(engine=engine, engine_file_path=tensorrt_path)
# important to check the engine has been correctly serialized
@@ -242,28 +277,6 @@ def main():
)
conf.create_folders(tokenizer=tokenizer, model_path=onnx_optim_fp16_path)

if "pytorch" in args.backend:
with torch.inference_mode():
for _ in range(args.warmup):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
time_buffer = []
for _ in range(args.nb_measures):
with track_infer_time(time_buffer):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
timings["Pytorch (FP32)"] = time_buffer
with autocast():
for _ in range(args.warmup):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
time_buffer = []
for _ in range(args.nb_measures):
with track_infer_time(time_buffer):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
timings["Pytorch (FP16)"] = time_buffer

print(f"Inference done on {get_device_name(0)}")
print("latencies:")
for name, time_buffer in timings.items():
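
The latency lines in the README excerpt above (mean, sd, min, max, median, 95p, 99p) are derived from each `time_buffer`; the reporting loop is truncated here, so the following is only a sketch of how such statistics could be printed, assuming the buffers already hold milliseconds:

```python
from typing import Dict, List

import numpy as np

# Sketch (assumption): how the truncated reporting loop in convert.py could
# turn each time_buffer into the latency lines shown in the README above.
def format_timings(name: str, time_buffer: List[float]) -> str:
    t = np.asarray(time_buffer)  # assumed to already be in milliseconds
    return (
        f"[{name}] mean={t.mean():.2f}ms, sd={t.std():.2f}ms, "
        f"min={t.min():.2f}ms, max={t.max():.2f}ms, median={np.median(t):.2f}ms, "
        f"95p={np.percentile(t, 95):.2f}ms, 99p={np.percentile(t, 99):.2f}ms"
    )

def print_timings(timings: Dict[str, List[float]]) -> None:
    for name, time_buffer in timings.items():
        print(format_timings(name, time_buffer))
```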

