diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
index 3201651e4696c..6d788eb85a5a6 100644
--- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
+++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
@@ -21,4 +21,4 @@
 # TODO(wanghao107)
 # remove this file and support Vjp methods
 # code gen.
-vjp_interface_gen_op_list = ["tanh", "mean"]
+vjp_interface_gen_op_list = ["tanh", "mean", "add"]
diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc
index be43ddd60491c..ee03e826b652d 100644
--- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc
+++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc
@@ -98,5 +98,55 @@ std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
   }
   return res;
 }
+
+std::vector<std::vector<ir::OpResult>> AddOp::Vjp(
+    ir::Operation* op,
+    const std::vector<std::vector<ir::OpResult>>& out_grads,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  AddOp op_obj = op->dyn_cast<AddOp>();
+  Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
+  Tensor y(std::make_shared<primitive::experimental::DescTensor>(op_obj.y()));
+  Tensor out_grad(
+      std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
+  int axis = -1;
+
+  std::vector<std::vector<Tensor>> tensor_res =
+      primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients);
+  std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
+  for (size_t i = 0; i < 2; ++i) {
+    if (tensor_res[i][0].defined()) {
+      res[i][0] =
+          std::static_pointer_cast<primitive::experimental::DescTensor>(
+              tensor_res[i][0].impl())
+              ->getValue()
+              .dyn_cast<ir::OpResult>();
+    }
+  }
+  return res;
+}
+
+std::vector<std::vector<ir::OpResult>> Add_Op::Vjp(
+    ir::Operation* op,
+    const std::vector<std::vector<ir::OpResult>>& out_grads,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  Add_Op op_obj = op->dyn_cast<Add_Op>();
+  Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
+  Tensor y(std::make_shared<primitive::experimental::DescTensor>(op_obj.y()));
+  Tensor out_grad(
+      std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
+  int axis = -1;
+
+  std::vector<std::vector<Tensor>> tensor_res =
+      primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients);
+  std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
+  for (size_t i = 0; i < 2; ++i) {
+    if (tensor_res[i][0].defined()) {
+      res[i][0] =
+          std::static_pointer_cast<primitive::experimental::DescTensor>(
+              tensor_res[i][0].impl())
+              ->getValue()
+              .dyn_cast<ir::OpResult>();
+    }
+  }
+  return res;
+}
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc
index b041d3710c25d..be3f29eeca49d 100644
--- a/paddle/fluid/primitive/backend/static_backend.cc
+++ b/paddle/fluid/primitive/backend/static_backend.cc
@@ -59,6 +59,31 @@ Tensor mean_grad(const Tensor& x,
   return Tensor(std::make_shared<primitive::experimental::DescTensor>(op_res));
 }
 
+template <>
+std::tuple<Tensor, Tensor> add_grad<DescTensor>(const Tensor& x,
+                                                const Tensor& y,
+                                                const Tensor& out_grad,
+                                                int axis) {
+  ir::OpResult x_res = std::static_pointer_cast<DescTensor>(x.impl())
+                           ->getValue()
+                           .dyn_cast<ir::OpResult>();
+  ir::OpResult y_res = std::static_pointer_cast<DescTensor>(y.impl())
+                           ->getValue()
+                           .dyn_cast<ir::OpResult>();
+  ir::OpResult out_grad_res =
+      std::static_pointer_cast<DescTensor>(out_grad.impl())
+          ->getValue()
+          .dyn_cast<ir::OpResult>();
+
+  std::tuple<ir::OpResult, ir::OpResult> op_res =
+      paddle::dialect::add_grad(x_res, y_res, out_grad_res, axis);
+
+  return std::make_tuple(
+      Tensor(std::make_shared<primitive::experimental::DescTensor>(
+          std::get<0>(op_res))),
+      Tensor(std::make_shared<primitive::experimental::DescTensor>(
+          std::get<1>(op_res))));
+}
 }  // namespace experimental
 }  // namespace backend
 }  // namespace primitive
diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h
index 09835bb759674..063532d8cd23a 100644
--- a/paddle/fluid/primitive/backend/static_backend.h
+++ b/paddle/fluid/primitive/backend/static_backend.h
@@ -37,6 +37,12 @@ Tensor mean_grad(const Tensor& x,
                  const IntArray& axis = {},
                  bool keepdim = false,
                  bool reduce_all = false);
