From db692e72611c4112488db43f7acd5cd822c0bec3 Mon Sep 17 00:00:00 2001
From: ckl117
Date: Tue, 13 Aug 2024 20:01:24 +0800
Subject: [PATCH 1/2] qwen2-1.5b a8w8c8

---
 .../transformers/generation_utils.py |   3 +-
 .../transformers/qwen2/modeling.py    | 464 ++++++++++++++++--
 2 files changed, 419 insertions(+), 48 deletions(-)

diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py
index eed68a83d53f..b95653bbd0f8 100644
--- a/paddlenlp/experimental/transformers/generation_utils.py
+++ b/paddlenlp/experimental/transformers/generation_utils.py
@@ -672,8 +672,7 @@ def _post_process_(
 
             # sample
             probs = F.softmax(logits)
-            # _, next_tokens = top_p_sampling(probs, top_p, -1)
-            _, next_tokens = paddle.topk(probs, 1, -1)
+            _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
 
             if self.config.tensor_parallel_degree > 1:
                 paddle.distributed.broadcast(next_tokens, 0)
diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py
index 189552aff006..c09a48e371ce 100644
--- a/paddlenlp/experimental/transformers/qwen2/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -13,16 +13,26 @@
 # limitations under the License.
 from __future__ import annotations
 
+import json
+import os
 from functools import partial
 
 import numpy as np
 import paddle
 from paddle import nn
+from paddle.distributed import fleet
 from paddle.nn.quant import weight_quantize
 
+from paddlenlp.experimental.model_utils import (
+    ActScalesLoader,
+    CacheScaleLoader,
+    WeightScalesLoader,
+)
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedBlockMultiTransformer,
+    FusedBlockMultiTransformerA8W8,
     FusedBlockMultiTransformerWeightOnly,
+    FusedMultiTransformerA8W8,
     FusedMultiTransformerBase,
     FusedMultiTransformerConfig,
     FusedMultiTransformerWeightOnly,
@@ -33,6 +43,7 @@
 )
 from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained
 from paddlenlp.transformers import Qwen2Config, Qwen2PretrainedModel
+from paddlenlp.transformers.conversion_utils import split_param_func
 from paddlenlp.transformers.model_outputs import (  # CausalLMOutputWithCrossAttentions,
     BaseModelOutputWithPast,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -88,6 +99,11 @@ def __init__(self, config: Qwen2Config):
         elif config.quant_type == "weight_only_int4":
             self.use_weight_only = True
             self.quant_algo = "weight_only_int4"
+        elif "a8w8" in config.quant_type:
+            self.quant_model_path = config.model_name_or_path
+            self.shift = config.quantization_config.shift
+            self.smooth = config.quantization_config.smooth
+            self.shift_smooth_all_linears = config.quantization_config.shift_smooth_all_linears
 
         if self.use_weight_only:
             assert (
@@ -96,7 +112,26 @@ def __init__(self, config: Qwen2Config):
                 self.quant_type
             )
 
-        self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)
+        if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
+            self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding(
+                self.vocab_size,
+                self.hidden_size,
+                weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
+            )
+        else:
+            self.embed_tokens = nn.Embedding(
+                self.vocab_size,
+                self.hidden_size,
+            )
+
+        # get ring_id
+        ring_id = -1
+        try:
+            hcg = fleet.get_hybrid_communicate_group()
+            model_parallel_group = hcg.get_model_parallel_group()
+            ring_id = model_parallel_group.id
+        except:
+            pass
 
         ln_scale_attrs = [paddle.ParamAttr(name="fuseqwen2.{}.ln_scale".format(i)) for i in range(self.num_layers)]
         qkv_weight_attrs = [
@@ -115,6 +150,7 @@ def __init__(self, config: Qwen2Config):
         ffn_ln_scale_attrs = [
             paddle.ParamAttr(name="fuseqwen2.{}.ffn_ln_scale".format(i)) for i in range(self.num_layers)
         ]
+
         ffn1_weight_attrs = [
             paddle.ParamAttr(
                 name="fuseqwen2.{}.ffn1_weight".format(i), initializer=paddle.nn.initializer.Constant(value=0)
@@ -148,6 +184,60 @@ def __init__(self, config: Qwen2Config):
         ffn1_bias_attrs = None
         ffn2_bias_attrs = None
 
+        if "a8w8" in self.quant_type:
+            qkv_out_scale_attrs = [
+                paddle.ParamAttr(name="fuseqwen2.{}.qkv_out_scale".format(i)) for i in range(self.num_layers)
+            ]
+            linear_out_scale_attrs = [
+                paddle.ParamAttr(name="fuseqwen2.{}.linear_out_scale".format(i)) for i in range(self.num_layers)
+            ]
+            ffn1_out_scale_attrs = [
+                paddle.ParamAttr(name="fuseqwen2.{}.ffn1_out_scale".format(i)) for i in range(self.num_layers)
+            ]
+            ffn2_out_scale_attrs = [
+                paddle.ParamAttr(name="fuseqwen2.{}.ffn2_out_scale".format(i)) for i in range(self.num_layers)
+            ]
+
+            if self.shift_smooth_all_linears:
+                linear_shift_attrs = [
+                    paddle.ParamAttr(name="fuseqwen2.{}.linear_shift".format(i)) for i in range(self.num_layers)
+                ]
+                linear_smooth_attrs = [
+                    paddle.ParamAttr(name="fuseqwen2.{}.linear_smooth".format(i)) for i in range(self.num_layers)
+                ]
+                ffn2_shift_attrs = [
+                    paddle.ParamAttr(name="fuseqwen2.{}.ffn2_shift".format(i)) for i in range(self.num_layers)
+                ]
+                ffn2_smooth_attrs = [
+                    paddle.ParamAttr(name="fuseqwen2.{}.ffn2_smooth".format(i)) for i in range(self.num_layers)
+                ]
+
+            if self.shift:
+                ln_bias_attrs = [
+                    paddle.ParamAttr(name="fuseqwen2.{}.ln_bias".format(i)) for i in range(self.num_layers)
+                ]
+                ffn_ln_bias_attrs = [
+                    paddle.ParamAttr(name="fuseqwen2.{}.ffn_ln_bias".format(i)) for i in range(self.num_layers)
+                ]
+                qkv_bias_attrs = [
+                    paddle.ParamAttr(name="fuseqwen2.{}.qkv_bias".format(i)) for i in range(self.num_layers)
+                ]
+                ffn1_bias_attrs = [
+                    paddle.ParamAttr(name="fuseqwen2.{}.ffn1_bias".format(i)) for i in range(self.num_layers)
+                ]
+                if self.shift_smooth_all_linears:
+                    out_proj_bias_attrs = [
+                        paddle.ParamAttr(name="fuseqwen2.{}.out_proj_bias".format(i)) for i in range(self.num_layers)
+                    ]
+                    ffn2_bias_attrs = [
+                        paddle.ParamAttr(name="fuseqwen2.{}.ffn2_bias".format(i)) for i in range(self.num_layers)
+                    ]
+
+        qkv_weight_scale_attrs = None
+        out_proj_weight_scale_attrs = None
+        ffn1_weight_scale_attrs = None
+        ffn2_weight_scale_attrs = None
+
         if self.use_weight_only:
             qkv_weight_scale_attrs = [
                 paddle.ParamAttr(name="fuseqwen2.{}.qkv_weight_scale".format(i)) for i in range(self.num_layers)
             ]
@@ -166,6 +256,19 @@ def __init__(self, config: Qwen2Config):
         cache_v_scale_attrs = None
         cache_k_out_scale_attrs = None
         cache_v_out_scale_attrs = None
+        if config.cachekv_int8_type == "static":
+            cache_k_scale_attrs = [
+                paddle.ParamAttr(name="fuseqwen2.{}.cache_k_scale".format(i)) for i in range(self.num_layers)
+            ]
+            cache_v_scale_attrs = [
+                paddle.ParamAttr(name="fuseqwen2.{}.cache_v_scale".format(i)) for i in range(self.num_layers)
+            ]
+            cache_k_out_scale_attrs = [
+                paddle.ParamAttr(name="fuseqwen2.{}.cache_k_out_scale".format(i)) for i in range(self.num_layers)
+            ]
+            cache_v_out_scale_attrs = [
+                paddle.ParamAttr(name="fuseqwen2.{}.cache_v_out_scale".format(i)) for i in range(self.num_layers)
+            ]
 
         transformer_config = FusedMultiTransformerConfig(
             embed_dim=self.hidden_size,
@@ -176,7 +279,7 @@ def __init__(self, config: Qwen2Config):
             activation="swiglu",
             num_layers=config.num_hidden_layers,
             nranks=config.tensor_parallel_degree,
-            ring_id=-1,
+            ring_id=ring_id,
             ln_scale_attrs=ln_scale_attrs,
             qkv_weight_attrs=qkv_weight_attrs,
             qkv_weight_scale_attrs=qkv_weight_scale_attrs,
@@ -210,6 +313,7 @@ def __init__(self, config: Qwen2Config):
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
             rank_id=config.tensor_parallel_rank,
+            trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
         )
 
         self.set_transformer_block(transformer_config)
@@ -222,6 +326,8 @@ def __init__(self, config: Qwen2Config):
     def set_transformer_block(self, transformer_config):
         if self.use_weight_only:
             self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config)
+        elif self.quant_type == "a8w8" or self.quant_type == "a8w8c8":
+            self.transformer_block = FusedMultiTransformerA8W8(transformer_config)
         else:
             self.transformer_block = FusedMultiTransformerBase(transformer_config)
 
@@ -234,6 +340,7 @@ def set_input_embeddings(self, value):
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         head_size = self.hidden_size // self.num_attention_heads
+        split_fn = split_param_func()
         self.embed_tokens.weight.set_value(
             paddle.to_tensor(state_dict["qwen2.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype)
         )
@@ -246,42 +353,75 @@ def set_state_dict(self, state_dict):
             )
             self.transformer_block.ln_scales[idx].set_value(ln_scale)
 
-            unfused_state_dict["qwen2.self_attn.q_proj.weight"] = state_dict[
-                "qwen2.layers.{}.self_attn.q_proj.weight".format(idx)
-            ]
-            unfused_state_dict["qwen2.self_attn.k_proj.weight"] = state_dict[
-                "qwen2.layers.{}.self_attn.k_proj.weight".format(idx)
-            ]
-            unfused_state_dict["qwen2.self_attn.v_proj.weight"] = state_dict[
-                "qwen2.layers.{}.self_attn.v_proj.weight".format(idx)
-            ]
-
-            concated_qkv_weight = (
-                np.concatenate(
-                    [
-                        unfused_state_dict["qwen2.self_attn.q_proj.weight"],
-                        unfused_state_dict["qwen2.self_attn.k_proj.weight"],
-                        unfused_state_dict["qwen2.self_attn.v_proj.weight"],
-                    ],
-                    axis=-1,
-                )
-                .transpose(1, 0)
-                .reshape(
-                    (
-                        self.num_attention_heads // self.config.tensor_parallel_degree
-                        + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
-                    )
-                    * (head_size),
-                    self.hidden_size,
-                )
-            )
+            if "qwen2.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
+                concated_qkv_weight = np.concatenate(
+                    split_fn(
+                        state_dict["qwen2.layers.{}.self_attn.qkv_proj.weight".format(idx)],
+                        is_qkv=True,
+                        num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                        num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
+                    ),
+                    axis=-1,
+                ).transpose(1, 0)
+            else:
+                unfused_state_dict = {}
+                unfused_state_dict["qwen2.self_attn.q_proj.weight"] = state_dict[
+                    "qwen2.layers.{}.self_attn.q_proj.weight".format(idx)
+                ]
+                unfused_state_dict["qwen2.self_attn.k_proj.weight"] = state_dict[
+                    "qwen2.layers.{}.self_attn.k_proj.weight".format(idx)
+                ]
+                unfused_state_dict["qwen2.self_attn.v_proj.weight"] = state_dict[
+                    "qwen2.layers.{}.self_attn.v_proj.weight".format(idx)
+                ]
+                if paddle.is_compiled_with_rocm() and (self.quant_type == "a8w8" or self.quant_type == "a8w8c8"):
+                    concated_qkv_weight = np.concatenate(
+                        [
+                            unfused_state_dict["qwen2.self_attn.q_proj.weight"],
+                            unfused_state_dict["qwen2.self_attn.k_proj.weight"],
+                            unfused_state_dict["qwen2.self_attn.v_proj.weight"],
+                        ],
+                        axis=-1,
+                    ).reshape(
+                        self.hidden_size,
+                        (
+                            self.num_attention_heads // self.config.tensor_parallel_degree
+                            + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
+                        )
+                        * (head_size),
+                    )
+                else:
+                    concated_qkv_weight = (
+                        np.concatenate(
+                            [
+                                unfused_state_dict["qwen2.self_attn.q_proj.weight"],
+                                unfused_state_dict["qwen2.self_attn.k_proj.weight"],
+                                unfused_state_dict["qwen2.self_attn.v_proj.weight"],
+                            ],
+                            axis=-1,
+                        )
+                        .transpose(1, 0)
+                        .reshape(
+                            (
+                                self.num_attention_heads // self.config.tensor_parallel_degree
+                                + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
+                            )
+                            * (head_size),
+                            self.hidden_size,
+                        )
+                    )
             qkv_weight = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype())
+
             if self.use_weight_only:
                 qkv_weight = paddle.transpose(qkv_weight, perm=[1, 0])
                 qkv_quanted_weight, qkv_weight_scale = weight_quantize(qkv_weight, algo=self.quant_algo)
                 self.transformer_block.qkv_weights[idx].set_value(qkv_quanted_weight)
                 self.transformer_block.qkv_weights_scale[idx].set_value(qkv_weight_scale)
+            elif "a8w8" in self.quant_type:
+                self.transformer_block.qkv_weights[idx].set_value(
+                    paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
+                )
             else:
                 self.transformer_block.qkv_weights[idx].set_value(qkv_weight)
@@ -303,8 +443,10 @@ def set_state_dict(self, state_dict):
                 ],
                 axis=-1,
             )
