Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR API adaptor No.48、49】 Migrate cummax/min into pir #58629

Merged
merged 4 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -495,24 +495,26 @@
data_type : out_grad

- backward_op : cummax_grad
forward : cummax(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
forward : cummax(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, DataType dtype)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : cummax_grad
data_type : out_grad

- backward_op : cummin_grad
forward : cummin(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
forward : cummin(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, DataType dtype)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : cummin_grad
data_type : out_grad

- backward_op : cumprod_grad
forward : cumprod (Tensor x, int dim) -> Tensor(out)
Expand Down
6 changes: 4 additions & 2 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -602,21 +602,23 @@
backward : cross_entropy_with_softmax_grad

- op : cummax
args : (Tensor x, int axis=-1, int dtype=3)
args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64)
output : Tensor(out), Tensor(indices)
infer_meta :
func : CumWithIndicesInferMeta
kernel :
func : cummax
data_type : x
backward : cummax_grad

- op : cummin
args : (Tensor x, int axis=-1, int dtype=3)
args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64)
output : Tensor(out), Tensor(indices)
infer_meta :
func : CumWithIndicesInferMeta
kernel :
func : cummin
data_type : x
backward : cummin_grad

- op : cumprod
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -549,17 +549,17 @@ void CumScalarAxisInferMeta(const MetaTensor& x,

void CumWithIndicesInferMeta(const MetaTensor& x,
int axis,
int dtype,
DataType dtype,
MetaTensor* out,
MetaTensor* indices) {
auto x_dims = x.dims();
auto indices_type = phi::TransToPhiDataType(dtype);
PADDLE_ENFORCE_EQ(
(indices_type == DataType::INT32 || indices_type == DataType::INT64),
(dtype == DataType::INT32 || dtype == DataType::INT64),
true,
phi::errors::InvalidArgument("dtype of indices must be int32 or int64"));
phi::errors::InvalidArgument(
"dtype of indices must be DataType::INT32 or DataType::INT64"));

if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
int _axis = 0;
if (axis < 0) {
_axis = axis + x_dims.size();
Expand Down Expand Up @@ -606,7 +606,7 @@ void CumWithIndicesInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
out->share_lod(x);
indices->set_dims(x_dims);
indices->set_dtype(indices_type);
indices->set_dtype(dtype);
indices->share_lod(x);
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ void CumScalarAxisInferMeta(const MetaTensor& x,

void CumWithIndicesInferMeta(const MetaTensor& x,
int axis,
int dtype,
DataType dtype,
MetaTensor* out,
MetaTensor* indices);

Expand Down
14 changes: 6 additions & 8 deletions paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,18 @@ void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
Expand All @@ -52,19 +51,18 @@ void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
Expand Down
14 changes: 6 additions & 8 deletions paddle/phi/kernels/cpu/cum_maxmin_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,13 @@ template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
}
Expand All @@ -166,14 +165,13 @@ template <typename T, typename Context>
void CumminKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cum_maxmin_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad);

template <typename T, typename Context>
Expand All @@ -33,7 +33,7 @@ void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad);

} // namespace phi
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cum_maxmin_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices);

template <typename T, typename Context>
void CumminKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices);

Expand Down
16 changes: 8 additions & 8 deletions paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {

if (dtype == DataType::INT32) {
phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
Expand All @@ -52,19 +52,19 @@ void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DataType dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {

if (dtype == DataType::INT32) {
phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
Expand Down
14 changes: 6 additions & 8 deletions paddle/phi/kernels/gpu/cum_maxmin_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,16 @@ template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
T init = std::is_floating_point<T>::value
? (-1 * std::numeric_limits<T>::infinity())
: std::numeric_limits<T>::lowest();
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
}
Expand All @@ -332,16 +331,15 @@ template <typename T, typename Context>
void CumminKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DataType dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
T init = std::is_floating_point<T>::value ? std::numeric_limits<T>::infinity()
: std::numeric_limits<T>::max();
if (indices_type == DataType::INT32) {
if (dtype == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
} else if (indices_type == DataType::INT64) {
} else if (dtype == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
}
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4034,7 +4034,7 @@ def cummax(x, axis=None, dtype='int64', name=None):
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummax')
dtype = convert_np_dtype_to_dtype_(dtype)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.cummax(x, axis, dtype)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -4119,7 +4119,7 @@ def cummin(x, axis=None, dtype='int64', name=None):
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummin')
dtype = convert_np_dtype_to_dtype_(dtype)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.cummin(x, axis, dtype)
else:
check_variable_and_dtype(
Expand Down
28 changes: 15 additions & 13 deletions test/legacy_test/test_cummax_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def cummax_dim2(arr, axis=None):
Expand Down Expand Up @@ -91,11 +92,11 @@ def set_attrs(self):

def test_check_output(self):
paddle.enable_static()
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
paddle.enable_static()
self.check_grad(['x'], 'out')
self.check_grad(['x'], 'out', check_pir=True)


class TestCummaxOpAxis1(TestCummaxOp):
Expand Down Expand Up @@ -151,6 +152,7 @@ def run_cases(self):
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())

@test_with_pir_api
def run_static(self, use_gpu=False):
with base.program_guard(base.Program()):
data_np = np.random.random((100, 100)).astype(np.float32)
Expand All @@ -163,20 +165,19 @@ def run_static(self, use_gpu=False):

place = base.CUDAPlace(0) if use_gpu else base.CPUPlace()
exe = base.Executor(place)
exe.run(base.default_startup_program())
out = exe.run(
feed={'x': data_np},
fetch_list=[
y1.name,
indices1.name,
y2.name,
indices2.name,
y3.name,
indices3.name,
y4.name,
indices4.name,
y5.name,
indices5.name,
y1,
indices1,
y2,
indices2,
y3,
indices3,
y4,
indices4,
y5,
indices5,
],
)

Expand Down Expand Up @@ -218,6 +219,7 @@ def test_errors(self):
paddle.enable_static()
with base.program_guard(base.Program()):

@test_with_pir_api
def test_x_type():
data = [1, 2, 3]
y, indices = paddle.cummax(data, axis=0)
Expand Down
Loading