Support GPU INT-8 quantization #15

Merged · 17 commits · Dec 8, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -216,3 +216,4 @@ cython_debug/
.idea/
TensorRT/
triton_models/
demo/roberta-*/
14 changes: 12 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# From 🤗 to 🤯, Hugging Face Transformer submillisecond inference️ and deployment to production
# Hugging Face Transformer submillisecond inference️ and deployment to production: 🤗 → 🤯

[![tests](https://github.com/ELS-RD/transformer-deploy/actions/workflows/python-app.yml/badge.svg)](https://github.com/ELS-RD/transformer-deploy/actions/workflows/python-app.yml) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](./LICENCE) [![Python 3.6](https://img.shields.io/badge/python-3.6-blue.svg)](https://www.python.org/downloads/release/python-360/)

@@ -14,6 +14,7 @@
* [🐍 TensorRT usage in Python script](#tensorrt-usage-in-python-script)
* [⏱ benchmarks](#benchmarks)
* [🤗 end to end reproduction of Infinity Hugging Face demo](./demo/README.md) (to replay [Medium article](https://towardsdatascience.com/hugging-face-transformer-inference-under-1-millisecond-latency-e1be0057a51c?source=friends_link&sk=cd880e05c501c7880f2b9454830b8915))
* [🏎️ end to end GPU quantization tutorial](./demo/quantization_end_to_end.ipynb)

#### Why this tool?

@@ -85,7 +86,16 @@ With the single command below, you will:
* **generate** configuration files for Triton inference server

```shell
convert_model -m roberta-large-mnli --backend tensorrt onnx pytorch --seq-len 16 128 128 --batch-size 1 32 32
convert_model -m roberta-large-mnli --backend tensorrt onnx --seq-len 16 128 128 --batch-size 1 32 32
# ...
# Inference done on NVIDIA GeForce RTX 3090
# latencies:
# [Pytorch (FP32)] mean=123.26ms, sd=3.35ms, min=117.84ms, max=136.12ms, median=122.09ms, 95p=129.50ms, 99p=131.24ms
# [Pytorch (FP16)] mean=78.41ms, sd=2.83ms, min=75.58ms, max=88.48ms, median=77.28ms, 95p=84.66ms, 99p=85.97ms
# [TensorRT (FP16)] mean=182.99ms, sd=3.15ms, min=175.75ms, max=191.58ms, median=182.32ms, 95p=188.37ms, 99p=190.80ms
# [ONNX Runtime (vanilla)] mean=119.03ms, sd=8.27ms, min=112.15ms, max=185.57ms, median=116.51ms, 95p=129.18ms, 99p=167.70ms
# [ONNX Runtime (optimized)] mean=53.82ms, sd=0.81ms, min=52.79ms, max=58.27ms, median=53.74ms, 95p=55.38ms, 99p=57.29ms

```

> **16 128 128** -> minimum, optimal, maximum sequence length, to help TensorRT better optimize your model
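The latency lines shown above can be reproduced from raw per-inference timings. A minimal sketch of the summary statistics, assuming timings are collected in seconds as a plain list (this is not the repository's exact reporting code, which lives in `convert.py`):

```python
from typing import List

import numpy as np


def summarize(name: str, time_buffer: List[float]) -> str:
    # time_buffer holds one latency per inference, in seconds; report milliseconds
    t = np.asarray(time_buffer) * 1000.0
    return (
        f"[{name}] mean={t.mean():.2f}ms, sd={t.std():.2f}ms, min={t.min():.2f}ms, "
        f"max={t.max():.2f}ms, median={np.median(t):.2f}ms, "
        f"95p={np.percentile(t, 95):.2f}ms, 99p={np.percentile(t, 99):.2f}ms"
    )


print(summarize("ONNX Runtime (optimized)", [0.0538, 0.0540, 0.0551]))
```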
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.1.1
0.2.0
2 changes: 1 addition & 1 deletion demo/README.md
@@ -40,7 +40,7 @@ docker run -it --rm --gpus all \
-v $PWD:/project ghcr.io/els-rd/transformer-deploy:0.1.1 \
bash -c "cd /project && \
convert_model -m \"philschmid/MiniLM-L6-H384-uncased-sst2\" \
--backend tensorrt onnx pytorch \
--backend tensorrt onnx \
--seq-len 16 128 128"
```

5,777 changes: 5,777 additions & 0 deletions demo/quantization_end_to_end.ipynb

Large diffs are not rendered by default.
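The 5,777-line quantization notebook is not rendered above. As a rough illustration of the post-training calibration step such a GPU INT-8 tutorial relies on — this uses the standard `pytorch_quantization` API and is not code taken from the notebook itself:

```python
import torch
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn


def calibrate(model: torch.nn.Module, calibration_batches) -> None:
    """Collect activation ranges on a few batches, then load amax values so the
    fake-quantization nodes produce realistic INT-8 scales."""
    # switch every TensorQuantizer to calibration mode
    for _, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()
    with torch.no_grad():
        for batch in calibration_batches:
            model(**batch)
    # compute scales and switch back to (fake-)quantized inference
    for _, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax("percentile", percentile=99.99)
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
```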

2 changes: 1 addition & 1 deletion requirements.txt
@@ -9,6 +9,6 @@ sympy
coloredlogs
pytest
colored
black
black[jupyter]
isort
flake8
3 changes: 1 addition & 2 deletions requirements_gpu.txt
@@ -1,10 +1,9 @@
onnx
onnxruntime-gpu
onnxruntime-gpu==1.9.0
nvidia-pyindex
tritonclient[all]
pycuda
torch==1.10.0+cu113
nvidia-pyindex
nvidia-tensorrt
onnx_graphsurgeon
polygraphy
1,631 changes: 1,631 additions & 0 deletions src/transformer_deploy/QDQModels/QDQRoberta.py

Large diffs are not rendered by default.
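The 1,631-line `QDQRoberta.py` adds a RoBERTa variant carrying quantization (QDQ) nodes; the diff is too large to render here. A toy illustration of the general QDQ approach, with made-up class and attribute names rather than the actual implementation:

```python
import torch
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.nn import TensorQuantizer
from pytorch_quantization.tensor_quant import QuantDescriptor


class QDQAttentionSketch(torch.nn.Module):
    """Toy module (not QDQRoberta): nn.Linear is swapped for QuantLinear and the
    matmul operands get explicit quantizers, so TensorRT can fold the resulting
    QuantizeLinear/DequantizeLinear pairs into INT-8 kernels."""

    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.query = quant_nn.QuantLinear(hidden_size, hidden_size)
        self.key = quant_nn.QuantLinear(hidden_size, hidden_size)
        self.q_quantizer = TensorQuantizer(QuantDescriptor(num_bits=8, calib_method="histogram"))
        self.k_quantizer = TensorQuantizer(QuantDescriptor(num_bits=8, calib_method="histogram"))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q = self.q_quantizer(self.query(hidden_states))
        k = self.k_quantizer(self.key(hidden_states))
        return torch.matmul(q, k.transpose(-1, -2))
```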

13 changes: 13 additions & 0 deletions src/transformer_deploy/QDQModels/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2021, Lefebvre Sarrut Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
8 changes: 5 additions & 3 deletions src/transformer_deploy/backends/ort_utils.py
@@ -36,7 +36,9 @@ def create_model_for_provider(path: str, provider_to_use: str) -> InferenceSessi
return InferenceSession(path, options, providers=provider_to_use)


def convert_to_onnx(model_pytorch: PreTrainedModel, output_path: str, inputs_pytorch: OD[str, torch.Tensor]) -> None:
def convert_to_onnx(
model_pytorch: PreTrainedModel, output_path: str, inputs_pytorch: OD[str, torch.Tensor], opset: int = 12
) -> None:
# dynamic axis == variable length axis
dynamic_axis = OrderedDict()
for k in inputs_pytorch.keys():
@@ -47,7 +49,7 @@ def convert_to_onnx(model_pytorch: PreTrainedModel, output_path: str, inputs_pyt
model_pytorch, # model to optimize
args=tuple(inputs_pytorch.values()), # tuple of multiple inputs
f=output_path, # output path / file object
opset_version=12, # the ONNX version to use
opset_version=opset, # the ONNX opset version to use: 13 for quantized models, 12 otherwise
do_constant_folding=True, # simplify model (replace constant expressions)
input_names=list(inputs_pytorch.keys()), # input names
output_names=["output"], # output axis name
@@ -65,7 +67,7 @@ def optimize_onnx(onnx_path: str, onnx_optim_fp16_path: str, use_cuda: bool) ->
model_type="bert",
use_gpu=use_cuda,
opt_level=1,
num_heads=0, # automatic detection
num_heads=0, # automatic detection doesn't work with opset 13
hidden_size=0, # automatic detection
optimization_options=optimization_options,
)
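A short usage sketch of the updated `convert_to_onnx` signature; the output path is a placeholder and the import path assumes the standard `src/` layout of the repository:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from transformer_deploy.backends.ort_utils import convert_to_onnx

model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli").eval().cuda()
tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
encodings = tokenizer("This is a test", return_tensors="pt")
inputs_pytorch = {k: v.cuda() for k, v in encodings.items()}

# default opset 12 for a regular model; pass opset=13 when the graph contains
# fake-quantization (QuantizeLinear/DequantizeLinear) nodes
convert_to_onnx(
    model_pytorch=model,
    output_path="model.onnx",
    inputs_pytorch=inputs_pytorch,
    opset=12,
)
```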
56 changes: 28 additions & 28 deletions src/transformer_deploy/backends/trt_utils.py
@@ -77,20 +77,20 @@ def setup_binding_shapes(
host_inputs: List[np.ndarray],
input_binding_idxs: List[int],
output_binding_idxs: List[int],
):
) -> Tuple[List[np.ndarray], List[DeviceAllocation]]:
# explicitly set dynamic input shapes, so dynamic output shapes can be computed internally
for host_input, binding_index in zip(host_inputs, input_binding_idxs):
context.set_binding_shape(binding_index, host_input.shape)
assert context.all_binding_shapes_specified
host_outputs = []
device_outputs = []
host_outputs: List[np.ndarray] = []
device_outputs: List[DeviceAllocation] = []
for binding_index in output_binding_idxs:
output_shape = context.get_binding_shape(binding_index)
# allocate buffers to hold output results after copying back to host
buffer = np.empty(output_shape, dtype=np.float32)
host_outputs.append(buffer)
# allocate output buffers on device
device_outputs.append(cuda.mem_alloc(buffer.nbytes))
# allocate buffers to hold output results after copying back to host
buffer = np.empty(output_shape, dtype=np.float32)
host_outputs.append(buffer)
# allocate output buffers on device
device_outputs.append(cuda.mem_alloc(buffer.nbytes))
return host_outputs, device_outputs


@@ -136,6 +136,8 @@ def build_engine(
optimal_shape: Tuple[int, int],
max_shape: Tuple[int, int],
workspace_size: int,
fp16: bool,
int8: bool,
) -> ICudaEngine:
with trt.Builder(logger) as builder: # type: Builder
with builder.create_network(
@@ -144,22 +146,20 @@
with trt.OnnxParser(network_definition, logger) as parser: # type: OnnxParser
builder.max_batch_size = max_shape[0] # max batch size
config: IBuilderConfig = builder.create_builder_config()
# config.min_timing_iterations = 1
# config.avg_timing_iterations = 1
config.max_workspace_size = workspace_size
# to enable complete trt inspector debugging, only for TensorRT >= 8.2
# config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
# disable CUDNN optimizations
config.set_tactic_sources(
tactic_sources=1 << int(trt.TacticSource.CUBLAS) | 1 << int(trt.TacticSource.CUBLAS_LT)
)
# config.set_flag(trt.BuilderFlag.INT8)
# config.set_quantization_flag(trt.QuantizationFlag.CALIBRATE_BEFORE_FUSION)
# config.int8_calibrator = Calibrator()
config.set_flag(trt.BuilderFlag.FP16)
if int8:
config.set_flag(trt.BuilderFlag.INT8)
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)
# https://github.com/NVIDIA/TensorRT/issues/1196 (sometimes big diff in output when using FP16)
config.set_flag(trt.BuilderFlag.STRICT_TYPES)
config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)
with open(onnx_file_path, "rb") as f:
parser.parse(f.read())
profile: IOptimizationProfile = builder.create_optimization_profile()
@@ -171,12 +171,8 @@
max=max_shape,
)
config.add_optimization_profile(profile)
# for i in range(network.num_layers):
# layer: ILayer = network.get_layer(i)
# if "gemm" in str(layer.name).lower():
# for g in range(layer.num_outputs):
# layer.precision = trt.DataType.FLOAT
network_definition = fix_fp16_network(network_definition)
if fp16:
network_definition = fix_fp16_network(network_definition)
trt_engine = builder.build_serialized_network(network_definition, config)
engine: ICudaEngine = runtime.deserialize_cuda_engine(trt_engine)
assert engine is not None, "error during engine generation, check error messages above :-("
@@ -200,16 +196,20 @@ def infer_tensorrt(
output_binding_idxs: List[int],
stream: Stream,
) -> np.ndarray:
# warning: small change in output if int64 is used instead of int32
input_list: List[ndarray] = [tensor.astype(np.int32) for tensor in host_inputs.values()]
# allocate GPU memory for input tensors
device_inputs = [cuda.mem_alloc(tensor.nbytes) for tensor in input_list]
for h_input, d_input in zip(input_list, device_inputs):
cuda.memcpy_htod_async(d_input, h_input) # host to GPU
input_list: List[ndarray] = list()
device_inputs: List[DeviceAllocation] = list()
for tensor in host_inputs.values():
# warning: small change in output if int64 is used instead of int32
tensor_int32: np.ndarray = np.asarray(tensor, dtype=np.int32)
input_list.append(tensor_int32)
# allocate GPU memory for input tensors
device_input: DeviceAllocation = cuda.mem_alloc(tensor_int32.nbytes)
device_inputs.append(device_input)
cuda.memcpy_htod_async(device_input, tensor_int32.ravel(), stream)
# calculate input shape, bind it, allocate GPU memory for the output
host_outputs, device_outputs = setup_binding_shapes(context, input_list, input_binding_idxs, output_binding_idxs)
bindings = device_inputs + device_outputs
context.execute_async_v2(bindings, stream.handle)
assert context.execute_async_v2(bindings, stream_handle=stream.handle), "failure during execution of inference"
for h_output, d_output in zip(host_outputs, device_outputs):
cuda.memcpy_dtoh_async(h_output, d_output) # GPU to host
stream.synchronize() # sync all CUDA ops
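A rough sketch of how the new `fp16`/`int8` switches might be driven when building an engine for a quantized model. The keyword names mirror the call site in `convert.py`; the file paths, shapes, the name of the hidden `min_shape` parameter, and the assumption that `save_engine` lives alongside `build_engine` in `trt_utils` are all placeholders, not values from this PR:

```python
import tensorrt as trt
from tensorrt.tensorrt import Logger, Runtime

from transformer_deploy.backends.trt_utils import build_engine, save_engine

trt_logger: Logger = trt.Logger(trt.Logger.WARNING)
runtime: Runtime = trt.Runtime(trt_logger)

engine = build_engine(
    runtime=runtime,
    onnx_file_path="model_qdq.onnx",   # placeholder: ONNX export containing Q/DQ nodes
    logger=trt_logger,
    min_shape=(1, 16),                 # (batch size, sequence length)
    optimal_shape=(32, 128),
    max_shape=(32, 128),
    workspace_size=10000 * 1024 * 1024,
    fp16=False,                        # mirrors fp16=not args.quantization
    int8=True,                         # mirrors int8=args.quantization -> trt.BuilderFlag.INT8
)
save_engine(engine=engine, engine_file_path="model.plan")
```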
71 changes: 42 additions & 29 deletions src/transformer_deploy/convert.py
@@ -25,6 +25,7 @@
import tensorrt as trt
import torch
from pycuda._driver import Stream
from pytorch_quantization.nn import TensorQuantizer
from tensorrt.tensorrt import IExecutionContext, Logger, Runtime
from torch.cuda import get_device_name
from torch.cuda.amp import autocast
@@ -47,12 +48,13 @@ def main():
parser = argparse.ArgumentParser(
description="optimize and deploy transformers", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("-m", "--model", required=True, help="path to model or URL to Hugging Face Hub")
parser.add_argument("-m", "--model", required=True, help="path to model or URL to Hugging Face hub")
parser.add_argument("-t", "--tokenizer", help="path to tokenizer or URL to Hugging Face hub")
parser.add_argument(
"--auth-token",
default=None,
help=(
"HuggingFace Hub auth token. Set to `None` (default) for public models. "
"Hugging Face Hub auth token. Set to `None` (default) for public models. "
"For private models, use `True` to use local cached token, or a string of your HF API token"
),
)
@@ -72,6 +74,7 @@
type=int,
nargs=3,
)
parser.add_argument("-q", "--quantization", action="store_true", help="int-8 GPU quantization support")
parser.add_argument("-w", "--workspace-size", default=10000, help="workspace size in MiB (TensorRT)", type=int)
parser.add_argument("-o", "--output", default="triton_models", help="name to be used for ")
parser.add_argument("-n", "--name", default="transformer", help="model name to be used in triton server")
@@ -81,7 +84,7 @@
default=["onnx"],
help="backend to use. One of [onnx,tensorrt, pytorch] or all",
nargs="*",
choices=["onnx", "tensorrt", "pytorch"],
choices=["onnx", "tensorrt"],
)
parser.add_argument("--nb-instances", default=1, help="# of model instances, may improve troughput", type=int)
parser.add_argument("--warmup", default=100, help="# of inferences to warm each model", type=int)
Expand All @@ -107,7 +110,8 @@ def main():
tensorrt_path = os.path.join(args.output, "model.plan")

assert torch.cuda.is_available(), "CUDA is not available. Please check your CUDA installation"
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(args.model, use_auth_token=auth_token)
tokenizer_path = args.tokenizer if args.tokenizer else args.model
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_auth_token=auth_token)
input_names: List[str] = tokenizer.model_input_names
logging.info(f"axis: {input_names}")
include_token_ids = "token_type_ids" in input_names
Expand All @@ -130,16 +134,45 @@ def main():
logging.info(f"[Pytorch] input shape {inputs_pytorch['input_ids'].shape}")
logging.info(f"[Pytorch] output shape: {output_pytorch.shape}")
# create onnx model and compare results
convert_to_onnx(model_pytorch=model_pytorch, output_path=onnx_model_path, inputs_pytorch=inputs_pytorch)
opset = 12
if args.quantization:
TensorQuantizer.use_fb_fake_quant = True
opset = 13

convert_to_onnx(
model_pytorch=model_pytorch, output_path=onnx_model_path, inputs_pytorch=inputs_pytorch, opset=opset
)
if args.quantization:
TensorQuantizer.use_fb_fake_quant = False
onnx_model = create_model_for_provider(path=onnx_model_path, provider_to_use="CUDAExecutionProvider")
output_onnx = onnx_model.run(None, inputs_onnx)
assert np.allclose(a=output_onnx, b=output_pytorch, atol=args.atol)
del onnx_model
if "pytorch" not in args.backend:
del model_pytorch

timings = {}

with torch.inference_mode():
for _ in range(args.warmup):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
time_buffer = []
for _ in range(args.nb_measures):
with track_infer_time(time_buffer):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
timings["Pytorch (FP32)"] = time_buffer
with autocast():
for _ in range(args.warmup):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
time_buffer = []
for _ in range(args.nb_measures):
with track_infer_time(time_buffer):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
timings["Pytorch (FP16)"] = time_buffer
del model_pytorch

if "tensorrt" in args.backend:
trt_logger: Logger = trt.Logger(trt.Logger.INFO if args.verbose else trt.Logger.WARNING)
runtime: Runtime = trt.Runtime(trt_logger)
@@ -151,6 +184,8 @@
optimal_shape=tensor_shapes[1],
max_shape=tensor_shapes[2],
workspace_size=args.workspace_size * 1024 * 1024,
fp16=not args.quantization,
int8=args.quantization,
)
save_engine(engine=engine, engine_file_path=tensorrt_path)
# important to check the engine has been correctly serialized
@@ -242,28 +277,6 @@ def main():
)
conf.create_folders(tokenizer=tokenizer, model_path=onnx_optim_fp16_path)

if "pytorch" in args.backend:
with torch.inference_mode():
for _ in range(args.warmup):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
time_buffer = []
for _ in range(args.nb_measures):
with track_infer_time(time_buffer):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
timings["Pytorch (FP32)"] = time_buffer
with autocast():
for _ in range(args.warmup):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
time_buffer = []
for _ in range(args.nb_measures):
with track_infer_time(time_buffer):
_ = model_pytorch(**inputs_pytorch)
torch.cuda.synchronize()
timings["Pytorch (FP16)"] = time_buffer

print(f"Inference done on {get_device_name(0)}")
print("latencies:")
for name, time_buffer in timings.items():
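Finally, a condensed, self-contained sketch of the export toggle this PR adds around the ONNX conversion. The tiny placeholder model stands in for the calibrated QDQ transformer that `convert.py` actually exports; on a model without fake-quantization nodes the flag is simply a no-op:

```python
import torch
from pytorch_quantization.nn import TensorQuantizer

# placeholder model; convert.py passes the real (possibly quantized) transformer
model = torch.nn.Sequential(torch.nn.Linear(8, 2)).eval()
dummy_input = torch.ones(1, 8)

TensorQuantizer.use_fb_fake_quant = True   # export Q/DQ ONNX nodes instead of custom fake-quant ops
torch.onnx.export(
    model,
    (dummy_input,),
    "model_qdq.onnx",      # placeholder output path
    opset_version=13,      # per-channel QuantizeLinear/DequantizeLinear need opset 13
    do_constant_folding=True,
)
TensorQuantizer.use_fb_fake_quant = False  # restore the default once the export is done
```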