diff --git a/paddlenlp/quantization/quantization_linear.py b/paddlenlp/quantization/quantization_linear.py
index 996eb74d7188..871eef56e93e 100644
--- a/paddlenlp/quantization/quantization_linear.py
+++ b/paddlenlp/quantization/quantization_linear.py
@@ -83,7 +83,7 @@ def __init__(
         self.quant_scale = self.create_parameter(
             shape=[out_features],
             attr=scale_attr,
-            dtype="float32",
+            dtype=self._dtype,
             is_bias=False,
         )
         if self.quant_algo in ["fp4", "nf4"]:
@@ -231,7 +231,7 @@ def __init__(
         self.quant_scale = self.create_parameter(
             shape=[self.output_size_per_partition],
             attr=scale_attr,
-            dtype="float32",
+            dtype=self._dtype,
             is_bias=False,
         )
         self.quant_scale.is_distributed = True if self.is_mp else False
@@ -345,7 +345,7 @@ def __init__(
         self.quant_scale = self.create_parameter(
             shape=[out_features],
             attr=scale_attr,
-            dtype="float32",
+            dtype=self._dtype,
             is_bias=False,
         )
         self.quant_scale.is_distributed = True if self.is_mp else False
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index b5b2a85a75f4..74df2c24bfc9 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -2216,6 +2216,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
                 quantization_config=config.quantization_config,
                 llm_int8_threshold=config.quantization_config.llm_int8_threshold,
             )
+            quantization_linear_list = []
+            for key in model.state_dict().keys():
+                if "quant_weight" in key:
+                    quantization_linear_list.append(key[:-13])

         model, missing_keys, unexpected_keys, mismatched_keys = cls._load_pretrained_model(
             model=model,
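
For context, here is a minimal standalone sketch (not part of the patch) of what the new loop added to `from_pretrained` computes. It assumes quantized state-dict keys follow the pattern `<layer_name>.quant_weight`; since `".quant_weight"` is 13 characters long, `key[:-13]` recovers the owning layer's name. The example keys below are hypothetical.

```python
# Illustrative sketch of the added quantization_linear_list logic.
# Assumption: quantized keys look like "<layer_name>.quant_weight".
state_dict_keys = [
    "llama.layers.0.self_attn.q_proj.quant_weight",  # hypothetical key
    "llama.layers.0.self_attn.q_proj.quant_scale",   # hypothetical key
    "llama.layers.0.mlp.gate_proj.quant_weight",     # hypothetical key
]

quantization_linear_list = []
for key in state_dict_keys:
    if "quant_weight" in key:
        # ".quant_weight" has 13 characters, so key[:-13] strips the suffix
        # and leaves the prefix naming the quantized linear layer.
        quantization_linear_list.append(key[:-13])

print(quantization_linear_list)
# ['llama.layers.0.self_attn.q_proj', 'llama.layers.0.mlp.gate_proj']
```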