diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 0c8d65340f6b9..7f44afabf2590 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -499,12 +499,12 @@ static void ParseIndexingSlice(
       none_axes->push_back(dim);
     } else if (PyList_Check(slice_item)) {
       *list_select_flag = true;
-      if (size != 1) {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "When index contains a list, its length is excepted to 1, "
-            "but received %d",
-            size));
-      }
+      PADDLE_ENFORCE_EQ(
+          size, 1,
+          platform::errors::InvalidArgument(
+              "When index contains a list, its length is expected to be 1, "
+              "but received %d",
+              size));
       bool all_bool = true;
       int list_size = PyList_GET_SIZE(slice_item);
       for (int j = 0; j < list_size; ++j) {
@@ -517,12 +517,13 @@ static void ParseIndexingSlice(
         }
       }
       if (all_bool) {
-        if (list_size != shape[0]) {
-          PADDLE_THROW(platform::errors::InvalidArgument(
-              "The dimension of bool index doesn't match indexed array along "
-              "dimension 0, the target dimension is %d, but received %d.",
-              shape[0], list_size));
-        }
+        PADDLE_ENFORCE_EQ(
+            list_size, shape[0],
+            platform::errors::InvalidArgument(
+                "The dimension of bool index doesn't match indexed array along "
+                "dimension 0, the target dimension is %d, but received %d.",
+                shape[0], list_size));
+
         for (int j = 0; j < list_size; ++j) {
           PyObject *list_item = PyList_GetItem(slice_item, j);
           if (list_item == Py_True) {
@@ -818,7 +819,7 @@ void BindImperative(py::module *m_ptr) {
       .def("__setitem_varbase__",
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
             py::object &value_obj) {
-            VLOG(4) << "Call __setitem__";
+            VLOG(4) << "Call __setitem_varbase__";
            auto self_tensor =
                self->MutableVar()->GetMutable<framework::LoDTensor>();
@@ -871,7 +872,6 @@ void BindImperative(py::module *m_ptr) {
            // TODO(liym27): Try not to call TensorToPyArray because it always
            // copys data to cpu place, which reduces performance.
            if (parse_index && value_is_tensor) {
-             VLOG(4) << "index is integer/slice/ellipsis and value is tensor";
              std::vector<int> axes, starts, ends, steps, decrease_axes,
                  none_axes, infer_flags, list_select_idxs;
              // if index is a list, list_select_flag will be true
@@ -880,6 +880,7 @@
                  &steps, &decrease_axes, &none_axes, &infer_flags,
                  &list_select_idxs, &list_select_flag);
+
              framework::AttributeMap attrs = {
                  {"axes", axes},
                  {"starts", starts},
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 4a36445ac0c94..c42a2a5943d11 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -587,14 +587,25 @@ def _is_list_tuple(item):
             return self._getitem_index_not_tensor(item)

     def __setitem__(self, item, value):
+        def contain_tensor_or_list(item):
+            if not isinstance(item, tuple):
+                item = [item]

-        if contain_tensor(item):
-            # 1. Call _setitem_impl_ when item contains tensor.
-            # Why not call a c++ function ? Because item can't be parsed when it contains tensor.
+            for slice_item in item:
+                if isinstance(slice_item, list):
+                    return True
+                elif isinstance(slice_item, Variable):
+                    return True
+
+            return False
+
+        if contain_tensor_or_list(item):
+            # To reuse code with the static graph,
+            # call _setitem_impl_ when item contains a tensor or list.
             return _setitem_impl_(self, item, value)
         else:
-            # 2. Call c++ func __setitem_varbase__ to speedup.
+            # Call the C++ func __setitem_varbase__ to speed up.
             return self.__setitem_varbase__(item, value)

     for method_name, method in (
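With the dispatch above, a dygraph `__setitem__` whose index contains a Python list or a `Variable` (Tensor) is routed through the Python-side `_setitem_impl_` shared with the static graph, while plain int/slice/ellipsis indices keep using the C++ `__setitem_varbase__` binding. A minimal dygraph usage sketch of the two paths (the tensor shape and values here are illustrative):

import numpy as np
import paddle

x = paddle.ones([2, 3], dtype='float32')

# int/slice index: served by the C++ __setitem_varbase__ fast path.
x[0, 1:] = 2.0

# index containing a bool list (or a bool Tensor): routed to _setitem_impl_,
# which reuses the static-graph implementation in variable_index.py.
x[[False, True]] = np.zeros(3)

print(x.numpy())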
diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py
index d26055b3166d6..21f506d03ce68 100644
--- a/python/paddle/fluid/tests/unittests/test_set_value_op.py
+++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py
@@ -408,6 +408,61 @@ def _get_answer(self):
         self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None]


+# 1.5 item is a list or Tensor of bool
+class TestSetValueItemBool1(TestSetValueApi):
+    def _call_setitem(self, x):
+        x[[True, False]] = self.value
+
+    def _get_answer(self):
+        self.data[[True, False]] = self.value
+
+
+class TestSetValueItemBool2(TestSetValueApi):
+    def _call_setitem(self, x):
+        x[[False, False]] = self.value
+
+    def _get_answer(self):
+        self.data[[False, False]] = self.value
+
+
+class TestSetValueItemBool3(TestSetValueApi):
+    def _call_setitem(self, x):
+        x[[False, True]] = np.zeros(self.shape[2])
+
+    def _get_answer(self):
+        self.data[[False, True]] = np.zeros(self.shape[2])
+
+
+class TestSetValueItemBool4(TestSetValueApi):
+    def _call_setitem(self, x):
+        idx = paddle.assign(np.array([False, True]))
+        x[idx] = np.zeros(self.shape[2])
+
+    def _get_answer(self):
+        self.data[np.array([False, True])] = np.zeros(self.shape[2])
+
+
+class TestSetValueItemBool5(TestSetValueApi):
+    def _call_setitem(self, x):
+        idx = paddle.assign(
+            np.array([[False, True, False], [True, True, False]]))
+        x[idx] = self.value
+
+    def _get_answer(self):
+        self.data[np.array([[False, True, False], [True, True, False]
+                            ])] = self.value
+
+
+class TestSetValueItemBool6(TestSetValueApi):
+    def _call_setitem(self, x):
+        x[0, ...] = 0
+        x[x > 0] = self.value
+
+    def _get_answer(self):
+        self.data[0, ...] = 0
+        self.data[self.data > 0] = self.value
+
+
 # 2. Test different type of value: int, float, numpy.ndarray, Tensor
 # 2.1 value is int32, int64, float32, float64, bool

@@ -830,6 +885,21 @@ def _ellipsis_error(self):
             one = paddle.ones([1])
             x[::one] = self.value

+    def _bool_list_error(self):
+        with self.assertRaises(TypeError):
+            x = paddle.ones(shape=self.shape, dtype=self.dtype)
+            x[[True, False, 0]] = 0
+
+        with self.assertRaises(IndexError):
+            x = paddle.ones(shape=self.shape, dtype=self.dtype)
+            x[[True, False], [True, False]] = 0
+
+    def _bool_tensor_error(self):
+        with self.assertRaises(IndexError):
+            x = paddle.ones(shape=self.shape, dtype=self.dtype)
+            idx = paddle.assign([True, False, True])
+            x[idx] = 0
+
     def _broadcast_mismatch(self):
         program = paddle.static.Program()
         with paddle.static.program_guard(program):
@@ -846,6 +916,8 @@ def test_error(self):
             self._value_type_error()
             self._dtype_error()
             self._step_error()
+            self._bool_list_error()
+            self._bool_tensor_error()
         self._broadcast_mismatch()
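Each TestSetValueItemBool case above pairs a Paddle assignment (`_call_setitem`) with the equivalent NumPy assignment (`_get_answer`), and the base class compares the two results. A standalone sketch of the check performed by TestSetValueItemBool4 (the shape below is illustrative; TestSetValueApi supplies its own):

import numpy as np
import paddle

shape = [2, 3, 4]                     # illustrative; the test base class defines the real shape
data = np.ones(shape).astype('float32')
x = paddle.to_tensor(data)

idx = paddle.assign(np.array([False, True]))        # bool Tensor used as the index
x[idx] = np.zeros(shape[2])                         # Paddle path under test
data[np.array([False, True])] = np.zeros(shape[2])  # NumPy reference answer

np.testing.assert_allclose(x.numpy(), data)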
diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py
index 3ae7d8cfd413e..1b9a82ba85f05 100644
--- a/python/paddle/fluid/variable_index.py
+++ b/python/paddle/fluid/variable_index.py
@@ -509,16 +509,6 @@ def _setitem_impl_(var, item, value):
             start = slice_item
             end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
             step = 1
-        elif isinstance(slice_item, list):
-            if not is_list_tuple(slice_item, int):
-                raise TypeError(
-                    "Only support int or list in index list. But revceived {}.".
-                    format(slice_item))
-            slice_info.update(slice_item)
-            continue
-        elif isinstance(slice_item, (Variable, np.ndarray)):
-            slice_info.update(slice_item)
-            continue
         elif isinstance(slice_item, slice):
             start = slice_item.start
@@ -547,10 +537,43 @@ def _setitem_impl_(var, item, value):
             if end is None:
                 end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)
+        elif isinstance(slice_item, list):
+            if is_list_tuple(slice_item, int):
+                slice_info.update(slice_item)
+                continue
+
+            for i in slice_item:
+                if not isinstance(i, bool):
+                    raise TypeError("Doesn't support {} in index list.".format(
+                        type(i)))
+
+            if len(item) != 1:
+                raise IndexError(
+                    "When index contains a bool list, its length must be 1, but received {}.".
+                    format(len(item)))
+
+            from .layers import assign
+            idx_tensor = assign(slice_item)
+            return set_value_for_bool_tensor(var, idx_tensor, value)
+
+        elif isinstance(slice_item, np.ndarray):
+            slice_info.update(slice_item)
+            continue
+
+        elif isinstance(slice_item, Variable):
+            if slice_item.dtype == core.VarDesc.VarType.BOOL:
+                if len(item) != 1:
+                    raise IndexError(
+                        "When index contains a bool tensor, its length must be 1, but received {}.".
+                        format(len(item)))
+                return set_value_for_bool_tensor(var, slice_item, value)
+            else:
+                slice_info.update(slice_item)
+                continue
         else:
             raise IndexError(
-                "Valid index accept int, slice, ellipsis or None, but received {}.".
-                format(slice_item))
+                "Valid index accepts int, slice, ellipsis, None, list of bool or "
+                "Variable, but received {}.".format(slice_item))

         axes.append(dim)
         starts.append(start)
@@ -632,3 +655,47 @@ def _setitem_impl_(var, item, value):
         type="set_value", inputs=inputs, outputs={'Out': var}, attrs=attrs)

     return var
+
+
+# Set the value of `var` at the positions where the bool tensor `item` is True.
+def set_value_for_bool_tensor(var, item, value):
+
+    # TODO(zyfncg): Now scatter_nd_add only supports float32 and float64 tensor,
+    # so in the current version we also only support float32 and float64 tensor;
+    # this problem will be fixed in the future.
+    if var.dtype != core.VarDesc.VarType.FP32 and var.dtype != core.VarDesc.VarType.FP64:
+        raise TypeError("Only support float and double tensor for bool index, "
+                        "but received {}.".format(var.dtype))
+
+    if len(item.shape) > len(var.shape):
+        raise IndexError("The dims of bool index doesn't match indexed array, "
+                         "the dims of bool index are expected to be equal to "
+                         "or less than {}, but received {}.".format(
+                             len(var.shape), len(item.shape)))
+    for i, dim_len in enumerate(item.shape):
+        if dim_len != var.shape[i]:
+            raise IndexError(
+                "The dimension of bool index doesn't match indexed array along "
+                "dimension {}, the target dimension is {}, but received {}.".
+                format(i, var.shape[i], dim_len))
+
+    def idx_not_empty(var, item, value):
+        from .framework import Variable
+        from .layers import assign
+        from .layers.nn import where
+        from ..tensor import gather_nd, scatter_nd_add
+
+        if not isinstance(value, Variable):
+            value = assign(value).cast(var.dtype)
+
+        idx = where(item)
+        gather_val = gather_nd(var, idx)
+        gather_val_new = value - gather_val
+        out = scatter_nd_add(var, idx, gather_val_new)
+        var[:] = out
+
+    from .layers.control_flow import cond
+    # If all values of the bool index are False, just do nothing.
+    cond(item.any(), lambda: idx_not_empty(var, item, value))
+
+    return var
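set_value_for_bool_tensor builds the masked assignment from three existing ops: `where` turns the bool mask into the coordinates of the True entries, `gather_nd` reads the current values at those coordinates, and `scatter_nd_add` adds back `value - gather_val`, so the selected elements end up equal to `value` while every unselected element is untouched; `var[:] = out` then writes the result back, and the `cond` wrapper skips the whole update when the mask is all False. A NumPy sketch of the same arithmetic (array contents are illustrative):

import numpy as np

var = np.arange(6, dtype=np.float32).reshape(2, 3)
mask = np.array([True, False])             # bool index along dimension 0
value = np.zeros(3, dtype=np.float32)      # broadcast against the masked rows

idx = np.argwhere(mask)                    # where: coordinates of the True entries
gathered = var[tuple(idx.T)]               # gather_nd: current values at those coordinates
delta = value - gathered                   # chosen so that old + delta == value
var[tuple(idx.T)] += delta                 # scatter_nd_add: additive update at the coordinates
print(var)                                 # row 0 is now all zeros, row 1 unchanged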