[bloom] Add kv cache support for flash attention & fix bugs (#7735)
* Add kv cache support for flash attention

* Update chatglm flash attention version check

* Add test for flash attention

* Fix unit test bug

* Add flash attention to predictor

* Add flash attention 2

* Add flash attention unit tests

* Fix prefix decoder

* Remove unused comments

* Update unit tests

* Update unit tests
w5688414 committed Dec 29, 2023
1 parent fda20a7 commit fb8f2be
Showing 10 changed files with 67 additions and 57 deletions.
10 changes: 9 additions & 1 deletion llm/predictor.py
@@ -75,6 +75,10 @@ class PredictorArgument:
"help": "the decoding strategy of generation, which should be one of ['sampling', 'greedy_search', 'beam_search']. Default to sampling"
},
)
use_flash_attention: bool = field(
default=False,
metadata={"help": "Whether to use flash attention"},
)

mode: str = field(
default="dynamic", metadata={"help": "the type of predictor, it should be one of [dynamic, static]"}
@@ -241,6 +245,7 @@ def __init__(
if self.model is None:
self.model = AutoModelForCausalLM.from_pretrained(
config.model_name_or_path,
use_flash_attention=config.use_flash_attention,
dtype=dtype,
tensor_parallel_degree=self.tensor_parallel_degree,
tensor_parallel_rank=self.tensor_parallel_rank,
@@ -685,7 +690,9 @@ def create_predictor(
tensor_parallel_degree: int = 1,
tensor_parallel_rank: int = 0,
):
tokenizer = AutoTokenizer.from_pretrained(predictor_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(
predictor_args.model_name_or_path,
)
# init chat_template for tokenizer
init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template)

Expand Down Expand Up @@ -727,6 +734,7 @@ def create_predictor(
model = AutoModelForCausalLM.from_pretrained(
predictor_args.model_name_or_path,
dtype=predictor_args.dtype,
use_flash_attention=predictor_args.use_flash_attention,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
)
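The new flag is simply forwarded to AutoModelForCausalLM.from_pretrained in both code paths above. A minimal stand-alone sketch of the equivalent call (the checkpoint name is the tiny test model from the repo's fixtures and is only a placeholder; flash-attention kernels themselves need a GPU build of Paddle >= 2.5.3):

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

# Tiny checkpoint borrowed from the test fixtures; swap in a real model as needed.
model_name = "__internal_testing__/tiny-random-llama"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype="float16",           # mirrors PredictorArgument.dtype
    use_flash_attention=True,  # the switch added in this commit
)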
1 change: 0 additions & 1 deletion paddlenlp/data/data_collator.py
@@ -405,7 +405,6 @@ def __call__(self, features, return_tensors=None):
return_tensors=return_tensors,
return_attention_mask=self.return_attention_mask,
)

# prepare decoder_input_ids
if (
labels is not None
4 changes: 2 additions & 2 deletions paddlenlp/peft/prefix/utils.py
@@ -17,10 +17,10 @@

def bloom_postprocess_past_key_value(past_key_values):
# (layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim)*2
past_key_values = paddle.transpose(past_key_values, perm=[2, 0, 3, 1, 4]).split(2)
keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
# keys: [layer_num, bs, head_num/tensor_parallel_degree, head_dim, prefixlen]
# value: [layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim]
keys, values = past_key_values[0].transpose([0, 1, 2, 4, 3]), past_key_values[1]
# keys, values = past_key_values[0].transpose([0, 1, 2, 4, 3]), past_key_values[1]
return tuple(zip(keys, values))


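Shape-wise, the rewritten bloom_postprocess_past_key_value keeps both halves in (prefix_len, head_dim) order instead of transposing the keys, matching the cache layout the bloom attention now expects. A quick check of that transpose/split on a dummy tensor, assuming the prefix cache arrives as [batch, prefix_len, 2 * num_layers, num_heads, head_dim] (the layout implied by the comment above):

import paddle

# Assumed input layout: [batch, prefix_len, 2 * num_layers, num_heads, head_dim]
batch, prefix_len, num_layers, num_heads, head_dim = 2, 8, 4, 3, 16
past_key_values = paddle.randn([batch, prefix_len, 2 * num_layers, num_heads, head_dim])

keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
# Both halves now share one layout: [num_layers, batch, prefix_len, num_heads, head_dim]
print(keys.shape, values.shape)  # [4, 2, 8, 3, 16] for each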
51 changes: 18 additions & 33 deletions paddlenlp/transformers/bloom/modeling.py
@@ -378,9 +378,23 @@ def forward(
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

batch_size, q_length, _, _ = query_layer.shape

if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size, kv_length, self.num_heads, head_dim]
# - value: [batch_size, kv_length, self.num_heads, head_dim]
key_layer = paddle.concat((past_key, key_layer), axis=1)
value_layer = paddle.concat((past_value, value_layer), axis=1)

if use_cache is True:
present = (key_layer, value_layer)
else:
present = None

version = paddle.version.full_version
version_check = True
if version != "0.0.0" and version <= "2.5.2":
if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
logger.warning(
"PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
)
@@ -397,46 +411,19 @@ def forward(
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.config.attention_dropout,
training=self.training,
is_causal=False,
)
attn_weights = None
# [batch_size, seq_len, num_heads, head_dim] => [batch_size, seq_len, hidden_size]
attn_output = attn_output.reshape([attn_output.shape[0], attn_output.shape[1], -1])
output_tensor = self.dense(attn_output)

query_layer = query_layer.transpose([0, 2, 1, 3])
key_layer = key_layer.transpose([0, 2, 3, 1])
value_layer = value_layer.transpose([0, 2, 1, 3])
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size, self.num_heads, head_dim, kv_length]
# - value: [batch_size, self.num_heads, kv_length, head_dim]
key_layer = paddle.concat((past_key, key_layer), axis=3)
value_layer = paddle.concat((past_value, value_layer), axis=2)

if use_cache:
present = (key_layer, value_layer)
else:
present = None
else:

query_layer = query_layer.transpose([0, 2, 1, 3])
key_layer = key_layer.transpose([0, 2, 3, 1])
value_layer = value_layer.transpose([0, 2, 1, 3])
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size, self.num_heads, head_dim, kv_length]
# - value: [batch_size, self.num_heads, kv_length, head_dim]
key_layer = paddle.concat((past_key, key_layer), axis=3)
value_layer = paddle.concat((past_value, value_layer), axis=2)

if use_cache is True:
present = (key_layer, value_layer)
else:
present = None

_, _, _, kv_length = key_layer.shape

query_layer = query_layer.reshape([batch_size * self.num_heads, q_length, self.head_dim])
@@ -449,7 +436,6 @@ def forward(
attention_scores = baddbmm(
alibi, batch1=query_layer, batch2=key_layer, beta=self.beta, alpha=self.inv_norm_factor
)

# change view to [batch_size, num_heads, q_length, kv_length]
# attention_scores = matmul_result.reshape([batch_size, self.num_heads, q_length, kv_length])

@@ -949,14 +935,13 @@ def forward(
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[3]
past_key_values_length = past_key_values[0][0].shape[1]
seq_length_with_past = seq_length_with_past + past_key_values_length

if attention_mask is None:
attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
elif attention_mask.dtype != paddle.bool:
attention_mask = paddle.cast(attention_mask, "bool")

if len(attention_mask.shape) > 2:
_attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
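In short, the flash-attention path in bloom now keeps the KV cache in [batch, seq_len, num_heads, head_dim] layout and appends the current step along the sequence axis, which is also why past_key_values_length is read from shape[1] further down. A minimal sketch of that cache update outside the model (all sizes are made up; Paddle >= 2.5.3 with GPU flash-attention support assumed):

import paddle
import paddle.nn.functional as F

batch, q_len, past_len, num_heads, head_dim = 1, 1, 7, 4, 32

# Current step, already split into heads: [batch, q_len, num_heads, head_dim]
query = paddle.randn([batch, q_len, num_heads, head_dim]).astype("float16")
key = paddle.randn([batch, q_len, num_heads, head_dim]).astype("float16")
value = paddle.randn([batch, q_len, num_heads, head_dim]).astype("float16")

# Cached steps share the same layout, so concatenation happens along axis=1.
past_key = paddle.randn([batch, past_len, num_heads, head_dim]).astype("float16")
past_value = paddle.randn([batch, past_len, num_heads, head_dim]).astype("float16")
key = paddle.concat([past_key, key], axis=1)
value = paddle.concat([past_value, value], axis=1)

present = (key, value)  # handed back as the new cache when use_cache is True
past_key_values_length = key.shape[1] - q_len  # consistent with the shape[3] -> shape[1] change

# Paddle's flash attention consumes [batch, seq_len, num_heads, head_dim] directly
# (needs a GPU build with flash-attention kernels).
out = F.scaled_dot_product_attention(query, key, value, is_causal=False)
print(out.shape)  # [1, 1, 4, 32]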
30 changes: 11 additions & 19 deletions paddlenlp/transformers/chatglm/modeling.py
@@ -245,25 +245,26 @@ def forward(
q_layer, k_layer, v_layer = paddle.split(mixed_layer, 3, axis=-1)
# [s, b, n, h/n]
q_layer, k_layer = self._core_attention(q_layer, k_layer, position_ids, rotary_embeds)

if cache is not None:
cache_k, cache_v = cache[0], cache[1]
# [s + c, b, n, h/n]
k_layer = paddle.concat([cache_k, k_layer], axis=0)
v_layer = paddle.concat([cache_v, v_layer], axis=0)

cache_kv = None
if use_cache:
cache_kv = (k_layer, v_layer)
version = paddle.version.full_version
version_check = True
if version != "0.0.0" and version <= "2.5.2":
if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
logger.warning(
"PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
)
version_check = False
if self.config.use_flash_attention and version_check:
# Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
# Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
if cache is not None:
cache_k, cache_v = cache[0], cache[1]
# [s + c, b, n, h/n]
k_layer = paddle.concat([cache_k, k_layer], axis=0)
v_layer = paddle.concat([cache_v, v_layer], axis=0)
cache_kv = None
if use_cache:
cache_kv = (k_layer, v_layer)

# [s, b, n, h/n] => [batch_size, seq_len, num_heads, head_dim]
q_layer = paddle.transpose(q_layer, [1, 0, 2, 3])
k_layer = paddle.transpose(k_layer, [1, 0, 2, 3])
@@ -286,18 +287,9 @@ def forward(

output, attention_probs = attn_output, attn_weights
else:
if cache is not None:
cache_k, cache_v = cache[0], cache[1]
# [s + c, b, n, h/n]
k_layer = paddle.concat([cache_k, k_layer], axis=0)
v_layer = paddle.concat([cache_v, v_layer], axis=0)

seq_length, batch_size, num_heads, hidden_size = k_layer.shape

cache_kv = None
if use_cache:
cache_kv = (k_layer, v_layer)

attention_scale_coeff = float(layer_id) + 1.0
if self.attention_scale:
# [s, b, n, h/n]
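The chatglm edit mirrors the bloom one: the cache concatenation (axis 0, since chatglm keeps [seq, batch, heads, head_dim]) and the cache_kv construction are hoisted out of the two attention branches, and the Paddle version warning now only fires when flash attention was actually requested. A rough sketch of that guarded check as a free function (the helper name is hypothetical):

import paddle
from paddlenlp.utils.log import logger


def flash_attention_usable(use_flash_attention: bool) -> bool:
    """Hypothetical helper mirroring the in-line check: only warn, and fall back to
    the eager attention path, when flash attention was requested but the installed
    Paddle is older than 2.5.3."""
    version = paddle.version.full_version
    # "0.0.0" is assumed to be a develop build, which skips the check.
    if use_flash_attention and version != "0.0.0" and version <= "2.5.2":
        logger.warning(
            "PaddlePaddle version 2.5.3 or higher is required, please upgrade your "
            "PaddlePaddle to 2.5.3 or other higher version."
        )
        return False
    return use_flash_attention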
1 change: 1 addition & 0 deletions tests/fixtures/llm/finetune.yaml
@@ -17,6 +17,7 @@ finetune:
fp16_opt_level: "O2"
do_train: true
do_eval: true
use_flash_attention: true
disable_tqdm: true
load_best_model_at_end: true
eval_with_do_generation: false
1 change: 1 addition & 0 deletions tests/fixtures/llm/predictor.yaml
@@ -3,6 +3,7 @@ inference-predict:
mode: dynamic
max_length: 40
batch_size: 2
use_flash_attention: false
decode_strategy: greedy_search
dtype: float16
data_file: tests/fixtures/llm/data/train.json
1 change: 1 addition & 0 deletions tests/fixtures/llm/prefix_tuning.yaml
@@ -17,6 +17,7 @@ prefix_tuning:
do_train: true
do_eval: true
disable_tqdm: true
use_flash_attention: false
load_best_model_at_end: true
eval_with_do_generation: false
metric_for_best_model: "accuracy"
24 changes: 24 additions & 0 deletions tests/llm/test_predictor.py
@@ -77,6 +77,30 @@ def test_predictor(self):
else:
self.assertGreaterEqual(count / len(result_0), 0.4)

def test_flash_attention(self):
self.run_predictor({"inference_model": False, "use_flash_attention": False})
result_0 = self._read_result(os.path.join(self.output_dir, "predict.json"))

self.run_predictor({"inference_model": False, "use_flash_attention": True})
result_1 = self._read_result(os.path.join(self.output_dir, "predict.json"))

# compare the generation result of dygraph & flash attention model
assert len(result_0) == len(result_1)

count, full_match = 0, 0
for inference_item, no_inference_item in zip(result_0, result_1):
if self.model_name_or_path == "__internal_testing__/tiny-random-llama":
min_length = 5
else:
min_length = min(len(inference_item), len(no_inference_item))
count += int(inference_item[: min_length // 2] == no_inference_item[: min_length // 2])
full_match += int(inference_item[:min_length] == no_inference_item[:min_length])

if self.model_name_or_path == "__internal_testing__/tiny-random-llama":
self.assertGreaterEqual(count / len(result_0), 0.2)
else:
self.assertEqual(full_match / len(result_0), 1.0)

def test_wint8(self):
self.run_predictor({"inference_model": True, "quant_type": "weight_only_int8"})
result_0 = self._read_result(os.path.join(self.output_dir, "predict.json"))
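For a quick local check, the new case can be run on its own through pytest's -k filter (assuming pytest drives the llm tests, as elsewhere in the repo):

# Equivalent to: pytest tests/llm/test_predictor.py -k test_flash_attention -q
import pytest

raise SystemExit(pytest.main(["tests/llm/test_predictor.py", "-k", "test_flash_attention", "-q"]))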
1 change: 0 additions & 1 deletion tests/llm/test_prefix_tuning.py
@@ -55,7 +55,6 @@ def test_prefix_tuning(self):

prefix_tuning_config["dataset_name_or_path"] = self.data_dir
prefix_tuning_config["output_dir"] = self.output_dir

with argv_context_guard(prefix_tuning_config):
from finetune_generation import main

