Commit

Auto codegen for supporting calling new_ir api in static operants (PaddlePaddle#56955)

* support new ir primitive operator in static operants

* support more vjp code gen

* support more vjp code gen

* support more vjp code gen

* use code gen

* fix operants codegen

* support more vjp code gen

* Fix ci build error

* set FLAGS_tensor_operants_mode to static in generated_vjp for testing

* fix bugs

* change the order of ops_name of divide_grad

* replace FLAGS_enable_new_ir_in_executor by FLAGS_enable_new_ir_api in codegen and test_vjp_prim

---------

Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
3 people authored Sep 6, 2023
1 parent c62902e commit 3eafa1f
Showing 6 changed files with 80 additions and 22 deletions.
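The net effect of the change: the generated StaticTensorOperants methods now branch on FLAGS_enable_new_ir_api, calling paddle::primitive::backend::* on LazyTensor when the flag is on and falling back to the existing paddle::prim::* DescTensor path otherwise. As a minimal sketch (not part of the diff), this is how the new path would be toggled from Python, using only the flag and the prim switch that appear in test_vjp_prim.py below; the program construction and vjp call themselves are elided:

# Sketch only: mirrors the toggling pattern used in test/prim/new_ir_prim/test_vjp_prim.py.
import paddle

paddle.framework.core._set_prim_backward_enabled(True)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
try:
    # Build the new-IR program and run the decomposed backward under test here;
    # with the flag on, StaticTensorOperants dispatches to the new-IR backend ops.
    pass
finally:
    # Restore global state so later tests see the default (old) behavior.
    paddle.framework.core._set_prim_backward_enabled(False)
    paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})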
82 changes: 68 additions & 14 deletions paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py
@@ -211,6 +211,11 @@ class StaticTensorOperants : public TensorOperantsBase {
#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/primitive/backend/backend.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
PHI_DECLARE_bool(enable_new_ir_api);
"""


@@ -219,47 +224,88 @@ class StaticTensorOperants : public TensorOperantsBase {
namespace prim {
using DescTensor = paddle::prim::DescTensor;
using LazyTensor = paddle::primitive::LazyTensor;
Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) {
return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::add<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
} else {
return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
}
Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::subtract<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
} else {
return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
}
Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::scale<LazyTensor>(x, y, 0.0f, true);
} else {
return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
}
}
Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::divide<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
} else {
return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
}
Tensor StaticTensorOperants::add(const Scalar& x, const Tensor& y) {
return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::add<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
} else {
return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}
}
Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) {
return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::subtract<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
} else {
return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}
}
Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::scale<LazyTensor>(y, x, 0.0f, true);
} else {
return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
}
}
Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::divide<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
} else {
return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}
}
Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) {
return paddle::prim::elementwise_pow<DescTensor>(x, y);
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, y);
} else {
return paddle::prim::elementwise_pow<DescTensor>(x, y);
}
}
Tensor StaticTensorOperants::pow(const Tensor& x, const Scalar& y) {
return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
if (FLAGS_enable_new_ir_api) {
return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
} else {
return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
}
"""


@@ -339,13 +385,21 @@ def gene_eager_tensor_operants_implementation(self):

def gene_static_tensor_func_call(self):
api_func_name = self.get_api_func_name()

backend_static_func_name = (
'paddle::primitive::backend::' + api_func_name + '<LazyTensor>'
)
prim_static_func_name = (
'paddle::prim::' + api_func_name + '<DescTensor>'
)
prim_static_func_parameters = self.get_func_args()
static_func_parameters = self.get_func_args()

static_tensor_func_call = f"""if (FLAGS_enable_new_ir_api) {{
return {backend_static_func_name}({static_func_parameters});
}} else {{
return {prim_static_func_name}({static_func_parameters});
}}"""

return f"""return {prim_static_func_name}({prim_static_func_parameters});"""
return static_tensor_func_call

def gene_static_tensor_operants_implementation(self):
api_code = ""
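The gene_static_tensor_func_call change above emits the flag-guarded dispatch as a plain f-string. A standalone sketch of how that template renders, using a hypothetical "add" operant with arguments "x, y" standing in for self.get_api_func_name() and self.get_func_args() (the real generator's class plumbing is omitted):

# Hypothetical inputs for illustration only.
api_func_name = "add"
static_func_parameters = "x, y"

backend_static_func_name = (
    'paddle::primitive::backend::' + api_func_name + '<LazyTensor>'
)
prim_static_func_name = 'paddle::prim::' + api_func_name + '<DescTensor>'

# Same f-string shape as gene_static_tensor_func_call above.
static_tensor_func_call = f"""if (FLAGS_enable_new_ir_api) {{
  return {backend_static_func_name}({static_func_parameters});
}} else {{
  return {prim_static_func_name}({static_func_parameters});
}}"""

print(static_tensor_func_call)
# if (FLAGS_enable_new_ir_api) {
#   return paddle::primitive::backend::add<LazyTensor>(x, y);
# } else {
#   return paddle::prim::add<DescTensor>(x, y);
# }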
2 changes: 1 addition & 1 deletion paddle/fluid/prim/utils/static/CMakeLists.txt
@@ -6,4 +6,4 @@ cc_library(
cc_library(
static_tensor_operants
SRCS static_tensor_operants.cc
DEPS static_prim_api)
DEPS static_prim_api primitive_backend_static_experimental)
@@ -10,7 +10,9 @@
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/fluid/primitive/utils/utils.h"
#include "paddle/ir/core/operation.h"
#include "paddle/phi/core/flags.h"

PHI_DECLARE_string(tensor_operants_mode);

namespace paddle {
namespace primitive {
@@ -95,6 +97,7 @@ for (size_t i=0; i< stop_gradients[0].size(); i++ ) {
{% endmacro %}

{% macro body_prim(api) %}
FLAGS_tensor_operants_mode = "static";
{% for i in range(api.outputs|length) %}
{% if api.outputs[i].typename=='Tensor' %}
paddle::Tensor* {{api.outputs[i].name}} = !stop_gradients[{{i}}][0] ? &vjp_res[{{i}}][0] : nullptr;
7 changes: 2 additions & 5 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -39,10 +39,7 @@ void divide_grad(const Tensor& x,
Tensor* dy) {
if (dy) {
// dy = -(x/y^2) * dout
auto denominator =
elementwise_pow<T>(y, full<T>(y.shape(), 2.0, y.dtype(), y.place()));
auto dy_res = scale<T>(
multiply<T>(divide<T>(x, denominator), out_grad), -1.0, 0.0, true);
auto dy_res = -(x / y.pow(2.0)) * out_grad;
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
@@ -61,7 +58,7 @@ void divide_grad(const Tensor& x,
if (dx) {
// dx = (1/y) * dout
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
auto dx_res = multiply<T>(divide<T>(one_tensor, y), out_grad);
auto dx_res = one_tensor / y * out_grad;
if (y.dims() != x.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
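For reference, the simplified divide_grad above is just the quotient rule for z = x / y applied to the upstream gradient out_grad (standard calculus, not part of the diff):

% dout denotes the upstream gradient out_grad
\frac{\partial z}{\partial x} = \frac{1}{y}
  \quad\Longrightarrow\quad dx = \frac{1}{y}\, dout,
\qquad
\frac{\partial z}{\partial y} = -\frac{x}{y^{2}}
  \quad\Longrightarrow\quad dy = -\frac{x}{y^{2}}\, dout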
2 changes: 1 addition & 1 deletion paddle/phi/core/extended_tensor.cc
@@ -38,7 +38,7 @@ DataType ExtendedTensor::dtype() const {

DataLayout ExtendedTensor::layout() const {
PADDLE_THROW(phi::errors::Unavailable(
"ExtendedTensor does not support `dtype` method."));
"ExtendedTensor does not support `layout` method."));
}

bool ExtendedTensor::valid() const {
6 changes: 5 additions & 1 deletion test/prim/new_ir_prim/test_vjp_prim.py
@@ -63,6 +63,7 @@ class TestVjpPrim(unittest.TestCase):
def test_divide_grad_prim_case1(self):
newir_program = get_ir_divide_program()
paddle.framework.core._set_prim_backward_enabled(True)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
dout = newir_program.block().ops[-2].result(0)
out_grads = [[dout]]
stop_gradients = [[False], [False]]
@@ -83,9 +84,9 @@ def test_divide_grad_prim_case1(self):
"pd.full",
"pd.elementwise_pow",
"pd.divide",
"pd.multiply",
"pd.full",
"pd.scale",
"pd.multiply",
"pd.full_int_array",
"pd.sum",
"pd.full_int_array",
@@ -101,6 +102,7 @@ def test_divide_grad_prim_case1(self):
for idx, op in enumerate(newir_program.block().ops):
self.assertEqual(op.name(), all_op_names[idx])
paddle.framework.core._set_prim_backward_enabled(False)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})

def test_divide_grad_no_prim(self):
newir_program = get_ir_divide_program()
@@ -123,6 +125,7 @@ def test_divide_grad_no_prim(self):
def test_sum_grad_prim(self):
newir_program = get_ir_sum_program()
paddle.framework.core._set_prim_backward_enabled(True)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
dout = newir_program.block().ops[-3].result(0)
out_grads = [[dout]]
stop_gradients = [[False], [True]]
@@ -147,6 +150,7 @@ def test_sum_grad_prim(self):
for idx, op in enumerate(newir_program.block().ops):
self.assertEqual(op.name(), all_op_names[idx])
paddle.framework.core._set_prim_backward_enabled(False)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})

def test_sum_grad_no_prim(self):
newir_program = get_ir_sum_program()