-            qkv_bias = paddle.to_tensor(concated_qkv_biases).cast(self.transformer_block.qkv_biases[idx].dtype)
-            self.transformer_block.qkv_biases[idx].set_value(qkv_bias)
+            qkv_bias = paddle.to_tensor(concated_qkv_biases)
+            self.transformer_block.qkv_biases[idx].set_value(
+                qkv_bias.cast(self.transformer_block.qkv_biases[idx].dtype)
+            )
 
             linear_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]).cast(
                 paddle.get_default_dtype()
             )
@@ -313,39 +455,267 @@ def set_state_dict(self, state_dict):
                 linear_quanted_weight, linear_weight_scale = weight_quantize(linear_weight, algo=self.quant_algo)
                 self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight)
                 self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale)
+            elif "a8w8" in self.quant_type:
+                if paddle.is_compiled_with_rocm():
+                    self.transformer_block.linear_weights[idx].set_value(
+                        paddle.cast(
+                            paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]), "int8"
+                        )
+                    )
+                else:
+                    self.transformer_block.linear_weights[idx].set_value(
+                        paddle.cast(
+                            paddle.to_tensor(
+                                state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]
+                            ).transpose((1, 0)),
+                            "int8",
+                        )
+                    )
             else:
-                self.transformer_block.linear_weights[idx].set_value(linear_weight)
+                self.transformer_block.linear_weights[idx].set_value(
+                    linear_weight.cast(self.transformer_block.linear_weights[idx].dtype)
+                )
 
             ffn_ln_scale = paddle.to_tensor(
                 state_dict["qwen2.layers.{}.post_attention_layernorm.weight".format(idx)],
-            ).cast(
-                self.transformer_block.ffn_ln_scales[idx].dtype,
             )
