
[dtype] add fp16 support for dist_kernel #56184

Merged · 14 commits · Aug 15, 2023

Conversation

@jinyouzhi (Contributor) commented Aug 11, 2023

PR types

Function optimization

PR changes

OPs

Description

refer to #50915

@paddle-bot commented Aug 11, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot added the contributor (External developers) and status: proposed labels Aug 11, 2023
@luotao1 added the HappyOpenSource (happy open source event issues and PRs) label Aug 11, 2023
@jinyouzhi (Contributor, Author) commented:

Ready for review~ @zhangting2020

@zhangting2020 (Contributor) commented:

Has the inference accuracy regression caused by the previous implementation been verified against this PR? See #53525.

}

private:
T p_order_;
Ty p_order;
};
Contributor review comment on the diff above:

Please follow the google-cpp-styleguide for variable naming.
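For reference, a minimal sketch of the Google C++ Style Guide convention the reviewer points to: private data members of a class carry a trailing underscore. The functor name and template parameter below are illustrative and not taken from this PR.

```cpp
// Illustrative only: per google-cpp-styleguide, private data members end with '_'.
template <typename Ty>
struct PNormFunctor {
  explicit PNormFunctor(Ty p_order) : p_order_(p_order) {}

 private:
  Ty p_order_;  // trailing underscore marks a private data member
};
```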

phi::DistGradKernel,
float,
double,
phi::dtype::float16) {}
Contributor review comment on the kernel registration above:

Remember to register bfloat16 as well, and add the corresponding unit test.
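For context, a minimal sketch of what this review comment asks for, assuming the snippet above is the tail of a PD_REGISTER_KERNEL call for the GPU dist_grad kernel; the kernel name, backend, and layout arguments are assumptions, not copied from the PR.

```cpp
// Sketch (assumed call site): register the grad kernel for bfloat16 alongside float16.
PD_REGISTER_KERNEL(dist_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::DistGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
```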

x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
x (Tensor): 1-D to 6-D Tensor, its data type is float16, float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is float16, float32 or float64.
Contributor review comment on the docstring above:

After registering the bfloat16 type, remember to update the documentation.

)
check_variable_and_dtype(
y, 'dtype', ['float16', 'float32', 'float64'], 'dist'
)
Contributor review comment on the dtype check above:

Same as above: add bfloat16 support.

@jinyouzhi (Contributor, Author) commented:

> Has the inference accuracy regression caused by the previous implementation been verified against this PR? See #53525.

Somewhat awkwardly, I don't have a Turing or earlier-architecture GPU on hand; would rebuilding and verifying on an Ampere card be acceptable?
In fact, the earlier regression was most likely caused by the change to cuda_reduce_utils.h. This PR does not touch that file, so in theory it should not introduce a regression.

@zhangting2020 (Contributor) left a review comment:

LGTM.
Remember to also update the Chinese documentation.

@luotao1 luotao1 merged commit ea590ef into PaddlePaddle:develop Aug 15, 2023
@jinyouzhi jinyouzhi deleted the dist_fp16_new branch August 17, 2023 07:15