
fix llama3 static run #8849

Merged
Merged 1 commit on Aug 28, 2024

Conversation

yuanlehome
Collaborator

@yuanlehome yuanlehome commented Jul 31, 2024

PR types

Bug fixes

PR changes

Others

Description

Fixes a series of issues with llama3 static-graph inference using individual (non-fused) ops; accuracy is now correct.


paddle-bot bot commented Jul 31, 2024

Thanks for your contribution!


codecov bot commented Jul 31, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 53.87%. Comparing base (34a71c8) to head (10d3e95).
Report is 227 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8849      +/-   ##
===========================================
+ Coverage    53.81%   53.87%   +0.06%     
===========================================
  Files          652      652              
  Lines       104356   104356              
===========================================
+ Hits         56155    56220      +65     
+ Misses       48201    48136      -65     


@yuanlehome yuanlehome force-pushed the fix_llama3_static_run branch from e63c9b6 to 02680fa Compare August 27, 2024 06:27
DesmonDay
DesmonDay previously approved these changes Aug 27, 2024
Contributor

@DesmonDay DesmonDay left a comment


LGTM

@@ -1407,6 +1407,7 @@ def _post_process_(
# compute next_tokens
if use_top_p:
logits = logits / temperature
probs = paddle.cast(probs, paddle.float32)
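The added line casts the probabilities to fp32 before top-p sampling. As a minimal, framework-free sketch of what that sampling step consumes, here is nucleus (top-p) filtering in NumPy with the analogous fp32 cast; `top_p_filter` is a hypothetical helper for illustration, not code from this PR:

```python
import numpy as np

def top_p_filter(probs, top_p):
    # Cast to float32 first -- analogous to the paddle.cast added in this PR,
    # which worked around a sampling kernel without bf16 registration.
    probs = probs.astype(np.float32)
    # Sort probabilities in descending order.
    order = np.argsort(-probs)
    sorted_p = probs[order]
    # Keep the smallest prefix of tokens whose cumulative mass reaches top_p:
    # token i survives if the cumulative mass *before* it is still below top_p.
    cum = np.cumsum(sorted_p)
    keep = (cum - sorted_p) < top_p
    # Zero out the tail and renormalize the surviving mass.
    filtered = np.zeros_like(probs)
    filtered[order[keep]] = sorted_p[keep]
    return filtered / filtered.sum()
```

For example, with `top_p=0.9` and probabilities `[0.5, 0.3, 0.15, 0.05]`, the last token falls outside the nucleus and is zeroed before renormalization.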
Collaborator


Is this cast here to support top_p_sampling with the bf16 kernel?

Collaborator Author


The bf16 kernel implementation itself supports it; the cast to fp32 is here because without it an error is raised.

Collaborator


What exactly causes the error?

Collaborator Author

@yuanlehome yuanlehome Aug 28, 2024


I double-checked: the kernel indeed had no bf16 registration. I have added bf16 support on the Paddle side, so the cast logic added here has been removed, and verification passes.

Contributor

@DesmonDay DesmonDay left a comment


LGTM

@DesmonDay DesmonDay merged commit 2f567e6 into PaddlePaddle:develop Aug 28, 2024
10 of 12 checks passed
Mangodadada pushed a commit to Mangodadada/PaddleNLP that referenced this pull request Sep 10, 2024