-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] Fix bug for binary_cross_entropy_with_logits loss #54869
[BugFix] Fix bug for binary_cross_entropy_with_logits loss #54869
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
paddle/phi/api/yaml/op_compat.yaml
Outdated
@@ -2800,7 +2800,7 @@ | |||
- op: sigmoid_cross_entropy_with_logits | |||
backward: sigmoid_cross_entropy_with_logits_grad | |||
inputs : | |||
{x: X, label: Label} | |||
{x: X, label: Label, pos_weight: PosWeight} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个不用改,这是为了兼容原先的写法的,新增参数,不需要兼容旧的,后续新IR重构后原先驼峰式的命名写法都会删除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/nn/functional/loss.py
Outdated
|
||
helper.append_op( | ||
type="sigmoid_cross_entropy_with_logits", | ||
inputs={"X": logit, "Label": label, "PosWeight": log_weight}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里用pos_weight即可
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/phi/infermeta/multiary.cc
Outdated
@@ -3410,5 +3410,61 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row, | |||
out_count->set_dims({-1}); | |||
out_count->set_dtype(DataType::INT32); | |||
} | |||
|
|||
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数位置按字典序放置
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/phi/infermeta/multiary.h
Outdated
@@ -646,4 +646,12 @@ void MoeInferMeta(const MetaTensor& x, | |||
const std::string& act_type, | |||
MetaTensor* out); | |||
|
|||
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
70e5c66
to
673c884
Compare
PR types
Bug fixes
PR changes
OPs
Description
Fix bug of pos_weight calculation for binary_cross_entropy_with_logits loss. Related issue: #54730