From cdcdd8fae1b2b4af18b7dd785ad90c95d037000e Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 17 Nov 2021 14:48:11 +0000 Subject: [PATCH 1/3] support zero dim for slice op --- paddle/fluid/operators/slice_op.cc | 3 +++ paddle/fluid/operators/slice_utils.h | 9 +++++---- paddle/fluid/operators/strided_slice_op.cc | 3 +++ paddle/fluid/operators/strided_slice_op.h | 7 +++---- python/paddle/fluid/tests/unittests/test_var_base.py | 6 +++++- 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index a5513ba648776..e91176f312970 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -112,6 +112,9 @@ class SliceOp : public framework::OperatorWithKernel { out_dims = GetDecreasedDims(slice_dims, decrease_axis, nullptr); } + VLOG(1) << "######## slice_dims: " << slice_dims; + VLOG(1) << "######## out_dims: " << out_dims; + ctx->SetOutputDim("Out", out_dims); if (axes.size() > 0 && axes[0] != 0) { ctx->ShareLoD("Input", /*->*/ "Out"); diff --git a/paddle/fluid/operators/slice_utils.h b/paddle/fluid/operators/slice_utils.h index 290df94774b82..03687025088a5 100644 --- a/paddle/fluid/operators/slice_utils.h +++ b/paddle/fluid/operators/slice_utils.h @@ -51,7 +51,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, if (step > 0) { start = std::min(start, dim_value); end = std::max(end, static_cast(0)); - PADDLE_ENFORCE_GT( + PADDLE_ENFORCE_GE( end, start, platform::errors::InvalidArgument( "When step > 0, end should be greater than start, but " @@ -63,7 +63,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, // "end is -1" means contain the 0-th element of this axis. start = std::min(start, dim_value - 1); end = std::max(end, static_cast(-1)); - PADDLE_ENFORCE_GT( + PADDLE_ENFORCE_GE( start, end, platform::errors::InvalidArgument( "When step < 0, start should be greater than end, but " @@ -111,6 +111,7 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, const std::vector& decrease_axes, std::vector* infer_flags = nullptr) { framework::DDim decreased_dims(slice_dims); + std::vector decrease_flag(slice_dims.size(), 0); if (decrease_axes.size() > 0) { for (size_t i = 0; i < decrease_axes.size(); ++i) { T axis = decrease_axes[i]; @@ -119,12 +120,12 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, decreased_dims[axis], 1, platform::errors::InvalidArgument("decrease dim should be 1")); } - decreased_dims[axis] = 0; + decrease_flag[axis] = 1; } std::vector new_shape; for (int i = 0; i < decreased_dims.size(); ++i) { - if (decreased_dims[i] != 0) { + if (decrease_flag[i] == 0) { new_shape.push_back(decreased_dims[i]); } } diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index a1b5ca0f6a6eb..9ca3a4e6ffe17 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -165,8 +165,11 @@ class StridedSliceOp : public framework::OperatorWithKernel { new_out_shape.push_back(1); } + VLOG(1) << "########### out_dims: " << out_dims; + out_dims = framework::make_ddim(new_out_shape); } + VLOG(1) << "########### out_dims: " << out_dims; ctx->SetOutputDim("Out", out_dims); ctx->ShareLoD("Input", /*->*/ "Out"); } diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h index e5b808174ace4..9eae27cca6840 100644 --- a/paddle/fluid/operators/strided_slice_op.h +++ b/paddle/fluid/operators/strided_slice_op.h @@ -77,10 +77,9 @@ static void StridedSliceOutDims( end_index = end_index + 1; } - bool zero_dim_condition = - ((stride_index < 0 && (start_index <= end_index)) || - (stride_index > 0 && (start_index >= end_index))); - PADDLE_ENFORCE_EQ(zero_dim_condition, false, + bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) || + (stride_index > 0 && (start_index > end_index))); + PADDLE_ENFORCE_EQ(neg_dim_condition, false, platform::errors::InvalidArgument( "The start index and end index are invalid for their " "corresponding stride.")); diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index e2a90ed135b90..b4c0cdf26f3a2 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -568,10 +568,12 @@ def _test_slice(self): var14 = var[1:-1, 0:2, ::-1] var15 = var[::-1, ::-1, ::-1] var16 = var[-4:4] + var17 = var[:, 0, 0:0] + var18 = var[:, 1:1:2] vars = [ var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10, - var11, var12, var13, var14, var15, var16 + var11, var12, var13, var14, var15, var16, var17, var18 ] local_out = [var.numpy() for var in vars] @@ -600,6 +602,8 @@ def _test_slice(self): self.assertTrue( np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4])) + self.assertTrue(np.array_equal(local_out[17], tensor_array[:, 0, 0:0])) + self.assertTrue(np.array_equal(local_out[18], tensor_array[:, 1:1:2])) def _test_slice_for_tensor_attr(self): tensor_array = np.array( From 617c32c7ab93450d7f8b2bd88de62fd0e1ae62c1 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 19 Nov 2021 03:07:58 +0000 Subject: [PATCH 2/3] support zero dim Tensor in set_value op --- paddle/fluid/operators/set_value_op.h | 3 +++ paddle/fluid/operators/slice_op.cc | 3 --- paddle/fluid/operators/slice_utils.h | 26 ++++++++++++++----- .../tests/unittests/test_set_value_op.py | 8 ++++++ 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h index 72b94dfa77279..71eb03895404d 100644 --- a/paddle/fluid/operators/set_value_op.h +++ b/paddle/fluid/operators/set_value_op.h @@ -260,6 +260,9 @@ class SetValueKernel : public framework::OpKernel { starts_indices[axis_index] = starts[i]; ends_indices[axis_index] = ends[i]; strides_indices[axis_index] = steps[i]; + if (starts[i] == ends[i]) { // slice is empty, data will not be changed + return; + } } out_e.stridedSlice(starts_indices, ends_indices, strides_indices) diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index e91176f312970..a5513ba648776 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -112,9 +112,6 @@ class SliceOp : public framework::OperatorWithKernel { out_dims = GetDecreasedDims(slice_dims, decrease_axis, nullptr); } - VLOG(1) << "######## slice_dims: " << slice_dims; - VLOG(1) << "######## out_dims: " << out_dims; - ctx->SetOutputDim("Out", out_dims); if (axes.size() > 0 && axes[0] != 0) { ctx->ShareLoD("Input", /*->*/ "Out"); diff --git a/paddle/fluid/operators/slice_utils.h b/paddle/fluid/operators/slice_utils.h index 03687025088a5..ab82da803b369 100644 --- a/paddle/fluid/operators/slice_utils.h +++ b/paddle/fluid/operators/slice_utils.h @@ -30,12 +30,20 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, std::vector* infer_flags = nullptr) { for (size_t i = 0; i < axes.size(); ++i) { T axis = axes[i]; + PADDLE_ENFORCE_LT( + axis, in_dims.size(), + platform::errors::InvalidArgument( + "The axis value should be less than the rank of input, " + "but received axes[%d] = %d, rank of input is %d.", + i, axis, in_dims.size())); + + if (infer_flags != nullptr && (*infer_flags)[i] == -1) { + continue; + } + T dim_value = in_dims[axis]; if (dim_value > 0) { - if (infer_flags != nullptr && (*infer_flags)[i] == -1) { - continue; - } T step = steps == nullptr ? 1 : (*steps)[i]; PADDLE_ENFORCE_NE( step, 0, platform::errors::InvalidArgument( @@ -73,6 +81,9 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, (*starts)[i] = start; (*ends)[i] = end; + } else if (dim_value == 0) { + (*starts)[i] = 0; + (*ends)[i] = 0; } } } @@ -115,12 +126,13 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, if (decrease_axes.size() > 0) { for (size_t i = 0; i < decrease_axes.size(); ++i) { T axis = decrease_axes[i]; + decrease_flag[axis] = 1; if (infer_flags && (*infer_flags)[i] != -1) { - PADDLE_ENFORCE_EQ( - decreased_dims[axis], 1, - platform::errors::InvalidArgument("decrease dim should be 1")); + PADDLE_ENFORCE_EQ(decreased_dims[axis], 1, + platform::errors::InvalidArgument( + "Decrease dim should be 1, but new received %d", + decreased_dims[axis])); } - decrease_flag[axis] = 1; } std::vector new_shape; diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index e9809318cb393..3f33850b80f59 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -127,6 +127,14 @@ def _get_answer(self): self.data[0:, 1:2, :] = self.value +class TestSetValueItemSlice5(TestSetValueApi): + def _call_setitem(self, x): + x[0:, 1:1, :] = self.value + + def _get_answer(self): + self.data[0:, 1:1, :] = self.value + + class TestSetValueItemSliceInWhile(TestSetValueApi): def _call_setitem(self, x): def cond(i, x): From a36dae8bd0cc1440b9771f7277715e1485010548 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 19 Nov 2021 06:49:07 +0000 Subject: [PATCH 3/3] polish some debug log --- paddle/fluid/operators/slice_utils.h | 2 +- paddle/fluid/operators/strided_slice_op.cc | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/paddle/fluid/operators/slice_utils.h b/paddle/fluid/operators/slice_utils.h index ab82da803b369..fa36ded24f915 100644 --- a/paddle/fluid/operators/slice_utils.h +++ b/paddle/fluid/operators/slice_utils.h @@ -130,7 +130,7 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, if (infer_flags && (*infer_flags)[i] != -1) { PADDLE_ENFORCE_EQ(decreased_dims[axis], 1, platform::errors::InvalidArgument( - "Decrease dim should be 1, but new received %d", + "Decrease dim should be 1, but now received %d", decreased_dims[axis])); } } diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index 9ca3a4e6ffe17..a1b5ca0f6a6eb 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -165,11 +165,8 @@ class StridedSliceOp : public framework::OperatorWithKernel { new_out_shape.push_back(1); } - VLOG(1) << "########### out_dims: " << out_dims; - out_dims = framework::make_ddim(new_out_shape); } - VLOG(1) << "########### out_dims: " << out_dims; ctx->SetOutputDim("Out", out_dims); ctx->ShareLoD("Input", /*->*/ "Out"); }