- [2024/05] SmoothQuant enables INT8 model inference on AMD Instinct MI300X using Composable Kernel.
- [2024/03] We show that SmoothQuant enables W8A8 quantization for Llama-1/2/3, Falcon, Mistral, and Mixtral models with negligible loss (see the results below).
- [2024/01] SmoothQuant is integrated into Microsoft's ONNX Runtime.
- [2023/11] SmoothQuant is integrated into Amazon SageMaker.
- [2023/10] SmoothQuant is integrated into NVIDIA TensorRT-LLM.
- [2023/03] SmoothQuant is integrated into Intel Neural-Compressor.
Large language models (LLMs) show excellent performance but are compute- and memory-intensive. Quantization can reduce memory and accelerate inference. However, for LLMs beyond 100 billion parameters, existing methods either cannot maintain accuracy or do not run efficiently on hardware. We propose SmoothQuant, a training-free, accuracy-preserving, and general-purpose post-training quantization (PTQ) solution to enable 8-bit weight, 8-bit activation (W8A8) quantization for LLMs. Based on the fact that weights are easy to quantize while activations are not, SmoothQuant smooths the activation outliers by migrating the quantization difficulty offline from activations to weights with a mathematically equivalent transformation. SmoothQuant enables INT8 quantization of both weights and activations for all the matrix multiplications in LLMs, including OPT-175B, BLOOM-176B, GLM-130B, and MT-NLG 530B. SmoothQuant has better hardware efficiency than existing techniques. We demonstrate up to 1.56x speedup and 2x memory reduction for LLMs with negligible loss in accuracy. We integrate SmoothQuant into FasterTransformer, a state-of-the-art LLM serving framework, and achieve faster inference with half the number of GPUs compared to FP16, enabling the serving of a 530B LLM within a single node. Our work offers a turn-key solution that reduces hardware costs and democratizes LLMs.
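The equivalent transformation can be illustrated with a small sketch. It gives each input channel j a smoothing factor s_j = max(|X_j|)^alpha / max(|W_j|)^(1 - alpha). It then divides the activations by s and multiplies the matching weight rows by s, which leaves the product X·W unchanged. The helpers `matmul`, `smoothing_factors`, and `smooth` are hypothetical illustrations on plain Python lists. They are not the repository's code. The eps clamp is an assumption added to avoid division by zero.

```python
def matmul(a, b):
    """Matrix product of a (n x k) and b (k x m), both given as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def smoothing_factors(x, w, alpha=0.5, eps=1e-5):
    """Per-input-channel factors s_j = max|X_j|**alpha / max|W_j|**(1 - alpha).

    x: activations, shape (tokens, in_features); column j is input channel j.
    w: weights, shape (in_features, out_features); row j multiplies channel j.
    Both maxima are clamped to eps (an assumption) to avoid division by zero.
    """
    in_features = len(w)
    act_max = [max(max(abs(row[j]) for row in x), eps) for j in range(in_features)]
    w_max = [max(max(abs(v) for v in w[j]), eps) for j in range(in_features)]
    return [a ** alpha / b ** (1 - alpha) for a, b in zip(act_max, w_max)]


def smooth(x, w, s):
    """Return (X diag(s)^-1, diag(s) W); their product equals X W."""
    x_s = [[v / s[j] for j, v in enumerate(row)] for row in x]
    w_s = [[v * s[j] for v in w[j]] for j in range(len(w))]
    return x_s, w_s


# Input channel 1 carries an outlier much larger than the other channels.
x = [[1.0, 100.0, -0.5],
     [-2.0, 80.0, 1.5]]
w = [[0.5, -1.0],
     [0.2, 0.1],
     [-0.3, 0.4]]

s = smoothing_factors(x, w, alpha=0.5)
x_s, w_s = smooth(x, w, s)

y, y_s = matmul(x, w), matmul(x_s, w_s)
# Same output up to floating-point rounding.
assert all(abs(a - b) < 1e-9 for ra, rb in zip(y, y_s) for a, b in zip(ra, rb))
# Largest activation magnitude: 100.0 before smoothing, sqrt(20) ~= 4.47 after.
print(max(abs(v) for row in x for v in row), max(abs(v) for row in x_s for v in row))
```

With alpha = 0.5, the maxima of the smoothed activation channel and of its weight row both become sqrt(max|X_j| * max|W_j|). A larger alpha pushes more of the difficulty onto the weights. This is presumably what the per-model Alpha values in the results table below tune.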
conda create -n smoothquant python=3.8
conda activate smoothquant
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
pip install transformers==4.36.0 accelerate datasets zstandard
python setup.py install
We implement SmoothQuant INT8 inference for PyTorch with CUTLASS INT8 GEMM kernels, which are wrapped as PyTorch modules in torch-int. Please install torch-int before running the SmoothQuant PyTorch INT8 inference.
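The following sketch shows roughly what an INT8 linear layer computes. It is not the torch-int or CUTLASS code. It assumes static per-tensor symmetric scales. Activations and weights are rounded to INT8, multiplied with integer accumulation, and the accumulator is rescaled by the product of the two scales. Whether torch-int uses per-tensor or finer-grained scales is not shown here. The helper names are hypothetical.

```python
Q_MAX = 127  # symmetric signed INT8 range [-127, 127]


def absmax_scale(values):
    """Per-tensor symmetric scale that maps max|v| onto Q_MAX."""
    return max(abs(v) for v in values) / Q_MAX


def quantize(values, scale):
    """Round v / scale to the nearest integer (ties to even) and clamp to INT8."""
    return [max(-Q_MAX, min(Q_MAX, round(v / scale))) for v in values]


def int8_linear(x, w, x_scale, w_scale):
    """Approximate x @ w with integer products, then rescale to float.

    x: (tokens, in_features) floats; w: (in_features, out_features) floats.
    """
    xq = [quantize(row, x_scale) for row in x]
    wq = [quantize(row, w_scale) for row in w]
    out = []
    for row in xq:
        # Integer dot products; an INT8 GEMM kernel accumulates these in int32.
        acc = [sum(row[t] * wq[t][j] for t in range(len(wq))) for j in range(len(wq[0]))]
        # One float multiply per output restores the original magnitude.
        out.append([a * x_scale * w_scale for a in acc])
    return out


x = [[0.5, -1.2, 2.0],
     [1.0, 0.3, -0.7]]
w = [[0.2, -0.4],
     [0.9, 0.1],
     [-0.5, 0.6]]
x_scale = absmax_scale([v for row in x for v in row])
w_scale = absmax_scale([v for row in w for v in row])

y_int8 = int8_linear(x, w, x_scale, w_scale)
y_ref = [[sum(r[t] * w[t][j] for t in range(3)) for j in range(2)] for r in x]
```

The per-output error is bounded by the rounding step of each operand. In this example it is a few hundredths at most.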
We implement the quantized OPT model class in smoothquant/opt.py, which uses INT8 linear layers and bundles quantization scales. We provide the already smoothed and quantized OPT model at https://huggingface.co/mit-han-lab/opt-[MODEL-SIZE]-smoothquant, where [MODEL-SIZE] can be 125m, 1.3B, 2.7B, 6.7B, 13B, 30b, and 66b. You can load the INT8 model with the following code:
from smoothquant.opt import Int8OPTForCausalLM
model = Int8OPTForCausalLM.from_pretrained("mit-han-lab/opt-30b-smoothquant")
You can also check generate_act_scales.py and export_int8_model.py to see how we smooth, quantize and export INT8 models.
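The smoothing is applied offline, so the division by s can be folded into the parameters of whatever produces X. This means no extra runtime operation is needed. The sketch below assumes the smoothed linear layer is fed directly by a LayerNorm, as for the attention input projections in OPT. It uses PyTorch's nn.Linear weight layout of (out_features, in_features). The helpers are hypothetical and are not the code in the export scripts.

```python
import math


def layer_norm(h, gamma, beta, eps=1e-5):
    """LayerNorm of one token vector h with per-channel gamma and beta."""
    mean = sum(h) / len(h)
    var = sum((v - mean) ** 2 for v in h) / len(h)
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(h, gamma, beta)]


def linear(x, weight):
    """x W^T for one token x; weight uses the (out_features, in_features) layout."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]


def fold_smoothing(gamma, beta, weight, s):
    """Fold the 1/s division into LayerNorm and the s multiplication into the weight."""
    gamma_f = [g / sj for g, sj in zip(gamma, s)]
    beta_f = [b / sj for b, sj in zip(beta, s)]
    # Input channel j is column j of an (out, in) weight, so scale each column by s[j].
    weight_f = [[v * sj for v, sj in zip(row, s)] for row in weight]
    return gamma_f, beta_f, weight_f


h = [0.3, 4.0, -1.0, 0.5]
gamma = [1.0, 2.0, 0.5, 1.5]
beta = [0.1, 0.0, -0.2, 0.3]
weight = [[0.2, -0.1, 0.4, 0.3],
          [-0.5, 0.05, 0.1, 0.2]]
s = [1.0, 4.0, 0.5, 2.0]  # arbitrary positive smoothing factors for illustration

gamma_f, beta_f, weight_f = fold_smoothing(gamma, beta, weight, s)
y = linear(layer_norm(h, gamma, beta), weight)
y_folded = linear(layer_norm(h, gamma_f, beta_f), weight_f)
```

The folded LayerNorm emits X / s directly, which is the smoothed activation that gets quantized. The rescaled weight is quantized once at export time.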
In examples/smoothquant_opt_real_int8_demo.ipynb, we use the OPT-30B model to demonstrate the latency and memory advantages of SmoothQuant. We use OPT-30B because it is the largest model for which we can run both FP16 and INT8 inference on a single A100 GPU. For larger models requiring multiple GPUs, we recommend using the FasterTransformer implementation of SmoothQuant.
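A back-of-envelope weight-memory estimate shows why OPT-30B is the cutoff. It rests on four assumptions: roughly 30B and 66B parameters, 2 bytes per FP16 weight, 1 byte per INT8 weight, and the 80 GB A100 variant. It ignores activations, the KV cache, and the small scale tensors, all of which a real run also needs.

```python
def weight_memory_gb(num_params, bytes_per_param):
    """Weight storage in GB (1 GB = 10**9 bytes); excludes activations and KV cache."""
    return num_params * bytes_per_param / 1e9


A100_MEMORY_GB = 80  # assumption: the 80 GB A100 variant

# Approximate parameter counts; 2 bytes per FP16 weight, 1 byte per INT8 weight.
estimates = {name: (weight_memory_gb(p, 2), weight_memory_gb(p, 1))
             for name, p in [("OPT-30B", 30e9), ("OPT-66B", 66e9)]}

for name, (fp16, int8) in estimates.items():
    print(f"{name}: FP16 {fp16:.0f} GB (fits: {fp16 < A100_MEMORY_GB}), "
          f"INT8 {int8:.0f} GB (fits: {int8 < A100_MEMORY_GB})")
```

Under these assumptions, OPT-30B's FP16 weights (about 60 GB) fit on one 80 GB GPU. OPT-66B's FP16 weights (about 132 GB) do not, though its INT8 weights (about 66 GB) would. This lines up with the 1-instead-of-2 GPU count for OPT-66B quoted later in this README.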
We provide the activation channel scales for Llama, Mistral, Mixtral, Falcon, OPT, and BLOOM models in act_scales/. We obtained these scales from 512 random sentences in the Pile validation set. You can use the OPT demo (examples/smoothquant_opt_demo.ipynb) and the Llama demo (examples/smoothquant_llama_demo.ipynb) to test smoothing and quantizing these models.
We also provide a script to compute the activation channel scales for your own models; please refer to examples/generate_act_scales.py. You can use the following command to compute the scales:
python examples/generate_act_scales.py \
--model-name <model_name_or_path> \
--output-path <output_act_scales_file_path> \
--num-samples <num_samples> \
--seq-len <sequence_length> \
--dataset-path <path_to_the_calibration_dataset>
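Conceptually, an activation channel scale is the maximum absolute value that each input channel of a linear layer reaches over the calibration samples. The sketch below shows that running reduction. It assumes the activations for one layer arrive as lists of token vectors. Capturing the activations from a real model is left out (PyTorch forward hooks are one way to do it), and `update_act_scales` is a hypothetical name.

```python
def update_act_scales(scales, batch):
    """Fold one batch of activations into the running per-channel max |x|.

    scales: list of current maxima, or None before the first batch.
    batch: list of token vectors, each of length in_features.
    """
    batch_max = [max(abs(tok[j]) for tok in batch) for j in range(len(batch[0]))]
    if scales is None:
        return batch_max
    return [max(a, b) for a, b in zip(scales, batch_max)]


calibration_batches = [
    [[0.1, -8.0, 0.3], [0.2, 6.5, -0.4]],
    [[-0.5, 12.0, 0.1]],
]
scales = None
for batch in calibration_batches:
    scales = update_act_scales(scales, batch)
print(scales)  # [0.5, 12.0, 0.4]
```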
In examples/smoothquant_opt_demo.ipynb, we use OPT-13B as an example to demonstrate that SmoothQuant with INT8 inference matches the accuracy of FP16 inference, while the naive W8A8 baseline does not. We simulate INT8 inference in FP16 (smoothquant/fake_quant.py), i.e., fake quantization.
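Fake quantization rounds values to the INT8 grid and immediately maps them back to floating point. The quantization error appears even though the matmul still runs in FP16. The sketch below uses a hypothetical helper with a per-tensor symmetric scale. It is not necessarily the granularity used in smoothquant/fake_quant.py. It shows why a single outlier hurts a naive baseline: the outlier inflates the shared scale until the small values collapse.

```python
def fake_quant_per_tensor(values, n_bits=8):
    """Quantize to symmetric n-bit integers, then immediately dequantize to float."""
    q_max = 2 ** (n_bits - 1) - 1
    scale = max(abs(v) for v in values) / q_max
    return [max(-q_max, min(q_max, round(v / scale))) * scale for v in values]


def max_abs_error(a, b):
    return max(abs(p - q) for p, q in zip(a, b))


normal = [0.12, -0.05, 0.31, 0.02]
with_outlier = normal + [60.0]  # one large activation, like an LLM outlier channel

err_normal = max_abs_error(normal, fake_quant_per_tensor(normal))
deq = fake_quant_per_tensor(with_outlier)
err_outlier = max_abs_error(normal, deq[:4])
# err_normal is about 0.001; with the outlier the shared scale grows to 60 / 127,
# err_outlier is about 0.16, and three of the four small values round to 0.
print(err_normal, err_outlier, deq[:4])
```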
We provide an evaluation script to measure the language modeling perplexity of OPT, BLOOM, Llama, Falcon, Mistral, and Mixtral models under simulated W8A8 quantization. Please refer to smoothquant/ppl_eval.py. You can use the following command to evaluate the models:
python smoothquant/ppl_eval.py \
--model_path <model_name_or_path> \
--act_scales_path <act_scales_file_path> \
--smooth \
--alpha <alpha> \
--quantize
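Perplexity is the exponential of the average per-token negative log-likelihood over the evaluation text. The sketch below covers only that final step. It assumes the per-token natural-log probabilities have already been computed by the model. The text windowing and dataset handling in smoothquant/ppl_eval.py are not reproduced, and `perplexity` is a hypothetical helper.

```python
import math


def perplexity(token_log_probs):
    """exp(-mean(log p)) over natural-log token probabilities."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)


# A model that assigns probability 1/4 to every token has perplexity 4.
print(perplexity([math.log(0.25)] * 10))
```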
Results:
| Model | Method | PPL | Alpha |
|---|---|---|---|
| Llama-2-7B | FP16 | 5.474 | |
| Llama-2-7B | SQ W8A8 | 5.515 | 0.85 |
| Llama-2-13B | FP16 | 4.950 | |
| Llama-2-13B | SQ W8A8 | 4.929 | 0.85 |
| Llama-2-70B | FP16 | 3.320 | |
| Llama-2-70B | SQ W8A8 | 3.359 | 0.9 |
| Llama-3-8B | FP16 | 6.138 | |
| Llama-3-8B | SQ W8A8 | 6.258 | 0.85 |
| Llama-3-70B | FP16 | 2.857 | |
| Llama-3-70B | SQ W8A8 | 2.982 | 0.85 |
| Mistral-7B | FP16 | 5.253 | |
| Mistral-7B | SQ W8A8 | 5.277 | 0.8 |
| Mixtral-8x7B | FP16 | 3.842 | |
| Mixtral-8x7B | SQ W8A8 | 3.893 | 0.8 |
| Falcon-7B | FP16 | 6.590 | |
| Falcon-7B | SQ W8A8 | 6.629 | 0.6 |
| Falcon-40B | FP16 | 5.228 | |
| Falcon-40B | SQ W8A8 | 5.255 | 0.7 |
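To put the gap in numbers, the snippet below computes the relative perplexity change of SQ W8A8 over FP16 for each model. It uses only the values listed in the results table.

```python
# (FP16 PPL, SQ W8A8 PPL) copied from the results table.
results = {
    "Llama-2-7B": (5.474, 5.515),
    "Llama-2-13B": (4.950, 4.929),
    "Llama-2-70B": (3.320, 3.359),
    "Llama-3-8B": (6.138, 6.258),
    "Llama-3-70B": (2.857, 2.982),
    "Mistral-7B": (5.253, 5.277),
    "Mixtral-8x7B": (3.842, 3.893),
    "Falcon-7B": (6.590, 6.629),
    "Falcon-40B": (5.228, 5.255),
}

# Relative change (SQ - FP16) / FP16; positive means higher (worse) perplexity.
rel_change = {m: (sq - fp16) / fp16 for m, (fp16, sq) in results.items()}
for model, r in rel_change.items():
    print(f"{model}: {100 * r:+.2f}%")
```

By this measure, Llama-3-70B shows the largest increase (about 4.4%). Llama-2-13B's W8A8 perplexity is slightly below its FP16 value.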
For measured speedup, we recommend using the NVIDIA TensorRT-LLM implementation of SmoothQuant.
- SmoothQuant migrates part of the quantization difficulty from activations to weights, which smooths out the systematic outliers in activations and makes both weights and activations easy to quantize.
- SmoothQuant can achieve W8A8 quantization of LLMs (e.g., OPT-175B) without degrading performance.
- SmoothQuant achieves faster inference than FP16 when integrated into PyTorch, while the previous work LLM.int8() does not lead to acceleration (it is usually slower).
- We also integrate SmoothQuant into the state-of-the-art serving framework FasterTransformer, achieving faster inference while using only half as many GPUs as FP16 (1 instead of 2 for OPT-66B, 4 instead of 8 for OPT-175B).
If you find SmoothQuant useful or relevant to your research, please cite our paper:
@InProceedings{xiao2023smoothquant,
title = {{S}mooth{Q}uant: Accurate and Efficient Post-Training Quantization for Large Language Models},
author = {Xiao, Guangxuan and Lin, Ji and Seznec, Mickael and Wu, Hao and Demouth, Julien and Han, Song},
booktitle = {Proceedings of the 40th International Conference on Machine Learning},
year = {2023}
}