From ba1771b6b264678e64ce2b72b33e24517d4d60b7 Mon Sep 17 00:00:00 2001
From: zerorains
Date: Wed, 26 Jun 2024 08:17:53 +0000
Subject: [PATCH 1/2] support dynamic shape for relu_grad and add the dynamic
 shape test cases for relu_grad and sigmoid_grad

---
 paddle/fluid/primitive/rule/vjp/details.h     | 14 ++++----
 python/paddle/autograd/backward_utils.py      |  2 ++
 ...t_prim_sub_graph_backward_dynamic_shape.py | 32 +++++++++++++++++++
 3 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h
index 3eaafe9006efe..5513c7deaa63a 100644
--- a/paddle/fluid/primitive/rule/vjp/details.h
+++ b/paddle/fluid/primitive/rule/vjp/details.h
@@ -1287,12 +1287,14 @@ void masked_select_grad(const Tensor& x,
 template <typename T>
 void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto condition = greater_than<T>(
-        out, full<T>(common::vectorize(out.dims()), 0.0, out.dtype()));
-    auto res =
-        where<T>(condition,
-                 out_grad,
-                 full<T>(common::vectorize(out.dims()), 0.0, out.dtype()));
+    Tensor zeros;
+    if (has_dynamic_shape(out.shape())) {
+      zeros = backend::full_with_tensor<T>(shape<T>(out), 0.0, out.dtype());
+    } else {
+      zeros = full<T>(common::vectorize(out.dims()), 0.0, out.dtype());
+    }
+    auto mask = greater_than<T>(out, zeros);
+    auto res = cast<T>(mask, out.dtype()) * out_grad;
     set_output<T>(res, x_grad);
   }
 }
diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py
index 4810e7845f9c7..1c657ee1f61b6 100644
--- a/python/paddle/autograd/backward_utils.py
+++ b/python/paddle/autograd/backward_utils.py
@@ -48,6 +48,8 @@
     "pd_op.concat",
     "pd_op.split",
     "pd_op.multiply",
+    "pd_op.relu",
+    "pd_op.sigmoid",
 ]
 
 
diff --git a/test/prim/pir_prim/test_prim_sub_graph_backward_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_backward_dynamic_shape.py
index 2b4c35d0bf1a5..a95443330ad80 100644
--- a/test/prim/pir_prim/test_prim_sub_graph_backward_dynamic_shape.py
+++ b/test/prim/pir_prim/test_prim_sub_graph_backward_dynamic_shape.py
@@ -91,6 +91,14 @@ def multiply_net(x, y):
     return x * y
 
 
+def relu_net(x):
+    return paddle.nn.functional.relu(x)
+
+
+def sigmoid_net(x):
+    return paddle.nn.functional.sigmoid(x)
+
+
 def apply_to_static(net, use_cinn, input_spec=None):
     build_strategy = paddle.static.BuildStrategy()
     build_strategy.build_cinn_pass = use_cinn
@@ -782,5 +790,29 @@ def setUp(self):
         self.tol = 1e-5
 
 
+class TestPrimReluWithGrad(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = relu_net
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimSigmoidWithGrad(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = sigmoid_net
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 if __name__ == "__main__":
     unittest.main()

From b9801d2480133536959ccbe1c912a9a2d4c32062 Mon Sep 17 00:00:00 2001
From: zerorains
Date: Thu, 27 Jun 2024 02:10:24 +0000
Subject: [PATCH 2/2] modify the full

---
 paddle/fluid/primitive/rule/vjp/details.h | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h
index 5513c7deaa63a..95c8410b2efca 100644
--- a/paddle/fluid/primitive/rule/vjp/details.h
+++ b/paddle/fluid/primitive/rule/vjp/details.h
@@ -1287,12 +1287,7 @@ void masked_select_grad(const Tensor& x,
 template <typename T>
 void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    Tensor zeros;
-    if (has_dynamic_shape(out.shape())) {
-      zeros = backend::full_with_tensor<T>(shape<T>(out), 0.0, out.dtype());
-    } else {
-      zeros = full<T>(common::vectorize(out.dims()), 0.0, out.dtype());
-    }
+    Tensor zeros = full_scalar<T>(0.0, out.dtype());
     auto mask = greater_than<T>(out, zeros);
     auto res = cast<T>(mask, out.dtype()) * out_grad;
     set_output<T>(res, x_grad);
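
Reviewer note: after these two patches the composite rule computes dx = cast(out > 0, out.dtype()) * out_grad against a scalar zero, so it no longer builds a tensor from the static dims of `out` and therefore composes with dynamic shapes. Below is a minimal NumPy sketch of why this matches the earlier where-based formulation; it is a standalone illustration only, not part of the patch or of Paddle's test suite.

    import numpy as np

    # relu's output and an incoming gradient of the same shape.
    out = np.maximum(np.random.randn(3, 4).astype("float32"), 0.0)
    out_grad = np.random.randn(3, 4).astype("float32")

    # Previous formulation: where(out > 0, out_grad, zeros shaped like out).
    old_dx = np.where(out > 0.0, out_grad, np.zeros_like(out))

    # New formulation: cast(out > 0, dtype) * out_grad; the comparison is
    # against a scalar zero, so no dims-sized zero tensor is required.
    new_dx = (out > 0.0).astype(out.dtype) * out_grad

    assert np.allclose(old_dx, new_dx)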