Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support zero value in dimension for slice #37313

Merged
merged 5 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/fluid/operators/set_value_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ class SetValueKernel : public framework::OpKernel<T> {
starts_indices[axis_index] = starts[i];
ends_indices[axis_index] = ends[i];
strides_indices[axis_index] = steps[i];
if (starts[i] == ends[i]) { // slice is empty, data will not be changed
return;
}
}

out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
Expand Down
33 changes: 23 additions & 10 deletions paddle/fluid/operators/slice_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,20 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
std::vector<T>* infer_flags = nullptr) {
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
PADDLE_ENFORCE_LT(
axis, in_dims.size(),
platform::errors::InvalidArgument(
"The axis value should be less than the rank of input, "
"but received axes[%d] = %d, rank of input is %d.",
i, axis, in_dims.size()));

if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
}

T dim_value = in_dims[axis];

if (dim_value > 0) {
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
}
T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
Expand All @@ -51,7 +59,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GE(
end, start,
platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
Expand All @@ -63,7 +71,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GE(
start, end,
platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
Expand All @@ -73,6 +81,9 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,

(*starts)[i] = start;
(*ends)[i] = end;
} else if (dim_value == 0) {
(*starts)[i] = 0;
(*ends)[i] = 0;
}
}
}
Expand Down Expand Up @@ -111,20 +122,22 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) {
framework::DDim decreased_dims(slice_dims);
std::vector<uint8_t> decrease_flag(slice_dims.size(), 0);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
T axis = decrease_axes[i];
decrease_flag[axis] = 1;
if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ(
decreased_dims[axis], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
PADDLE_ENFORCE_EQ(decreased_dims[axis], 1,
platform::errors::InvalidArgument(
"Decrease dim should be 1, but now received %d",
decreased_dims[axis]));
}
decreased_dims[axis] = 0;
}

std::vector<T> new_shape;
for (int i = 0; i < decreased_dims.size(); ++i) {
if (decreased_dims[i] != 0) {
if (decrease_flag[i] == 0) {
new_shape.push_back(decreased_dims[i]);
}
}
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/operators/strided_slice_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,9 @@ static void StridedSliceOutDims(
end_index = end_index + 1;
}

bool zero_dim_condition =
((stride_index < 0 && (start_index <= end_index)) ||
(stride_index > 0 && (start_index >= end_index)));
PADDLE_ENFORCE_EQ(zero_dim_condition, false,
bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) ||
(stride_index > 0 && (start_index > end_index)));
PADDLE_ENFORCE_EQ(neg_dim_condition, false,
platform::errors::InvalidArgument(
"The start index and end index are invalid for their "
"corresponding stride."));
Expand Down
8 changes: 8 additions & 0 deletions python/paddle/fluid/tests/unittests/test_set_value_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ def _get_answer(self):
self.data[0:, 1:2, :] = self.value


class TestSetValueItemSlice5(TestSetValueApi):
def _call_setitem(self, x):
x[0:, 1:1, :] = self.value

def _get_answer(self):
self.data[0:, 1:1, :] = self.value


class TestSetValueItemSliceInWhile(TestSetValueApi):
def _call_setitem(self, x):
def cond(i, x):
Expand Down
6 changes: 5 additions & 1 deletion python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,12 @@ def _test_slice(self):
var14 = var[1:-1, 0:2, ::-1]
var15 = var[::-1, ::-1, ::-1]
var16 = var[-4:4]
var17 = var[:, 0, 0:0]
var18 = var[:, 1:1:2]

vars = [
var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10,
var11, var12, var13, var14, var15, var16
var11, var12, var13, var14, var15, var16, var17, var18
]
local_out = [var.numpy() for var in vars]

Expand Down Expand Up @@ -600,6 +602,8 @@ def _test_slice(self):
self.assertTrue(
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))
self.assertTrue(np.array_equal(local_out[17], tensor_array[:, 0, 0:0]))
self.assertTrue(np.array_equal(local_out[18], tensor_array[:, 1:1:2]))

def _test_slice_for_tensor_attr(self):
tensor_array = np.array(
Expand Down