Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support settiem by Bool index #35133

Merged
merged 18 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
8fd53fd
Support getitem by Bool index
zyfncg Aug 19, 2021
13a6a07
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 19, 2021
0155baf
delete some debug info of bool index
zyfncg Aug 19, 2021
f901c11
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 19, 2021
b4430b4
support the case that the shape of bool index is different from index…
zyfncg Aug 22, 2021
57fa5eb
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 22, 2021
d8b4e40
support setitem by bool index
zyfncg Aug 24, 2021
5278598
resolve conflit with develop
zyfncg Aug 24, 2021
d64dd08
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 24, 2021
37ffff2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 25, 2021
03b9dbf
add the unittest for throwing exception
zyfncg Aug 25, 2021
f412017
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 25, 2021
75a192c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 27, 2021
17adbe3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 27, 2021
37ed12f
merge conflict
zyfncg Aug 27, 2021
9495851
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 27, 2021
6666c94
add check for int tensor when index is bool
zyfncg Aug 31, 2021
29c3e5b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Aug 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,12 +499,12 @@ static void ParseIndexingSlice(
none_axes->push_back(dim);
} else if (PyList_Check(slice_item)) {
*list_select_flag = true;
if (size != 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"When index contains a list, its length is excepted to 1, "
"but received %d",
size));
}
PADDLE_ENFORCE_EQ(
size, 1,
platform::errors::InvalidArgument(
"When index contains a list, its length is excepted to 1, "
"but received %d",
size));
bool all_bool = true;
int list_size = PyList_GET_SIZE(slice_item);
for (int j = 0; j < list_size; ++j) {
Expand All @@ -517,12 +517,13 @@ static void ParseIndexingSlice(
}
}
if (all_bool) {
if (list_size != shape[0]) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The dimension of bool index doesn't match indexed array along "
"dimension 0, the target dimension is %d, but received %d.",
shape[0], list_size));
}
PADDLE_ENFORCE_EQ(
list_size, shape[0],
platform::errors::InvalidArgument(
"The dimension of bool index doesn't match indexed array along "
"dimension 0, the target dimension is %d, but received %d.",
shape[0], list_size));

