From 16380404a572b9cb202dd634259296fa54d96cfb Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Thu, 2 Nov 2023 14:05:12 +0000 Subject: [PATCH 1/4] add cummax/min test --- python/paddle/tensor/math.py | 4 ++-- test/legacy_test/test_cummax_op.py | 2 ++ test/legacy_test/test_cummin_op.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 4b35337d2c36a..c2ff9bb1b0ca4 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4034,7 +4034,7 @@ def cummax(x, axis=None, dtype='int64', name=None): check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummax') dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cummax(x, axis, dtype) else: check_variable_and_dtype( @@ -4119,7 +4119,7 @@ def cummin(x, axis=None, dtype='int64', name=None): check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummin') dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cummin(x, axis, dtype) else: check_variable_and_dtype( diff --git a/test/legacy_test/test_cummax_op.py b/test/legacy_test/test_cummax_op.py index 91df4866a75a6..53ba293d3b506 100644 --- a/test/legacy_test/test_cummax_op.py +++ b/test/legacy_test/test_cummax_op.py @@ -21,6 +21,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def cummax_dim2(arr, axis=None): @@ -218,6 +219,7 @@ def test_errors(self): paddle.enable_static() with base.program_guard(base.Program()): + @test_with_pir_api def test_x_type(): data = [1, 2, 3] y, indices = paddle.cummax(data, axis=0) diff --git a/test/legacy_test/test_cummin_op.py b/test/legacy_test/test_cummin_op.py index 416e4c48f0fc0..ac050f4b2361c 100644 --- a/test/legacy_test/test_cummin_op.py +++ b/test/legacy_test/test_cummin_op.py @@ -21,6 +21,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def cummin_dim2(arr, axis=None): @@ -218,6 +219,7 @@ def test_errors(self): paddle.enable_static() with base.program_guard(base.Program()): + @test_with_pir_api def test_x_type(): data = [1, 2, 3] y, indices = paddle.cummin(data, axis=0) From 96707d6355876e82e6692c723ebdb8b6e6e708de Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Fri, 3 Nov 2023 08:00:16 +0000 Subject: [PATCH 2/4] fix dtype --- paddle/phi/api/yaml/backward.yaml | 8 ++++---- paddle/phi/api/yaml/ops.yaml | 4 ++-- paddle/phi/infermeta/unary.cc | 12 ++++++------ paddle/phi/infermeta/unary.h | 2 +- paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc | 14 ++++++-------- paddle/phi/kernels/cpu/cum_maxmin_kernel.cc | 14 ++++++-------- paddle/phi/kernels/cum_maxmin_grad_kernel.h | 4 ++-- paddle/phi/kernels/cum_maxmin_kernel.h | 4 ++-- paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu | 16 ++++++++-------- paddle/phi/kernels/gpu/cum_maxmin_kernel.cu | 14 ++++++-------- 10 files changed, 43 insertions(+), 49 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 241fafbf9cdf5..07cb4c16934d0 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -495,8 +495,8 @@ data_type : out_grad - backward_op : cummax_grad - forward : cummax(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices) - args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype) + forward : cummax(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices) + args : (Tensor x, Tensor indices, Tensor out_grad, int axis, DataType dtype) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta @@ -505,8 +505,8 @@ func : cummax_grad - backward_op : cummin_grad - forward : cummin(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices) - args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype) + forward : cummin(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices) + args : (Tensor x, Tensor indices, Tensor out_grad, int axis, DataType dtype) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 5a0c6abc7688b..75d1a01c3a57d 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -602,7 +602,7 @@ backward : cross_entropy_with_softmax_grad - op : cummax - args : (Tensor x, int axis=-1, int dtype=3) + args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64) output : Tensor(out), Tensor(indices) infer_meta : func : CumWithIndicesInferMeta @@ -611,7 +611,7 @@ backward : cummax_grad - op : cummin - args : (Tensor x, int axis=-1, int dtype=3) + args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64) output : Tensor(out), Tensor(indices) infer_meta : func : CumWithIndicesInferMeta diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 8873a617ef303..0308093ed9fc6 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -549,17 +549,17 @@ void CumScalarAxisInferMeta(const MetaTensor& x, void CumWithIndicesInferMeta(const MetaTensor& x, int axis, - int dtype, + DataType dtype, MetaTensor* out, MetaTensor* indices) { auto x_dims = x.dims(); - auto indices_type = phi::TransToPhiDataType(dtype); PADDLE_ENFORCE_EQ( - (indices_type == DataType::INT32 || indices_type == DataType::INT64), + (dtype == DataType::INT32 || dtype == DataType::INT64), true, - phi::errors::InvalidArgument("dtype of indices must be int32 or int64")); + phi::errors::InvalidArgument( + "dtype of indices must be DataType::INT32 or DataType::INT64")); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { int _axis = 0; if (axis < 0) { _axis = axis + x_dims.size(); @@ -606,7 +606,7 @@ void CumWithIndicesInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); out->share_lod(x); indices->set_dims(x_dims); - indices->set_dtype(indices_type); + indices->set_dtype(dtype); indices->share_lod(x); } diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 8a28d454e42f7..70cfefa2a1daa 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -137,7 +137,7 @@ void CumScalarAxisInferMeta(const MetaTensor& x, void CumWithIndicesInferMeta(const MetaTensor& x, int axis, - int dtype, + DataType dtype, MetaTensor* out, MetaTensor* indices); diff --git a/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc b/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc index 88fb4f4feb91f..acd84a80be2ad 100644 --- a/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc @@ -28,7 +28,7 @@ void CummaxGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); phi::funcs::SetConstant functor; @@ -36,11 +36,10 @@ void CummaxGradKernel(const Context& dev_ctx, if (axis < 0) { axis = axis + x.dims().size(); } - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); } @@ -52,7 +51,7 @@ void CumminGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); phi::funcs::SetConstant functor; @@ -60,11 +59,10 @@ void CumminGradKernel(const Context& dev_ctx, if (axis < 0) { axis = axis + x.dims().size(); } - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); } diff --git a/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc b/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc index be1cfe3d86b1f..881664601b85c 100644 --- a/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc @@ -149,14 +149,13 @@ template void CummaxKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices) { - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, out, indices); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, out, indices); } @@ -166,14 +165,13 @@ template void CumminKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices) { - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, out, indices); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, out, indices); } diff --git a/paddle/phi/kernels/cum_maxmin_grad_kernel.h b/paddle/phi/kernels/cum_maxmin_grad_kernel.h index 13a6b7ee6ec1e..a018a3bfcc940 100644 --- a/paddle/phi/kernels/cum_maxmin_grad_kernel.h +++ b/paddle/phi/kernels/cum_maxmin_grad_kernel.h @@ -24,7 +24,7 @@ void CummaxGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad); template @@ -33,7 +33,7 @@ void CumminGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad); } // namespace phi diff --git a/paddle/phi/kernels/cum_maxmin_kernel.h b/paddle/phi/kernels/cum_maxmin_kernel.h index 37755deb5d91e..19e3fc9da0b80 100644 --- a/paddle/phi/kernels/cum_maxmin_kernel.h +++ b/paddle/phi/kernels/cum_maxmin_kernel.h @@ -22,7 +22,7 @@ template void CummaxKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices); @@ -30,7 +30,7 @@ template void CumminKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices); diff --git a/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu b/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu index a89373c607f7d..f8dc67f5bafe8 100644 --- a/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu @@ -28,7 +28,7 @@ void CummaxGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); phi::funcs::SetConstant functor; @@ -36,11 +36,11 @@ void CummaxGradKernel(const Context& dev_ctx, if (axis < 0) { axis = axis + x.dims().size(); } - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + + if (dtype == DataType::INT32) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); } @@ -52,7 +52,7 @@ void CumminGradKernel(const Context& dev_ctx, const DenseTensor& indices, const DenseTensor& out_grad, int axis, - int dtype, + DataType dtype, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); phi::funcs::SetConstant functor; @@ -60,11 +60,11 @@ void CumminGradKernel(const Context& dev_ctx, if (axis < 0) { axis = axis + x.dims().size(); } - auto indices_type = phi::TransToPhiDataType(dtype); - if (indices_type == DataType::INT32) { + + if (dtype == DataType::INT32) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, indices, out_grad, dev_ctx); } diff --git a/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu index bf836af72c58f..49903bde6ff99 100644 --- a/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu @@ -312,17 +312,16 @@ template void CummaxKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices) { - auto indices_type = phi::TransToPhiDataType(dtype); T init = std::is_floating_point::value ? (-1 * std::numeric_limits::infinity()) : std::numeric_limits::lowest(); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, init, out, indices); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, init, out, indices); } @@ -332,16 +331,15 @@ template void CumminKernel(const Context& dev_ctx, const DenseTensor& x, int axis, - int dtype, + DataType dtype, DenseTensor* out, DenseTensor* indices) { - auto indices_type = phi::TransToPhiDataType(dtype); T init = std::is_floating_point::value ? std::numeric_limits::infinity() : std::numeric_limits::max(); - if (indices_type == DataType::INT32) { + if (dtype == DataType::INT32) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, init, out, indices); - } else if (indices_type == DataType::INT64) { + } else if (dtype == DataType::INT64) { ScanWithIndicesKernel, Context>( dev_ctx, x, axis, init, out, indices); } From 581f84b2294cb41dea9cebcc0ff77452b78ee996 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Fri, 3 Nov 2023 08:49:38 +0000 Subject: [PATCH 3/4] fix data_type & add test --- paddle/phi/api/yaml/backward.yaml | 2 ++ paddle/phi/api/yaml/ops.yaml | 2 ++ test/legacy_test/test_cummax_op.py | 5 +++-- test/legacy_test/test_cummin_op.py | 5 +++-- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 07cb4c16934d0..a379c16314798 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -503,6 +503,7 @@ param: [x] kernel : func : cummax_grad + data_type : out_grad - backward_op : cummin_grad forward : cummin(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices) @@ -513,6 +514,7 @@ param: [x] kernel : func : cummin_grad + data_type : out_grad - backward_op : cumprod_grad forward : cumprod (Tensor x, int dim) -> Tensor(out) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 75d1a01c3a57d..15021e827e796 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -608,6 +608,7 @@ func : CumWithIndicesInferMeta kernel : func : cummax + data_type : x backward : cummax_grad - op : cummin @@ -617,6 +618,7 @@ func : CumWithIndicesInferMeta kernel : func : cummin + data_type : x backward : cummin_grad - op : cumprod diff --git a/test/legacy_test/test_cummax_op.py b/test/legacy_test/test_cummax_op.py index 53ba293d3b506..b12e0f0ba306d 100644 --- a/test/legacy_test/test_cummax_op.py +++ b/test/legacy_test/test_cummax_op.py @@ -92,11 +92,11 @@ def set_attrs(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): paddle.enable_static() - self.check_grad(['x'], 'out') + self.check_grad(['x'], 'out', check_pir=True) class TestCummaxOpAxis1(TestCummaxOp): @@ -152,6 +152,7 @@ def run_cases(self): np.testing.assert_array_equal(z, y.numpy()) np.testing.assert_array_equal(ind, indices.numpy()) + @test_with_pir_api def run_static(self, use_gpu=False): with base.program_guard(base.Program()): data_np = np.random.random((100, 100)).astype(np.float32) diff --git a/test/legacy_test/test_cummin_op.py b/test/legacy_test/test_cummin_op.py index ac050f4b2361c..8acf573584f94 100644 --- a/test/legacy_test/test_cummin_op.py +++ b/test/legacy_test/test_cummin_op.py @@ -92,11 +92,11 @@ def set_attrs(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): paddle.enable_static() - self.check_grad(['x'], 'out') + self.check_grad(['x'], 'out', check_pir=True) class TestCuinOpAxis1(TestCumminOp): @@ -152,6 +152,7 @@ def run_cases(self): np.testing.assert_array_equal(z, y.numpy()) np.testing.assert_array_equal(ind, indices.numpy()) + @test_with_pir_api def run_static(self, use_gpu=False): with base.program_guard(base.Program()): data_np = np.random.random((100, 100)).astype(np.float32) From 4108f2432e798b3af84c8221a42da5ba4bf210e0 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Fri, 3 Nov 2023 12:25:11 +0000 Subject: [PATCH 4/4] must be pir.Program --- test/legacy_test/test_cummax_op.py | 21 ++++++++++----------- test/legacy_test/test_cummin_op.py | 21 ++++++++++----------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/test/legacy_test/test_cummax_op.py b/test/legacy_test/test_cummax_op.py index b12e0f0ba306d..89429cf347096 100644 --- a/test/legacy_test/test_cummax_op.py +++ b/test/legacy_test/test_cummax_op.py @@ -165,20 +165,19 @@ def run_static(self, use_gpu=False): place = base.CUDAPlace(0) if use_gpu else base.CPUPlace() exe = base.Executor(place) - exe.run(base.default_startup_program()) out = exe.run( feed={'x': data_np}, fetch_list=[ - y1.name, - indices1.name, - y2.name, - indices2.name, - y3.name, - indices3.name, - y4.name, - indices4.name, - y5.name, - indices5.name, + y1, + indices1, + y2, + indices2, + y3, + indices3, + y4, + indices4, + y5, + indices5, ], ) diff --git a/test/legacy_test/test_cummin_op.py b/test/legacy_test/test_cummin_op.py index 8acf573584f94..d8e5512cbf9b4 100644 --- a/test/legacy_test/test_cummin_op.py +++ b/test/legacy_test/test_cummin_op.py @@ -165,20 +165,19 @@ def run_static(self, use_gpu=False): place = base.CUDAPlace(0) if use_gpu else base.CPUPlace() exe = base.Executor(place) - exe.run(base.default_startup_program()) out = exe.run( feed={'x': data_np}, fetch_list=[ - y1.name, - indices1.name, - y2.name, - indices2.name, - y3.name, - indices3.name, - y4.name, - indices4.name, - y5.name, - indices5.name, + y1, + indices1, + y2, + indices2, + y3, + indices3, + y4, + indices4, + y5, + indices5, ], )