From 4e4d7cc89a1a518bb389426e87a40113bc58730f Mon Sep 17 00:00:00 2001 From: RachelXu7 Date: Tue, 1 Aug 2023 11:29:06 +0000 Subject: [PATCH 1/3] Add WeightOnlyPTQ and GPTQ --- llm/causallm/argument.py | 2 ++ llm/causallm/finetune_generation.py | 13 +++++++++++-- llm/causallm/quant.py | 26 ++++++++++++++++++++++++-- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/llm/causallm/argument.py b/llm/causallm/argument.py index 06fef64c3de2..7bc1a4e6f133 100644 --- a/llm/causallm/argument.py +++ b/llm/causallm/argument.py @@ -52,6 +52,8 @@ class QuantArgument: # PTQ related parameters do_ptq: bool = field(default=False, metadata={"help": "Whether to use PTQ"}) ptq_step: int = field(default=8, metadata={"help": "Step for PTQ"}) + ptq_weight_only: bool = field(default=False, metadata={"help": "Whether to use PTQ weight only"}) + quant_bits: int = field(default=8, metadata={"help": "Quantization bit size"}) fused_qkv: bool = field(default=False, metadata={"help": "Whether to use Fused Quantized QKV"}) parallel_ffn: bool = field(default=False, metadata={"help": "Whether to use Parallel FFN"}) diff --git a/llm/causallm/finetune_generation.py b/llm/causallm/finetune_generation.py index 8c6429564f19..a87641a0d5ff 100644 --- a/llm/causallm/finetune_generation.py +++ b/llm/causallm/finetune_generation.py @@ -231,7 +231,13 @@ def compute_metrics_do_generation(eval_preds): raise NotImplementedError( "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first." ) - from quant import apply_ptq, apply_shift, apply_smooth, get_ptq_model_config + from quant import ( + apply_gptq, + apply_ptq, + apply_shift, + apply_smooth, + get_ptq_model_config, + ) trainer.model.eval() # Prepare ptq dataloader @@ -255,7 +261,10 @@ def compute_metrics_do_generation(eval_preds): if quant_args.smooth: apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config) - apply_ptq(quant_args, trainer, ptq_dataloader) + if quant_args.do_gptq: + apply_gptq(quant_args, trainer, ptq_dataloader) + else: + apply_ptq(quant_args, trainer, ptq_dataloader) # Evaluation dev set if training_args.do_eval: diff --git a/llm/causallm/quant.py b/llm/causallm/quant.py index 9b0d55b1a2ed..d0c89a7db64b 100644 --- a/llm/causallm/quant.py +++ b/llm/causallm/quant.py @@ -20,12 +20,14 @@ from paddle.quantization import PTQ, QAT, QuantConfig from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer from paddleslim.quant.advanced import ( + GPTQ, EMASampler, MultiStepSampler, PieceWiseSearch, Shift, Smooth, ) +from paddleslim.quant.advanced.utils import find_parent_layer_and_sub_name from paddleslim.quant.layers import ( QuantizedColumnParallelLinear, QuantizedRowParallelLinear, @@ -117,8 +119,8 @@ def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config): def apply_ptq(quant_args, trainer, ptq_dataloader): q_config = QuantConfig(activation=None, weight=None) - act_quanter = AbsmaxObserver() - weight_quanter = AbsMaxChannelWiseWeightObserver() + act_quanter = AbsmaxObserver() if not quant_args.ptq_weight_only else None + weight_quanter = AbsMaxChannelWiseWeightObserver(quant_bits=quant_args.quant_bits) q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear) q_config.add_qat_layer_mapping(RowParallelLinear, QuantizedRowParallelLinear) q_config.add_type_config( @@ -137,6 +139,26 @@ def apply_ptq(quant_args, trainer, ptq_dataloader): trainer.model = ptq.convert(trainer.model, inplace=True) +def apply_gptq(quant_args, trainer, ptq_dataloader): 
+ num_layer = 0 + model = trainer.model + for cur_name, cur_layer in model.named_sublayers(): + if type(cur_layer) in [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear]: + num_layer += 1 + print("GPTQ layer", num_layer, cur_name) + parent_layer, sub_name = find_parent_layer_and_sub_name(model, cur_name) + cur_quant_layer = GPTQ(cur_layer) + setattr(parent_layer, sub_name, cur_quant_layer) + trainer.ptq_loop( + ptq_dataloader, + description="PTQ", + max_eval_iters=quant_args.ptq_step, + ) + cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True) + del cur_quant_layer + setattr(parent_layer, sub_name, cur_layer) + + def get_ptq_model_config(model): if isinstance(model, PrefixModelForCausalLM): base_model_prefix = model.model.base_model_prefix From a7947b2d100002d847c087e9ef57f96f73296aa7 Mon Sep 17 00:00:00 2001 From: RachelXu7 Date: Wed, 2 Aug 2023 02:50:03 +0000 Subject: [PATCH 2/3] update --- llm/causallm/quant.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llm/causallm/quant.py b/llm/causallm/quant.py index d0c89a7db64b..dde88a22c803 100644 --- a/llm/causallm/quant.py +++ b/llm/causallm/quant.py @@ -32,7 +32,7 @@ QuantizedColumnParallelLinear, QuantizedRowParallelLinear, ) -from paddleslim.quant.observers import AbsMaxChannelWiseWeightObserver, AbsmaxObserver +from paddleslim.quant.observers import AbsMaxChannelWiseWeightObserver, AVGObserver from paddleslim.quant.quanters import PACTQuanter from paddlenlp.peft import PrefixModelForCausalLM @@ -95,7 +95,7 @@ def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config): search_scale_min=1.0, search_scale_max=5.0, weight_quant_method="abs_max_channel_wise", - act_quant_method="abs_max", + act_quant_method="avg", ) else: search_func = None @@ -119,7 +119,7 @@ def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config): def apply_ptq(quant_args, trainer, ptq_dataloader): q_config = QuantConfig(activation=None, weight=None) - act_quanter = AbsmaxObserver() if not quant_args.ptq_weight_only else None + act_quanter = AVGObserver() if not quant_args.ptq_weight_only else None weight_quanter = AbsMaxChannelWiseWeightObserver(quant_bits=quant_args.quant_bits) q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear) q_config.add_qat_layer_mapping(RowParallelLinear, QuantizedRowParallelLinear) From e898a0619cc2c2f703db25b0e665a97fbac1aa1a Mon Sep 17 00:00:00 2001 From: RachelXu7 Date: Wed, 2 Aug 2023 04:21:19 +0000 Subject: [PATCH 3/3] update --- llm/causallm/finetune_generation.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/llm/causallm/finetune_generation.py b/llm/causallm/finetune_generation.py index a87641a0d5ff..1abb79cce6e6 100644 --- a/llm/causallm/finetune_generation.py +++ b/llm/causallm/finetune_generation.py @@ -48,9 +48,9 @@ def main(): training_args.print_config(quant_args, "Quant") training_args.print_config(gen_args, "Generation") - if sum([quant_args.do_ptq, quant_args.do_qat, training_args.do_train]) > 1: + if sum([quant_args.do_ptq, quant_args.do_qat, quant_args.do_gptq, training_args.do_train]) > 1: raise ValueError( - "--do_train, --do_ptq and --do_qat cannot work at the same time. Please choose only one at a time" + "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. 
Please choose only one at a time" ) # Setup GPU & distributed training @@ -231,13 +231,7 @@ def compute_metrics_do_generation(eval_preds): raise NotImplementedError( "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first." ) - from quant import ( - apply_gptq, - apply_ptq, - apply_shift, - apply_smooth, - get_ptq_model_config, - ) + from quant import apply_ptq, apply_shift, apply_smooth, get_ptq_model_config trainer.model.eval() # Prepare ptq dataloader @@ -261,10 +255,16 @@ def compute_metrics_do_generation(eval_preds): if quant_args.smooth: apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config) - if quant_args.do_gptq: - apply_gptq(quant_args, trainer, ptq_dataloader) - else: - apply_ptq(quant_args, trainer, ptq_dataloader) + apply_ptq(quant_args, trainer, ptq_dataloader) + + if quant_args.do_gptq: + if isinstance(model, LoRAModel): + raise NotImplementedError( + "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first." + ) + from quant import apply_gptq + + apply_gptq(quant_args, trainer, ptq_dataloader) # Evaluation dev set if training_args.do_eval:
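
Note for reviewers: the sketch below illustrates, outside of the Trainer, the per-layer GPTQ flow that apply_gptq() in quant.py drives through trainer.ptq_loop(). The toy two-layer model and the random calibration batches are hypothetical stand-ins for the causal LM and ptq_dataloader; the GPTQ wrapper, the fasterquant(percdamp=0.1, groupsize=-1, actorder=True) call, and find_parent_layer_and_sub_name() are used exactly as in patch 1.

import paddle
from paddleslim.quant.advanced import GPTQ
from paddleslim.quant.advanced.utils import find_parent_layer_and_sub_name


class ToyModel(paddle.nn.Layer):
    """Hypothetical stand-in for the causal LM being quantized."""

    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(16, 32)
        self.fc2 = paddle.nn.Linear(32, 16)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = ToyModel()
# Random batches stand in for the ptq_dataloader / ptq_step calibration passes.
calib_batches = [paddle.rand([4, 16]) for _ in range(8)]

for cur_name, cur_layer in model.named_sublayers():
    if isinstance(cur_layer, paddle.nn.Linear):
        parent_layer, sub_name = find_parent_layer_and_sub_name(model, cur_name)
        # Swap the Linear layer for its GPTQ wrapper so the calibration forward
        # passes collect the statistics GPTQ needs for this layer.
        gptq_layer = GPTQ(cur_layer)
        setattr(parent_layer, sub_name, gptq_layer)
        for batch in calib_batches:
            model(batch)
        # Quantize the wrapped layer's weights in place, then restore the
        # original layer object, mirroring apply_gptq().
        gptq_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True)
        del gptq_layer
        setattr(parent_layer, sub_name, cur_layer)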