for (int j = 0; j < list_size; ++j) {
PyObject *list_item = PyList_GetItem(slice_item, j);
if (list_item == Py_True) {
Expand Down Expand Up @@ -818,7 +819,7 @@ void BindImperative(py::module *m_ptr) {
.def("__setitem_varbase__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
py::object &value_obj) {
VLOG(4) << "Call __setitem__";
VLOG(4) << "Call __setitem_varbase__";

auto self_tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>();
Expand Down Expand Up @@ -871,7 +872,6 @@ void BindImperative(py::module *m_ptr) {
// TODO(liym27): Try not to call TensorToPyArray because it always
// copys data to cpu place, which reduces performance.
if (parse_index && value_is_tensor) {
VLOG(4) << "index is integer/slice/ellipsis and value is tensor";
std::vector<int> axes, starts, ends, steps, decrease_axes,
none_axes, infer_flags, list_select_idxs;
// if index is a list, list_select_flag will be true
Expand All @@ -880,6 +880,7 @@ void BindImperative(py::module *m_ptr) {
&steps, &decrease_axes, &none_axes,
&infer_flags, &list_select_idxs,
&list_select_flag);

framework::AttributeMap attrs = {
{"axes", axes},
{"starts", starts},
Expand Down
19 changes: 15 additions & 4 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,14 +587,25 @@ def _is_list_tuple(item):
return self._getitem_index_not_tensor(item)

def __setitem__(self, item, value):
def contain_tensor_or_list(item):
if not isinstance(item, tuple):
item = [item]

if contain_tensor(item):
# 1. Call _setitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
for slice_item in item:
if isinstance(slice_item, list):
return True
elif isinstance(slice_item, Variable):
return True

return False

if contain_tensor_or_list(item):
# To reuse code with static graph,
# Call _setitem_impl_ when item contains tensor or list.
return _setitem_impl_(self, item, value)

else:
# 2. Call c++ func __setitem_varbase__ to speedup.
# Call c++ func __setitem_varbase__ to speedup.
return self.__setitem_varbase__(item, value)

for method_name, method in (
Expand Down
72 changes: 72 additions & 0 deletions python/paddle/fluid/tests/unittests/test_set_value_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,61 @@ def _get_answer(self):
self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None]


# 1.5 item is list or Tensor of bol
class TestSetValueItemBool1(TestSetValueApi):
def _call_setitem(self, x):
x[[True, False]] = self.value

def _get_answer(self):
self.data[[True, False]] = self.value


class TestSetValueItemBool2(TestSetValueApi):
def _call_setitem(self, x):
x[[False, False]] = self.value

def _get_answer(self):
self.data[[False, False]] = self.value


class TestSetValueItemBool3(TestSetValueApi):
def _call_setitem(self, x):
x[[False, True]] = np.zeros(self.shape[2])

def _get_answer(self):
self.data[[False, True]] = np.zeros(self.shape[2])


class TestSetValueItemBool4(TestSetValueApi):
def _call_setitem(self, x):
idx = paddle.assign(np.array([False, True]))
x[idx] = np.zeros(self.shape[2])

def _get_answer(self):
self.data[np.array([False, True])] = np.zeros(self.shape[2])


class TestSetValueItemBool5(TestSetValueApi):
def _call_setitem(self, x):
idx = paddle.assign(
np.array([[False, True, False], [True, True, False]]))
x[idx] = self.value

def _get_answer(self):
self.data[np.array([[False, True, False], [True, True, False]
])] = self.value


class TestSetValueItemBool6(TestSetValueApi):
def _call_setitem(self, x):
x[0, ...] = 0
x[x > 0] = self.value

def _get_answer(self):
self.data[0, ...] = 0
self.data[self.data > 0] = self.value


# 2. Test different type of value: int, float, numpy.ndarray, Tensor
# 2.1 value is int32, int64, float32, float64, bool

Expand Down Expand Up @@ -830,6 +885,21 @@ def _ellipsis_error(self):
one = paddle.ones([1])
x[::one] = self.value

def _bool_list_error(self):
with self.assertRaises(TypeError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
x[[True, False, 0]] = 0

with self.assertRaises(IndexError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
x[[True, False], [True, False]] = 0

def _bool_tensor_error(self):
with self.assertRaises(IndexError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
idx = paddle.assign([True, False, True])
x[idx] = 0

def _broadcast_mismatch(self):
program = paddle.static.Program()
with paddle.static.program_guard(program):
Expand All @@ -846,6 +916,8 @@ def test_error(self):
self._value_type_error()
self._dtype_error()
self._step_error()
self._bool_list_error()
self._bool_tensor_error()
self._broadcast_mismatch()


Expand Down
91 changes: 79 additions & 12 deletions python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,16 +509,6 @@ def _setitem_impl_(var, item, value):
start = slice_item
end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
step = 1
elif isinstance(slice_item, list):
if not is_list_tuple(slice_item, int):
raise TypeError(
"Only support int or list in index list. But revceived {}.".
format(slice_item))
slice_info.update(slice_item)
continue
elif isinstance(slice_item, (Variable, np.ndarray)):
slice_info.update(slice_item)
continue

elif isinstance(slice_item, slice):
start = slice_item.start
Expand Down Expand Up @@ -547,10 +537,43 @@ def _setitem_impl_(var, item, value):

if end is None:
end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)
elif isinstance(slice_item, list):
if is_list_tuple(slice_item, int):
slice_info.update(slice_item)
continue

for i in slice_item:
if not isinstance(i, bool):
raise TypeError("Doesn't support {} in index list.".format(
type(i)))

if len(item) != 1:
raise IndexError(
"When index contains a bool list, its length must be 1, but received {}.".
format(len(item)))

from .layers import assign
idx_tensor = assign(slice_item)
return set_value_for_bool_tensor(var, idx_tensor, value)

elif isinstance(slice_item, np.ndarray):
slice_info.update(slice_item)
continue

elif isinstance(slice_item, Variable):
if slice_item.dtype == core.VarDesc.VarType.BOOL:
if len(item) != 1:
raise IndexError(
"When index contains a bool tensor, its length must be 1, but received {}.".
format(len(item)))
return set_value_for_bool_tensor(var, slice_item, value)
else:
slice_info.update(slice_item)
continue
else:
raise IndexError(
"Valid index accept int, slice, ellipsis or None, but received {}.".
format(slice_item))
"Valid index accept int, slice, ellipsis, None, list of bool, Variable, "
"but received {}.".format(slice_item))

axes.append(dim)
starts.append(start)
Expand Down Expand Up @@ -632,3 +655,47 @@ def _setitem_impl_(var, item, value):
type="set_value", inputs=inputs, outputs={'Out': var}, attrs=attrs)

return var


# the item is a tensor of bool
def set_value_for_bool_tensor(var, item, value):

# TODO(zyfncg): Now scatter_nd_add only support float32 and float64 tensor,
# so in the current version we also only support float32 and float64 tensor,
# this problem will be fixed in the future.
if var.dtype != core.VarDesc.VarType.FP32 and var.dtype != core.VarDesc.VarType.FP64:
raise TypeError("Only support float and double tensor for bool index, "
"but received {}.".format(var.dtype))

if len(item.shape) > len(var.shape):
raise IndexError("The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(item.shape)))
for i, dim_len in enumerate(item.shape):
if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))

def idx_not_empty(var, item, value):
from .framework import Variable
from .layers import assign
from .layers.nn import where
from ..tensor import gather_nd, scatter_nd_add

if not isinstance(value, Variable):
value = assign(value).cast(var.dtype)

idx = where(item)
gather_val = gather_nd(var, idx)
gather_val_new = value - gather_val
out = scatter_nd_add(var, idx, gather_val_new)
Copy link
Contributor

@hbwx24 hbwx24 Aug 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scatter_nd_add 的API文档上写的支持float32, float64,这里是否需要做一下类型转换。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里先设置了类型检查,暂时只支持float32和float64数据类型使用bool索引,后续再根据scatter_nd_add的支持情况进行修改

var[:] = out

from .layers.control_flow import cond
# If all the bool index is False, just do nothing
cond(item.any(), lambda: idx_not_empty(var, item, value))

return var