4-bit quantization of LLaMA and BLOOM using GPTQ
This repo is a modified version of GPTQ-for-LLaMa; basic usage is the same as in that repo.
GPTQ is a state-of-the-art one-shot weight quantization method.
Offers the fastest speed, but requires both Triton and CUDA. Triton only supports Linux, so Windows users should use WSL2.
Supports 4-bit quantization of the PULSE model, including PULSE fine-tuned with LoRA.
PULSE results are evaluated on the medical dataset.
Quantization requires a large amount of CPU memory. If physical RAM is insufficient, swap space can make up the shortfall, at the cost of speed.
Results may differ slightly depending on the GPU and driver; the difference shrinks as model size increases (IST-DASLab/gptq#1).
According to the GPTQ paper, the performance gap between FP16 and GPTQ-quantized models shrinks as model size increases.
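GPTQ quantizes each linear layer in one shot from a small calibration set: it rounds the weight columns one at a time and compensates each column's rounding error by updating the columns not yet quantized, using the inverse of the Hessian H = 2XᵀX of the layer's squared output error. The NumPy sketch below illustrates that idea for a single layer. It is a simplified illustration, not this repo's code: it omits the blocked ("lazy batch") updates and the CUDA/Triton kernels, and the function names are hypothetical. The `wbits`, `groupsize` and `act_order` arguments are meant to mirror the `--wbits`, `--groupsize` and `--act-order` flags; here groups are formed over the permuted column order, and how the repo maps groups back to the original columns is not shown.

```python
import numpy as np


def find_params(w, maxq):
    """Per-row asymmetric scale and zero point from the min/max of w (rows x k)."""
    wmin = np.minimum(w.min(axis=1), 0.0)
    wmax = np.maximum(w.max(axis=1), 0.0)
    scale = (wmax - wmin) / maxq
    scale[scale == 0] = 1.0          # avoid division by zero for all-zero rows
    zero = np.round(-wmin / scale)
    return scale, zero


def quantize(w, scale, zero, maxq):
    """Round to the integer grid [0, maxq] and map back to floats (fake quantization)."""
    q = np.clip(np.round(w / scale) + zero, 0, maxq)
    return scale * (q - zero)


def gptq_quantize(W, X, wbits=4, groupsize=128, act_order=True, percdamp=0.01):
    """Simplified GPTQ for one linear layer (no blocked updates).

    W: (rows, cols) weight, X: (n_samples, cols) calibration inputs.
    Returns the dequantized weight with the same shape and column order as W.
    """
    W = W.astype(np.float64)                 # astype returns a copy
    rows, cols = W.shape
    maxq = 2 ** wbits - 1
    H = 2.0 * X.T @ X / X.shape[0]           # Hessian of the layer's squared output error

    # act-order: quantize columns with the largest Hessian diagonal first
    perm = np.argsort(-np.diag(H)) if act_order else np.arange(cols)
    W, H = W[:, perm], H[np.ix_(perm, perm)]

    H = H + percdamp * np.mean(np.diag(H)) * np.eye(cols)   # dampening for stability
    Hinv = np.linalg.cholesky(np.linalg.inv(H)).T           # upper Cholesky factor of H^-1

    Q = np.zeros_like(W)
    scale, zero = find_params(W, maxq)       # used when groupsize <= 0
    for j in range(cols):
        if groupsize > 0 and j % groupsize == 0:
            # fresh scale/zero for each group of `groupsize` (permuted) columns
            scale, zero = find_params(W[:, j:j + groupsize], maxq)
        q = quantize(W[:, j], scale, zero, maxq)
        Q[:, j] = q
        err = (W[:, j] - q) / Hinv[j, j]
        # compensate: push the rounding error onto the columns not yet quantized
        W[:, j + 1:] -= np.outer(err, Hinv[j, j + 1:])

    return Q[:, np.argsort(perm)]            # undo the act-order permutation
```

With `wbits=4` each group of a row takes at most 16 distinct values, so a row of 256 columns with `groupsize=128` holds at most 32.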
If you don't have conda, install it first.
conda create --name gptq python=3.9 -y
conda activate gptq
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
# Or, if you're having trouble with conda, use pip with python3.9:
# pip3 install torch torchvision torchaudio
git clone -b pulse https://github.com/hanrui1sensetime/GPTQ-for-PULSE.git
cd GPTQ-for-PULSE
pip install -r requirements.txt
python setup_cuda.py install
torch: tested on v2.0.0+cu117
transformers: tested on v4.34.0
datasets: tested on v2.13.1
safetensors: tested on v0.3.1
peft: tested on v0.7.0
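To compare an environment against the tested versions listed above, a small script along these lines can help. It is a sketch using only the standard library's `importlib.metadata`; the `check_versions` helper is hypothetical and not part of this repo. It does an exact string comparison, so a build with a different local tag (for example `torch` without `+cu117`) is reported as a mismatch even though it may work.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions this README lists as tested; other versions may also work.
TESTED = {
    "torch": "2.0.0+cu117",
    "transformers": "4.34.0",
    "datasets": "2.13.1",
    "safetensors": "0.3.1",
    "peft": "0.7.0",
}


def check_versions(tested=TESTED, get_version=version):
    """Return {package: (installed or None, tested)} for every package that differs."""
    mismatches = {}
    for pkg, want in tested.items():
        try:
            have = get_version(pkg)
        except PackageNotFoundError:
            have = None              # package is not installed
        if have != want:
            mismatches[pkg] = (have, want)
    return mismatches


if __name__ == "__main__":
    for pkg, (have, want) in check_versions().items():
        print(f"{pkg}: installed {have or 'not installed'}, tested on {want}")
```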
All experiments were run on a single NVIDIA RTX3090.
The PULSE-7B model is based on BLOOMZ.
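As a rough sense of scale, the arithmetic below estimates weight storage for a 7B-parameter model in FP16 versus 4-bit with `--groupsize 128`. It is a back-of-the-envelope sketch under stated assumptions: a round 7e9 parameters (not PULSE-7B's exact count), one FP16 scale plus one 4-bit zero point per group of 128 weights, and every parameter quantized. In practice some layers, such as the embeddings, are usually left unquantized, so real files are larger.

```python
def weight_bytes(n_params: float, wbits: int, groupsize: int = 128,
                 scale_bits: int = 16, zero_bits: int = 4) -> float:
    """Approximate bytes for quantized weights: wbits per weight plus one
    scale and one zero point per group of `groupsize` weights (assumed layout)."""
    bits_per_weight = wbits + (scale_bits + zero_bits) / groupsize
    return n_params * bits_per_weight / 8


n = 7e9  # hypothetical round parameter count for a "7B" model
fp16_gb = n * 2 / 1e9                                  # 2 bytes per FP16 weight
q4_gb = weight_bytes(n, wbits=4, groupsize=128) / 1e9
print(f"FP16: {fp16_gb:.1f} GB, 4-bit/g128: {q4_gb:.2f} GB")
# FP16: 14.0 GB, 4-bit/g128: 3.64 GB
```

Under these assumptions the group metadata adds only about 0.16 bits per weight, so the 4-bit model is roughly a quarter the size of the FP16 one.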
# Generate 4-bit PULSE-7B model
CUDA_VISIBLE_DEVICES=0 python bloom.py ${MODEL_DIR} custom --wbits 4 --act-order --groupsize 128 --save pulse7b-4bit-128g.bin --calib_data ${CALIB_DATA_PATH}
# Generate 4-bit PULSE-7B model with LoRA
CUDA_VISIBLE_DEVICES=0 python bloom_lora.py ${MODEL_DIR} custom --wbits 4 --act-order --groupsize 128 --save pulse7b-lora-4bit-128g.bin --calib_data ${CALIB_DATA_PATH} --peft_path ${PEFT_PATH}
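This README does not say how `bloom_lora.py` combines the adapter at `--peft_path` with the base model. One common approach, assumed here, is to fold the LoRA update into the base weights before quantizing, so that GPTQ sees a single dense matrix per layer (in peft this is what `merge_and_unload()` does). The sketch below shows that merge for one linear layer with NumPy; `merge_lora` is a hypothetical helper, and the `lora_alpha / r` scaling follows peft's LoRA convention, with A of shape (r, in) and B of shape (out, r).

```python
import numpy as np


def merge_lora(W, A, B, alpha, r):
    """Fold a LoRA update into a base weight: W + (alpha / r) * B @ A.

    W: (out, in) base weight, A: (r, in) down-projection, B: (out, r) up-projection.
    """
    return W + (alpha / r) * (B @ A)


rng = np.random.default_rng(0)
out_f, in_f, r, alpha = 16, 32, 4, 8
W = rng.standard_normal((out_f, in_f))
A = rng.standard_normal((r, in_f))
B = rng.standard_normal((out_f, r))
x = rng.standard_normal(in_f)

W_merged = merge_lora(W, A, B, alpha, r)
# The merged layer gives the same output as base + adapter applied separately
print(np.allclose(W_merged @ x, W @ x + (alpha / r) * (B @ (A @ x))))  # True
```

After merging, the adapter adds no extra matrices at inference time, and the merged weight is what a GPTQ pass would quantize.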
This code is based on GPTQ-for-LLaMa
Thanks to Meta AI for releasing LLaMA, a powerful LLM.
Triton GPTQ kernel code is based on GPTQ-triton