+
+template <typename T>
+std::tuple<Tensor, Tensor> add_grad(const Tensor& x,
+                                    const Tensor& y,
+                                    const Tensor& out_grad,
+                                    int axis);
 }  // namespace experimental
 }  // namespace backend
 }  // namespace primitive
diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc
index b5f0acf98c1d8..110a2d3389923 100644
--- a/paddle/fluid/primitive/rule/vjp/vjp.cc
+++ b/paddle/fluid/primitive/rule/vjp/vjp.cc
@@ -111,6 +111,47 @@ std::vector<std::vector<paddle::Tensor>> mean_vjp(
   return vjp_res;
 }
 
+std::vector<std::vector<paddle::Tensor>> add_vjp(
+    const Tensor& x,
+    const Tensor& y,
+    const Tensor& out_grad,
+    int axis,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  std::vector<std::vector<paddle::Tensor>> vjp_res(
+      2, std::vector<paddle::Tensor>(1));
+  // get add_grad res.
+  std::tuple<Tensor, Tensor> op_res =
+      backend::experimental::add_grad<primitive::experimental::DescTensor>(
+          x, y, out_grad, axis);
+
+  // set op stop_gradient info
+  // TODO(wanghao107): Replace with more generic code.
+  // Support set stop_gradients for all ops.
+  ir::Operation* grad_op =
+      std::static_pointer_cast<primitive::experimental::DescTensor>(
+          std::get<0>(op_res).impl())
+          ->getValue()
+          .dyn_cast<ir::OpResult>()
+          .owner();
+  std::vector<ir::Attribute> ir_stop_gradients(2);
+  for (size_t i = 0; i < 2; i++) {
+    if (stop_gradients[i][0]) {
+      ir_stop_gradients[i] =
+          ir::BoolAttribute::get(ir::IrContext::Instance(), true);
+    } else {
+      ir_stop_gradients[i] =
+          ir::BoolAttribute::get(ir::IrContext::Instance(), false);
+    }
+  }
+  grad_op->set_attribute(
+      "stop_gradient",
+      ir::ArrayAttribute::get(ir::IrContext::Instance(), ir_stop_gradients));
+
+  // construct vjp result by op result and stop_gradients info
+  vjp_res[0][0] = !stop_gradients[0][0] ? std::get<0>(op_res) : vjp_res[0][0];
+  vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0];
+  return vjp_res;
+}
 }  // namespace experimental
 }  // namespace primitive
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/rule/vjp/vjp.h b/paddle/fluid/primitive/rule/vjp/vjp.h
index 48bc2affa9db4..8ef03c39c6eb6 100644
--- a/paddle/fluid/primitive/rule/vjp/vjp.h
+++ b/paddle/fluid/primitive/rule/vjp/vjp.h
@@ -46,6 +46,13 @@ std::vector<std::vector<paddle::Tensor>> mean_vjp(
     bool reduce_all,
     const std::vector<std::vector<bool>>& stop_gradients);
 
