Skip to content

Commit

Permalink
[static getitem]Support index is list bool for getitem in static mode (
Browse files Browse the repository at this point in the history
  • Loading branch information
liym27 authored Jun 10, 2021
1 parent 11b5776 commit a225636
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
29 changes: 29 additions & 0 deletions python/paddle/fluid/tests/unittests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,34 @@ def _test_slice_index_ellipsis(self, place):
with self.assertRaises(TypeError):
res = x[[1.2, 0]]

def _test_slice_index_list_bool(self, place):
data = np.random.rand(2, 3).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [True, False]
idx1 = [False, True]
idx2 = [False, False]
idx3 = [True, True]

out0 = x[idx0]
out1 = x[idx1]
out2 = x[idx2]
out3 = x[idx3]

exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])

expected = [data[idx0], data[idx1], data[idx2], data[idx3]]

self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())

with self.assertRaises(TypeError):
res = x[[True, 0]]

def test_slice(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
Expand All @@ -255,6 +283,7 @@ def test_slice(self):
self._test_slice_index_tensor(place)
self._test_slice_index_list(place)
self._test_slice_index_ellipsis(place)
self._test_slice_index_list_bool(place)

def _tostring(self):
b = default_main_program().current_block()
Expand Down
23 changes: 20 additions & 3 deletions python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,36 @@ def _getitem_impl_(var, item):
end = MAX_INTEGER if end is None else end

elif isinstance(slice_item, list):
is_bool_list = False
for i in slice_item:
if not isinstance(i, int):
raise TypeError("Only support int value in list")
if not isinstance(i, (int, bool)):
raise TypeError("Only support int or bool in index list.")

if isinstance(i, bool):
is_bool_list = True
break

if len(item) != 1:
raise IndexError(
"When index contains a list, its length must be 1, but received {}".
format(len(item)))

if is_bool_list:
new_slice_item = []
for idx, ele in enumerate(slice_item):
if not isinstance(ele, bool):
raise TypeError(
"Mixed bool index with other types is not supported."
)

if ele is True:
new_slice_item.append(idx)
slice_item = new_slice_item

from .layers import assign
from ..tensor import index_select

idx = assign(np.array(slice_item))
idx = assign(np.array(slice_item).astype("int32"))
return index_select(var, index=idx, axis=0)

elif isinstance(slice_item, Variable):
Expand Down

0 comments on commit a225636

Please sign in to comment.