Stage1 support bf16 #58212
Your PR was submitted successfully. Thank you for contributing to this open-source project!
use_pure_bf16(bool): Whether to use the pure bf16 training. Default False.

use_amp_guard(bool): Whether to use `amp_guard` when constructing the program.
Default True. Only takes effect when `use_pure_fp16` or `use_pure_bf16` is turned on.
This may be a backward-incompatible change: if users previously set use_fp16_guard, that setting may no longer take effect.
done
save_dtype=None,
dtype="bfloat16",
)
The two blocks on lines 90~108 that handle the different dtypes can be simplified. All other settings are identical; only the dtype differs.
done
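One way to do the simplification requested above is to build the shared keyword arguments once and pass the dtype in as a parameter. This is a minimal sketch, not the code that was merged: the helper name `amp_decorate_kwargs` is hypothetical, and it assumes the call shown in the diff is `paddle.amp.decorate`, which the excerpt does not confirm.

```python
# Hypothetical helper: the name and the set of accepted dtypes are assumptions
# made for illustration; only save_dtype=None and dtype=... appear in the diff.
SUPPORTED_DTYPES = ("float16", "bfloat16")


def amp_decorate_kwargs(dtype: str) -> dict:
    """Return the keyword arguments shared by the fp16 and bf16 paths."""
    if dtype not in SUPPORTED_DTYPES:
        raise ValueError(f"unsupported AMP dtype: {dtype!r}")
    # Everything except `dtype` is identical for both precisions.
    return {"save_dtype": None, "dtype": dtype}


fp16_kwargs = amp_decorate_kwargs("float16")
bf16_kwargs = amp_decorate_kwargs("bfloat16")

# Assumed usage, with Paddle installed and `model` already built:
# model = paddle.amp.decorate(models=model, level="O2", **bf16_kwargs)
```

With this shape, the fp16 and bf16 test paths differ only in the string passed to the helper.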
c91d1d6 to ba0ec56
optional bool use_optimizer_fp16 = 12
optional bool use_pure_bf16 = 11 [ default = false ];
optional bool use_amp_guard = 12 [ default = true ];
optional bool use_optimizer_fp16 = 13
Do not modify the existing fields.
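The reason behind this request can be sketched with the protobuf wire format, where a field is identified by its number and not by its name. The snippet below is a simplified illustration limited to bool fields numbered 1 to 15; the helper names are hypothetical, and the two schemas are taken from the diff excerpt above.

```python
def encode_bool_field(field_number: int, value: bool) -> bytes:
    """Encode one bool field: a tag byte followed by a one-byte varint."""
    if not 1 <= field_number <= 15:
        raise ValueError("this sketch only handles single-byte tags")
    tag = (field_number << 3) | 0  # wire type 0 = varint
    return bytes([tag, 1 if value else 0])


def decode_bool_field(data: bytes, schema: dict) -> tuple:
    """Decode one bool field and look its name up by field number."""
    tag, raw = data[0], data[1]
    field_number = tag >> 3
    return schema.get(field_number, "<unknown>"), bool(raw)


# Field numbers before and after the renumbering shown in the diff.
OLD_SCHEMA = {12: "use_optimizer_fp16"}
NEW_SCHEMA = {11: "use_pure_bf16", 12: "use_amp_guard", 13: "use_optimizer_fp16"}

# A message serialized by old code with use_optimizer_fp16 = true ...
payload = encode_bool_field(12, True)
read_by_old = decode_bool_field(payload, OLD_SCHEMA)
# ... is silently read as a different field by code built from the new schema.
read_by_new = decode_bool_field(payload, NEW_SCHEMA)
```

Appending new fields with unused numbers avoids this, because existing serialized configs keep their meaning.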
LGTM for set_tests_properties(test_dygraph_sharding_stage1_bf16 PROPERTIES TIMEOUT "200")
PR types
New features
PR changes
APIs
Description
stage1 support bf16
refine `dygraph_group_sharded_stage1_fp16.py`
User-facing API change: `amp_configs` in `DistributedStrategy` gains a `use_pure_bf16` option.
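As a rough illustration of the new option, the sketch below sets it on a strategy object. The helper `enable_pure_bf16` is hypothetical, and a `SimpleNamespace` stands in for the real strategy so the snippet runs without Paddle; whether other `amp_configs` keys are also needed depends on the merged implementation, which this page does not show.

```python
from types import SimpleNamespace

# Only the key named in the PR description is set here.
AMP_CONFIGS = {"use_pure_bf16": True}


def enable_pure_bf16(strategy):
    """Turn on AMP and request pure bf16 training on a strategy object."""
    strategy.amp = True
    strategy.amp_configs = dict(AMP_CONFIGS)
    return strategy


# Stand-in object for demonstration only; not a Paddle class.
demo_strategy = enable_pure_bf16(SimpleNamespace())

# Assumed usage with Paddle installed:
# import paddle.distributed.fleet as fleet
# strategy = enable_pure_bf16(fleet.DistributedStrategy())
```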