-            self.transformer_block.ffn_ln_scales[idx].set_value(ffn_ln_scale)
 
-            up_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.up_proj.weight".format(idx)]).cast(
-                paddle.get_default_dtype()
+            self.transformer_block.ffn_ln_scales[idx].set_value(
+                ffn_ln_scale.cast(self.transformer_block.ffn_ln_scales[idx].dtype)
             )
-            gate_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.gate_proj.weight".format(idx)]).cast(
-                paddle.get_default_dtype()
-            )
-            ffn1_weight = paddle.concat(x=[gate_weight, up_weight], axis=-1)
+
+            if "qwen2.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
+                concated_ffn1_weight = np.concatenate(
+                    split_fn(state_dict["qwen2.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
+                )
+            else:
+                unfused_state_dict["mlp.gate_proj.weight"] = state_dict[
+                    "qwen2.layers.{}.mlp.gate_proj.weight".format(idx)
+                ]
+                unfused_state_dict["mlp.up_proj.weight"] = state_dict["qwen2.layers.{}.mlp.up_proj.weight".format(idx)]
+                concated_ffn1_weight = np.concatenate(
+                    [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
+                )
+            ffn1_weight = paddle.to_tensor(concated_ffn1_weight)
+
             if self.use_weight_only:
                 ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(ffn1_weight, algo=self.quant_algo)
                 self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight)
                 self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale)
+            elif "a8w8" in self.quant_type:
+                if paddle.is_compiled_with_rocm():
+                    self.transformer_block.ffn1_weights[idx].set_value(
+                        paddle.cast(paddle.to_tensor(ffn1_weight), "int8")
+                    )
+                else:
+                    self.transformer_block.ffn1_weights[idx].set_value(
+                        paddle.cast(paddle.to_tensor(ffn1_weight).transpose((1, 0)), "int8")
+                    )
             else:
-                self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight)
+                self.transformer_block.ffn1_weights[idx].set_value(
+                    ffn1_weight.cast(self.transformer_block.ffn1_weights[idx].dtype)
+                )
 
