[PIR] Fix cross_entropy_with_softmax vjp bug && check_grad passed #59451

Merged 1 commit on Nov 29, 2023
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
@@ -149,7 +149,7 @@ def gen_op_vjp_str(
             index_0 = fwd_outputs_list.index(bw_input_name)
         else:
             vjp_param_name = 'out_grads'
-            grad_idx += 1
+            grad_idx = fwd_outputs_list.index(bw_input_name[:-5])
@Aurelius84 (Contributor) commented on Nov 29, 2023:
Suggested change:
-            grad_idx = fwd_outputs_list.index(bw_input_name[:-5])
+            offset = len("_grad")
+            grad_idx = fwd_outputs_list.index(bw_input_name[:-offset])

I suggest not hard-coding -5 here; other developers unfamiliar with the background will wonder where this magic number comes from. When I upgraded gen.py earlier, I hit a CI failure caused by a change that affected another magic number, and it was not pleasant to track down.

The PR author (Contributor) replied:

OK, I'll fix that in a follow-up.

             index_0 = grad_idx
     if op_grad_info.input_optional_list[idx] == 'true':
         if input_type == 'Tensor':
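For readers unfamiliar with the naming convention behind the fix, here is a minimal standalone sketch (not part of this PR; the values are hypothetical, though fwd_outputs_list and bw_input_name mirror the generator's variables) of why stripping the "_grad" suffix recovers the forward-output index, written with the named offset the reviewer suggested:

    # Standalone sketch with hypothetical values; in the generator these
    # come from the op's parsed gradient info.
    fwd_outputs_list = ["softmax", "loss"]  # forward op outputs
    bw_input_name = "loss_grad"             # backward op input: grad of "loss"

    offset = len("_grad")  # named offset instead of the magic number -5
    grad_idx = fwd_outputs_list.index(bw_input_name[:-offset])
    assert grad_idx == 1  # "loss_grad" maps back to forward output "loss"
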
24 changes: 12 additions & 12 deletions test/legacy_test/test_softmax_with_cross_entropy_op.py
@@ -160,22 +160,22 @@ def test_check_grad(self):
         if core.is_compiled_with_rocm():
             if self.python_api is not None:
                 self.check_grad(
-                    ["Logits"], "Loss", max_relative_error=5e-1, check_pir=False
+                    ["Logits"], "Loss", max_relative_error=5e-1, check_pir=True
                 )
             # HIP will have accuracy fail when using float32 in CPU place
             self.check_grad(
-                ["Logits"], "Loss", max_relative_error=5e-1, check_pir=False
+                ["Logits"], "Loss", max_relative_error=5e-1, check_pir=True
             )
         else:
             if self.python_api is not None:
                 self.check_grad(
                     ["Logits"],
                     "Loss",
                     numeric_grad_delta=0.001,
-                    check_pir=False,
+                    check_pir=True,
                 )
             self.check_grad(
-                ["Logits"], "Loss", numeric_grad_delta=0.001, check_pir=False
+                ["Logits"], "Loss", numeric_grad_delta=0.001, check_pir=True
             )


@@ -517,9 +517,9 @@ def test_check_output(self):

     def test_check_grad(self):
         if self.python_api is not None:
-            self.check_grad(["Logits"], "Loss", check_pir=False)
+            self.check_grad(["Logits"], "Loss", check_pir=True)
         self.check_grad(
-            ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+            ["Logits"], "Loss", max_relative_error=0.1, check_pir=True
         )


@@ -540,10 +540,10 @@ def initParams(self):
     def test_check_grad(self):
         if self.python_api is not None:
             self.check_grad(
-                ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+                ["Logits"], "Loss", max_relative_error=0.1, check_pir=True
             )
         self.check_grad(
-            ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+            ["Logits"], "Loss", max_relative_error=0.1, check_pir=True
         )


@@ -574,15 +574,15 @@ def test_check_grad(self):
             # HIP will have accuracy fail when using float32 in CPU place
             if self.python_api is not None:
                 self.check_grad(
-                    ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+                    ["Logits"], "Loss", max_relative_error=0.1, check_pir=True
                 )
             self.check_grad(
-                ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+                ["Logits"], "Loss", max_relative_error=0.1, check_pir=True
             )
         else:
             if self.python_api is not None:
-                self.check_grad(["Logits"], "Loss", check_pir=False)
-            self.check_grad(["Logits"], "Loss", check_pir=False)
+                self.check_grad(["Logits"], "Loss", check_pir=True)
+            self.check_grad(["Logits"], "Loss", check_pir=True)


class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp):