[NewIR]Split python api and vjp #56518
…into split_vjp2
LGTM overall
@@ -18,5 +18,16 @@
#include "paddle/ir/core/builtin_op.h"

namespace paddle {
namespace dialect {}  // namespace dialect
namespace dialect {
ir::OpResult split_grad(std::vector<ir::OpResult> out_grads,
ir::OpResult split_grad(std::vector<ir::OpResult> out_grads,
ir::OpResult split_grad(const std::vector<ir::OpResult>& out_grads,
@0x45f I see that the vector parameters of concat and add_n in pd_api.h are also passed by value. Could those be changed to const & as well?
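Once the hand-written build functions are complete, this can be folded into the automatic code generation and improved there together.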
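For context on the suggestion above, here is a minimal sketch of why `const std::vector<T>&` avoids work that a by-value parameter incurs when the caller passes an lvalue. `CountedValue` is a hypothetical stand-in for `ir::OpResult` that counts its own copies; the function names are illustrative and not part of Paddle.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ir::OpResult (an assumption, not Paddle code).
// It records how many times any instance has been copied.
struct CountedValue {
  static int& copies() {
    static int n = 0;
    return n;
  }
  int id = 0;
  CountedValue() = default;
  CountedValue(const CountedValue& other) : id(other.id) { ++copies(); }
  CountedValue& operator=(const CountedValue& other) {
    id = other.id;
    ++copies();
    return *this;
  }
};

// By-value parameter: passing an lvalue copies the whole vector.
std::size_t size_by_value(std::vector<CountedValue> values) {
  return values.size();
}

// Const-reference parameter: the caller's vector is only borrowed.
std::size_t size_by_const_ref(const std::vector<CountedValue>& values) {
  return values.size();
}

// Returns how many element copies a by-value call makes for n elements.
int copies_made_by_value(std::size_t n) {
  std::vector<CountedValue> v(n);  // default-constructs n elements, no copies
  CountedValue::copies() = 0;
  size_by_value(v);
  return CountedValue::copies();
}

// Returns how many element copies a const-reference call makes.
int copies_made_by_const_ref(std::size_t n) {
  std::vector<CountedValue> v(n);
  CountedValue::copies() = 0;
  size_by_const_ref(v);
  return CountedValue::copies();
}
```

With three elements, `copies_made_by_value(3)` returns 3 and `copies_made_by_const_ref(3)` returns 0. A by-value parameter can still be the right choice when the function needs its own copy or the caller moves the vector in, so which signature fits the generated APIs is a judgment for the maintainers.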
@@ -51,9 +51,28 @@ class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface> {
  static void InferMeta(phi::InferMetaContext *infer_meta);
};

class SplitGradOp : public ir::Op<SplitGradOp, OpYamlInfoInterface> {
Doesn't this op need the paddle::dialect::InferMetaInterface interface? I see that all the other GradOps in pd_op.h have it.
The split_grad entry in the yaml has no infermeta; it uses the infermeta of the invoked operator, concat. Added.
@@ -14,6 +14,7 @@

#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
Could we include only the headers that are actually needed here?
split_grad calls the full op, so this header has to be included:
int axis = axis_.owner()
               ->dyn_cast<paddle::dialect::FullOp>()
               .attributes()
               .at("value")
               .dyn_cast<paddle::dialect::ScalarAttribute>()
               .data()
               .to<int>();
                 .dyn_cast<paddle::dialect::ScalarAttribute>()
                 .data()
                 .to<int>();
  (void)axis;
This line isn't needed, is it?
done
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void SplitGradOp::Verify() {}
According to the GetOpInfo description above, split has two inputs, out_grad and axis, so Verify should check them.
done
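To illustrate the kind of check requested above, here is a minimal sketch of an input/output count verification. `FakeOp`, `VerifySplitGradLike`, and `IsValid` are hypothetical names that stand in for the real operation type; this is not the Verify body that was actually committed. The expectation of one output is an assumption drawn from the `ir::OpResult` return type of `split_grad` in the hunk above.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for an operation (an assumption, not ir::Operation).
struct FakeOp {
  std::size_t num_operands;
  std::size_t num_results;
};

// Throws std::invalid_argument unless the op has exactly two inputs
// (out_grad and axis) and, by assumption, exactly one output.
void VerifySplitGradLike(const FakeOp& op) {
  if (op.num_operands != 2) {
    throw std::invalid_argument(
        "expected 2 inputs (out_grad, axis), got " +
        std::to_string(op.num_operands));
  }
  if (op.num_results != 1) {
    throw std::invalid_argument("expected 1 output, got " +
                                std::to_string(op.num_results));
  }
}

// Convenience wrapper that turns the exception into a boolean.
bool IsValid(const FakeOp& op) {
  try {
    VerifySplitGradLike(op);
    return true;
  } catch (const std::invalid_argument&) {
    return false;
  }
}
```

A real Verify would likely also check the operand types (a vector type for out_grad, a scalar-producing value for axis), which this sketch leaves out.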
                  false,
                  false),
      OpInputInfo(
          "axis", "paddle::dialect::ScalarAttribute", false, false, true)};
If axis is a mutable attribute, shouldn't there also be a Build interface that takes axis as a Scalar?
Done. Note that this Build interface has no caller here yet, so there is a coverage problem.
@@ -402,5 +402,74 @@ TEST(VJP, Add_BackwardTest) {
  ASSERT_EQ(dx.data<float>()[0], 1.0);
  ASSERT_EQ(dy.data<float>()[0], 1.0);
}

TEST(VJP, SplitBackwardTest) {
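Could this unit test be run on the Python side with the new executor instead? Reading values directly out of InterpreterCore in C++ like this is not really recommended.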
* support ir api form prim
* convert vector of int to intarray
* add reference of lbfgs
* add reference of lbfgs
* support ir api for prim
* Add more gen api
* concat python api to concat_grad
* fix gen conflict
* support vjp prim mode in new ir
* remove useless code
* add vjp autogen v1.0
* add test for prim
* resolve type conflict
* modify utils
* remove useless code
* add split op and modify some bug of vectorType
* fix conflict
* add concat python test
* add split python api to vjp
* modify build bug
* modify run bug
* fix conflict bug
* build bug fix
* modify python api bug
* modify test
* fix conflict
* fluid backward recover
* recover conflict
* reply review comments
* modify opruntimeinfo num
---------
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Co-authored-by: chenzhiyang <1792266893@qq.com>
Co-authored-by: Chen Zhiyang <chenzhiyang99@126.com>
PR types
others
PR changes
others
Description
pcard-67164
Develops the forward-call and backward-call chain code for the split operator.