-            ffn2_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)]).cast(
-                paddle.get_default_dtype()
-            )
+            ffn2_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)])
             if self.use_weight_only:
                 ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(ffn2_weight, algo=self.quant_algo)
                 self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight)
                 self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale)
+            elif "a8w8" in self.quant_type:
+                if paddle.is_compiled_with_rocm():
+                    self.transformer_block.ffn2_weights[idx].set_value(
+                        paddle.cast(
+                            paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)]), "int8"
+                        )
+                    )
+                else:
+                    self.transformer_block.ffn2_weights[idx].set_value(
+                        paddle.cast(
+                            paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)]).transpose(
+                                (1, 0)
+                            ),
+                            "int8",
+                        )
+                    )
             else:
-                self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight)
+                self.transformer_block.ffn2_weights[idx].set_value(
+                    ffn2_weight.cast(self.transformer_block.ffn2_weights[idx].dtype)
+                )
+
+            if "a8w8" in self.quant_type:
+                if self.shift_smooth_all_linears:
+                    self.transformer_block.linear_shifts[idx].set_value(
+                        paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
+                    )
+                    self.transformer_block.linear_smooths[idx].set_value(
+                        paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
+                    )
+                    self.transformer_block.ffn2_shifts[idx].set_value(
+                        paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)])
+                    )
+                    self.transformer_block.ffn2_smooths[idx].set_value(
+                        paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
+                    )
+
+                if self.shift:
+                    self.transformer_block.ln_biases[idx].set_value(
+                        paddle.to_tensor(state_dict["qwen2.layers.{}.input_layernorm.bias".format(idx)])
+                    )
+                    self.transformer_block.ffn_ln_biases[idx].set_value(
+                        paddle.to_tensor(state_dict["qwen2.layers.{}.post_attention_layernorm.bias".format(idx)])
+                    )
+
+                    unfused_state_dict["self_attn.q_proj.bias"] = state_dict[
+                        "qwen2.layers.{}.self_attn.q_proj.bias".format(idx)
+                    ]
+                    unfused_state_dict["self_attn.k_proj.bias"] = state_dict[
"qwen2.layers.{}.self_attn.k_proj.bias".format(idx) + ] + unfused_state_dict["self_attn.v_proj.bias"] = state_dict[ + "qwen2.layers.{}.self_attn.v_proj.bias".format(idx) + ] + + concated_qkv_biases = np.concatenate( + [ + unfused_state_dict["self_attn.q_proj.bias"], + unfused_state_dict["self_attn.k_proj.bias"], + unfused_state_dict["self_attn.v_proj.bias"], + ], + axis=-1, + ) + + self.transformer_block.qkv_biases[idx].set_value(paddle.to_tensor(concated_qkv_biases)) + + unfused_state_dict["mlp.gate_proj.bias"] = state_dict[ + "qwen2.layers.{}.mlp.gate_proj.bias".format(idx) + ] + unfused_state_dict["mlp.up_proj.bias"] = state_dict["qwen2.layers.{}.mlp.up_proj.bias".format(idx)] + + concated_ffn1_bias = np.concatenate( + [unfused_state_dict["mlp.gate_proj.bias"], unfused_state_dict["mlp.up_proj.bias"]], axis=-1 + ) + + self.transformer_block.ffn1_biases[idx].set_value(paddle.to_tensor(concated_ffn1_bias)) + + if self.shift_smooth_all_linears: + self.transformer_block.linear_biases[idx].set_value( + paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.bias".format(idx)]) + ) + self.transformer_block.ffn2_biases[idx].set_value( + paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.layer.bias".format(idx)]) + ) + + if "a8w8" in self.quant_type: + current_work_dir = os.path.dirname(__file__) + scale_map_file = ( + f"{current_work_dir}/ptq_scales_map.json" + if not self.shift_smooth_all_linears + else f"{current_work_dir}/ptq_scales_map_shift_smooth.json" + ) + print(f"scale_map_file = {scale_map_file}") + with open(scale_map_file) as json_file: + scale_map_dict = json.load(json_file) + act_scale_map_dict = scale_map_dict["act_scale"] + weight_scale_map_dict = scale_map_dict["weight_scale"] + cache_scale_map_dict = scale_map_dict["cachekv_scale"] + # TODO(RichardWooSJTU): support multi-cards + + act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json") + weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json") + if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: + act_scale_json_path = os.path.join( + self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json" + ) + weight_scale_json_path = os.path.join( + self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json" + ) + act_scale_loader = ActScalesLoader( + act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers + ) + self.transformer_block.act_scales = act_scale_loader.scale + print(f"weight_scale_json_path = {weight_scale_json_path}") + # print(f'weight_scale_map_dict = {weight_scale_map_dict}') + weight_scales_loader = WeightScalesLoader( + weight_scale_json_path, + weight_scale_map_dict, + num_of_layers=self.config.num_hidden_layers, + concat_qkv=True, + concat_ffn1=True, + ) + + if self.config.cachekv_int8_type == "static": + cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json") + if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: + cache_scale_json_path = os.path.join( + self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json" + ) + cache_scales_loader = CacheScaleLoader( + cache_scale_json_path, + cache_scale_map_dict, + num_of_layers=self.config.num_hidden_layers, + num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, + num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, + ) + + for k, v in cache_scales_loader.scale.items(): + for i_layer, 
+                        for i_layer, weight_scale in enumerate(v):
+                            weight_scale = weight_scale.astype("float32")
+                            if k == "cache_k_scale":
+                                self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale)
+                            elif k == "cache_v_scale":
+                                self.transformer_block.cache_v_scales[i_layer].set_value(weight_scale)
+                            elif k == "cache_k_out_scale":
+                                self.transformer_block.cache_k_out_scales[i_layer].set_value(weight_scale)
+                            else:
+                                self.transformer_block.cache_v_out_scales[i_layer].set_value(weight_scale)
+
+                for k, v in weight_scales_loader.scale.items():
+                    if "qkv_" in k:
+                        for i_layer, weight_scale in enumerate(v):
+                            tmp = paddle.to_tensor(
+                                weight_scale
+                                / (
+                                    127.0 * 127.0 * act_scale_loader.scale["qkv_in_scale"][i_layer]
+                                )  # [3 * num_head * dim_head]
+                            ).reshape([-1])
+                            if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq:
+                                tmp = (
+                                    tmp.reshape([3, self.num_attention_heads, head_size])
+                                    .split(self.config.tensor_parallel_degree, axis=1)[
+                                        self.config.tensor_parallel_rank
+                                    ]
+                                    .reshape([-1])
+                                )
+                            self.transformer_block.qkv_out_scales[i_layer].set_value(tmp)
+                            pass
+                    elif "out_linear_" in k:
+                        for i_layer, weight_scale in enumerate(v):
+                            tmp = paddle.to_tensor(
+                                weight_scale / (127.0 * 127.0 * act_scale_loader.scale["out_linear_in_scale"][i_layer])
+                            )
+                            self.transformer_block.linear_out_scales[i_layer].set_value(tmp)
+                    elif "ffn1_weight_scale" in k:
+                        for i_layer, weight_scale in enumerate(v):
+                            tmp = paddle.to_tensor(
+                                weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn1_in_scale"][i_layer])
+                            )
+                            if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq:
+                                tmp = paddle.split(tmp, self.config.tensor_parallel_degree * 2)
+                                tmp = paddle.concat(
+                                    [
+                                        tmp[self.config.tensor_parallel_rank],
+                                        tmp[self.config.tensor_parallel_rank + self.config.tensor_parallel_degree],
+                                    ],
+                                    axis=0,
+                                )
+                            self.transformer_block.ffn1_out_scales[i_layer].set_value(tmp)
+                    elif "ffn2" in k:
+                        for i_layer, weight_scale in enumerate(v):
+                            self.transformer_block.ffn2_out_scales[i_layer].set_value(
+                                paddle.to_tensor(
+                                    weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn2_in_scale"][i_layer])
+                                )
+                            )
 
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
@@ -647,6 +1017,8 @@ def __init__(self, config: Qwen2Config):
     def set_transformer_block(self, transformer_config):
         if self.use_weight_only:
             self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config)
+        elif self.quant_type == "a8w8" or self.quant_type == "a8w8c8":
+            self.transformer_block = FusedBlockMultiTransformerA8W8(transformer_config)
         else:
             self.transformer_block = FusedBlockMultiTransformer(transformer_config)

From 946a632436f01275d55c99de97392c5877cc7a73 Mon Sep 17 00:00:00 2001
From: ckl117
Date: Mon, 19 Aug 2024 14:32:58 +0800
Subject: [PATCH 2/2] code check

---
 paddlenlp/experimental/transformers/qwen2/modeling.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py
index c09a48e371ce..2d785341f546 100644
--- a/paddlenlp/experimental/transformers/qwen2/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -496,7 +496,7 @@ def set_state_dict(self, state_dict):
             concated_ffn1_weight = np.concatenate(
                 [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
             )
-            ffn1_weight = paddle.to_tensor(concated_ffn1_weight)
+            ffn1_weight = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype())
 
             if self.use_weight_only:
                 ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(ffn1_weight, algo=self.quant_algo)
@@ -612,7 +612,6 @@ def set_state_dict(self, state_dict):
                 if not self.shift_smooth_all_linears
                 else f"{current_work_dir}/ptq_scales_map_shift_smooth.json"
             )
-            print(f"scale_map_file = {scale_map_file}")
             with open(scale_map_file) as json_file:
                 scale_map_dict = json.load(json_file)
                 act_scale_map_dict = scale_map_dict["act_scale"]
@@ -633,8 +632,6 @@ def set_state_dict(self, state_dict):
                     act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers
                 )
                 self.transformer_block.act_scales = act_scale_loader.scale
-                print(f"weight_scale_json_path = {weight_scale_json_path}")
-                # print(f'weight_scale_map_dict = {weight_scale_map_dict}')
                 weight_scales_loader = WeightScalesLoader(
                     weight_scale_json_path,
                     weight_scale_map_dict,
                     num_of_layers=self.config.num_hidden_layers,