+std::vector<std::vector<paddle::Tensor>> add_vjp(
+    const Tensor& x,
+    const Tensor& y,
+    const Tensor& out_grad,
+    int axis,
+    const std::vector<std::vector<bool>>& stop_gradients);
+
 namespace details {
 // NOTE: this namespace will store
 // primitive ops grad composite rules.
diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc
index 9eb865a579765..c4409e5dd2a0c 100644
--- a/test/cpp/prim/test_vjp.cc
+++ b/test/cpp/prim/test_vjp.cc
@@ -36,6 +36,7 @@ PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(mean, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(mean_grad, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT);
 
 namespace paddle {
 namespace framework {
@@ -204,5 +205,133 @@ TEST(VJP, MeanBackwardTest) {
   ASSERT_EQ(grad_out_tensor.data<float>()[3], 0.25);
 }
 
+TEST(VJP, AddBackwardTest) {
+  ir::IrContext* ctx = ir::IrContext::Instance();
+  ir::Program program((ctx));
+  paddle::dialect::APIBuilder::Instance().SetProgram(&program);
+
+  std::shared_ptr<ir::Builder> builder =
+      paddle::dialect::APIBuilder::Instance().GetBuilder();
+  paddle::dialect::FullOp op1 = builder->Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
+  paddle::dialect::FullOp op2 = builder->Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
+  paddle::dialect::AddOp op3 =
+      builder->Build<paddle::dialect::AddOp>(op1.out(), op2.out());
+
+  paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  std::vector<std::vector<bool>> stop_gradients{{false}, {false}};
+  std::vector<std::vector<ir::OpResult>> out_grads{{op4.out()}};
+
+  ir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd.add");
+  auto add_vjp_interface_impl =
+      op3_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
+  add_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients);
+
+  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
+
+  auto place = platform::CPUPlace();
+  Scope scope;
+
+  ProgramDesc prog_desc;
+  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
+  std::stringstream os;
+  os << reinterpret_cast<NewIRInterpreter*>(
+      const_cast<InterpreterBaseImpl*>(test_core.Impl()));
+  std::string prefix_str = os.str();
+  test_core.SetSkipGcVars({prefix_str + "_inner_var_2",
+                           prefix_str + "_inner_var_4",
+                           prefix_str + "_inner_var_5"});
+  test_core.Run({});
+  auto out_tensor =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar(prefix_str + "_inner_var_2")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar(prefix_str + "_inner_var_2")
+                ->Get<phi::DenseTensor>();
+  auto dx =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar(prefix_str + "_inner_var_4")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar(prefix_str + "_inner_var_4")
+                ->Get<phi::DenseTensor>();
+
+  auto dy =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar(prefix_str + "_inner_var_5")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar(prefix_str + "_inner_var_5")
+                ->Get<phi::DenseTensor>();
+  ASSERT_EQ(out_tensor.data<float>()[0], 4.0);
+  ASSERT_EQ(dx.data<float>()[0], 1.0);
+  ASSERT_EQ(dy.data<float>()[0], 1.0);
+}
+
+TEST(VJP, Add_BackwardTest) {
+  ir::IrContext* ctx = ir::IrContext::Instance();
+  ir::Program program((ctx));
+  paddle::dialect::APIBuilder::Instance().SetProgram(&program);
+
+  std::shared_ptr<ir::Builder> builder =
+      paddle::dialect::APIBuilder::Instance().GetBuilder();
+  paddle::dialect::FullOp op1 = builder->Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
+  paddle::dialect::FullOp op2 = builder->Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
+  paddle::dialect::Add_Op op3 =
+      builder->Build<paddle::dialect::Add_Op>(op1.out(), op2.out());
+
+  paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  std::vector<std::vector<bool>> stop_gradients{{false}, {false}};
+  std::vector<std::vector<ir::OpResult>> out_grads{{op4.out()}};
+
+  ir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd.add_");
+  auto add_inplace_vjp_interface_impl =
+      op3_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
+  add_inplace_vjp_interface_impl->vjp_(
+      op3.operation(), out_grads, stop_gradients);
+
+  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
+
+  auto place = platform::CPUPlace();
+  Scope scope;
+
+  ProgramDesc prog_desc;
+  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
+  std::stringstream os;
+  os << reinterpret_cast<NewIRInterpreter*>(
+      const_cast<InterpreterBaseImpl*>(test_core.Impl()));
+  std::string prefix_str = os.str();
+  test_core.SetSkipGcVars({prefix_str + "_inner_var_0",
+                           prefix_str + "_inner_var_3",
+                           prefix_str + "_inner_var_4"});
+  test_core.Run({});
+  auto out_tensor =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar(prefix_str + "_inner_var_0")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar(prefix_str + "_inner_var_0")
+                ->Get<phi::DenseTensor>();
+  auto dx =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar(prefix_str + "_inner_var_3")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar(prefix_str + "_inner_var_3")
+                ->Get<phi::DenseTensor>();
+
+  auto dy =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar(prefix_str + "_inner_var_4")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar(prefix_str + "_inner_var_4")
+                ->Get<phi::DenseTensor>();
+  ASSERT_EQ(out_tensor.data<float>()[0], 4.0);
+  ASSERT_EQ(dx.data<float>()[0], 1.0);
+  ASSERT_EQ(dy.data<float>()[0], 1.0);
+}
 }  // namespace framework
 }  // namespace paddle