[Prim][PIR] Support dynamic shape for elementwise_pow_grad (#65692)
* Support dynamic shape for elementwise_pow_grad and add dynamic-shape test cases for pow_grad

* Fix a bug in the backward composition
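
For reference, the gradients being composed here (the same formulas that appear in the VJP code below, writing g for out_grad) are, in LaTeX:

\frac{\partial L}{\partial x} = y \, x^{y-1} \, g, \qquad \frac{\partial L}{\partial y} = \ln(x) \, x^{y} \, g

When x and y are broadcast against each other, each result is then reduced back to its operand's shape: reduce_as in the new dynamic-shape branch, or a sum over the broadcast dims plus a reshape in the existing static-shape path.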
zeroRains authored Jul 6, 2024
1 parent 3f8331d commit 7766939
Showing 3 changed files with 228 additions and 22 deletions.
63 changes: 41 additions & 22 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -715,31 +715,43 @@ void elementwise_pow_grad(const Tensor& x,
     auto lnx = log<T>(x);
     auto x_pow_y = elementwise_pow<T>(x, y);
     auto dy_res = lnx * x_pow_y * out_grad;
-    if (out_grad.dims() != y.dims()) {
-      phi::DDim reduce_dim =
-          get_reduce_dims_from_out(out_grad.dims(), y.dims());
-      auto dy_reduce_res =
-          dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-      auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-      set_output<T>(dy_tmp, dy);
+    if (has_dynamic_shape(out_grad.shape()) || has_dynamic_shape(y.shape())) {
+      auto dy_reduce_res = reduce_as<T>(dy_res, y);
+      set_output<T>(dy_reduce_res, dy);
     } else {
-      set_output<T>(dy_res, dy);
+      if (out_grad.dims() != y.dims()) {
+        phi::DDim reduce_dim =
+            get_reduce_dims_from_out(out_grad.dims(), y.dims());
+        auto dy_reduce_res =
+            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        set_output<T>(dy_tmp, dy);
+      } else {
+        set_output<T>(dy_res, dy);
+      }
     }
   }  // indicate we will compute dy
   if (dx) {
     // dx = y * x^(y-1)
-    auto tmp_z = y - 1.0;
-    auto x_pow_z = elementwise_pow<T>(x, tmp_z);
-    auto dx_res = y * x_pow_z * out_grad;
-    if (out_grad.dims() != x.dims()) {
-      auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-      auto dx_reduce_res =
-          dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-      set_output<T>(dx_tmp, dx);
+
+    if (has_dynamic_shape(out_grad.shape()) || has_dynamic_shape(x.shape())) {
+      Tensor one_tensor = full_scalar<T>(1.0, y.dtype());
+      Tensor x_pow_z = elementwise_pow<T>(x, y - one_tensor);
+      Tensor dx_res = y * x_pow_z * out_grad;
+      auto dx_reduce_res = reduce_as<T>(dx_res, x);
+      set_output<T>(dx_reduce_res, dx);
     } else {
-      set_output<T>(dx_res, dx);
+      auto tmp_z = y - 1.0;
+      auto x_pow_z = elementwise_pow<T>(x, tmp_z);
+      auto dx_res = y * x_pow_z * out_grad;
+      if (out_grad.dims() != x.dims()) {
+        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
+        auto dx_reduce_res =
+            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        set_output<T>(dx_tmp, dx);
+      } else {
+        set_output<T>(dx_res, dx);
+      }
     }
   }  // indicate we will compute dx
 }
@@ -750,9 +762,16 @@ void pow_grad(const Tensor& x,
               const Scalar& y,
               Tensor* x_grad) {
   if (x_grad) {
-    auto y_value = y.to<float>();
-    auto dx_res = y_value * x.pow(y_value - 1) * out_grad;
-    set_output<T>(dx_res, x_grad);
+    if (has_dynamic_shape(x.shape())) {
+      Tensor y_tensor = backend::full_with_tensor<T>(shape<T>(x), y, x.dtype());
+      Tensor one_tensor = full_scalar<T>(1.0, x.dtype());
+      auto dx_res = y_tensor * elementwise_pow<T>(x, y - one_tensor) * out_grad;
+      set_output<T>(dx_res, x_grad);
+    } else {
+      auto y_value = y.to<float>();
+      auto dx_res = y_value * x.pow(y_value - 1) * out_grad;
+      set_output<T>(dx_res, x_grad);
+    }
   }
 }
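
Both dynamic-shape branches above rely on reduce_as<T>(grad, target) to fold a broadcasted gradient back down to the operand's shape without knowing the concrete dims at trace time. A rough numpy sketch of that reduction (reduce_to_shape is a hypothetical stand-in used only for illustration, not the Paddle API):

import numpy as np

def reduce_to_shape(grad, target_shape):
    # Hypothetical numpy analogue of reduce_as: sum a broadcasted gradient
    # back down to the operand's shape.
    extra = grad.ndim - len(target_shape)
    if extra > 0:
        # Sum away the leading axes the operand never had.
        grad = grad.sum(axis=tuple(range(extra)))
    # Sum (keeping dims) over axes where the operand was broadcast from size 1.
    axes = tuple(
        i for i, n in enumerate(target_shape) if n == 1 and grad.shape[i] != 1
    )
    if axes:
        grad = grad.sum(axis=axes, keepdims=True)
    return grad.reshape(target_shape)

# e.g. the gradient w.r.t. an x of shape [1, 200, 1] broadcast against [30, 200, 40]
g = np.ones((30, 200, 40), dtype=np.float32)
print(reduce_to_shape(g, (1, 200, 1)).shape)  # -> (1, 200, 1)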

2 changes: 2 additions & 0 deletions python/paddle/autograd/backward_utils.py
@@ -51,6 +51,8 @@
     "pd_op.relu",
     "pd_op.sigmoid",
     "pd_op.divide",
+    "pd_op.pow",
+    "pd_op.elementwise_pow",
 ]


185 changes: 185 additions & 0 deletions test/prim/pir_prim/test_prim_sub_graph_backward_dynamic_shape.py
@@ -107,6 +107,14 @@ def divide_net(x, y):
     return x / y


+def elementwise_pow_net(x, y):
+    return paddle.pow(x, y)
+
+
+def pow_net(x):
+    return paddle.pow(x, 3.2)
+
+
 def apply_to_static(net, use_cinn, input_spec=None):
     build_strategy = paddle.static.BuildStrategy()
     build_strategy.build_cinn_pass = use_cinn
@@ -1145,5 +1153,182 @@ def setUp(self):
         self.tol = 1e-5


+class TestPrimElementwisePowWithGrad1(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [1, 1, 40]
+        self.init_x_shape = [None, None, 40]
+        self.y_shape = [30, 200, 40]
+        self.init_y_shape = [None, None, 40]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad2(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [1, 200, 1]
+        self.init_x_shape = [None, None, 1]
+        self.y_shape = [30, 200, 40]
+        self.init_y_shape = [None, None, 40]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad3(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 1]
+        self.init_x_shape = [None, None, 1]
+        self.y_shape = [30, 200, 40]
+        self.init_y_shape = [None, None, 40]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad4(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, 40]
+        self.y_shape = [1, 1, 40]
+        self.init_y_shape = [None, None, 40]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad5(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, 40]
+        self.y_shape = [1, 200, 1]
+        self.init_y_shape = [None, None, 1]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad6(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, 40]
+        self.y_shape = [30, 200, 1]
+        self.init_y_shape = [None, None, 1]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad7(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, 40]
+        self.y_shape = [30, 200, 40]
+        self.init_y_shape = [None, None, 40]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad8(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, 40]
+        self.y_shape = [40]
+        self.init_y_shape = [None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad9(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [40]
+        self.init_x_shape = [None]
+        self.y_shape = [30, 200, 40]
+        self.init_y_shape = [None, None, 40]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad10(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.y_shape = [200, 40]
+        self.init_y_shape = self.y_shape
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimElementwisePowWithGrad11(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [200, 40]
+        self.init_x_shape = self.x_shape
+        self.y_shape = [30, 200, 40]
+        self.init_y_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = elementwise_pow_net
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimPowWithGrad(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [100, 20, 30]
+        self.init_x_shape = [None, None, 30]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = pow_net
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 if __name__ == "__main__":
     unittest.main()
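
The new cases above all follow the same pattern: wrap the net with an InputSpec whose leading dims are None, so the composed VJP must take the dynamic-shape branch. A minimal standalone sketch of that setup (the real harness lives in TestPrimTwoWithGrad / TestPrimBaseWithGrad, which this diff does not show, and it additionally toggles the prim decomposition flags; everything below is illustrative, not the test code itself):

import numpy as np
import paddle

def elementwise_pow_net(x, y):
    return paddle.pow(x, y)

# Broadcasted operands; the None dims in InputSpec mark them as dynamic.
x = paddle.to_tensor(np.random.rand(1, 200, 1).astype("float32"), stop_gradient=False)
y = paddle.to_tensor(np.random.rand(30, 200, 40).astype("float32"), stop_gradient=False)

spec = [
    paddle.static.InputSpec(shape=[None, None, 1], dtype="float32"),
    paddle.static.InputSpec(shape=[None, None, 40], dtype="float32"),
]
net = paddle.jit.to_static(elementwise_pow_net, input_spec=spec, full_graph=True)

out = net(x, y)
dx, dy = paddle.grad(out, [x, y], grad_outputs=paddle.ones_like(out))
print(dx.shape, dy.shape)  # expected: [1, 200, 1] and [30, 200, 40]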
