Skip to content

Commit

Permalink
【PIR api adaptor No.48、49】 Migrate cummax/min into pir (#58629)
Browse files: browse the repository at this point in the history
  • Loading branch information
DrRyanHuang authored Nov 6, 2023
1 parent e44fb26 commit ee36554
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 77 deletions.
10 changes: 6 additions & 4 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -495,24 +495,26 @@
data_type : out_grad

- backward_op : cummax_grad
forward : cummax(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
forward : cummax(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, DataType dtype)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : cummax_grad
data_type : out_grad

- backward_op : cummin_grad
forward : cummin(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
forward : cummin(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, DataType dtype)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : cummin_grad
data_type : out_grad

- backward_op : cumprod_grad
forward : cumprod (Tensor x, int dim) -> Tensor(out)
Expand Down
6 changes: 4 additions & 2 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -602,21 +602,23 @@
backward : cross_entropy_with_softmax_grad

- op : cummax
args : (Tensor x, int axis=-1, int dtype=3)
args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64)
output : Tensor(out), Tensor(indices)
infer_meta :
func : CumWithIndicesInferMeta
kernel :
func : cummax
data_type : x
backward : cummax_grad

- op : cummin
args : (Tensor x, int axis=-1, int dtype=3)
args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64)
output : Tensor(out), Tensor(indices)
infer_meta :
func : CumWithIndicesInferMeta
kernel :
func : cummin
data_type : x
backward : cummin_grad

- op : cumprod
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -549,17 +549,17 @@ void CumScalarAxisInferMeta(const MetaTensor& x,

void CumWithIndicesInferMeta(const MetaTensor& x,
int axis,
int dtype,
DataType dtype,
MetaTensor* out,
MetaTensor* indices) {
auto x_dims = x.dims();
auto indices_type = phi::TransToPhiDataType(dtype);
PADDLE_ENFORCE_EQ(
(indices_type == DataType::INT32 || indices_type == DataType::INT64),
(dtype == DataType::INT32 || dtype == DataType::INT64),
true,
phi::errors::InvalidArgument("dtype of indices must be int32 or int64"));
phi::errors::InvalidArgument(
"dtype of indices must be DataType::INT32 or DataType::INT64"));

if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
int _axis = 0;
if (axis < 0) {
_axis = axis + x_dims.size();
Expand Down Expand Up @@ -606,7 +606,7 @@ void CumWithIndicesInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
out->share_lod(x);
indices->set_dims(x_dims);
indices->set_dtype(indices_type);
indices->set_dtype(dtype);
indices->share_lod(x);
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ void CumScalarAxisInferMeta(const MetaTensor& x,

void CumWithIndicesInferMeta(const MetaTensor& x,
int axis,
int dtype,
DataType dtype,
MetaTensor* out,
MetaTensor* indices);

Expand Down
14 changes: 6 additions & 8 deletions paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,18 @@ void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
Expand All @@ -52,19 +51,18 @@ void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
Expand Down
14 changes: 6 additions & 8 deletions paddle/phi/kernels/cpu/cum_maxmin_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,13 @@ template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
}
Expand All @@ -166,14 +165,13 @@ template <typename T, typename Context>
void CumminKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cum_maxmin_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad);

template <typename T, typename Context>
Expand All @@ -33,7 +33,7 @@ void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad);

} // namespace phi
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cum_maxmin_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices);

template <typename T, typename Context>
void CumminKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices);

Expand Down
16 changes: 8 additions & 8 deletions paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {

if (dtype == DataType::INT32) {
phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
Expand All @@ -52,19 +52,19 @@ void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {

if (dtype == DataType::INT32) {
phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
Expand Down
14 changes: 6 additions & 8 deletions paddle/phi/kernels/gpu/cum_maxmin_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,16 @@ template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
T init = std::is_floating_point<T>::value
? (-1 * std::numeric_limits<T>::infinity())
: std::numeric_limits<T>::lowest();
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
}
Expand All @@ -332,16 +331,15 @@ template <typename T, typename Context>
void CumminKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
T init = std::is_floating_point<T>::value ? std::numeric_limits<T>::infinity()
: std::numeric_limits<T>::max();
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
}
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4035,7 +4035,7 @@ def cummax(x, axis=None, dtype='int64', name=None):
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummax')
dtype = convert_np_dtype_to_dtype_(dtype)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.cummax(x, axis, dtype)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -4120,7 +4120,7 @@ def cummin(x, axis=None, dtype='int64', name=None):
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummin')
dtype = convert_np_dtype_to_dtype_(dtype)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.cummin(x, axis, dtype)
else:
check_variable_and_dtype(
Expand Down
28 changes: 15 additions & 13 deletions test/legacy_test/test_cummax_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def cummax_dim2(arr, axis=None):
Expand Down Expand Up @@ -91,11 +92,11 @@ def set_attrs(self):

def test_check_output(self):
paddle.enable_static()
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
paddle.enable_static()
self.check_grad(['x'], 'out')
self.check_grad(['x'], 'out', check_pir=True)


class TestCummaxOpAxis1(TestCummaxOp):
Expand Down Expand Up @@ -151,6 +152,7 @@ def run_cases(self):
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())

@test_with_pir_api
def run_static(self, use_gpu=False):
with base.program_guard(base.Program()):
data_np = np.random.random((100, 100)).astype(np.float32)
Expand All @@ -163,20 +165,19 @@ def run_static(self, use_gpu=False):

place = base.CUDAPlace(0) if use_gpu else base.CPUPlace()
exe = base.Executor(place)
exe.run(base.default_startup_program())
out = exe.run(
feed={'x': data_np},
fetch_list=[
y1.name,
indices1.name,
y2.name,
indices2.name,
y3.name,
indices3.name,
y4.name,
indices4.name,
y5.name,
indices5.name,
y1,
indices1,
y2,
indices2,
y3,
indices3,
y4,
indices4,
y5,
indices5,
],
)

Expand Down Expand Up @@ -218,6 +219,7 @@ def test_errors(self):
paddle.enable_static()
with base.program_guard(base.Program()):

@test_with_pir_api
def test_x_type():
data = [1, 2, 3]
y, indices = paddle.cummax(data, axis=0)
Expand Down
Loading

0 comments on commit ee36554

Please sign in to comment.