
Customizable Python Layer in Dygraph #32130

Merged · 26 commits · Apr 15, 2021

Conversation

@hbwx24 (Contributor) commented Apr 7, 2021

PR types

Function optimization

PR changes

APIs

Describe

Description:

To support custom Python-side OPs in dynamic graph (dygraph) mode, the following APIs are added:

  • class PyLayer: build a custom Layer by subclassing PyLayer. The subclass must follow these rules:

    • The subclass must define forward and backward functions, both decorated with @staticmethod. Their first argument must be a context object, and their return values must not contain None.
    • The first input of backward is the context; the remaining inputs are the gradients of the tensors returned by forward, so the number of backward input tensors equals the number of forward output tensors. If the gradient computation needs forward's inputs or outputs, store the required tensors with ctx.save_for_backward in forward and retrieve them in backward.
    • The output of backward must be a Tensor or a tuple/list of Tensors. Because backward's output tensors are the gradients of forward's input tensors, the number of backward output tensors equals the number of forward input tensors.
    • Call the custom layer via apply.
  • PyLayerContext: used to pass objects from forward to backward; see the example code below for usage.

    • PyLayerContext.save_for_backward: called in forward to store tensors in the PyLayerContext object.
    • PyLayerContext.saved_tensor: retrieves the tensors saved in the PyLayerContext object.
  • The APIs above live under paddle.autograd; a custom Layer is built by subclassing paddle.autograd.PyLayer.
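To illustrate the forward/backward contract above without depending on Paddle itself, here is a hypothetical pure-Python mock of the mechanism; MiniContext and ScalarTanh are illustrative names only, not part of the Paddle API, and operate on plain floats instead of tensors:

```python
import math


class MiniContext:
    """Hypothetical stand-in for PyLayerContext (illustration only)."""

    def save_for_backward(self, *values):
        # Mirrors PyLayerContext.save_for_backward: stash values for backward.
        self._saved = values

    def saved_tensor(self):
        # Mirrors PyLayerContext.saved_tensor: return the stashed values.
        return self._saved


class ScalarTanh:
    """Scalar mock of a PyLayer subclass: one forward input maps to one
    backward output, and one forward output maps to one backward input."""

    @staticmethod
    def forward(ctx, x):
        y = math.tanh(x)
        ctx.save_for_backward(y)  # save what backward will need
        return y

    @staticmethod
    def backward(ctx, dy):
        (y,) = ctx.saved_tensor()
        # tanh'(x) = 1 - tanh(x)^2, expressed via the saved output y
        return dy * (1.0 - y * y)


ctx = MiniContext()
y = ScalarTanh.forward(ctx, 0.5)
dx = ScalarTanh.backward(ctx, 1.0)
print(dx)
```

The real PyLayer dispatches forward/backward through the autograd engine via apply; this sketch only shows how the context object carries state between the two static methods.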

TODO:

  • Only the most basic functionality is supported for now; further refinement is pending.
  • Only dynamic graph mode is supported for now; static-graph support will be added later as needed.

Example:

import paddle
import numpy as np
from paddle.autograd import PyLayer


class double_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x1, x2, func1, func2=paddle.square):
        ctx.func = func2
        y1 = func1(x1)
        y2 = func1(x2)
        # Save the tensors that backward needs
        ctx.save_for_backward(y1, y2)
        return y1, y2

    @staticmethod
    def backward(ctx, dy1, dy2):
        # dy1: gradient of y1; dy2: gradient of y2.
        # Retrieve the tensors saved in forward
        y1, y2 = ctx.saved_tensor()
        re1 = dy1 * (1 - ctx.func(y1))
        re2 = dy2 * (1 - paddle.square(y2))
        return re1, re2


input1 = paddle.randn([2, 3]).astype("float64")
input2 = input1.detach().clone()
input1.stop_gradient = False
input2.stop_gradient = False

# Use the custom layer
z = double_tanh.apply(input1, input1, paddle.tanh, paddle.square)

z = z[0] + z[1]
z.mean().backward()

z2 = paddle.tanh(input2) + paddle.tanh(input2)
z2.mean().backward()
print(np.array_equal(input1.grad, input2.grad))
print(input1.grad - input2.grad)
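The backward in the example relies on the identity tanh'(x) = 1 - tanh(x)^2. A quick finite-difference sanity check of that identity in plain Python (independent of Paddle; `numeric_grad` is an illustrative helper, not a Paddle function):

```python
import math


def numeric_grad(f, x, h=1e-6):
    # Central finite difference approximation of f'(x)
    return (f(x + h) - f(x - h)) / (2 * h)


for x in (-1.0, 0.0, 0.7, 2.0):
    analytic = 1.0 - math.tanh(x) ** 2
    numeric = numeric_grad(math.tanh, x)
    assert abs(analytic - numeric) < 1e-6
print("tanh'(x) = 1 - tanh(x)^2 verified")
```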

English API documentation:
http://10.136.157.23:8090/documentation/docs/en/api/paddle/autograd/py_layer/PyLayer_en.html
http://10.136.157.23:8090/documentation/docs/en/api/paddle/autograd/py_layer/PyLayerContext_en.html

@@ -279,6 +279,8 @@ class TracedGradOp {

void SetType(const std::string& type) { op_->SetType(type); }

const framework::OperatorBase& InnerOp() { return op_->InnerOp(); }

Reviewer: add a const qualifier after InnerOp()?

Author: done, thanks.

@@ -279,6 +279,6 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place,
const std::map<std::string, std::string>& inplace_map);

void ClearNoNeedBufferInputs(OpBase* op);

Reviewer: add a blank line before and after this statement?

Author: done, thanks.

auto a = ptr->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(a);
}
} catch (py::cast_error& err) {

Reviewer: shouldn't this raise an error? It could report that this is not supported yet; silently skipping it seems questionable and may confuse users.

Author: added a comment: all tensor variables are collected, and all non-tensor variables are ignored.

auto a = ptr->second.cast<std::shared_ptr<VarBase>>();
input_vars.push_back(a);
}
} catch (py::cast_error&) {

Reviewer: same as above.

Author: done, thanks.

output_vars.push_back(temp_out);
} catch (py::cast_error&) {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of forward should be `Tensor`."));

Reviewer: write it out in full as PyLayer.forward, so users know which forward this is? The same error message appears in several places below; suggest improving them all.

Author: done, thanks.

class PyLayerOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}

Reviewer: if this is unsupported, add PADDLE_THROW rather than leaving the body empty; otherwise problems are hard to trace on misuse. Consider adding a few exception unit tests for coverage.

Author: added a log message noting that InferShape is not implemented here.

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

Reviewer: the expectation here is that operations across different types are unsupported, right? X is duplicable, so if X's two inputs have different types, e.g. one float and one int, this will raise an error?

Author: yes, operations across different types are not supported. If the two types differ, a warning is triggered: "The dtype of left and right variables are not the same, left dtype is paddle.int64, but right dtype is paddle.float32, the right dtype will convert to paddle.int64".

ops::PyLayerGradOpMaker<paddle::framework::OpDesc>);

REGISTER_OP_CPU_KERNEL(
py_layer, ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, float>,

Reviewer: do the data types need to be completed here? This should support all types, right?

Author (@hbwx24, Apr 8, 2021): done, thanks.

return super(FunctionMeta, cls).__init__(name, bases, attrs)


class PyLayer(with_mateclass(FunctionMeta, CFunction)):

Reviewer: move this under autograd first, and complete the documentation.

Author: done, thanks.

@@ -0,0 +1,158 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Reviewer: same here; the new directory manages only one file. Suggest moving it directly under imperative rather than adding a directory.

Author: done, thanks.

"""
Build a custom `Layer` by creating subclasses. Subclasses need to follow the following rules:
1. Subclasses contain `forward` and `backward` function. Both forward and backward are @staticmethod.
Their first argument should be a context and `None` can be included in the returned result.

Reviewer: should that be `None` can, or can not, be included?

Author: thanks, done.

@staticmethod
def forward(ctx, *args, **kwargs):
raise NotImplementedError(
"You must implement the forward function for custom layer.")

Reviewer: custom layer -> PyLayer?

Author: thanks, done.

@staticmethod
def backward(ctx, *args, **kwargs):
raise NotImplementedError(
"You must implement the backward function for custom layer.")

Reviewer: same as above.

Author: thanks, done.

@hbwx24 hbwx24 changed the title from Custom op/save for bk to Python-side custom OP, Apr 9, 2021
TCChenlong previously approved these changes Apr 9, 2021
@TCChenlong (Contributor) left a comment: LGTM

@hbwx24 hbwx24 changed the title from Python-side custom OP to Customizable Python Layer, Apr 9, 2021
@hbwx24 hbwx24 changed the title from Customizable Python Layer to Customizable Python Layer in Dygraph, Apr 9, 2021
chenwhql previously approved these changes Apr 9, 2021
@chenwhql (Contributor) left a comment: LGTM

lanxianghit previously approved these changes Apr 13, 2021
XieYunshen previously approved these changes Apr 13, 2021
raindrops2sea previously approved these changes Apr 13, 2021
XiaoguangHu01 previously approved these changes Apr 14, 2021
@XiaoguangHu01 (Contributor) left a comment: LGTM

XieYunshen previously approved these changes Apr 14, 2021
lanxianghit previously approved these changes Apr 14, 2021
TCChenlong previously approved these changes Apr 14, 2021
@TCChenlong (Contributor) left a comment: LGTM

ForFishes previously approved these changes Apr 14, 2021
@ForFishes (Member) left a comment: LGTM

@TCChenlong (Contributor) left a comment: LGTM

@chenwhql chenwhql merged commit 29f6522 into PaddlePaddle:develop Apr 15, 2021