[PRIM][IR] support add vjp #56163

Merged 63 commits on Aug 16, 2023
Commits (63)
f630e33
[prim][newir] add basic framework for primitive
cxxly Jul 20, 2023
c8bd625
support desctensor in new ir
Charles-hit Jul 24, 2023
5612359
add vjp interface
zhangbo9674 Jul 24, 2023
4d8079f
Merge commit 'refs/pull/55660/head' of https://github.com/PaddlePaddl…
Charles-hit Jul 25, 2023
fe5605b
support vjp in new ir
Charles-hit Jul 25, 2023
f9389ec
support vjp in new ir
Charles-hit Jul 26, 2023
67cf1fc
polish vjp interface
Charles-hit Jul 27, 2023
35f867b
fix stop_gradients set
Charles-hit Jul 27, 2023
5fe88d5
resolve conflict
Charles-hit Jul 27, 2023
703c168
fix vjp dispatch
Charles-hit Jul 27, 2023
0738201
add comment
Charles-hit Jul 27, 2023
d49d38a
add vjp test for new ir
Charles-hit Jul 27, 2023
a9e9d01
add test for tanh vjp
Charles-hit Jul 27, 2023
4df18b5
[prim][newir] add basic framework for primitive
cxxly Jul 20, 2023
5a65b50
support desctensor in new ir
Charles-hit Jul 24, 2023
5a3710a
support vjp in new ir
Charles-hit Jul 25, 2023
c035675
support vjp in new ir
Charles-hit Jul 26, 2023
a9b8240
polish vjp interface
Charles-hit Jul 27, 2023
901352c
fix stop_gradients set
Charles-hit Jul 27, 2023
de4ac55
fix vjp dispatch
Charles-hit Jul 27, 2023
f3da449
add comment
Charles-hit Jul 27, 2023
84b92dd
add vjp test for new ir
Charles-hit Jul 27, 2023
690a0b9
add test for tanh vjp
Charles-hit Jul 27, 2023
4ee2d44
add eager and static backend for warp lower level api
cxxly Jul 28, 2023
866dc2c
support call_vjp pybind
Charles-hit Jul 28, 2023
0d3d7d6
support call_vjp pybind
Charles-hit Jul 28, 2023
dc3e7be
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 1, 2023
b4579f2
polish code and add test for vjp
Charles-hit Aug 2, 2023
be05029
remove useless code
Charles-hit Aug 2, 2023
619bcd0
polish code
Charles-hit Aug 2, 2023
e57d1f0
remove useless code
Charles-hit Aug 2, 2023
ac8b2a6
support mean vjp
Charles-hit Aug 3, 2023
5612b2f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 3, 2023
afcb454
add test for mean vjp and support has_vjp function
Charles-hit Aug 3, 2023
40d7ab0
fix call_vjp
Charles-hit Aug 3, 2023
d9a78f6
polish code
Charles-hit Aug 4, 2023
ed442ff
add primitive ops set for backend
Charles-hit Aug 4, 2023
95efe5e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 4, 2023
f802b36
add vjp test for tanh_
Charles-hit Aug 6, 2023
820b313
fix inference CI
Charles-hit Aug 7, 2023
4f320f0
fix inference ci
Charles-hit Aug 7, 2023
fe1b035
modify fluid cmake
Charles-hit Aug 7, 2023
587bea0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 7, 2023
c155302
remove useless deps
Charles-hit Aug 7, 2023
d4f37b2
add cmake
Charles-hit Aug 7, 2023
bde35c0
fix comment
Charles-hit Aug 10, 2023
aaa32d9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 10, 2023
d06b608
fix test
Charles-hit Aug 10, 2023
cd562f0
polish code
Charles-hit Aug 10, 2023
f9bca3c
modify backward stop_gradients
Charles-hit Aug 10, 2023
7ba60cc
modify static_backend.cc
Charles-hit Aug 10, 2023
87adbb4
Merge commit 'refs/pull/56137/head' of https://github.com/PaddlePaddl…
Charles-hit Aug 10, 2023
0a1ff71
support add and add_inplace vjp
Charles-hit Aug 10, 2023
096db4e
remove useless code
Charles-hit Aug 10, 2023
f07fd0f
remove useless code
Charles-hit Aug 10, 2023
bc5a8e8
Merge commit 'refs/pull/56137/head' of https://github.com/PaddlePaddl…
Charles-hit Aug 10, 2023
05842d0
remove cout
Charles-hit Aug 11, 2023
30add70
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 11, 2023
c176eef
remove cout
Charles-hit Aug 11, 2023
19e3fcf
fix add_grad
Charles-hit Aug 14, 2023
05be317
fix add test exe
Charles-hit Aug 14, 2023
3f8531d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 15, 2023
f5141b9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 15, 2023
@@ -21,4 +21,4 @@
# TODO(wanghao107)
# remove this file and support Vjp methods
# code gen.
vjp_interface_gen_op_list = ["tanh", "mean"]
vjp_interface_gen_op_list = ["tanh", "mean", "add"]
50 changes: 50 additions & 0 deletions paddle/fluid/ir/dialect/pd_op_vjp_manual.cc
@@ -98,5 +98,55 @@ std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
}
return res;
}

std::vector<std::vector<ir::OpResult>> AddOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
AddOp op_obj = op->dyn_cast<AddOp>();
Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
Tensor y(std::make_shared<primitive::experimental::DescTensor>(op_obj.y()));
Tensor out_grad(
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
int axis = -1;

std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
for (size_t i = 0; i < 2; ++i) {
if (tensor_res[i][0].defined()) {
res[i][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[i][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
}
return res;
}

std::vector<std::vector<ir::OpResult>> Add_Op::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
Add_Op op_obj = op->dyn_cast<Add_Op>();
Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
Tensor y(std::make_shared<primitive::experimental::DescTensor>(op_obj.y()));
Tensor out_grad(
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
int axis = -1;

std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
for (size_t i = 0; i < 2; ++i) {
if (tensor_res[i][0].defined()) {
res[i][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[i][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
}
Review comment (Contributor): When this code is auto-generated, expand it into two nested loops so that both levels of the outputs are traversed.

Reply (Author): The code generator can decide that per operator.
return res;
}
} // namespace dialect
} // namespace paddle
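Regarding the review exchange embedded in Add_Op::Vjp above (expanding into two nested loops during code generation so that every level of the outputs is visited), the following is a minimal editorial sketch of what such a generated collection loop could look like. It is not the code this PR generates; it only reuses the DescTensor unwrapping shown above.

    // Illustrative only: walk both the per-input (outer) and per-tensor (inner)
    // levels of tensor_res instead of assuming a single tensor per input.
    std::vector<std::vector<ir::OpResult>> res(tensor_res.size());
    for (size_t i = 0; i < tensor_res.size(); ++i) {
      res[i].resize(tensor_res[i].size());
      for (size_t j = 0; j < tensor_res[i].size(); ++j) {
        if (tensor_res[i][j].defined()) {
          res[i][j] =
              std::static_pointer_cast<primitive::experimental::DescTensor>(
                  tensor_res[i][j].impl())
                  ->getValue()
                  .dyn_cast<ir::OpResult>();
        }
      }
    }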
25 changes: 25 additions & 0 deletions paddle/fluid/primitive/backend/static_backend.cc
@@ -59,6 +59,31 @@ Tensor mean_grad<DescTensor>(const Tensor& x,
return Tensor(std::make_shared<primitive::experimental::DescTensor>(op_res));
}

template <>
std::tuple<Tensor, Tensor> add_grad<DescTensor>(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis) {
ir::OpResult x_res = std::static_pointer_cast<DescTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<DescTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult out_grad_res =
std::static_pointer_cast<DescTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();

std::tuple<ir::OpResult, ir::OpResult> op_res =
paddle::dialect::add_grad(x_res, y_res, out_grad_res, axis);

return std::make_tuple(
Tensor(std::make_shared<primitive::experimental::DescTensor>(
std::get<0>(op_res))),
Tensor(std::make_shared<primitive::experimental::DescTensor>(
std::get<1>(op_res))));
}
} // namespace experimental
} // namespace backend
} // namespace primitive
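The add_grad<DescTensor> specialization above follows the same pattern as the existing mean_grad<DescTensor>: unwrap each DescTensor-backed paddle::Tensor to its underlying ir::OpResult, call paddle::dialect::add_grad to build the gradient op in the current program, then wrap the two results back into Tensors. A minimal sketch of the unwrap step as a hypothetical helper (the helper name is illustrative and does not exist in Paddle):

    // Hypothetical convenience helper; mirrors the three-line unwrap repeated
    // for x, y and out_grad in the specialization above.
    static ir::OpResult ToOpResult(const Tensor& t) {
      return std::static_pointer_cast<DescTensor>(t.impl())
          ->getValue()
          .dyn_cast<ir::OpResult>();
    }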
6 changes: 6 additions & 0 deletions paddle/fluid/primitive/backend/static_backend.h
@@ -37,6 +37,12 @@ Tensor mean_grad(const Tensor& x,
const IntArray& axis = {},
bool keepdim = false,
bool reduce_all = false);

template <typename T>
std::tuple<Tensor, Tensor> add_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis);
} // namespace experimental
} // namespace backend
} // namespace primitive
41 changes: 41 additions & 0 deletions paddle/fluid/primitive/rule/vjp/vjp.cc
@@ -111,6 +111,47 @@ std::vector<std::vector<paddle::Tensor>> mean_vjp(
return vjp_res;
}

std::vector<std::vector<paddle::Tensor>> add_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
2, std::vector<paddle::Tensor>(1));
// get add_grad res.
std::tuple<Tensor, Tensor> op_res =
backend::experimental::add_grad<primitive::experimental::DescTensor>(
x, y, out_grad, axis);

// set op stop_gradient info
// TODO(wanghao107): Replace with more generic code.
// Support set stop_gradients for all ops.
ir::Operation* grad_op =
std::static_pointer_cast<primitive::experimental::DescTensor>(
std::get<0>(op_res).impl())
->getValue()
.dyn_cast<ir::OpResult>()
.owner();
std::vector<ir::Attribute> ir_stop_gradients(2);
for (size_t i = 0; i < 2; i++) {
if (stop_gradients[i][0]) {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), true);
} else {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), false);
}
}
grad_op->set_attribute(
"stop_gradient",
ir::ArrayAttribute::get(ir::IrContext::Instance(), ir_stop_gradients));

// construct vjp result by op result and stop_gradients info
vjp_res[0][0] = !stop_gradients[0][0] ? std::get<0>(op_res) : vjp_res[0][0];
vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0];
return vjp_res;
}
} // namespace experimental
} // namespace primitive
} // namespace paddle
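For reference, the gradient rule that add_vjp encodes is the standard one for elementwise addition. With $z = x + y$ and incoming cotangent $\bar{z}$ (out_grad):

$$\bar{x} = \bar{z}\,\frac{\partial z}{\partial x} = \bar{z}, \qquad \bar{y} = \bar{z}\,\frac{\partial z}{\partial y} = \bar{z}.$$

When x and y have the same shape both gradients are simply out_grad; under broadcasting the gradients are additionally reduced over the broadcast dimensions, with the axis argument (always -1 here, i.e. the default alignment) selecting how y is aligned against x.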
7 changes: 7 additions & 0 deletions paddle/fluid/primitive/rule/vjp/vjp.h
@@ -46,6 +46,13 @@ std::vector<std::vector<paddle::Tensor>> mean_vjp(
bool reduce_all,
const std::vector<std::vector<bool>>& stop_gradients);

std::vector<std::vector<paddle::Tensor>> add_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
const std::vector<std::vector<bool>>& stop_gradients);

namespace details {
// NOTE: this namespace will store
// primitive ops grad composite rules.
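A minimal usage sketch of the add_vjp declaration above, assuming x, y and out_grad are DescTensor-backed paddle::Tensor values already built in the current new-IR program (the variable names are illustrative):

    // Request gradients for both inputs; setting stop_gradients[i][0] to true
    // suppresses the i-th gradient in the returned structure.
    std::vector<std::vector<bool>> stop_gradients{{false}, {false}};
    std::vector<std::vector<paddle::Tensor>> grads =
        paddle::primitive::experimental::add_vjp(x, y, out_grad, /*axis=*/-1,
                                                 stop_gradients);
    // grads[0][0] is x_grad, grads[1][0] is y_grad.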
129 changes: 129 additions & 0 deletions test/cpp/prim/test_vjp.cc
@@ -36,6 +36,7 @@ PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT);

namespace paddle {
namespace framework {
@@ -204,5 +205,133 @@ TEST(VJP, MeanBackwardTest) {
ASSERT_EQ(grad_out_tensor.data<float>()[3], 0.25);
}

TEST(VJP, AddBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);

std::shared_ptr<ir::Builder> builder =
paddle::dialect::APIBuilder::Instance().GetBuilder();
paddle::dialect::FullOp op1 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::FullOp op2 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::AddOp op3 =
builder->Build<paddle::dialect::AddOp>(op1.out(), op2.out());

paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}, {false}};
std::vector<std::vector<ir::OpResult>> out_grads{{op4.out()}};

ir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd.add");
auto add_vjp_interface_impl =
op3_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
add_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients);

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

auto place = platform::CPUPlace();
Scope scope;

ProgramDesc prog_desc;
InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
test_core.SetSkipGcVars({prefix_str + "_inner_var_2",
prefix_str + "_inner_var_4",
prefix_str + "_inner_var_5"});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_2")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_2")
->Get<phi::DenseTensor>();
auto dx =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_4")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_4")
->Get<phi::DenseTensor>();

auto dy =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_5")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_5")
->Get<phi::DenseTensor>();
ASSERT_EQ(out_tensor.data<float>()[0], 4.0);
ASSERT_EQ(dx.data<float>()[0], 1.0);
ASSERT_EQ(dy.data<float>()[0], 1.0);
}

TEST(VJP, Add_BackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);

std::shared_ptr<ir::Builder> builder =
paddle::dialect::APIBuilder::Instance().GetBuilder();
paddle::dialect::FullOp op1 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::FullOp op2 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::Add_Op op3 =
builder->Build<paddle::dialect::Add_Op>(op1.out(), op2.out());

paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}, {false}};
std::vector<std::vector<ir::OpResult>> out_grads{{op4.out()}};

ir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd.add_");
auto add_inplace_vjp_interface_impl =
op3_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
add_inplace_vjp_interface_impl->vjp_(
op3.operation(), out_grads, stop_gradients);

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

auto place = platform::CPUPlace();
Scope scope;

ProgramDesc prog_desc;
InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
test_core.SetSkipGcVars({prefix_str + "_inner_var_0",
prefix_str + "_inner_var_3",
prefix_str + "_inner_var_4"});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_0")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_0")
->Get<phi::DenseTensor>();
auto dx =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_3")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_3")
->Get<phi::DenseTensor>();

auto dy =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_4")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_4")
->Get<phi::DenseTensor>();
ASSERT_EQ(out_tensor.data<float>()[0], 4.0);
ASSERT_EQ(dx.data<float>()[0], 1.0);
ASSERT_EQ(dy.data<float>()[0], 1.0);
}
} // namespace framework
} // namespace paddle
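As a quick sanity check of the asserted values in both tests (using the add gradient rule noted earlier): the two FullOps each produce a tensor filled with 2.0, so out = 2.0 + 2.0 = 4.0; the incoming out_grad is filled with 1.0, and since d(x + y)/dx = d(x + y)/dy = 1, both dx and dy come out as 1.0.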