From fb8f2bec6a42f633676c03c0f0c595999b093092 Mon Sep 17 00:00:00 2001
From: w5688414
Date: Fri, 29 Dec 2023 20:27:47 +0800
Subject: [PATCH] [bloom] Add kv cache support for flash attention & fix bugs (#7735)

* Add kv cache support for flash attention
* Update chatglm flash attention version check
* Add test for flash attention
* Fix unitest bug
* Add flash attention to predictor
* Add flash attention2
* Add flash attention unitests
* fix prefix decoder
* remove unused comments
* Update unitest
* Update unitest
---
 llm/predictor.py                           | 10 ++++-
 paddlenlp/data/data_collator.py            |  1 -
 paddlenlp/peft/prefix/utils.py             |  4 +-
 paddlenlp/transformers/bloom/modeling.py   | 51 ++++++++--------------
 paddlenlp/transformers/chatglm/modeling.py | 30 +++++--------
 tests/fixtures/llm/finetune.yaml           |  1 +
 tests/fixtures/llm/predictor.yaml          |  1 +
 tests/fixtures/llm/prefix_tuning.yaml      |  1 +
 tests/llm/test_predictor.py                | 24 ++++++++++
 tests/llm/test_prefix_tuning.py            |  1 -
 10 files changed, 67 insertions(+), 57 deletions(-)

diff --git a/llm/predictor.py b/llm/predictor.py
index b432e79c84d5..19fd6e8d840d 100644
--- a/llm/predictor.py
+++ b/llm/predictor.py
@@ -75,6 +75,10 @@ class PredictorArgument:
             "help": "the decoding strategy of generation, which should be one of ['sampling', 'greedy_search', 'beam_search']. Default to sampling"
         },
     )
+    use_flash_attention: bool = field(
+        default=False,
+        metadata={"help": "Whether to use flash attention"},
+    )
     mode: str = field(
         default="dynamic", metadata={"help": "the type of predictor, it should be one of [dynamic, static]"}
     )
@@ -241,6 +245,7 @@ def __init__(
         if self.model is None:
             self.model = AutoModelForCausalLM.from_pretrained(
                 config.model_name_or_path,
+                use_flash_attention=config.use_flash_attention,
                 dtype=dtype,
                 tensor_parallel_degree=self.tensor_parallel_degree,
                 tensor_parallel_rank=self.tensor_parallel_rank,
             )
@@ -685,7 +690,9 @@ def create_predictor(
     tensor_parallel_degree: int = 1,
     tensor_parallel_rank: int = 0,
 ):
-    tokenizer = AutoTokenizer.from_pretrained(predictor_args.model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(
+        predictor_args.model_name_or_path,
+    )
     # init chat_template for tokenizer
     init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template)
@@ -727,6 +734,7 @@ def create_predictor(
             model = AutoModelForCausalLM.from_pretrained(
                 predictor_args.model_name_or_path,
                 dtype=predictor_args.dtype,
+                use_flash_attention=predictor_args.use_flash_attention,
                 tensor_parallel_degree=tensor_parallel_degree,
                 tensor_parallel_rank=tensor_parallel_rank,
             )
diff --git a/paddlenlp/data/data_collator.py b/paddlenlp/data/data_collator.py
index 9514f1225b2d..aecd186a91e4 100644
--- a/paddlenlp/data/data_collator.py
+++ b/paddlenlp/data/data_collator.py
@@ -405,7 +405,6 @@ def __call__(self, features, return_tensors=None):
             return_tensors=return_tensors,
             return_attention_mask=self.return_attention_mask,
         )
-
         # prepare decoder_input_ids
         if (
             labels is not None
diff --git a/paddlenlp/peft/prefix/utils.py b/paddlenlp/peft/prefix/utils.py
index b42412d8a446..684245c2f380 100644
--- a/paddlenlp/peft/prefix/utils.py
+++ b/paddlenlp/peft/prefix/utils.py
@@ -17,10 +17,10 @@

 def bloom_postprocess_past_key_value(past_key_values):
     # (layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim)*2
-    past_key_values = paddle.transpose(past_key_values, perm=[2, 0, 3, 1, 4]).split(2)
+    keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
     # keys: [layer_num, bs, head_num/tensor_parallel_degree, head_dim, prefixlen]
     # value: [layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim]
-    keys, values = past_key_values[0].transpose([0, 1, 2, 4, 3]), past_key_values[1]
+    # keys, values = past_key_values[0].transpose([0, 1, 2, 4, 3]), past_key_values[1]
     return tuple(zip(keys, values))
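Note (not part of the patch): the new perm in bloom_postprocess_past_key_value keeps the prefix cache in the [batch, kv_length, num_heads, head_dim] layout that the reworked Bloom attention below expects. A standalone shape check, assuming the prefix tensor is packed as [batch_size, prefix_len, 2 * num_layers, num_heads, head_dim] (an assumption about the prefix encoder output, not taken from the patch):

    import paddle

    bs, prefix_len, n_layer, n_head, head_dim = 2, 8, 4, 16, 64
    # hypothetical prefix tensor; the layout is assumed, verify against your prefix config
    past_key_values = paddle.randn([bs, prefix_len, 2 * n_layer, n_head, head_dim])

    # perm=[2, 0, 1, 3, 4] moves the stacked layer axis first, then split(2) separates keys and values
    keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
    print(keys.shape)    # [n_layer, bs, prefix_len, n_head, head_dim]
    print(values.shape)  # per layer: [bs, prefix_len, n_head, head_dim], i.e. [batch, kv_length, num_heads, head_dim]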
diff --git a/paddlenlp/transformers/bloom/modeling.py b/paddlenlp/transformers/bloom/modeling.py
index 17fb6f80b677..1f7ad2d299f9 100644
--- a/paddlenlp/transformers/bloom/modeling.py
+++ b/paddlenlp/transformers/bloom/modeling.py
@@ -378,9 +378,23 @@ def forward(
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

         batch_size, q_length, _, _ = query_layer.shape
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            # concatenate along seq_length dimension:
+            # - key: [batch_size, kv_length, self.num_heads, head_dim]
+            # - value: [batch_size, kv_length, self.num_heads, head_dim]
+            key_layer = paddle.concat((past_key, key_layer), axis=1)
+            value_layer = paddle.concat((past_value, value_layer), axis=1)
+
+        if use_cache is True:
+            present = (key_layer, value_layer)
+        else:
+            present = None
+
         version = paddle.version.full_version
         version_check = True
-        if version != "0.0.0" and version <= "2.5.2":
+        if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
             logger.warning(
                 "PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
             )
@@ -397,6 +411,8 @@ def forward(
                 key_states,
                 value_states,
                 attn_mask=attention_mask,
+                dropout_p=self.config.attention_dropout,
+                training=self.training,
                 is_causal=False,
             )
             attn_weights = None
@@ -404,39 +420,10 @@ def forward(
             attn_output = attn_output.reshape([attn_output.shape[0], attn_output.shape[1], -1])
             output_tensor = self.dense(attn_output)

-            query_layer = query_layer.transpose([0, 2, 1, 3])
-            key_layer = key_layer.transpose([0, 2, 3, 1])
-            value_layer = value_layer.transpose([0, 2, 1, 3])
-            if layer_past is not None:
-                past_key, past_value = layer_past
-                # concatenate along seq_length dimension:
-                # - key: [batch_size, self.num_heads, head_dim, kv_length]
-                # - value: [batch_size, self.num_heads, kv_length, head_dim]
-                key_layer = paddle.concat((past_key, key_layer), axis=3)
-                value_layer = paddle.concat((past_value, value_layer), axis=2)
-
-            if use_cache:
-                present = (key_layer, value_layer)
-            else:
-                present = None
         else:
-            query_layer = query_layer.transpose([0, 2, 1, 3])
             key_layer = key_layer.transpose([0, 2, 3, 1])
             value_layer = value_layer.transpose([0, 2, 1, 3])
-            if layer_past is not None:
-                past_key, past_value = layer_past
-                # concatenate along seq_length dimension:
-                # - key: [batch_size, self.num_heads, head_dim, kv_length]
-                # - value: [batch_size, self.num_heads, kv_length, head_dim]
-                key_layer = paddle.concat((past_key, key_layer), axis=3)
-                value_layer = paddle.concat((past_value, value_layer), axis=2)
-
-            if use_cache is True:
-                present = (key_layer, value_layer)
-            else:
-                present = None
-
             _, _, _, kv_length = key_layer.shape

             query_layer = query_layer.reshape([batch_size * self.num_heads, q_length, self.head_dim])
@@ -449,7 +436,6 @@ def forward(
             attention_scores = baddbmm(
                 alibi, batch1=query_layer, batch2=key_layer, beta=self.beta, alpha=self.inv_norm_factor
             )
-
             # change view to [batch_size, num_heads, q_length, kv_length]
             # attention_scores = matmul_result.reshape([batch_size, self.num_heads, q_length, kv_length])
@@ -949,14 +935,13 @@ def forward(
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[3]
+            past_key_values_length = past_key_values[0][0].shape[1]
             seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
             attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
         elif attention_mask.dtype != paddle.bool:
             attention_mask = paddle.cast(attention_mask, "bool")
-
         if len(attention_mask.shape) > 2:
             _attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
             alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
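Note (not part of the patch): the key change above is that the KV cache is now concatenated before the attention call, so the flash-attention branch sees the full key/value history and the cache stays in [batch, kv_length, num_heads, head_dim]. A minimal sketch of that ordering, assuming [batch, seq_len, num_heads, head_dim] inputs and a GPU PaddlePaddle >= 2.5.3 build with flash attention available:

    import paddle
    import paddle.nn.functional as F

    bs, q_len, past_len, n_head, head_dim = 1, 4, 12, 16, 64
    query = paddle.randn([bs, q_len, n_head, head_dim]).astype("float16")
    key = paddle.randn([bs, q_len, n_head, head_dim]).astype("float16")
    value = paddle.randn([bs, q_len, n_head, head_dim]).astype("float16")
    # hypothetical cached prefix from an earlier decoding step
    past_key = paddle.randn([bs, past_len, n_head, head_dim]).astype("float16")
    past_value = paddle.randn([bs, past_len, n_head, head_dim]).astype("float16")

    # concatenate along the sequence axis (axis=1) *before* attention,
    # mirroring the new layer_past handling in the Bloom diff
    key = paddle.concat((past_key, key), axis=1)
    value = paddle.concat((past_value, value), axis=1)

    out = F.scaled_dot_product_attention(query, key, value, is_causal=False)
    print(out.shape)  # [bs, q_len, n_head, head_dim]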
diff --git a/paddlenlp/transformers/chatglm/modeling.py b/paddlenlp/transformers/chatglm/modeling.py
index 1fb1bcabf45b..dd103d070642 100644
--- a/paddlenlp/transformers/chatglm/modeling.py
+++ b/paddlenlp/transformers/chatglm/modeling.py
@@ -245,9 +245,19 @@ def forward(
         q_layer, k_layer, v_layer = paddle.split(mixed_layer, 3, axis=-1)
         # [s, b, n, h/n]
         q_layer, k_layer = self._core_attention(q_layer, k_layer, position_ids, rotary_embeds)
+
+        if cache is not None:
+            cache_k, cache_v = cache[0], cache[1]
+            # [s + c, b, n, h/n]
+            k_layer = paddle.concat([cache_k, k_layer], axis=0)
+            v_layer = paddle.concat([cache_v, v_layer], axis=0)
+
+        cache_kv = None
+        if use_cache:
+            cache_kv = (k_layer, v_layer)
         version = paddle.version.full_version
         version_check = True
-        if version != "0.0.0" and version <= "2.5.2":
+        if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
             logger.warning(
                 "PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
             )
@@ -255,15 +265,6 @@ def forward(
         if self.config.use_flash_attention and version_check:
             # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
             # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
-            if cache is not None:
-                cache_k, cache_v = cache[0], cache[1]
-                # [s + c, b, n, h/n]
-                k_layer = paddle.concat([cache_k, k_layer], axis=0)
-                v_layer = paddle.concat([cache_v, v_layer], axis=0)
-            cache_kv = None
-            if use_cache:
-                cache_kv = (k_layer, v_layer)
-
             # [s, b, n, h/n] = > [batch_size, seq_len, num_heads, head_dim]
             q_layer = paddle.transpose(q_layer, [1, 0, 2, 3])
             k_layer = paddle.transpose(k_layer, [1, 0, 2, 3])
@@ -286,18 +287,9 @@ def forward(
             output, attention_probs = attn_output, attn_weights
         else:
-            if cache is not None:
-                cache_k, cache_v = cache[0], cache[1]
-                # [s + c, b, n, h/n]
-                k_layer = paddle.concat([cache_k, k_layer], axis=0)
-                v_layer = paddle.concat([cache_v, v_layer], axis=0)
             seq_length, batch_size, num_heads, hidden_size = k_layer.shape

-            cache_kv = None
-            if use_cache:
-                cache_kv = (k_layer, v_layer)
-
             attention_scale_coeff = float(layer_id) + 1.0
             if self.attention_scale:
                 # [s, b, n, h/n]
diff --git a/tests/fixtures/llm/finetune.yaml b/tests/fixtures/llm/finetune.yaml
index bc9c1d73fd41..0e08f04d4c59 100644
--- a/tests/fixtures/llm/finetune.yaml
+++ b/tests/fixtures/llm/finetune.yaml
@@ -17,6 +17,7 @@ finetune:
     fp16_opt_level: "O2"
     do_train: true
     do_eval: true
+    use_flash_attention: true
     disable_tqdm: true
     load_best_model_at_end: true
     eval_with_do_generation: false
diff --git a/tests/fixtures/llm/predictor.yaml b/tests/fixtures/llm/predictor.yaml
index f251443b7e9e..5c08c8f28f83 100644
--- a/tests/fixtures/llm/predictor.yaml
+++ b/tests/fixtures/llm/predictor.yaml
@@ -3,6 +3,7 @@ inference-predict:
     mode: dynamic
     max_length: 40
     batch_size: 2
+    use_flash_attention: false
     decode_strategy: greedy_search
     dtype: float16
     data_file: tests/fixtures/llm/data/train.json
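Note (not part of the patch): Bloom and ChatGLM now apply the same gate, so the "upgrade to 2.5.3" warning only fires when flash attention is actually requested. Distilled into a hypothetical helper (the function name is ours; the patch inlines this logic, and version strings are compared lexically exactly as in the patch):

    import paddle
    from paddlenlp.utils.log import logger

    def flash_attention_supported(config) -> bool:
        version = paddle.version.full_version
        # "0.0.0" is the placeholder version of source/develop builds and is skipped
        if config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
            logger.warning(
                "PaddlePaddle version 2.5.3 or higher is required, please upgrade your "
                "PaddlePaddle to 2.5.3 or other higher version."
            )
            return False
        return config.use_flash_attention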
diff --git a/tests/fixtures/llm/prefix_tuning.yaml b/tests/fixtures/llm/prefix_tuning.yaml
index cd4c96e8c17a..e8d17f867e99 100644
--- a/tests/fixtures/llm/prefix_tuning.yaml
+++ b/tests/fixtures/llm/prefix_tuning.yaml
@@ -17,6 +17,7 @@ prefix_tuning:
     do_train: true
     do_eval: true
     disable_tqdm: true
+    use_flash_attention: false
     load_best_model_at_end: true
     eval_with_do_generation: false
     metric_for_best_model: "accuracy"
diff --git a/tests/llm/test_predictor.py b/tests/llm/test_predictor.py
index 6878d373e5cb..41fc190566fc 100644
--- a/tests/llm/test_predictor.py
+++ b/tests/llm/test_predictor.py
@@ -77,6 +77,30 @@ def test_predictor(self):
         else:
             self.assertGreaterEqual(count / len(result_0), 0.4)

+    def test_flash_attention(self):
+        self.run_predictor({"inference_model": False, "use_flash_attention": False})
+        result_0 = self._read_result(os.path.join(self.output_dir, "predict.json"))
+
+        self.run_predictor({"inference_model": False, "use_flash_attention": True})
+        result_1 = self._read_result(os.path.join(self.output_dir, "predict.json"))
+
+        # compare the generation result of dygraph & flash attention model
+        assert len(result_0) == len(result_1)
+
+        count, full_match = 0, 0
+        for inference_item, no_inference_item in zip(result_0, result_1):
+            if self.model_name_or_path == "__internal_testing__/tiny-random-llama":
+                min_length = 5
+            else:
+                min_length = min(len(inference_item), len(no_inference_item))
+            count += int(inference_item[: min_length // 2] == no_inference_item[: min_length // 2])
+            full_match += int(inference_item[:min_length] == no_inference_item[:min_length])
+
+        if self.model_name_or_path == "__internal_testing__/tiny-random-llama":
+            self.assertGreaterEqual(count / len(result_0), 0.2)
+        else:
+            self.assertEqual(full_match / len(result_0), 1.0)
+
     def test_wint8(self):
         self.run_predictor({"inference_model": True, "quant_type": "weight_only_int8"})
         result_0 = self._read_result(os.path.join(self.output_dir, "predict.json"))
diff --git a/tests/llm/test_prefix_tuning.py b/tests/llm/test_prefix_tuning.py
index c649dfd60059..9074c564bd6b 100644
--- a/tests/llm/test_prefix_tuning.py
+++ b/tests/llm/test_prefix_tuning.py
@@ -55,7 +55,6 @@ def test_prefix_tuning(self):

         prefix_tuning_config["dataset_name_or_path"] = self.data_dir
         prefix_tuning_config["output_dir"] = self.output_dir
-
         with argv_context_guard(prefix_tuning_config):
             from finetune_generation import main
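Note (not part of the patch): the new PredictorArgument field is forwarded to AutoModelForCausalLM.from_pretrained, so flash attention can also be toggled directly when loading a model outside the predictor script. A hedged end-to-end sketch (the checkpoint name is a placeholder; any Bloom or ChatGLM checkpoint you have available applies):

    from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "bigscience/bloom-560m"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype="float16",
        use_flash_attention=True,  # falls back with a warning on PaddlePaddle <= 2.5.2
    )

    inputs = tokenizer("Hello, how are you?", return_tensors="pd")
    # use_cache exercises the kv-cache path added for the flash-attention branch
    output = model.generate(**inputs, max_length=32, use_cache=True)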