diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py index ad530c76d1d2f..51b5781d5134f 100644 --- a/python/paddle/autograd/backward_utils.py +++ b/python/paddle/autograd/backward_utils.py @@ -63,6 +63,7 @@ "pd_op.reduce_as", "pd_op.relu", "pd_op.reshape", + "pd_op.roll", "pd_op.rsqrt", "pd_op.scale", "pd_op.scatter", diff --git a/test/prim/pir_prim/test_prim_sub_graph_pqrst_backward_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_pqrst_backward_dynamic_shape.py index 018e6b019d77a..d3b09a6e21020 100644 --- a/test/prim/pir_prim/test_prim_sub_graph_pqrst_backward_dynamic_shape.py +++ b/test/prim/pir_prim/test_prim_sub_graph_pqrst_backward_dynamic_shape.py @@ -64,6 +64,10 @@ def reshape_net(x): return paddle.reshape(x, [30, 200 * 40]) +def roll_net(x): + return paddle.roll(x, shifts=[101, -1], axis=[0, -2]) + + def scale_net(x): return paddle.scale(x, scale=-2.3) @@ -334,6 +338,32 @@ def setUp(self): self.tol = 1e-6 +class TestPrimRollWithGrad1(TestPrimBaseWithGrad): + def setUp(self): + np.random.seed(2024) + self.op_name = "pd_op.roll_grad" + self.dtype = "float32" + self.x_shape = [100, 4, 5] + self.init_x_shape = [None, None, 5] + self.x = np.random.random(self.x_shape).astype(self.dtype) + self.net = roll_net + self.enable_cinn = False + self.tol = 1e-6 + + +class TestPrimRollWithGrad2(TestPrimBaseWithGrad): + def setUp(self): + np.random.seed(2024) + self.op_name = "pd_op.roll_grad" + self.dtype = "float32" + self.x_shape = [100, 4, 5] + self.init_x_shape = [100, None, None] + self.x = np.random.random(self.x_shape).astype(self.dtype) + self.net = roll_net + self.enable_cinn = False + self.tol = 1e-6 + + class TestPrimScaleWithGrad(TestPrimBaseWithGrad): def setUp(self): np.random.seed(2023)