From 0e242c5e40d7939d32181634463d959a943a3112 Mon Sep 17 00:00:00 2001 From: zerorains Date: Tue, 17 Sep 2024 07:47:10 +0000 Subject: [PATCH] add the dynamic shape test case for roll_grad --- python/paddle/autograd/backward_utils.py | 1 + ..._sub_graph_pqrst_backward_dynamic_shape.py | 30 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py index 974ae63bc113c..c54040d96b02a 100644 --- a/python/paddle/autograd/backward_utils.py +++ b/python/paddle/autograd/backward_utils.py @@ -63,6 +63,7 @@ "pd_op.reduce_as", "pd_op.relu", "pd_op.reshape", + "pd_op.roll", "pd_op.rsqrt", "pd_op.scale", "pd_op.scatter", diff --git a/test/prim/pir_prim/test_prim_sub_graph_pqrst_backward_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_pqrst_backward_dynamic_shape.py index 2477fa85e6071..8f9cd265acbe4 100644 --- a/test/prim/pir_prim/test_prim_sub_graph_pqrst_backward_dynamic_shape.py +++ b/test/prim/pir_prim/test_prim_sub_graph_pqrst_backward_dynamic_shape.py @@ -64,6 +64,10 @@ def reshape_net(x): return paddle.reshape(x, [30, 200 * 40]) +def roll_net(x): + return paddle.roll(x, shifts=[101, -1], axis=[0, -2]) + + def scale_net(x): return paddle.scale(x, scale=-2.3) @@ -330,6 +334,32 @@ def setUp(self): self.tol = 1e-6 +class TestPrimRollWithGrad1(TestPrimBaseWithGrad): + def setUp(self): + np.random.seed(2024) + self.op_name = "pd_op.roll_grad" + self.dtype = "float32" + self.x_shape = [100, 4, 5] + self.init_x_shape = [None, None, 5] + self.x = np.random.random(self.x_shape).astype(self.dtype) + self.net = roll_net + self.enable_cinn = False + self.tol = 1e-6 + + +class TestPrimRollWithGrad2(TestPrimBaseWithGrad): + def setUp(self): + np.random.seed(2024) + self.op_name = "pd_op.roll_grad" + self.dtype = "float32" + self.x_shape = [100, 4, 5] + self.init_x_shape = [100, None, None] + self.x = np.random.random(self.x_shape).astype(self.dtype) + self.net = roll_net + self.enable_cinn = False + self.tol = 1e-6 + + class TestPrimScaleWithGrad(TestPrimBaseWithGrad): def setUp(self): np.random.seed(2023)