Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NewIR]Split python api and vjp #56518

Merged
merged 61 commits into from
Aug 28, 2023

Conversation

xiaoguoguo626807
Copy link
Contributor

PR types

others

PR changes

others

Description

pcard-67164
split算子前向调用+反向调用链代码开发

cyber-pioneer and others added 30 commits August 10, 2023 11:42
@Aurelius84 Aurelius84 changed the title 【newir】Split python api and vjp [NewIR]Split python api and vjp Aug 25, 2023
Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall

@@ -18,5 +18,16 @@
#include "paddle/ir/core/builtin_op.h"

namespace paddle {
namespace dialect {} // namespace dialect
namespace dialect {
ir::OpResult split_grad(std::vector<ir::OpResult> out_grads,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ir::OpResult split_grad(std::vector<ir::OpResult> out_grads,
ir::OpResult split_grad(const std::vector<ir::OpResult>& out_grads,

@0x45f 我看 pd_api.h 里的 concat、add_n 的vector 入参也是值copy,这个是不是可以优化为 const &?

Copy link
Contributor Author

@xiaoguoguo626807 xiaoguoguo626807 Aug 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个在build 函数手写完备后,可加入自动代码生成一起完善

@@ -51,9 +51,28 @@ class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface> {
static void InferMeta(phi::InferMetaContext *infer_meta);
};

class SplitGradOp : public ir::Op<SplitGradOp, OpYamlInfoInterface> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是不需要有 paddle::dialect::InferMetaInterface 接口么?我看pd_op.h 里其他的GradOp 都是有的

Copy link
Contributor Author

@xiaoguoguo626807 xiaoguoguo626807 Aug 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个splitgrad yaml 中没有infermeta, 使用的是invoke算子concat 的infermeta, 已补充

@@ -14,6 +14,7 @@

#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是不是可以按需 include 必要的头文件

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

splitgrad 调用了full op, 需要引入此头文件
int axis = axis_.owner()
->dyn_castpaddle::dialect::FullOp()
.attributes()
.at("value")
.dyn_castpaddle::dialect::ScalarAttribute()
.data()
.to();

.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int>();
(void)axis;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要添加这行代码吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void SplitGradOp::Verify() {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据前面GetOpInfo的描述,split 存在两个输入:out_grad 和 axis,Verify 中有必要做一下校验

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

false,
false),
OpInputInfo(
"axis", "paddle::dialect::ScalarAttribute", false, false, true)};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axis 是可变 attribute 的话,是不是有必要添加一个axis 为Scalar类型的Build 接口?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, 此处build接口无使用样例,存在覆盖率问题

YuanRisheng
YuanRisheng previously approved these changes Aug 25, 2023
@@ -402,5 +402,74 @@ TEST(VJP, Add_BackwardTest) {
ASSERT_EQ(dx.data<float>()[0], 1.0);
ASSERT_EQ(dy.data<float>()[0], 1.0);
}

TEST(VJP, SplitBackwardTest) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这儿单测能否在python端用新执行器去测试? c++这种直接在InterpreterCore 中取值的方案不是很推荐。

@xiaoguoguo626807 xiaoguoguo626807 merged commit 7995a38 into PaddlePaddle:develop Aug 28, 2023
BeingGod pushed a commit to BeingGod/Paddle that referenced this pull request Sep 9, 2023
* support ir api form prim

* convert vector of int to intarray

* add reference of lbfgs

* add reference of lbfgs

* support ir api for prim

* Add more gen api

* concat python api to concat_grad

* fix gen conflict

* support vjp prim mode in new ir

* remove useless code

* add vjp autogen v1.0

* add test for prim

* resolve type conflict

* modify utils

* remove useless code

* add split op and modify some bug of vectorType

* fix conflict

* add concat python test

* add split python api to vjp

* modify build bug

* modify run bug

* fix conflict bug

* build bug fix

* modify python api bug

* modify test

* fix conflict

* fluid backward recover

* recover conflict

* reply review comments

* modify opruntimeinfo num

---------

Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Co-authored-by: chenzhiyang <1792266893@qq.com>
Co-authored-by: Chen Zhiyang <chenzhiyang99@126.com>
@xiaoguoguo626807 xiaoguoguo626807 deleted the split_vjp branch June 12, 2024 02:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants