Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix pylayer problem with amp #39950

Merged
merged 3 commits into from
Feb 27, 2022
Merged

Conversation

zhiqiu
Copy link
Contributor

@zhiqiu zhiqiu commented Feb 25, 2022

PR types

Bug fixes

PR changes

APIs

Describe

When using PyLayer with amp, the forward part of PyLayer may be invoked with amp enabled, while the backward part is not, thus resulting in errors.

related #39881

import paddle
from paddle.autograd import PyLayer

class MyMM(PyLayer):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)
    
    @staticmethod
    def backward(ctx, grad):
        a, b = ctx.saved_tensor()  
        # NOTE(zhiqiu): a and b is float32 now, while grad is fp16 when forward runs with auto_cast()
        # thus, the mm operation raise errors because of the dtype of inputs are inconsistent
        return grad.mm(b.t()), a.t().mm(grad)

x = paddle.rand([10, 10])
y = paddle.rand([10, 10])
x.stop_gradient=False
y.stop_gradient=False

with paddle.amp.auto_cast():
   res = MyMM.apply(x, y)
   loss = paddle.mean(res)

loss.backward()
  • before
    run failed.
    image
  • after
    run ok

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

chenwhql
chenwhql previously approved these changes Feb 25, 2022
Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

zhangbo9674
zhangbo9674 previously approved these changes Feb 26, 2022
Copy link
Contributor

@zhangbo9674 zhangbo9674 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiqiu zhiqiu dismissed stale reviews from zhangbo9674 and chenwhql via 3b051be February 26, 2022 03:05
Copy link
Contributor

@zhangbo9674 zhangbo9674 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiqiu zhiqiu merged commit 282e09d into PaddlePaddle:develop Feb 27, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants