
[bloom] Add kv cache support for flash attention & fix bugs #7735

Merged: 12 commits into develop on Dec 29, 2023

Conversation

@w5688414 (Contributor) commented Dec 27, 2023

PR types

PR changes

Description

TODO:

  • Locate and align the precision (accuracy) mismatch
  • Add CI itest
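
Since the description is terse, here is a minimal, hypothetical sketch (not code from this PR) of the general pattern the title refers to: prepending cached keys/values before calling Paddle's fused attention kernel. The shapes, names, and the use of `paddle.nn.functional.scaled_dot_product_attention` are illustrative assumptions; that kernel expects the `[batch, seq_len, num_heads, head_dim]` layout and fp16/bf16 inputs on GPU.

```python
import paddle
import paddle.nn.functional as F

# Hypothetical sizes; flash attention needs a GPU and float16/bfloat16 inputs.
bs, cache_len, q_len, num_heads, head_dim = 2, 8, 4, 16, 64
dtype = "float16"

q = paddle.randn([bs, q_len, num_heads, head_dim], dtype=dtype)
k_new = paddle.randn([bs, q_len, num_heads, head_dim], dtype=dtype)
v_new = paddle.randn([bs, q_len, num_heads, head_dim], dtype=dtype)

# Cached keys/values from earlier steps (e.g. a prefix-tuning prefix).
k_cache = paddle.randn([bs, cache_len, num_heads, head_dim], dtype=dtype)
v_cache = paddle.randn([bs, cache_len, num_heads, head_dim], dtype=dtype)

# Prepend the cache along the sequence axis, then run the fused kernel.
k = paddle.concat([k_cache, k_new], axis=1)
v = paddle.concat([v_cache, v_new], axis=1)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # [bs, q_len, num_heads, head_dim]
```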


paddle-bot bot commented Dec 27, 2023

Thanks for your contribution!


codecov bot commented Dec 27, 2023

Codecov Report

Attention: 1 line in your changes is missing coverage. Please review.

Comparison: base (5de7e57) 57.29% vs. head (a5f3286) 57.30%.
The report is 5 commits behind head on develop.

File                             Patch %   Missing lines
paddlenlp/peft/prefix/utils.py   0.00%     1 ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #7735   +/-   ##
========================================
  Coverage    57.29%   57.30%           
========================================
  Files          584      584           
  Lines        87646    87628   -18     
========================================
- Hits         50219    50215    -4     
+ Misses       37427    37413   -14     

☔ View full report in Codecov by Sentry.

JunnYu (Member) previously approved these changes Dec 28, 2023

@JunnYu left a comment:

LGTM

@w5688414 w5688414 self-assigned this Dec 28, 2023
@w5688414 w5688414 changed the title [bloom] Add kv cache support for flash attention [bloom] Add kv cache support for flash attention & fix bugs Dec 28, 2023
JunnYu (Member) previously approved these changes Dec 29, 2023

@JunnYu left a comment:

LGTM

@@ -17,10 +17,10 @@

def bloom_postprocess_past_key_value(past_key_values):
# (layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim)*2
-    past_key_values = paddle.transpose(past_key_values, perm=[2, 0, 3, 1, 4]).split(2)
+    keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
Contributor commented:

@lugimzzz could you take a look at this part? We previously trained ptuning with Bloom and the precision was aligned. Will this adjustment affect the currently aligned version, and does the front-end inference also need to be changed?

Contributor replied:

As long as it has been verified in training, it is fine; confirmed that this does not affect inference.
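
For reference, here is a minimal sketch of the shape change under review. The assumed input layout `[batch_size, prefix_len, 2 * num_layers, num_heads, head_dim]` and the toy sizes are illustrative assumptions, not taken from the PR.

```python
import paddle

# Hypothetical sizes for illustration.
bs, prefix_len, num_layers, num_heads, head_dim = 2, 4, 3, 8, 16
past_key_values = paddle.zeros([bs, prefix_len, 2 * num_layers, num_heads, head_dim])

# New perm [2, 0, 1, 3, 4]: bring the stacked key/value axis to the front and keep
# each layer's cache in the [bs, seq_len, num_heads, head_dim] layout consumed by
# the flash-attention kernel; split(2) then separates keys from values.
keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
print(keys.shape)  # [num_layers, bs, prefix_len, num_heads, head_dim]

# The old perm [2, 0, 3, 1, 4] instead produced
# [num_layers, bs, num_heads, prefix_len, head_dim], i.e. the conventional
# (bs, num_heads, seq_len, head_dim) per-layer layout.
```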

@@ -3,6 +3,7 @@ inference-predict:
mode: dynamic
max_length: 40
batch_size: 2
+ use_flash_attention: false
Contributor commented:

Should this option be set to true?

@w5688414 (Contributor, Author) replied on Dec 29, 2023:

I wrote a separate unit test and added the use_flash_attention: true setting there:

https://github.com/PaddlePaddle/PaddleNLP/pull/7735/files#diff-378fff328c26822fbce1c8f410ab466ca2b2b9f47b37167a1159de1ac67f3f31R81
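
For readers following along, a hypothetical sketch of what turning the flag on can look like when loading a Bloom checkpoint directly. The checkpoint name and the assumption that `from_pretrained` forwards `use_flash_attention` into the model config are illustrative, not statements from this PR.

```python
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloom-560m"  # hypothetical checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_flash_attention=True,  # the setting exercised by the new unit test
    dtype="float16",           # flash attention expects fp16/bf16 on GPU
)
```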

@JunnYu (Member) left a comment:

LGTM

@w5688414 w5688414 merged commit fb8f2be into develop Dec 29, 2023
10 of 11 checks passed
JunnYu pushed a commit that referenced this pull request Dec 29, 2023
* Add kv cache support for flash attention

* Update chatglm flash attention version check

* Add test for flash attention

* Fix unitest bug

* Add flash attention to predictor

* Add flash attention2

* Add flash attention unitests

* fix prefix decoder

* remove unused comments

* Update unitest

* Update unitest