Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR]Fix cross_entropy_with_softmax vjp bug && check_grad passed #59451

Merged
merged 1 commit into from
Nov 29, 2023

Conversation

changeyoung98
Copy link
Contributor

@changeyoung98 changeyoung98 commented Nov 28, 2023

PR types

Others

PR changes

Others

Description

pcard-67164
Fix cross_entropy_with_softmax vjp bug && check_grad passed

  • cross_entropy_with_softmax 反向out_grads只需要第二个输出的梯度,原本的vjp生成逻辑是默认全部的梯度都被使用,因此获取out_grads的index是从0开始递增,在这种情况下会获取到不被使用的梯度,导致后面被使用的梯度被剪掉。修改生成逻辑变为根据out_grad变量名字去掉'_grad'找到对应out变量在output_list中的index,然后从grad列表中获取对应out_grad。
  • 修复#58682 的反向报错问题,开启pir反向单测。

Copy link

paddle-bot bot commented Nov 28, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@xiaoguoguo626807 xiaoguoguo626807 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -149,7 +149,7 @@ def gen_op_vjp_str(
index_0 = fwd_outputs_list.index(bw_input_name)
else:
vjp_param_name = 'out_grads'
grad_idx += 1
grad_idx = fwd_outputs_list.index(bw_input_name[:-5])
Copy link
Contributor

@Aurelius84 Aurelius84 Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
grad_idx = fwd_outputs_list.index(bw_input_name[:-5])
offset = len("_grad")
grad_idx = fwd_outputs_list.index(bw_input_name[:-offset])

这里建议不要直接写 -5,不熟悉背景的其他开发者会比较奇怪这个magic_number 是怎么来的。之前我升级gen.py时遇到过影响其他magic_number导致CI失败问题,排查起来不是很友好。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,后面fix一下

@Aurelius84 Aurelius84 merged commit b8a02c0 into PaddlePaddle:develop Nov 29, 2023
29 of 30 checks passed
@changeyoung98 changeyoung98 deleted the czy-grad1 branch November 29, 2023 08:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants