Customizable Python Layer in Dygraph #32130
Conversation
@@ -279,6 +279,8 @@ class TracedGradOp {

  void SetType(const std::string& type) { op_->SetType(type); }

  const framework::OperatorBase& InnerOp() { return op_->InnerOp(); }
Add a `const` qualifier after `InnerOp()`?
Done, thanks.
@@ -279,6 +279,6 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
    const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
    const platform::Place& place,
    const std::map<std::string, std::string>& inplace_map);

void ClearNoNeedBufferInputs(OpBase* op);
Add a blank line before and after this statement?
Done, thanks.
        auto a = ptr->cast<std::shared_ptr<VarBase>>();
        input_vars.push_back(a);
      }
    } catch (py::cast_error& err) {
This should raise an error, right? It could report that this is not supported yet, but silently skipping seems problematic and may confuse users.
Added a comment: all tensor variables are collected, and all non-tensor variables are ignored.
        auto a = ptr->second.cast<std::shared_ptr<VarBase>>();
        input_vars.push_back(a);
      }
    } catch (py::cast_error&) {
Same as above.
Done, thanks.
      output_vars.push_back(temp_out);
    } catch (py::cast_error&) {
      PADDLE_THROW(platform::errors::Unimplemented(
          "The output of forward should be `Tensor`."));
Spell it out as `PyLayer.forward`? Make clear to users which forward it refers to. The same error message appears in several places below; suggest improving all of them.
Done, thanks.
paddle/fluid/operators/py_layer_op.h
Outdated
class PyLayerOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {}
If this is unsupported, add a PADDLE_THROW instead of leaving the body empty; otherwise problems will be hard to track down when it is misused. Consider adding a few exception unit tests to cover this.
Added a log message noting that InferShape is not implemented here.
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
The expectation here is that mixed-type computation is not supported, right? `X` is Duplicable; if two inputs of `X` have different types, e.g. one float and one int, this will raise an error?
Yes, mixed-type computation is not supported. If the two types differ, a warning is triggered: "The dtype of left and right variables are not the same, left dtype is paddle.int64, but right dtype is paddle.float32, the right dtype will convert to paddle.int64".
                  ops::PyLayerGradOpMaker<paddle::framework::OpDesc>);

REGISTER_OP_CPU_KERNEL(
    py_layer, ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, float>,
Should the list of data types be completed? This should support all types, right?
Done, thanks.
python/paddle/nn/layer/py_layer.py
Outdated
        return super(FunctionMeta, cls).__init__(name, bases, attrs)


class PyLayer(with_mateclass(FunctionMeta, CFunction)):
Move this under `autograd` first, and complete the documentation.
Done, thanks.
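The `with_mateclass(FunctionMeta, CFunction)` construction in the diff above can be sketched without the framework. The names `with_mateclass`, `FunctionMeta`, and `CFunction` come from the PR, but the bodies here are hypothetical: the helper applies a metaclass in a way that works on both Python 2 and 3, so `FunctionMeta` can intercept every `PyLayer` subclass as it is defined (the real metaclass would hook the subclass's backward into the tracer at that point).

```python
CREATED = []  # names of classes constructed by FunctionMeta (for illustration)


class FunctionMeta(type):
    def __init__(cls, name, bases, attrs):
        # The real metaclass would register the subclass's backward with the
        # autograd machinery here; this sketch only records the class name.
        CREATED.append(name)
        super(FunctionMeta, cls).__init__(name, bases, attrs)


def with_mateclass(meta, *bases):
    # Build an anonymous base class whose metaclass is `meta`; any class
    # inheriting from the result is then constructed by `meta`, on both
    # Python 2 and Python 3 (which differ in metaclass syntax).
    return meta("TempBase", bases, {})


class CFunction(object):
    pass


class PyLayer(with_mateclass(FunctionMeta, CFunction)):
    pass


class MyLayer(PyLayer):  # FunctionMeta sees this definition as well
    pass


print(CREATED)  # ['TempBase', 'PyLayer', 'MyLayer']
```

Every subclass of `PyLayer`, however deep, passes through `FunctionMeta.__init__`, which is what lets the framework wire up user-defined layers automatically at class-definition time.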
@@ -0,0 +1,158 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
Same here: the new directory manages only one file. Suggest moving it directly under `imperative` instead of adding a directory.
Done, thanks.
python/paddle/autograd/py_layer.py
Outdated
""" | ||
Build a custom `Layer` by creating subclasses. Subclasses need to follow the following rules: | ||
1. Subclasses contain `forward` and `backward` function. Both forward and backward are @staticmethod. | ||
Their first argument should be a context and `None` can be included in the returned result. |
`None`: should this be "can" or "can not"?
Thanks, done.
python/paddle/autograd/py_layer.py
Outdated
    @staticmethod
    def forward(ctx, *args, **kwargs):
        raise NotImplementedError(
            "You must implement the forward function for custom layer.")
`custom layer` -> `PyLayer`?
Thanks, done.
python/paddle/autograd/py_layer.py
Outdated
    @staticmethod
    def backward(ctx, *args, **kwargs):
        raise NotImplementedError(
            "You must implement the backward function for custom layer.")
Same as above.
Thanks, done.
LGTM
LGTM
c7405d3
LGTM
LGTM
LGTM
f01faec
LGTM
PR types
Function optimization
PR changes
APIs
Describe
Feature description:

To support custom Python operators in dygraph mode, the following APIs are added:

`class PyLayer`: build a custom Layer by subclassing `PyLayer`. The subclass must follow these rules: it contains `forward` and `backward` functions, both decorated with `@staticmethod`; their first argument must be a context; and their returned results must not contain `None`. `ctx.save_for_backward` stores the tensors that will be needed later in backward.

`PyLayerContext`: used to pass objects from forward to backward; its usage is shown in the example code below.

`PyLayerContext.save_for_backward`: called in forward to store tensors into the `PyLayerContext` object.

`PyLayerContext.saved_tensor`: retrieves the tensors saved in the `PyLayerContext` object.

The above APIs are added under `paddle.autograd`; a custom Layer is built by subclassing `paddle.autograd.PyLayer`.

TODO:
Example:
English API documentation:
http://10.136.157.23:8090/documentation/docs/en/api/paddle/autograd/py_layer/PyLayer_en.html
http://10.136.157.23:8090/documentation/docs/en/api/paddle/autograd/py_layer/PyLayerContext_en.html
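The forward/backward/context protocol described above can be illustrated with a framework-free sketch. The class and method names (`PyLayer`, `PyLayerContext`, `save_for_backward`, `saved_tensor`, `apply`) mirror the PR's API, but the bodies are simplified assumptions: a real framework records a grad-op node in `apply` and invokes `backward` from the autograd engine, while this sketch runs everything eagerly on plain floats.

```python
class PyLayerContext:
    """Framework-free stand-in for the PyLayerContext described above."""

    def __init__(self):
        self._saved = ()

    def save_for_backward(self, *tensors):
        # Stash values produced in forward for reuse in backward.
        self._saved = tensors

    def saved_tensor(self):
        return self._saved


class PyLayer:
    @staticmethod
    def forward(ctx, *args, **kwargs):
        raise NotImplementedError(
            "You must implement the forward function for PyLayer.")

    @staticmethod
    def backward(ctx, *args, **kwargs):
        raise NotImplementedError(
            "You must implement the backward function for PyLayer.")

    @classmethod
    def apply(cls, *args, **kwargs):
        # A real framework would record a grad-op node here; this sketch
        # runs forward eagerly and returns the context for a manual backward.
        ctx = PyLayerContext()
        return ctx, cls.forward(ctx, *args, **kwargs)


class Square(PyLayer):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensor()
        return 2.0 * x * dy  # d(x^2)/dx = 2x, chained with upstream dy


ctx, y = Square.apply(3.0)
dx = Square.backward(ctx, 1.0)
print(y, dx)  # 9.0 6.0
```

In the actual API the user only calls `Square.apply(x)` and then `loss.backward()`; the context handoff shown explicitly here is performed by the autograd engine.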