From e410e4a18f9f5274c894b925d494bfa091472fcc Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 6 Aug 2024 09:08:22 +0000 Subject: [PATCH 1/7] stage 1 --- csrc/generation/get_padding_offset_v2.cu | 2 +- llm/predict/export_model.py | 3 +- llm/predict/predictor.py | 390 +++++++----------- llm/utils/utils.py | 47 --- .../transformers/llama/modeling.py | 7 +- paddlenlp/quantization/quantization_config.py | 4 +- 6 files changed, 151 insertions(+), 302 deletions(-) diff --git a/csrc/generation/get_padding_offset_v2.cu b/csrc/generation/get_padding_offset_v2.cu index 080764ed9955..737351fa0b5d 100644 --- a/csrc/generation/get_padding_offset_v2.cu +++ b/csrc/generation/get_padding_offset_v2.cu @@ -103,7 +103,7 @@ std::vector GetPaddingOffsetV2InferDtype(const paddle::DataTyp } PD_BUILD_OP(get_padding_offset_v2) - .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"}) + .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"}) .Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"}) .SetKernelFn(PD_KERNEL(GetPaddingOffsetV2)) .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape)) diff --git a/llm/predict/export_model.py b/llm/predict/export_model.py index df8598e05cb8..91b3bc4cb7c3 100644 --- a/llm/predict/export_model.py +++ b/llm/predict/export_model.py @@ -61,7 +61,8 @@ def main(): }, ) predictor.model.config.save_pretrained(export_args.output_path) - predictor.model.generation_config.save_pretrained(export_args.output_path) + predictor.generation_config.save_pretrained(export_args.output_path) + predictor.tokenizer.save_pretrained(export_args.output_path) generate_rank_mapping(os.path.join(export_args.output_path, "rank_mapping.csv")) diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 39ffc6b40fe6..1841ca1eea58 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -31,10 +31,7 @@ from utils.utils import ( dybatch_preprocess, get_alibi_slopes, - get_default_max_decoding_length, - get_default_max_encoding_length, get_infer_model_path, - get_model_max_position_embeddings, get_prefix_tuning_params, init_chat_template, load_real_time_tokens, @@ -67,8 +64,9 @@ class PredictorArgument: model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."}) model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"}) - src_length: int = field(default=None, metadata={"help": "The max length of source text."}) - max_length: int = field(default=None, metadata={"help": "the max length for decoding."}) + src_length: int = field(default=4096, metadata={"help": "The max length of source text."}) + min_length: int = field(default=1, metadata={"help": "the min length for decoding."}) + max_length: int = field(default=2048, metadata={"help": "the max length for decoding."}) top_k: int = field(default=0, metadata={"help": "top_k parameter for generation"}) top_p: float = field(default=0.7, metadata={"help": "top_p parameter for generation"}) temperature: float = field(default=0.95, metadata={"help": "top_p parameter for generation"}) @@ -118,7 +116,9 @@ class PredictorArgument: block_size: int = field(default=64, metadata={"help": "the block size for cache_kvs."}) cachekv_int8_type: str = field( default=None, - metadata={"help": "If cachekv_int8_type set as `dynamic`, cache kv would be quantized to int8 dynamically. 
If cachekv_int8_type set as `static`, cache kv would be quantized to int8 Statically."}, + metadata={ + "help": "If cachekv_int8_type set as `dynamic`, cache kv would be quantized to int8 dynamically. If cachekv_int8_type set as `static`, cache kv would be quantized to int8 Statically." + }, ) chat_template: str = field( @@ -180,11 +180,11 @@ def init_dist_env(): def get_eos_token_id( tokenizer: PretrainedTokenizer, generation_config: Optional[GenerationConfig] = None -) -> int | List[List[int]]: +) -> List[List[int]]: """get eos_token_id from generation_config or tokenizer Returns: - int | List[int]: eos_token_id to stop the generation + List[int]: eos_token_id to stop the generation """ eos_token_ids = [] if tokenizer.eos_token_id is not None: @@ -390,8 +390,10 @@ def _infer(self, inputs: dict[str, np.ndarray]): return decoded_ids -class InferencePredictorMixin: +class InferencePredictorMixin(BasePredictor): def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): + BasePredictor.__init__(self, config, tokenizer) + self.architectures = self.model_config.architectures[0].lower() self.dtype = config.dtype or self.model_config @@ -461,14 +463,6 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): item.squeeze_(0) for item in paddle.split(prefix_cache, self.num_layers, axis=0) ] - try: - self.generation_config = GenerationConfig.from_pretrained(config.model_name_or_path) - except: - logger.warning( - "Can't find generation config, so it will not use generation_config field in the model config" - ) - self.generation_config = None - def _postprocess(self, predictions, return_tokens=False): if paddle.distributed.get_rank() == 0: tokens: np.ndarray = load_real_time_tokens() @@ -643,7 +637,7 @@ def _preprocess(self, source): return inputs -class StaticInferencePredictor(InferencePredictorMixin, BasePredictor): +class StaticInferencePredictor(InferencePredictorMixin): def __init__( self, config: PredictorArgument, @@ -651,7 +645,6 @@ def __init__( tokenizer: PretrainedTokenizer = None, ): self.cache_kvs_shape = cache_kvs_shape - BasePredictor.__init__(self, config, tokenizer) InferencePredictorMixin.__init__(self, config, tokenizer) self.predictor = self._create_predictor(config) @@ -735,7 +728,7 @@ def _infer(self, inputs): self.predictor.run() -class DygraphInferencePredictor(InferencePredictorMixin, BasePredictor): +class DygraphInferencePredictor(InferencePredictorMixin): def __init__( self, config: PredictorArgument, @@ -743,7 +736,6 @@ def __init__( tokenizer: PretrainedTokenizer = None, ): self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size, config.total_max_length) - BasePredictor.__init__(self, config, tokenizer) InferencePredictorMixin.__init__(self, config, tokenizer) self.model = model @@ -766,8 +758,9 @@ def _infer(self, inputs: dict[str, paddle.Tensor]): return None -class BlockInferencePredictorMixin: +class BlockInferencePredictorMixin(BasePredictor): def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): + BasePredictor.__init__(self, config, tokenizer) self.num_layers = len(self.cache_kvs_shape) // 2 self.num_attention_heads = self.cache_kvs_shape[0][-3] @@ -780,10 +773,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): self.dtype = config.dtype or self.model_config.dtype - self.total_max_length = config.src_length + config.max_length self.block_size = config.block_size - self.pre_max_block_num = (self.total_max_length + config.block_size - 1) 
// config.block_size - self.max_block_nums = config.batch_size * self.pre_max_block_num try: self.rope_theta = self.model_config.rope_theta @@ -838,60 +828,71 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): for _ in range(self.num_layers) ] - if config.benchmark: - self.min_length = config.max_length - else: - self.min_length = 2 - - self.free_list = [i for i in range(self.max_block_nums)][::-1] - self.used_list = [[] for _ in range(config.batch_size)] - - def init_inputs(self, config: PredictorArgument): - self.inputs = {} + def pad_batch_data(self, insts): + """Pad the instances to the max sequence length in batch.""" + seq_lens = [] + for i, inst in enumerate(insts): + length = len(inst) + seq_lens.append(length) + self.input_ids[i, :length] = np.array(inst) + return seq_lens + + def init_model_inputs(self, config: PredictorArgument): + self.input_ids = paddle.full( + shape=[config.batch_size, config.total_max_length], fill_value=self.tokenizer.pad_token_id, dtype="int64" + ) + self.model_inputs = {} if config.export_precache: - self.inputs["src_mask"] = (self.pre_cache_mask - 1) * 1e4 - self.inputs["pre_ids"] = paddle.full([config.batch_size, self.total_max_length], -1, dtype="int64") - self.inputs["bad_tokens"] = paddle.to_tensor( - [ - -1, - ], - dtype="int64", - ) - self.inputs["penalty_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=1.0, dtype="float32") - self.inputs["frequency_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32") - self.inputs["presence_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32") + self.model_inputs["src_mask"] = (self.pre_cache_mask - 1) * 1e4 - self.inputs["min_length"] = paddle.full( - shape=[config.batch_size, 1], fill_value=self.min_length, dtype="int64" + self.model_inputs["block_tables"] = paddle.full( + shape=[config.batch_size, (config.total_max_length + config.block_size - 1) // config.block_size], + fill_value=-1, + dtype="int32", ) - self.inputs["max_length"] = paddle.full( - shape=[config.batch_size, 1], fill_value=config.max_length, dtype="int64" + self.model_inputs["top_p"] = paddle.full( + shape=[config.batch_size, 1], fill_value=config.top_p, dtype="float32" + ) + self.model_inputs["temperature"] = paddle.full( + shape=[config.batch_size, 1], fill_value=config.temperature, dtype="float32" + ) + self.model_inputs["eos_token_id"] = paddle.to_tensor( + np.array(get_eos_token_id(self.tokenizer, self.generation_config)).reshape(-1, 1).astype("int64") + ) + self.model_inputs["penalty_score"] = paddle.full( + shape=[config.batch_size, 1], fill_value=config.repetition_penalty, dtype="float32" + ) + self.model_inputs["frequency_score"] = paddle.full( + shape=[config.batch_size, 1], fill_value=0.0, dtype="float32" ) - self.inputs["stop_nums"] = paddle.full(shape=[1], fill_value=config.batch_size, dtype="int64") - self.inputs["rope_emb"] = self._get_rotary_position_embedding( - paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim, self.rope_theta + self.model_inputs["presence_score"] = paddle.full( + shape=[config.batch_size, 1], fill_value=0.0, dtype="float32" ) - eos_token_id = get_eos_token_id(self.tokenizer, self.generation_config) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - self.inputs["eos_token_id"] = paddle.to_tensor( - np.array(eos_token_id * config.batch_size).reshape(-1, 1).astype("int64") + self.model_inputs["min_length"] = paddle.full( + shape=[config.batch_size, 1], 
fill_value=config.min_length, dtype="int64" ) + self.model_inputs["max_length"] = paddle.full( + shape=[config.batch_size, 1], fill_value=config.max_length, dtype="int64" + ) + self.model_inputs["rope_emb"] = self._get_rotary_position_embedding( + paddle.arange(config.total_max_length).reshape((1, -1)), self.head_dim, self.rope_theta + ) + self.model_inputs["bad_tokens"] = paddle.to_tensor([-1], dtype="int64") + self.model_inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool") + # bloom model needs src_mask and tgt_mask! if "bloom" in self.architectures: - lower_one_tril = paddle.tril( - paddle.ones(shape=(self.total_max_length, self.total_max_length), dtype=self.dtype) - ) + lower_one_tril = paddle.tril(paddle.ones(shape=(config.total_max_length, config.total_max_length), dtype=self.dtype)) lower_one_tril = lower_one_tril[None, None, :, :] - self.inputs["src_mask"] = lower_one_tril.tile([self.batch_size, 1, 1, 1]) - self.inputs["tgt_mask"] = paddle.full( - shape=[config.batch_size, 1, 1, self.total_max_length], fill_value=1, dtype=self.dtype + self.model_inputs["src_mask"] = lower_one_tril.tile([self.batch_size, 1, 1, 1]) + self.model_inputs["tgt_mask"] = paddle.full( + shape=[config.batch_size, 1, 1, config.total_max_length], fill_value=1, dtype=self.dtype ) - arange_tensor_encoder = paddle.arange(self.total_max_length).astype(self.dtype) + arange_tensor_encoder = paddle.arange(config.total_max_length).astype(self.dtype) alibi_slopes = get_alibi_slopes(self.num_attention_heads) alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder - alibi_encoder = alibi.tile([self.batch_size, 1, self.total_max_length, 1]) + alibi_encoder = alibi.tile([self.batch_size, 1, config.total_max_length, 1]) alibi_decoder = alibi.tile( [ self.batch_size, @@ -900,43 +901,14 @@ def init_inputs(self, config: PredictorArgument): 1, ] ) - # self.inputs["src_mask/tgt_mask"] is read only, will not be updated! - self.inputs["src_mask"] = ( - alibi_encoder + (1 - self.inputs["src_mask"]) * paddle.finfo(self.dtype).min + # self.model_inputs["src_mask/tgt_mask"] is read only, will not be updated! 
+ self.model_inputs["src_mask"] = ( + alibi_encoder + (1 - self.model_inputs["src_mask"]) * paddle.finfo(self.dtype).min ).cast(self.dtype) - self.inputs["tgt_mask"] = ( - alibi_decoder + (1 - self.inputs["tgt_mask"]) * paddle.finfo(self.dtype).min + self.model_inputs["tgt_mask"] = ( + alibi_decoder + (1 - self.model_inputs["tgt_mask"]) * paddle.finfo(self.dtype).min ).cast(self.dtype) - # need update - self.inputs["block_tables"] = paddle.full( - shape=[config.batch_size, self.pre_max_block_num], fill_value=-1, dtype="int32" - ) - self.inputs["input_ids"] = paddle.full( - shape=[config.batch_size, self.total_max_length], fill_value=-1, dtype="int64" - ) - self.inputs["top_p"] = paddle.full(shape=[config.batch_size, 1], fill_value=config.top_p, dtype="float32") - self.inputs["temperature"] = paddle.full(shape=[config.batch_size, 1], fill_value=1.0, dtype="float32") - self.inputs["seq_lens_this_time"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32") - self.inputs["seq_lens_encoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32") - self.inputs["seq_lens_decoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32") - self.inputs["step_idx"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int64") - self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool") - self.inputs["stop_flags"] = paddle.full(shape=[config.batch_size, 1], fill_value=True, dtype="bool") - self.inputs["next_tokens"] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64") - self.inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool") - free_list = list(range(self.pre_max_block_num - 1, int(self.pre_max_block_num * 0.75) - 1, -1)) - self.inputs["encoder_block_lens"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32") - self.inputs["step_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32") - self.inputs["step_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32") - self.inputs["recover_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32") - self.inputs["recover_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32") - self.inputs["need_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32") - self.inputs["need_block_len"] = paddle.full(shape=[1], fill_value=0, dtype="int32") - self.inputs["used_list_len"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32") - self.inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32") - self.inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.pre_max_block_num * 0.25, dtype="int32") - def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=10000.0): """ Pre-calculate rotary position embedding for position_ids. 
@@ -961,12 +933,13 @@ def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=1000 rot_emb[1] = paddle.sin(emb) return rot_emb - def _preprocess(self, source): + def _preprocess(self, input_text: list[str]): if self.tokenizer.chat_template is not None: - source = [source] if isinstance(source, str) else source - source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source] + input_text = [input_text] if isinstance(input_text, str) else input_text + input_text = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in input_text] - for i, text in enumerate(source): + input_ids = [] + for text in input_text: tokens = self.tokenizer( text, return_tensors="np", @@ -977,30 +950,42 @@ def _preprocess(self, source): add_special_tokens=self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)), ) - input_ids = tokens["input_ids"][0] - length = len(input_ids) - self.inputs["input_ids"][i : i + 1, :length] = input_ids - self.inputs["penalty_score"][i : i + 1] = self.config.repetition_penalty - self.inputs["frequency_score"][i : i + 1] = 0.0 - self.inputs["presence_score"][i : i + 1] = 0.0 - self.inputs["top_p"][i : i + 1] = self.config.top_p - self.inputs["temperature"][i : i + 1] = self.config.temperature - self.inputs["seq_lens_this_time"][i : i + 1] = length - self.inputs["seq_lens_encoder"][i : i + 1] = length - self.inputs["seq_lens_decoder"][i : i + 1] = 0 - self.inputs["step_idx"][i : i + 1] = 0 - self.inputs["stop_flags"][i : i + 1] = False - self.inputs["not_need_stop"][0] = True - need_block_nums = ( - length + self.config.max_length + self.pre_cache_length + self.block_size - 1 - ) // self.block_size - for bi in range(need_block_nums): - bi_now = self.free_list.pop() - self.used_list[i].append(bi_now) - self.inputs["block_tables"][i : i + 1, bi] = bi_now - - -class DygraphBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor): + input_ids.append(tokens["input_ids"][0]) + + seq_lens = self.pad_batch_data(input_ids) + self.model_inputs["input_ids"] = self.input_ids + + self.model_inputs["block_tables"][:][:] = -1 + free_list = list(range(self.max_block_nums)) + for i in range(self.config.batch_size): + for j in range( + (seq_lens[i] + self.config.max_length + self.config.block_size - 1) // self.config.block_size + ): + used_block_id = free_list.pop() + self.model_inputs["block_tables"][i, j] = used_block_id + + self.model_inputs["seq_lens_this_time"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1)) + self.model_inputs["seq_lens_encoder"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1)) + self.model_inputs["seq_lens_decoder"] = paddle.full( + shape=[self.config.batch_size, 1], fill_value=0, dtype="int32" + ) + self.model_inputs["step_idx"] = paddle.full(shape=[self.config.batch_size, 1], fill_value=0, dtype="int64") + self.model_inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=True, dtype="bool") + self.model_inputs["stop_flags"] = paddle.full( + shape=[self.config.batch_size, 1], fill_value=False, dtype="bool" + ) + self.model_inputs["stop_nums"] = paddle.full(shape=[1], fill_value=self.config.batch_size, dtype="int64") + self.model_inputs["pre_ids"] = paddle.full( + shape=[self.config.batch_size, self.config.max_length], fill_value=-1, dtype="int64" + ) + self.model_inputs["next_tokens"] = paddle.full(shape=[self.config.batch_size, 1], fill_value=-1, dtype="int64") + + if self.config.mode == "static": + 
for k, v in self.model_inputs.items(): + v.name = k + + +class DygraphBlockInferencePredictor(BlockInferencePredictorMixin): def __init__( self, config: PredictorArgument, @@ -1008,26 +993,23 @@ def __init__( tokenizer: PretrainedTokenizer = None, ): self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size) - BasePredictor.__init__(self, config, tokenizer) BlockInferencePredictorMixin.__init__(self, config, tokenizer) - cachekv_dtype = self.dtype - if config.cachekv_int8_type is not None: - cachekv_dtype = "uint8" + cachekv_dtype = self.dtype if config.cachekv_int8_type is None else "uint8" self.cache_kvs = [paddle.zeros(shape, dtype=cachekv_dtype) for shape in self.cache_kvs_shape] self.model = model - self.init_inputs(config) + self.init_model_inputs(config) if config.export_precache: - self.inputs["pre_caches"] = self.pre_caches + self.model_inputs["pre_caches"] = self.pre_caches if config.cachekv_int8_type == "dynamic": - self.inputs["k_quant_scales"] = self.k_quant_scales - self.inputs["v_quant_scales"] = self.v_quant_scales - self.inputs["k_dequant_scales"] = self.k_dequant_scales - self.inputs["v_dequant_scales"] = self.v_dequant_scales + self.model_inputs["k_quant_scales"] = self.k_quant_scales + self.model_inputs["v_quant_scales"] = self.v_quant_scales + self.model_inputs["k_dequant_scales"] = self.k_dequant_scales + self.model_inputs["v_dequant_scales"] = self.v_dequant_scales - self.inputs["cache_kvs"] = self.cache_kvs + self.model_inputs["cache_kvs"] = self.cache_kvs @paddle.no_grad() def _infer(self, inputs: dict[str, paddle.Tensor]): @@ -1036,7 +1018,7 @@ def _infer(self, inputs: dict[str, paddle.Tensor]): ) @paddle.no_grad() - def predict(self, input_texts: str | list[str], return_tokens=False): + def predict(self, input_texts: list[str], return_tokens=False): self._preprocess(input_texts) result_queue = mp.Queue() @@ -1049,12 +1031,8 @@ def predict(self, input_texts: str | list[str], return_tokens=False): read_res_process = mp.Process(target=read_res, args=[self.model_name_or_path, tensor_queue, result_queue]) read_res_process.start() - while self.inputs["not_need_stop"]: - self._infer(self.inputs) - # reset free_list - for i in range(self.config.batch_size): - self.free_list.extend(self.used_list[i]) - self.used_list[i] = [] + while self.model_inputs["not_need_stop"]: + self._infer(self.model_inputs) outputs = [] output_tokens = [] @@ -1068,7 +1046,7 @@ def predict(self, input_texts: str | list[str], return_tokens=False): return outputs -class StaticBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor): +class StaticBlockInferencePredictor(BlockInferencePredictorMixin): def __init__( self, config: PredictorArgument, @@ -1076,39 +1054,31 @@ def __init__( tokenizer: PretrainedTokenizer = None, ): self.cache_kvs_shape = cache_kvs_shape - BasePredictor.__init__(self, config, tokenizer) BlockInferencePredictorMixin.__init__(self, config, tokenizer) - self.init_inputs(config) + self._create_predictor(config) + + self.init_model_inputs(config) if config.export_precache: for i in range(self.num_layers): - self.inputs["pre_caches_{}".format(i)] = self.pre_caches[i] + self.model_inputs["pre_caches_{}".format(i)] = self.pre_caches[i] - self.cache_kvs = {} - cachekv_dtype = config.dtype - if config.cachekv_int8_type is not None: - cachekv_dtype = "uint8" + cachekv_dtype = config.dtype if config.cachekv_int8_type is None else "uint8" for i in range(len(self.cache_kvs_shape) // 2): - self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros( + 
self.model_inputs["key_caches_{}".format(i)] = paddle.zeros( self.cache_kvs_shape[2 * i], dtype=cachekv_dtype ) - self.cache_kvs["value_caches_{}".format(i)] = paddle.zeros( + self.model_inputs["value_caches_{}".format(i)] = paddle.zeros( self.cache_kvs_shape[2 * i + 1], dtype=cachekv_dtype ) for i in range(self.num_layers): if self.config.cachekv_int8_type == "dynamic": - self.inputs["k_quant_scales_" + str(i)] = self.k_quant_scales[i] - self.inputs["v_quant_scales_" + str(i)] = self.v_quant_scales[i] - self.inputs["k_dequant_scales_" + str(i)] = self.k_dequant_scales[i] - self.inputs["v_dequant_scales_" + str(i)] = self.v_dequant_scales[i] - - self._create_predictor(config) - self.input_names = self.predictor.get_input_names() - - self._share_data() - self.seq_lens_handle = self.predictor.get_input_handle("seq_lens_this_time") + self.model_inputs["k_quant_scales_" + str(i)] = self.k_quant_scales[i] + self.model_inputs["v_quant_scales_" + str(i)] = self.v_quant_scales[i] + self.model_inputs["k_dequant_scales_" + str(i)] = self.k_dequant_scales[i] + self.model_inputs["v_dequant_scales_" + str(i)] = self.v_dequant_scales[i] def _create_predictor(self, predictor_args: PredictorArgument): if not is_paddlenlp_ops_available(): @@ -1163,37 +1133,9 @@ def _create_predictor(self, predictor_args: PredictorArgument): self.predictor = paddle.inference.create_predictor(config) - def _share_data(self): - """ - Share external data for inference predictor. - """ - for name in self.input_names: - if "pre_key_" in name or "pre_value_" in name: - input_tensor = self.predictor.get_input_handle(name) - input_tensor.share_external_data(self.inputs[name]) - continue - if "caches" in name: - input_tensor = self.predictor.get_input_handle(name) - input_tensor.share_external_data(self.cache_kvs[name]) - continue - if "seq_lens_this_time" in name: - continue - input_tensor = self.predictor.get_input_handle(name) - input_tensor.share_external_data(self.inputs[name]) - - def _infer(self): - self.predictor.run() - - def predict(self, input_texts: str | list[str], return_tokens=False): - + def predict(self, input_texts: list[str], return_tokens=False): s_time = time.time() self._preprocess(input_texts) - real_bsz = len(input_texts) - - import copy - - seq_lens_this_time = copy.deepcopy(self.inputs["seq_lens_this_time"][:real_bsz]) - self.seq_lens_handle.share_external_data(seq_lens_this_time) logger.info(f"preprocess spend {time.time() - s_time}") result_queue = mp.Queue() @@ -1207,15 +1149,10 @@ def predict(self, input_texts: str | list[str], return_tokens=False): read_res_process.start() s_time = time.time() - while self.inputs["not_need_stop"]: - self.predictor.run() + while self.model_inputs["not_need_stop"]: + self.predictor.run(list(self.model_inputs.values())) logger.info(f"running spend {time.time() - s_time}") - # reset free_list - for i in range(self.config.batch_size): - self.free_list.extend(self.used_list[i]) - self.used_list[i] = [] - outputs = [] output_tokens = [] while len(outputs) < self.batch_size: @@ -1227,19 +1164,6 @@ def predict(self, input_texts: str | list[str], return_tokens=False): else: return outputs - def _preprocess(self, source): - BlockInferencePredictorMixin._preprocess(self, source) - for i, text in enumerate(source): - tokens = self.tokenizer( - text, return_tensors="np", padding=False, truncation=True, max_length=(self.config.src_length) - ) - input_ids = tokens["input_ids"][0] - length = len(input_ids) - need_block_nums = ( - length + self.config.max_length + self.pre_cache_length 
+ self.block_size - 1 - ) // self.block_size - self.inputs["encoder_block_lens"][i : i + 1] = need_block_nums - def get_ptq_multicards_num(directory): count = 0 @@ -1268,38 +1192,6 @@ def create_predictor( config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) - max_position_embeddings = get_model_max_position_embeddings(config) - if max_position_embeddings is None: - max_position_embeddings = 2048 - logger.warning("Can not retrieval `max_position_embeddings` from config.json, use default value 2048") - - if predictor_args.src_length is None: - if predictor_args.max_length is None: - predictor_args.src_length = get_default_max_encoding_length(config) - predictor_args.max_length = get_default_max_decoding_length(config) - else: - predictor_args.src_length = max_position_embeddings - predictor_args.max_length - if predictor_args.src_length <= 0: - raise ValueError( - f"--max_length<{predictor_args.max_length}> param should be smaller " - f"than max_position_embeddings<{max_position_embeddings}>" - ) - else: - if predictor_args.max_length is None: - predictor_args.max_length = max_position_embeddings - predictor_args.src_length - if predictor_args.max_length <= 0: - raise ValueError( - f"--src_length<{predictor_args.src_length}> param should be smaller " - f"than max_position_embeddings<{max_position_embeddings}>" - ) - else: - if predictor_args.src_length + predictor_args.max_length > max_position_embeddings: - raise ValueError( - f"The sum of src_length<{predictor_args.src_length}> and " - f"max_length<{predictor_args.max_length}> should be smaller than or equal to " - f"the maximum position embedding size<{max_position_embeddings}>" - ) - # update config parameter for inference predictor if predictor_args.decode_strategy == "greedy_search": predictor_args.top_p = 0.0 @@ -1362,8 +1254,10 @@ def create_predictor( config.avx_type = predictor_args.avx_type if config.quantization_config.quant_type is not None: + predictor_args.quant_type = config.quantization_config.quant_type config.quant_type = config.quantization_config.quant_type if "c8" in config.quant_type: + predictor_args.cachekv_int8_type = "static" config.cachekv_int8_type = "static" ptq_multicards_num = get_ptq_multicards_num(config.model_name_or_path) @@ -1631,7 +1525,7 @@ def predict(): target_texts.append("") else: - source_texts = ["你好,请问你是谁?"] * predictor_args.batch_size + source_texts = ["解释一下温故而知新"] * predictor_args.batch_size target_texts = [""] * predictor_args.batch_size batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size) diff --git a/llm/utils/utils.py b/llm/utils/utils.py index a7839d79f5e1..7c5397d90cbd 100644 --- a/llm/utils/utils.py +++ b/llm/utils/utils.py @@ -34,7 +34,6 @@ AutoTokenizer, ChatGLMv2Tokenizer, LlamaForCausalLMPipe, - PretrainedConfig, Qwen2ForCausalLMPipe, ) from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer @@ -725,52 +724,6 @@ def init_chat_template( tokenizer.init_chat_template(chat_template_file) -def get_model_max_position_embeddings(config: PretrainedConfig) -> Optional[int]: - names = [ - "max_position_embeddings", # most of models - "max_sequence_length", # GLM model - "seq_length", # llama model - ] - for name in names: - max_length = config.get(name, None) - if max_length is not None: - return max_length - return None - - -def get_default_max_decoding_length(config: PretrainedConfig, default: int = 1024) -> int: - """get the default max decoding length from config. 
- - Args: - config (PretrainedConfig): the instance of PretrainedConfig - default (int): the default value of max decoding length - - Returns: - int: the default max_length of decoding length - """ - max_position_embeddings = get_model_max_position_embeddings(config) - if max_position_embeddings is None: - return default - return max_position_embeddings // 4 - - -def get_default_max_encoding_length(config: PretrainedConfig, default: int = 1024) -> int: - """get the default max encoding length from config. - - Args: - config (PretrainedConfig): the instance of PretrainedConfig - default (int): the default value of max encoding length - - Returns: - int: the default max_length of encoding length - """ - - max_position_embeddings = get_model_max_position_embeddings(config) - if max_position_embeddings is None: - return default - return max_position_embeddings // 4 * 3 - - def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue): tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index a1b47c070f6d..bfd425293bdb 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -349,6 +349,7 @@ def __init__(self, config: LlamaConfig): self.quant_type = config.quant_type self.rope_theta = config.rope_theta + self.use_neox = True self.use_weight_only = False if config.quant_type == "weight_only_int8": @@ -562,7 +563,7 @@ def __init__(self, config: LlamaConfig): cache_v_out_scale_attrs=cache_v_out_scale_attrs, epsilon=self.epsilon, norm_type="rmsnorm", - use_neox_rotary_style=True, + use_neox_rotary_style=self.use_neox, cachekv_int8_type=config.cachekv_int8_type, rank_id=config.tensor_parallel_rank, trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True), @@ -683,7 +684,7 @@ def forward( from paddlenlp_ops import fused_get_rotary_embedding new_rope = fused_get_rotary_embedding( - input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, True + input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, self.use_neox ) with dy2st_nocheck_guard_context(): @@ -1009,7 +1010,7 @@ def set_state_dict(self, state_dict): ) if self.config.cachekv_int8_type == "static": - cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_act_scales.json") + cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json") if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: cache_scale_json_path = os.path.join( self.quant_model_path, f"cachekv_act_scales_{self.config.tensor_parallel_rank}.json" diff --git a/paddlenlp/quantization/quantization_config.py b/paddlenlp/quantization/quantization_config.py index 0222aeb1ef3e..f5b04e188e15 100644 --- a/paddlenlp/quantization/quantization_config.py +++ b/paddlenlp/quantization/quantization_config.py @@ -64,9 +64,9 @@ def __init__( raise ValueError( f"weight_quantize_algo:{weight_quantize_algo} not in supported list ['weight_only_int8', 'weight_only_int4', 'llm.int8', 'a8w8', 'nf4', 'fp4']" ) - if quant_type is not None and quant_type not in ["weight_only_int8", "weight_only_int4", "a8w8"]: + if quant_type is not None and quant_type not in ["weight_only_int8", "weight_only_int4", "a8w8", "a8w8c8"]: raise ValueError( - f"quant_type:{quant_type} not in supported list ['weight_only_int8', 'weight_only_int4', 
'a8w8']" + f"quant_type:{quant_type} not in supported list ['weight_only_int8', 'weight_only_int4', 'a8w8', 'a8w8c8']" ) self.weight_quantize_algo = weight_quantize_algo self.quant_type = quant_type From 6f9d819a85903ca7b0a2241398d6c53383de70b9 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 6 Aug 2024 14:54:16 +0000 Subject: [PATCH 2/7] update --- llm/predict/export_model.py | 5 ++- llm/predict/predictor.py | 19 +++++----- paddlenlp/experimental/model_utils.py | 26 ++++++++++---- .../transformers/llama/modeling.py | 36 +++++++++---------- 4 files changed, 50 insertions(+), 36 deletions(-) diff --git a/llm/predict/export_model.py b/llm/predict/export_model.py index 91b3bc4cb7c3..83dcc371427e 100644 --- a/llm/predict/export_model.py +++ b/llm/predict/export_model.py @@ -61,7 +61,10 @@ def main(): }, ) predictor.model.config.save_pretrained(export_args.output_path) - predictor.generation_config.save_pretrained(export_args.output_path) + if predictor.generation_config is not None: + predictor.generation_config.save_pretrained(export_args.output_path) + else: + predictor.model.generation_config.save_pretrained(export_args.output_path) predictor.tokenizer.save_pretrained(export_args.output_path) generate_rank_mapping(os.path.join(export_args.output_path, "rank_mapping.csv")) diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 1841ca1eea58..10c634424bba 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -396,7 +396,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): self.architectures = self.model_config.architectures[0].lower() - self.dtype = config.dtype or self.model_config + self.dtype = config.dtype or self.model_config.dtype self.pre_ids = paddle.full([config.batch_size, config.total_max_length], -1, dtype="int64") if config.device == "cpu" and config.avx_model: @@ -408,7 +408,6 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): self.tgt_generation_mask = None self.tgt_pos = None else: - self.arange_tensor_encoder = paddle.arange(config.total_max_length, dtype=self.dtype) self.cache_kvs = [paddle.zeros(shape, dtype=self.dtype) for shape in self.cache_kvs_shape] self.num_layers, self.num_attention_heads, self.head_dim = ( len(self.cache_kvs), @@ -548,8 +547,8 @@ def _preprocess(self, source): # alibi encoder alibi_slopes = get_alibi_slopes(self.model_config.n_head) inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32") - - alibi = alibi_slopes[None, :, None, None] * self.arange_tensor_encoder + arange_tensor_encoder = paddle.arange(self.config.total_max_length, dtype=self.config.dtype) + alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder if self.model_config.tensor_parallel_degree > 1: block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree @@ -773,8 +772,6 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): self.dtype = config.dtype or self.model_config.dtype - self.block_size = config.block_size - try: self.rope_theta = self.model_config.rope_theta except: @@ -883,19 +880,21 @@ def init_model_inputs(self, config: PredictorArgument): # bloom model needs src_mask and tgt_mask! 
if "bloom" in self.architectures: - lower_one_tril = paddle.tril(paddle.ones(shape=(config.total_max_length, config.total_max_length), dtype=self.dtype)) + lower_one_tril = paddle.tril( + paddle.ones(shape=(config.total_max_length, config.total_max_length), dtype=self.dtype) + ) lower_one_tril = lower_one_tril[None, None, :, :] - self.model_inputs["src_mask"] = lower_one_tril.tile([self.batch_size, 1, 1, 1]) + self.model_inputs["src_mask"] = lower_one_tril.tile([config.batch_size, 1, 1, 1]) self.model_inputs["tgt_mask"] = paddle.full( shape=[config.batch_size, 1, 1, config.total_max_length], fill_value=1, dtype=self.dtype ) arange_tensor_encoder = paddle.arange(config.total_max_length).astype(self.dtype) alibi_slopes = get_alibi_slopes(self.num_attention_heads) alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder - alibi_encoder = alibi.tile([self.batch_size, 1, config.total_max_length, 1]) + alibi_encoder = alibi.tile([config.batch_size, 1, config.total_max_length, 1]) alibi_decoder = alibi.tile( [ - self.batch_size, + config.batch_size, 1, 1, 1, diff --git a/paddlenlp/experimental/model_utils.py b/paddlenlp/experimental/model_utils.py index b5a43eebd387..b187bb3700e9 100644 --- a/paddlenlp/experimental/model_utils.py +++ b/paddlenlp/experimental/model_utils.py @@ -391,7 +391,12 @@ def __init__( class CacheScaleLoader: def __init__( - self, scale_json_file_path="cache_scales.json", key_map_dict=None, num_of_layers=None, num_heads=None + self, + scale_json_file_path="cache_scales.json", + key_map_dict=None, + num_of_layers=None, + num_heads=None, + num_key_value_heads=None, ): with open(scale_json_file_path) as json_file: self.scale_dict = json.load(json_file) @@ -402,12 +407,21 @@ def __init__( scale_type_out = "cache_k_out_scale" else: scale_type_out = "cache_v_out_scale" - self.scale[scale_type] = np.full([num_of_layers, num_heads], fill_value=-1.0) - self.scale[scale_type_out] = np.full([num_of_layers, num_heads], fill_value=-1.0) + self.scale[scale_type] = np.full([num_of_layers, num_key_value_heads], fill_value=-1.0) + self.scale[scale_type_out] = np.full([num_of_layers, num_key_value_heads], fill_value=-1.0) for i in range(num_of_layers): if key_template.replace("#", str(i)) in self.scale_dict.keys(): - self.scale[scale_type][i, :] = [ - 127.0 / num for num in self.scale_dict[key_template.replace("#", str(i))] + if num_heads != num_key_value_heads: + self.scale[scale_type][i, :] = [ + 127.0 / self.scale_dict[key_template.replace("#", str(i))][j] + for j in range(0, num_heads, num_heads // num_key_value_heads) + ] + else: + self.scale[scale_type][i, :] = [ + 127.0 / self.scale_dict[key_template.replace("#", str(i))][j] + for j in range(0, num_key_value_heads) + ] + self.scale[scale_type_out][i, :] = [ + 1.0 / self.scale[scale_type][i, j] for j in range(0, num_key_value_heads) ] - self.scale[scale_type_out][i, :] = [1.0 / self.scale[scale_type][i, j] for j in range(num_heads)] diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index bfd425293bdb..b90355a31afc 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -801,11 +801,9 @@ def set_state_dict(self, state_dict): concated_ffn1_weight = np.concatenate( [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1 ) - ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight) - qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight) + 
qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype()) if self.use_weight_only: - qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight) qkv_weight_tensor = paddle.transpose(qkv_weight_tensor, perm=[1, 0]) qkv_quanted_weight_tensor, qkv_weight_scale_tensor = weight_quantize( qkv_weight_tensor, algo=self.quant_algo @@ -817,11 +815,11 @@ def set_state_dict(self, state_dict): paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8") ) else: - self.transformer_block.qkv_weights[idx].set_value( - qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype) - ) + self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor) - linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]) + linear_weight_tensor = paddle.to_tensor( + state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)] + ).cast(paddle.get_default_dtype()) if self.use_weight_only: linear_quanted_weight_tensor, linear_weight_scale_tensor = weight_quantize( linear_weight_tensor, algo=self.quant_algo @@ -845,10 +843,9 @@ def set_state_dict(self, state_dict): ) ) else: - self.transformer_block.linear_weights[idx].set_value( - linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype) - ) + self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor) + ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype()) if self.use_weight_only: ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize( ffn1_weight_tensor, algo=self.quant_algo @@ -865,11 +862,11 @@ def set_state_dict(self, state_dict): paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8") ) else: - self.transformer_block.ffn1_weights[idx].set_value( - ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype) - ) + self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor) - ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]) + ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]).cast( + paddle.get_default_dtype() + ) if self.use_weight_only: ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize( ffn2_weight_tensor, algo=self.quant_algo @@ -893,9 +890,7 @@ def set_state_dict(self, state_dict): ) ) else: - self.transformer_block.ffn2_weights[idx].set_value( - ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype) - ) + self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor) if "a8w8" in self.quant_type: if self.shift_smooth_all_linears: @@ -1019,7 +1014,8 @@ def set_state_dict(self, state_dict): cache_scale_json_path, cache_scale_map_dict, num_of_layers=self.config.num_hidden_layers, - num_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, + num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, + num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, ) for k, v in cache_scales_loader.scale.items(): for i_layer, weight_scale in enumerate(v): @@ -1401,7 +1397,9 @@ def forward( @paddle.no_grad() def set_state_dict(self, state_dict): if "lm_head.weight" in state_dict: - self.lm_head.weight.set_value(state_dict["lm_head.weight"]) + self.lm_head.weight.set_value( + paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype) + ) self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) From 
244802a32b5058c76b161d2939d28b2a35f7121e Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 7 Aug 2024 09:14:12 +0000 Subject: [PATCH 3/7] update --- csrc/generation/encode_rotary_qk.cu | 6 +++--- llm/predict/predictor.py | 10 ++++++++-- llm/utils/utils.py | 7 +++---- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/csrc/generation/encode_rotary_qk.cu b/csrc/generation/encode_rotary_qk.cu index d5f55a172592..3c0860feb1c3 100644 --- a/csrc/generation/encode_rotary_qk.cu +++ b/csrc/generation/encode_rotary_qk.cu @@ -111,6 +111,7 @@ void LaunchRotaryQK(const paddle::Tensor& q, auto cu_stream = q.stream(); dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims); + dim3 grid_k(batch_size, kv_head_num, seq_len * rotary_emb_dims); const int last_dim = dim_head / rotary_emb_dims; auto getBlockSize = [](int dim) { if (dim > 256) { @@ -148,7 +149,6 @@ void LaunchRotaryQK(const paddle::Tensor& q, head_num, seq_len * rotary_emb_dims, last_dim); - dim3 grid_k(batch_size, kv_head_num, seq_len * rotary_emb_dims); RotaryKernel<<>>( k_data, cos_emb, @@ -172,7 +172,7 @@ void LaunchRotaryQK(const paddle::Tensor& q, head_num, seq_len * rotary_emb_dims, last_dim); - NeoXRotaryKernel<<>>( + NeoXRotaryKernel<<>>( k_data, cos_emb, sin_emb, @@ -180,7 +180,7 @@ void LaunchRotaryQK(const paddle::Tensor& q, k_out_data, rotary_emb_dims, batch_size, - head_num, + kv_head_num, seq_len * rotary_emb_dims, last_dim); } diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 10c634424bba..16e2bb4ba825 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -1039,6 +1039,9 @@ def predict(self, input_texts: list[str], return_tokens=False): result = result_queue.get(timeout=1) outputs.append(result[-1]) output_tokens.append(result[-2]) + + read_res_process.terminate() + if return_tokens: return outputs, output_tokens else: @@ -1158,6 +1161,9 @@ def predict(self, input_texts: list[str], return_tokens=False): result = result_queue.get(timeout=1) outputs.append(result[-1]) output_tokens.append(result[-2]) + + read_res_process.terminate() + if return_tokens: return outputs, output_tokens else: @@ -1524,8 +1530,8 @@ def predict(): target_texts.append("") else: - source_texts = ["解释一下温故而知新"] * predictor_args.batch_size - target_texts = [""] * predictor_args.batch_size + source_texts = ["解释一下温故而知新", "解释一下温故而知新"] + target_texts = ["", ""] batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size) batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size) diff --git a/llm/utils/utils.py b/llm/utils/utils.py index 7c5397d90cbd..098248437298 100644 --- a/llm/utils/utils.py +++ b/llm/utils/utils.py @@ -725,11 +725,10 @@ def init_chat_template( def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue): - tokenizer = AutoTokenizer.from_pretrained( - model_name_or_path, - ) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) paddle.device.set_device("cpu") + paddle.disable_static() outputs = [] output_tensor = tensor_queue.get(timeout=1) @@ -746,7 +745,7 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q output_numpy = output_tensor[2 : bsz + 2].numpy() output_numpy[output_numpy == -1] = 2 outputs.append(output_numpy) - if output_tensor[0, 0] == -1: + if int(output_tensor[0, 0]) == -1: break output = np.concatenate(outputs, axis=1).tolist() seqs = tokenizer.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False) From a6bde288a0a97a5e7c49791e09793cfdc4035512 Mon 
Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 7 Aug 2024 17:08:06 +0000 Subject: [PATCH 4/7] support qwen2 bf16/wint8 --- llm/predict/predictor.py | 53 +- .../experimental/transformers/__init__.py | 1 + .../transformers/qwen2/__init__.py | 15 + .../transformers/qwen2/modeling.py | 895 ++++++++++++++++++ 4 files changed, 963 insertions(+), 1 deletion(-) create mode 100644 paddlenlp/experimental/transformers/qwen2/__init__.py create mode 100644 paddlenlp/experimental/transformers/qwen2/modeling.py diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 16e2bb4ba825..a522dc60a852 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -547,9 +547,9 @@ def _preprocess(self, source): # alibi encoder alibi_slopes = get_alibi_slopes(self.model_config.n_head) inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32") + arange_tensor_encoder = paddle.arange(self.config.total_max_length, dtype=self.config.dtype) alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder - if self.model_config.tensor_parallel_degree > 1: block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree alibi = alibi[ @@ -1380,6 +1380,32 @@ def create_predictor( dtype=predictor_args.dtype, ) model.eval() + elif "qwen2" in config.architectures[0].lower(): + if predictor_args.block_attn: + config.max_seq_len = predictor_args.total_max_length + config.block_size = predictor_args.block_size + from paddlenlp.experimental.transformers import ( + Qwen2ForCausalLMBlockInferenceModel as Qwen2InferenceModel, + ) + + model = Qwen2InferenceModel.from_pretrained( + predictor_args.model_name_or_path, + config=config, + dtype=predictor_args.dtype, + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_rank=tensor_parallel_rank, + ) + else: + from paddlenlp.experimental.transformers import ( + Qwen2ForCausalLMInferenceModel as Qwen2InferenceModel, + ) + + model = Qwen2InferenceModel.from_pretrained( + predictor_args.model_name_or_path, + config=config, + dtype=predictor_args.dtype, + ) + model.eval() elif "qwen" in config.architectures[0].lower(): if model_args.model_type == "qwen-img2txt": # we use qwen for img2txt. 
@@ -1405,6 +1431,16 @@ def create_predictor( elif predictor_args.mode == "static": config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) + config.quant_type = predictor_args.quant_type + config.cachekv_int8_type = predictor_args.cachekv_int8_type + + if config.quantization_config.quant_type is not None: + predictor_args.quant_type = config.quantization_config.quant_type + config.quant_type = config.quantization_config.quant_type + if "c8" in config.quant_type: + predictor_args.cachekv_int8_type = "static" + config.cachekv_int8_type = "static" + if "llama" in config.architectures[0].lower(): if predictor_args.block_attn: config.block_size = predictor_args.block_size @@ -1471,6 +1507,21 @@ def create_predictor( cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape( config, predictor_args.batch_size, predictor_args.total_max_length ) + elif "qwen2" in config.architectures[0].lower(): + if predictor_args.block_attn: + config.block_size = predictor_args.block_size + config.max_seq_len = predictor_args.total_max_length + from paddlenlp.experimental.transformers import ( + Qwen2ForCausalLMBlockInferenceModel as Qwen2InferenceModel, + ) + else: + from paddlenlp.experimental.transformers import ( + Qwen2ForCausalLMInferenceModel as Qwen2InferenceModel, + ) + cache_kvs_shape = Qwen2InferenceModel.get_cache_kvs_shape( + config, predictor_args.batch_size, predictor_args.total_max_length + ) + elif "qwen" in config.architectures[0].lower(): from paddlenlp.experimental.transformers import ( QWenForCausalLMInferenceModel, diff --git a/paddlenlp/experimental/transformers/__init__.py b/paddlenlp/experimental/transformers/__init__.py index 1c7c0e2c0077..cb5e927e9d73 100644 --- a/paddlenlp/experimental/transformers/__init__.py +++ b/paddlenlp/experimental/transformers/__init__.py @@ -20,3 +20,4 @@ from .llama import * from .opt import * from .qwen import * +from .qwen2 import * diff --git a/paddlenlp/experimental/transformers/qwen2/__init__.py b/paddlenlp/experimental/transformers/qwen2/__init__.py new file mode 100644 index 000000000000..0f0d00141b52 --- /dev/null +++ b/paddlenlp/experimental/transformers/qwen2/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling import * diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py new file mode 100644 index 000000000000..2d4d39ecbfff --- /dev/null +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -0,0 +1,895 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from functools import partial + +import numpy as np +import paddle +from paddle import nn +from paddle.nn.quant import weight_quantize + +from paddlenlp.experimental.transformers.fused_transformer_layers import ( + FusedBlockMultiTransformer, + FusedBlockMultiTransformerWeightOnly, + FusedMultiTransformerBase, + FusedMultiTransformerConfig, + FusedMultiTransformerWeightOnly, +) +from paddlenlp.experimental.transformers.generation_utils import ( + GenerationBlockInferenceModel, + GenerationInferenceModel, +) +from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained +from paddlenlp.transformers import Qwen2Config, Qwen2PretrainedModel +from paddlenlp.transformers.model_outputs import ( # CausalLMOutputWithCrossAttentions, + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithPast, +) +from paddlenlp.transformers.model_utils import ( + dy2st_nocheck_guard_context, + register_base_model, +) +from paddlenlp.transformers.qwen2.modeling import Qwen2LMHead, Qwen2PretrainingCriterion +from paddlenlp.utils.log import logger + +__all__ = ["Qwen2ForCausalLMInferenceModel", "Qwen2ForCausalLMBlockInferenceModel"] + + +class FusedQwen2RMSNorm(nn.Layer): + def __init__(self, config): + super().__init__() + self.eps = config.rms_norm_eps + self.weight = paddle.create_parameter( + shape=[config.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + + def forward(self, x): + result = paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1) + if isinstance(result, tuple): + return result[0] + return result + + +@register_base_model +class Qwen2InferenceModel(Qwen2PretrainedModel): + def __init__(self, config: Qwen2Config): + super(Qwen2PretrainedModel, self).__init__(config) + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.intermediate_size = config.intermediate_size + self.num_layers = config.num_hidden_layers + self.rms_norm_eps = config.rms_norm_eps + self.quant_type = config.quant_type + self.rope_theta = config.rope_theta + + self.use_neox = True + + self.use_weight_only = False + if config.quant_type == "weight_only_int8": + self.use_weight_only = True + self.quant_algo = "weight_only_int8" + elif config.quant_type == "weight_only_int4": + self.use_weight_only = True + self.quant_algo = "weight_only_int4" + + if self.use_weight_only: + assert ( + self.quant_type == "weight_only_int8" or self.quant_type == "weight_only_int4" + ), "Expected quant_type equal to 'weight_only_int8' or 'weight_only_int4', but received {}".format( + self.quant_type + ) + + self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size) + + ln_scale_attrs = [paddle.ParamAttr(name="fuseqwen2.{}.ln_scale".format(i)) for i in range(self.num_layers)] + qkv_weight_attrs = [ + paddle.ParamAttr( + name="fuseqwen2.{}.qkv_weight".format(i), 
initializer=paddle.nn.initializer.Constant(value=0) + ) + for i in range(self.num_layers) + ] + qkv_bias_attrs = [paddle.ParamAttr(name="fuseqwen2.{}.qkv_bias".format(i)) for i in range(self.num_layers)] + out_proj_weight_attrs = [ + paddle.ParamAttr( + name="fuseqwen2.{}.out_proj_weight".format(i), initializer=paddle.nn.initializer.Constant(value=0) + ) + for i in range(self.num_layers) + ] + ffn_ln_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen2.{}.ffn_ln_scale".format(i)) for i in range(self.num_layers) + ] + ffn1_weight_attrs = [ + paddle.ParamAttr( + name="fuseqwen2.{}.ffn1_weight".format(i), initializer=paddle.nn.initializer.Constant(value=0) + ) + for i in range(self.num_layers) + ] + ffn2_weight_attrs = [ + paddle.ParamAttr( + name="fuseqwen2.{}.ffn2_weight".format(i), initializer=paddle.nn.initializer.Constant(value=0) + ) + for i in range(self.num_layers) + ] + + qkv_weight_scale_attrs = None + out_proj_weight_scale_attrs = None + ffn1_weight_scale_attrs = None + ffn2_weight_scale_attrs = None + + qkv_out_scale_attrs = None + linear_out_scale_attrs = None + ffn1_out_scale_attrs = None + ffn2_out_scale_attrs = None + linear_shift_attrs = None + linear_smooth_attrs = None + ffn2_shift_attrs = None + ffn2_smooth_attrs = None + + ln_bias_attrs = None + out_proj_bias_attrs = None + ffn_ln_bias_attrs = None + ffn1_bias_attrs = None + ffn2_bias_attrs = None + + if self.use_weight_only: + qkv_weight_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen2.{}.qkv_weight_scale".format(i)) for i in range(self.num_layers) + ] + out_proj_weight_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen2.{}.out_proj_weight_scale".format(i)) for i in range(self.num_layers) + ] + ffn1_weight_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen2.{}.ffn1_weight_scale".format(i)) for i in range(self.num_layers) + ] + ffn2_weight_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen2.{}.ffn2_weight_scale".format(i)) for i in range(self.num_layers) + ] + + cache_k_scale_attrs = None + cache_v_scale_attrs = None + cache_k_out_scale_attrs = None + cache_v_out_scale_attrs = None + + transformer_config = FusedMultiTransformerConfig( + embed_dim=self.hidden_size, + num_heads=self.num_attention_heads, + kv_num_heads=self.num_key_value_heads, + dim_feedforward=self.intermediate_size, + quant_type=self.quant_type, + activation="swiglu", + num_layers=config.num_hidden_layers, + nranks=config.tensor_parallel_degree, + ring_id=-1, + ln_scale_attrs=ln_scale_attrs, + qkv_weight_attrs=qkv_weight_attrs, + qkv_weight_scale_attrs=qkv_weight_scale_attrs, + linear_weight_attrs=out_proj_weight_attrs, + linear_weight_scale_attrs=out_proj_weight_scale_attrs, + ffn_ln_scale_attrs=ffn_ln_scale_attrs, + ffn1_weight_attrs=ffn1_weight_attrs, + ffn1_weight_scale_attrs=ffn1_weight_scale_attrs, + ffn2_weight_attrs=ffn2_weight_attrs, + ffn2_weight_scale_attrs=ffn2_weight_scale_attrs, + qkv_out_scale_attrs=qkv_out_scale_attrs, + linear_out_scale_attrs=linear_out_scale_attrs, + ffn1_out_scale_attrs=ffn1_out_scale_attrs, + ffn2_out_scale_attrs=ffn2_out_scale_attrs, + linear_shift_attrs=linear_shift_attrs, + linear_smooth_attrs=linear_smooth_attrs, + ffn2_shift_attrs=ffn2_shift_attrs, + ffn2_smooth_attrs=ffn2_smooth_attrs, + ln_bias_attrs=ln_bias_attrs, + qkv_bias_attrs=qkv_bias_attrs, + linear_bias_attrs=out_proj_bias_attrs, + ffn_ln_bias_attrs=ffn_ln_bias_attrs, + ffn1_bias_attrs=ffn1_bias_attrs, + ffn2_bias_attrs=ffn2_bias_attrs, + cache_k_scale_attrs=cache_k_scale_attrs, + cache_v_scale_attrs=cache_v_scale_attrs, + 
cache_k_out_scale_attrs=cache_k_out_scale_attrs, + cache_v_out_scale_attrs=cache_v_out_scale_attrs, + epsilon=self.rms_norm_eps, + norm_type="rmsnorm", + use_neox_rotary_style=self.use_neox, + cachekv_int8_type=config.cachekv_int8_type, + rank_id=config.tensor_parallel_rank, + ) + + self.set_transformer_block(transformer_config) + + self.norm = FusedQwen2RMSNorm(config) + + self.cache_kvs = None + self.head_dim_shape_tensor = paddle.ones((self.hidden_size // self.num_attention_heads), dtype="int8") + + def set_transformer_block(self, transformer_config): + if self.use_weight_only: + self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config) + else: + self.transformer_block = FusedMultiTransformerBase(transformer_config) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @paddle.no_grad() + def set_state_dict(self, state_dict): + head_size = self.hidden_size // self.num_attention_heads + self.embed_tokens.weight.set_value( + paddle.to_tensor(state_dict["qwen2.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype) + ) + self.norm.weight.set_value(paddle.to_tensor(state_dict["qwen2.norm.weight"]).cast(self.norm.weight.dtype)) + + for idx in range(self.num_layers): + unfused_state_dict = {} + ln_scale = paddle.to_tensor(state_dict["qwen2.layers.{}.input_layernorm.weight".format(idx)]).cast( + self.transformer_block.ln_scales[idx].dtype + ) + self.transformer_block.ln_scales[idx].set_value(ln_scale) + + unfused_state_dict["qwen2.self_attn.q_proj.weight"] = state_dict[ + "qwen2.layers.{}.self_attn.q_proj.weight".format(idx) + ] + unfused_state_dict["qwen2.self_attn.k_proj.weight"] = state_dict[ + "qwen2.layers.{}.self_attn.k_proj.weight".format(idx) + ] + unfused_state_dict["qwen2.self_attn.v_proj.weight"] = state_dict[ + "qwen2.layers.{}.self_attn.v_proj.weight".format(idx) + ] + + concated_qkv_weight = ( + np.concatenate( + [ + unfused_state_dict["qwen2.self_attn.q_proj.weight"], + unfused_state_dict["qwen2.self_attn.k_proj.weight"], + unfused_state_dict["qwen2.self_attn.v_proj.weight"], + ], + axis=-1, + ) + .transpose(1, 0) + .reshape( + ( + self.num_attention_heads // self.config.tensor_parallel_degree + + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree + ) + * (head_size), + self.hidden_size, + ) + ) + + qkv_weight = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype()) + if self.use_weight_only: + qkv_weight = paddle.transpose(qkv_weight, perm=[1, 0]) + qkv_quanted_weight, qkv_weight_scale = weight_quantize(qkv_weight, algo=self.quant_algo) + self.transformer_block.qkv_weights[idx].set_value(qkv_quanted_weight) + self.transformer_block.qkv_weights_scale[idx].set_value(qkv_weight_scale) + else: + self.transformer_block.qkv_weights[idx].set_value(qkv_weight) + + unfused_state_dict["qwen2.self_attn.q_proj.bias"] = state_dict[ + "qwen2.layers.{}.self_attn.q_proj.bias".format(idx) + ] + unfused_state_dict["qwen2.self_attn.k_proj.bias"] = state_dict[ + "qwen2.layers.{}.self_attn.k_proj.bias".format(idx) + ] + unfused_state_dict["qwen2.self_attn.v_proj.bias"] = state_dict[ + "qwen2.layers.{}.self_attn.v_proj.bias".format(idx) + ] + + concated_qkv_biases = np.concatenate( + [ + unfused_state_dict["qwen2.self_attn.q_proj.bias"], + unfused_state_dict["qwen2.self_attn.k_proj.bias"], + unfused_state_dict["qwen2.self_attn.v_proj.bias"], + ], + axis=-1, + ) + qkv_bias = paddle.to_tensor(concated_qkv_biases).cast(self.transformer_block.qkv_biases[idx].dtype) + 
self.transformer_block.qkv_biases[idx].set_value(qkv_bias)
+
+            linear_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]).cast(paddle.get_default_dtype())
+            if self.use_weight_only:
+                linear_quanted_weight, linear_weight_scale = weight_quantize(linear_weight, algo=self.quant_algo)
+                self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight)
+                self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale)
+            else:
+                self.transformer_block.linear_weights[idx].set_value(linear_weight)
+
+            ffn_ln_scale = paddle.to_tensor(
+                state_dict["qwen2.layers.{}.post_attention_layernorm.weight".format(idx)],
+            ).cast(
+                self.transformer_block.ffn_ln_scales[idx].dtype,
+            )
+            self.transformer_block.ffn_ln_scales[idx].set_value(ffn_ln_scale)
+
+            up_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.up_proj.weight".format(idx)]).cast(
+                paddle.get_default_dtype()
+            )
+            gate_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.gate_proj.weight".format(idx)]).cast(
+                paddle.get_default_dtype()
+            )
+            ffn1_weight = paddle.concat(x=[gate_weight, up_weight], axis=-1)
+            if self.use_weight_only:
+                ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(ffn1_weight, algo=self.quant_algo)
+                self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight)
+                self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale)
+            else:
+                self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight)
+
+            ffn2_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)]).cast(
+                paddle.get_default_dtype()
+            )
+            if self.use_weight_only:
+                ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(ffn2_weight, algo=self.quant_algo)
+                self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight)
+                self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale)
+            else:
+                self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight)
+
+    def remove_padding(self, input_ids, seq_lens_this_time):
+        cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
+        token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
+
+        ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
+            input_ids, cum_offsets_now, token_num, seq_lens_this_time
+        )
+        return ids_remove_padding, padding_offset, cum_offsets
+
+    # This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py,
+    # it is used to generate fake input_ids according to the length of inputs_embeds.
+    @staticmethod
+    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
+        batch_size = 1
+        seq_len = 1
+        if bos_token_id is None:
+            raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.")
+        if encoder_output is not None:
+            batch_size = encoder_output.shape[0]
+            seq_len = encoder_output.shape[1]
+        return paddle.full([batch_size, seq_len], bos_token_id, dtype="int64")
+
+    def forward(
+        self,
+        input_ids=None,
+        position_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        use_cache=None,
+        cache_kvs=None,
+        pre_caches=None,
+        seq_len_encoder=None,
+        seq_len_decoder=None,
+        past_key_values=None,
+        output_attentions=False,
+        output_hidden_states=None,
+        return_dict=False,
+        **kwargs,
+    ):
+        # kwargs["cache"] is used to distinguish between the encoder and decoder phases.
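+        # Prefill (encoder) phase: kwargs["cache"] is None, the full prompt is passed in and
+        # remove_padding strips padding according to seq_len_encoder.
+        # Decode phase: kwargs["cache"] is set, input_ids carries only the latest generated
+        # tokens and seq_len_decoder drives the attention lengths.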
+ past_key_values = kwargs.get("cache", None) + is_decoder = past_key_values is not None + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # generate a fake input_ids according to inputs_embeds + # this is usually occurred in img2txt multimodal model when first enter into this forward function. + if input_ids is None and inputs_embeds is not None: + input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds) + if inputs_embeds is not None: + batch, seq_len, hidden_dim = inputs_embeds.shape + inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim]) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if past_key_values is None: + past_key_values = tuple([None] * self.config.num_hidden_layers) + + if not is_decoder: + ids_remove_padding, padding_offset, cum_offsets = self.remove_padding(input_ids, seq_len_encoder) + else: + ids_remove_padding = input_ids + padding_offset = None + cum_offsets = None + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(ids_remove_padding) + + hidden_states = inputs_embeds + + # decoder layers + presents = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + seq_lens = seq_len_decoder if is_decoder else seq_len_encoder + + position_offset = 0 + if not is_decoder and pre_caches is not None: + position_offset = 128 + + from paddlenlp_ops import fused_get_rotary_embedding + + new_rope = fused_get_rotary_embedding( + input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, self.use_neox + ) + + with dy2st_nocheck_guard_context(): + hidden_states, _ = self.transformer_block( + input_ids, + hidden_states, + cum_offsets=cum_offsets, + padding_offset=padding_offset, + attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype), + caches=cache_kvs, + pre_caches=pre_caches, + pre_caches_length=position_offset, + seq_lens=seq_lens, + rotary_embs=new_rope, + rotary_emb_dims=1, + time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None, + ) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Qwen2ForCausalLMInferenceModel(GenerationInferenceModel, Qwen2PretrainedModel): + def __init__(self, config: Qwen2Config, **kwargs): + super(Qwen2ForCausalLMInferenceModel, self).__init__(config) + self.qwen2 = Qwen2InferenceModel(config) + if config.tie_word_embeddings: + self.lm_head = Qwen2LMHead(config, embedding_weights=self.qwen2.embed_tokens.weight, transpose_y=True) + self.tie_weights() + else: + 
self.lm_head = Qwen2LMHead(config) + self.criterion = Qwen2PretrainingCriterion(config) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs) + + @classmethod + def get_cache_kvs_shape( + cls, config: Qwen2Config, max_batch_size: int = None, max_length: int = None + ) -> list[list[int]]: + """get cache_kvs tensor for qwen model + + Args: + max_batch_size (int): the max batch size + max_length (int | None, optional): the max_length of cache_kvs. Defaults to None. + + Returns: + list[paddle.Tensor]: the list tensor shape for cache + """ + if max_length is None: + max_length = config.max_position_embeddings + + cache_kvs = [] + for _ in range(config.num_hidden_layers): + cache_kvs.append( + [ + 2, + max_batch_size, + config.num_key_value_heads // max(config.tensor_parallel_degree, 1), + max_length, + config.hidden_size // config.num_attention_heads, + ] + ) + return cache_kvs + + def prepare_inputs_for_generation( + self, + input_ids, + cache_kvs, + seq_len_encoder, + seq_len_decoder, + tgt_ids, + tgt_pos, + tgt_generation_mask, + **kwargs, + ): + position_ids = kwargs.get("position_ids", None) + attention_mask = kwargs.get("attention_mask", None) + cache = kwargs.get("cache", None) + pre_caches = kwargs.get("pre_caches", None) + inputs_embeds = kwargs.get("inputs_embeds", None) + if cache is not None: + input_ids = tgt_ids + position_ids = tgt_pos + attention_mask = (tgt_generation_mask - 1) * 1e4 + # make inputs_embeds be none in decoder phase. + # in forward function, it will be assigned according to input_ids. 
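+            # (input_ids has been replaced with tgt_ids above, so the decode step re-embeds only
+            # the newly generated tokens instead of reusing the prompt's inputs_embeds.)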
+ inputs_embeds = None + else: + attention_mask = (attention_mask - 1) * 1e4 + model_inputs = { + "input_ids": input_ids, + "inputs_embeds": inputs_embeds, + "position_ids": position_ids, + "attention_mask": attention_mask, + "cache_kvs": cache_kvs, + "seq_len_encoder": seq_len_encoder, + "seq_len_decoder": seq_len_decoder, + "cache": cache, + "pre_caches": pre_caches, + } + return model_inputs + + def forward( + self, + input_ids, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=False, + cache=None, + cache_kvs=None, + pre_caches=None, + seq_len_encoder=None, + seq_len_decoder=None, + past_key_values=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.qwen2( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache=cache, + cache_kvs=cache_kvs, + pre_caches=pre_caches, + seq_len_encoder=seq_len_encoder, + seq_len_decoder=seq_len_decoder, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + lm_logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + + loss = None + if labels is not None: + loss = self.criterion(lm_logits, labels) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @paddle.no_grad() + def set_state_dict(self, state_dict): + if "lm_head.weight" in state_dict: + lm_head_weight = paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype) + self.lm_head.weight.set_value(lm_head_weight) + self.qwen2.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) + + +@register_base_model +class Qwen2BlockInferenceModel(Qwen2InferenceModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.max_seq_len = config.max_seq_len + self.block_size = config.block_size + + def set_transformer_block(self, transformer_config): + if self.use_weight_only: + self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config) + else: + self.transformer_block = FusedBlockMultiTransformer(transformer_config) + + def remove_padding(self, input_ids, seq_lens_this_time): + cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) + token_num = paddle.sum(seq_lens_this_time) + from paddlenlp_ops import get_padding_offset_v2 + + ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( + input_ids, cum_offsets_now, token_num, seq_lens_this_time + ) + return ids_remove_padding, padding_offset, cum_offsets, 
cu_seqlens_q, cu_seqlens_k + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + caches=None, + pre_caches=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + **kwargs, + ): + + seq_lens_this_time = kwargs.get("seq_lens_this_time", None) + rope_emb = kwargs.get("rope_emb", None) + ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding( + input_ids, seq_lens_this_time + ) + kwargs["cu_seqlens_q"] = cu_seqlens_q + kwargs["cu_seqlens_k"] = cu_seqlens_k + kwargs["padding_offsets"] = padding_offset + kwargs["max_input_length"] = self.max_seq_len + + inputs_embeds = self.embed_tokens(ids_remove_padding) + + with dy2st_nocheck_guard_context(): + hidden_states, _ = self.transformer_block( + input_ids=input_ids, + src=inputs_embeds, + cum_offsets=cum_offsets, + attn_mask=attention_mask, + caches=caches, + pre_caches=pre_caches, + rotary_embs=rope_emb, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class Qwen2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, Qwen2PretrainedModel): + """ + Dynamic Batching for Qwen2 Model with pretraining tasks on top. + """ + + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.qwen2 = Qwen2BlockInferenceModel(config) + if config.tie_word_embeddings: + self.lm_head = Qwen2LMHead(config, embedding_weights=self.qwen2.embed_tokens.weight, transpose_y=True) + self.tie_weights() + else: + self.lm_head = Qwen2LMHead(config) + + @classmethod + def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True): + + logger.info("Qwen2 inference model _get_tensor_parallel_mappings") + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + # Column Linear + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn: + base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial( + fn, is_column=True, is_naive_2fuse=True + ) + else: + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." 
in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs) + + @classmethod + def get_cache_kvs_shape( + cls, config: Qwen2Config, max_batch_size: int = None, max_length: int = None + ) -> list[list[int]]: + """get cache_kvs tensor for Qwen2 model + + Args: + max_batch_size (int): the max batch size + max_length (int | None, optional): the max_length of cache_kvs. Defaults to None. + + Returns: + list[paddle.Tensor]: the list tensor shape for cache + """ + max_block_per_seq = (config.max_seq_len + config.block_size - 1) // config.block_size + if max_batch_size == -1: + max_block_nums = None + else: + max_block_nums = max_batch_size * max_block_per_seq + + cache_kvs = [] + for _ in range(config.num_hidden_layers): + cache_kv_shape = [ + max_block_nums, + config.num_key_value_heads // max(config.tensor_parallel_degree, 1), + config.block_size, + config.hidden_size // config.num_attention_heads, + ] + cache_kvs.append(cache_kv_shape) + cache_kvs.append(cache_kv_shape) + return cache_kvs + + def prepare_inputs_for_generation(self, **kwargs): + # only last token for inputs_ids if cache is defined in kwargs + input_ids = kwargs["input_ids"] + src_mask = kwargs.get("src_mask", None) + block_tables = kwargs.get("block_tables", None) + + pre_caches = kwargs.get("pre_caches", None) + caches = kwargs.get("caches", None) + + rope_emb = kwargs["rope_emb"] + seq_lens_this_time = kwargs["seq_lens_this_time"] + seq_lens_encoder = kwargs["seq_lens_encoder"] + seq_lens_decoder = kwargs["seq_lens_decoder"] + k_quant_scales = kwargs.get("k_quant_scales", None) + v_quant_scales = kwargs.get("v_quant_scales", None) + k_dequant_scales = kwargs.get("k_dequant_scales", None) + v_dequant_scales = kwargs.get("v_dequant_scales", None) + model_inputs = { + "input_ids": input_ids, + "src_mask": src_mask, + "rope_emb": rope_emb, + "pre_caches": pre_caches, + "caches": caches, + "seq_lens_this_time": seq_lens_this_time, + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_decoder": seq_lens_decoder, + "block_tables": block_tables, + "k_quant_scales": k_quant_scales, + "v_quant_scales": v_quant_scales, + "k_dequant_scales": k_dequant_scales, + "v_dequant_scales": v_dequant_scales, + } + return model_inputs + + def forward( + self, + input_ids, + src_mask=None, + pre_caches=None, + caches=None, + seq_lens_this_time=None, + seq_lens_encoder=None, + seq_lens_decoder=None, + rope_emb=None, + block_tables=None, + k_quant_scales=None, + v_quant_scales=None, + k_dequant_scales=None, + v_dequant_scales=None, + ): + outputs = self.qwen2( + input_ids, + src_mask=src_mask, + caches=caches, + rope_emb=rope_emb, + block_tables=block_tables, + pre_caches=pre_caches, + seq_lens_this_time=seq_lens_this_time, + seq_lens_encoder=seq_lens_encoder, + seq_lens_decoder=seq_lens_decoder, + k_quant_scales=k_quant_scales, + v_quant_scales=v_quant_scales, + k_dequant_scales=k_dequant_scales, + v_dequant_scales=v_dequant_scales, + ) + + hidden_states = outputs[0] + logits = self.lm_head( + hidden_states, + tensor_parallel_output=False, + ) + + return logits + + @paddle.no_grad() + def set_state_dict(self, state_dict): + if "lm_head.weight" in state_dict: + 
self.lm_head.weight.set_value( + paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype) + ) + self.qwen2.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) From 1b063ac40e045c306404ac616b25ece3237b804e Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Thu, 8 Aug 2024 05:03:50 +0000 Subject: [PATCH 5/7] add qwen2 ptq map --- .../transformers/qwen2/modeling.py | 4 +++- .../transformers/qwen2/ptq_scales_map.json | 21 +++++++++++++++++++ .../qwen2/ptq_scales_map_shift_smooth.json | 21 +++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 paddlenlp/experimental/transformers/qwen2/ptq_scales_map.json create mode 100644 paddlenlp/experimental/transformers/qwen2/ptq_scales_map_shift_smooth.json diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index 2d4d39ecbfff..c512593b82fb 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -308,7 +308,9 @@ def set_state_dict(self, state_dict): qkv_bias = paddle.to_tensor(concated_qkv_biases).cast(self.transformer_block.qkv_biases[idx].dtype) self.transformer_block.qkv_biases[idx].set_value(qkv_bias) - linear_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]).cast(paddle.get_default_dtype()) + linear_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]).cast( + paddle.get_default_dtype() + ) if self.use_weight_only: linear_quanted_weight, linear_weight_scale = weight_quantize(linear_weight, algo=self.quant_algo) self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight) diff --git a/paddlenlp/experimental/transformers/qwen2/ptq_scales_map.json b/paddlenlp/experimental/transformers/qwen2/ptq_scales_map.json new file mode 100644 index 000000000000..a069eddb3681 --- /dev/null +++ b/paddlenlp/experimental/transformers/qwen2/ptq_scales_map.json @@ -0,0 +1,21 @@ +{ + "act_scale":{ + "qkv_in_scale": "qwen2.layers.#.self_attn.q_proj.activation_quanter", + "out_linear_in_scale": "qwen2.layers.#.self_attn.o_proj.activation_quanter", + "ffn1_in_scale": "qwen2.layers.#.mlp.gate_proj.activation_quanter", + "ffn2_in_scale": "qwen2.layers.#.mlp.down_proj.activation_quanter" + }, + "weight_scale":{ + "q_weight_scale":"qwen2.layers.#.self_attn.q_proj.weight_quanter", + "k_weight_scale":"qwen2.layers.#.self_attn.k_proj.weight_quanter", + "v_weight_scale":"qwen2.layers.#.self_attn.v_proj.weight_quanter", + "out_linear_weight_scale":"qwen2.layers.#.self_attn.o_proj.weight_quanter", + "ffn1_1_weight_scale":"qwen2.layers.#.mlp.gate_proj.weight_quanter", + "ffn1_2_weight_scale":"qwen2.layers.#.mlp.up_proj.weight_quanter", + "ffn2_weight_scale":"qwen2.layers.#.mlp.down_proj.weight_quanter" + }, + "cachekv_scale":{ + "cache_k_scale": "qwen2.layers.#.self_attn.cachek_matmul.activation_quanter", + "cache_v_scale": "qwen2.layers.#.self_attn.cachev_matmul.activation_quanter" + } + } \ No newline at end of file diff --git a/paddlenlp/experimental/transformers/qwen2/ptq_scales_map_shift_smooth.json b/paddlenlp/experimental/transformers/qwen2/ptq_scales_map_shift_smooth.json new file mode 100644 index 000000000000..af6a04229f56 --- /dev/null +++ b/paddlenlp/experimental/transformers/qwen2/ptq_scales_map_shift_smooth.json @@ -0,0 +1,21 @@ +{ + "act_scale":{ + "qkv_in_scale": "qwen2.layers.#.self_attn.q_proj.activation_quanter", + "out_linear_in_scale": 
"qwen2.layers.#.self_attn.o_proj.layer.activation_quanter", + "ffn1_in_scale": "qwen2.layers.#.mlp.gate_proj.activation_quanter", + "ffn2_in_scale": "qwen2.layers.#.mlp.down_proj.layer.activation_quanter" + }, + "weight_scale":{ + "q_weight_scale":"qwen2.layers.#.self_attn.q_proj.weight_quanter", + "k_weight_scale":"qwen2.layers.#.self_attn.k_proj.weight_quanter", + "v_weight_scale":"qwen2.layers.#.self_attn.v_proj.weight_quanter", + "out_linear_weight_scale":"qwen2.layers.#.self_attn.o_proj.layer.weight_quanter", + "ffn1_1_weight_scale":"qwen2.layers.#.mlp.gate_proj.weight_quanter", + "ffn1_2_weight_scale":"qwen2.layers.#.mlp.up_proj.weight_quanter", + "ffn2_weight_scale":"qwen2.layers.#.mlp.down_proj.layer.weight_quanter" + }, + "cachekv_scale":{ + "cache_k_scale": "qwen2.layers.#.self_attn.cachek_matmul.activation_quanter", + "cache_v_scale": "qwen2.layers.#.self_attn.cachev_matmul.activation_quanter" + } +} \ No newline at end of file From 6091ccb23fe5fe4c26736963547e3a130feee87e Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 12 Aug 2024 03:40:43 +0000 Subject: [PATCH 6/7] update --- llm/predict/predictor.py | 2 +- paddlenlp/experimental/transformers/qwen2/modeling.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index ba52c11f8111..26b498187ed8 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -548,9 +548,9 @@ def _preprocess(self, source): # alibi encoder alibi_slopes = get_alibi_slopes(self.model_config.n_head) inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32") - arange_tensor_encoder = paddle.arange(self.config.total_max_length, dtype=self.config.dtype) alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder + if self.model_config.tensor_parallel_degree > 1: block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree alibi = alibi[ diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index c512593b82fb..189552aff006 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -1,6 +1,4 @@ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From cc76b2545ad1e55386493771a427623fcb931c7d Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 12 Aug 2024 03:43:39 +0000 Subject: [PATCH 7/7] fix tune_cublaslt_gemm.cu --- csrc/generation/tune_cublaslt_gemm.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/generation/tune_cublaslt_gemm.cu b/csrc/generation/tune_cublaslt_gemm.cu index 74d5a8acea64..780f0b9ef30c 100644 --- a/csrc/generation/tune_cublaslt_gemm.cu +++ b/csrc/generation/tune_cublaslt_gemm.cu @@ -327,7 +327,7 @@ void FindAlgo(const cublasLtHandle_t& ltHandle, sizeof(customOption))); CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k))); - int splitK_val = 0; + int splitK_val = 1; uint32_t redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, @@ -346,10 +346,10 @@ void FindAlgo(const cublasLtHandle_t& ltHandle, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitKSequenceA[l - 1], sizeof(splitKSequenceA[l - 1]))); - for (redScheme = 0; + for (redScheme = 1; redScheme < (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations); - redScheme++) { + redScheme <<= 1) { CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,