[cherry-pick] Polish code for setitem/getitem and support index for list/Tensor/None/Ellipsis/bool (#33528)

* [cherry-pick 2.1] Polish code for _getitem_impl_ (#32868)

* [cherry-pick] Polish code for setitem and getitem (#32911)

* [slice getitem] Support getitem idx is Tensor or List (#33000)

* [getitem] Support index is None for getitem in static mode (#33001)

* [Static getitem] Support static Variable getitem for Ellipsis index (#32876)

* [static getitem]Support index is list bool for getitem in static mode (#33298)
liym27 authored Jun 15, 2021
1 parent bbedca4 commit 2b44ae5
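
For context, a minimal sketch of the static-mode indexing forms these cherry-picked changes enable (illustrative only, not part of the commit; variable names are made up):

import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[3, 4, 5], dtype='float32')
idx = paddle.static.data(name='idx', shape=[2], dtype='int32')

a = x[..., 1:2]             # Ellipsis index (#32876)
b = x[:, None, :]           # None inserts a new axis (#33001)
c = x[[True, False, True]]  # bool-list index on the first axis (#33298)
d = x[idx]                  # Tensor index (#33000)
e = x[[0, 2]]               # list index (#33000)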
Showing 3 changed files with 569 additions and 357 deletions.
355 changes: 2 additions & 353 deletions python/paddle/fluid/framework.py
@@ -39,6 +39,7 @@
import paddle.version as fluid_version
import warnings
import functools
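# _getitem_impl_/_setitem_impl_ moved to a new module (presumably
# python/paddle/fluid/variable_index.py, among the files changed here).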
from .variable_index import _getitem_impl_, _setitem_impl_

__all__ = [
    'Program',
@@ -794,205 +795,6 @@ def __instancecheck__(cls, instance):
        return issubclass(t, Parameter)


def _getitem_impl_(var, item):
    """
    Slice the variable.

    Args:
        var(Variable): the variable to be sliced.
        item(int|slice|tuple): the index.

    Returns:
        Variable: the sliced variable.

    Examples:
        x[1:3]  -> one slice op
        x[::2]  -> one strided_slice op
        x[::-1] -> one reverse op
        x[0]    -> slice op with decrease_axis=[0], dropping that axis
    """

    if not isinstance(item, tuple):
        item = [item]

    decrease_axis = []
    slice_axis = []
    slice_start = []
    slice_end = []
    slice_step = []
    use_strided_slice = False
    reverse_axis = []
    target_block = default_main_program().current_block()

    # Helper: appends a fill_constant op so that a Python scalar can be
    # materialized as a 1-D tensor inside the current block.
    def fill_constant(shape, value, force_cpu=False, out=None):
        var.block.append_op(
            type='fill_constant',
            inputs={},
            outputs={'Out': [out]},
            attrs={
                'shape': shape,
                'dtype': out.dtype,
                'value': float(value),
                'force_cpu': force_cpu
            })
        out.stop_gradient = True
        return out

    # Classify each index element: a full slice is skipped, [::-1] becomes a
    # reverse, step != 1 selects strided_slice, and an integer/Tensor index
    # is turned into a length-1 slice whose axis is later squeezed away.
    for dim, slice_item in enumerate(item):
        if isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step

            if start is None and end is None and step is None:
                continue

            if step is None:
                step = 1

            if start is None and end is None:
                assert (step == -1)
                reverse_axis.append(dim)
                continue

            if start is None:
                start = 0

            if end is None:
                # 10000000 is a very large sentinel meaning "to the end".
                end = 10000000

            if step != 1:
                use_strided_slice = True

            slice_axis.append(dim)
            slice_start.append(start)
            slice_end.append(end)
            slice_step.append(step)
        else:
            decrease_axis.append(dim)
            slice_axis.append(dim)
            slice_start.append(slice_item)
            slice_step.append(1)
            if isinstance(slice_item, Variable):
                # end = index + 1 must be computed in-graph for Tensor indices.
                temp_1 = var.block.create_var(dtype=slice_item.dtype)
                fill_constant([1], 1, force_cpu=True, out=temp_1)
                temp_end = target_block.create_var(dtype=slice_item.dtype)
                target_block.append_op(
                    type='elementwise_add',
                    inputs={'X': slice_item,
                            'Y': temp_1},
                    outputs={'Out': temp_end},
                    attrs={'axis': -1})
                slice_end.append(temp_end)
            else:
                slice_end.append(slice_item + 1
                                 if slice_item != -1 else 10000000)

    def contain_var(one_list):
        for ele in one_list:
            if isinstance(ele, Variable):
                return True
        return False

    # Convert a mixed list of ints and Variables into a homogeneous list of
    # 1-D tensors, as required by the *TensorList op inputs.
    def get_new_list_tensor(old_list):
        new_list_tensor = []
        for dim in old_list:
            if isinstance(dim, Variable):
                dim.stop_gradient = True
                new_list_tensor.append(dim)
            else:
                assert (isinstance(dim, int))
                temp_out = var.block.create_var(dtype='int64')
                fill_constant([1], dim, force_cpu=True, out=temp_out)
                new_list_tensor.append(temp_out)
        return new_list_tensor

    inputs = {'Input': [var]}
    attrs = {
        'axes': slice_axis,
        'starts': [],
        'ends': [],
        'decrease_axis': decrease_axis
    }
    if use_strided_slice:
        attrs['strides'] = []
    infer_flags = [1] * len(slice_axis)

    # A -1 entry in the attrs together with infer_flags[i] = -1 marks a bound
    # that is supplied at run time through a *TensorList input.
    # starts
    if contain_var(slice_start):
        inputs['StartsTensorList'] = get_new_list_tensor(slice_start)
        for i, dim in enumerate(slice_start):
            if isinstance(dim, Variable):
                attrs['starts'].append(-1)
                infer_flags[i] = -1
            else:
                attrs['starts'].append(dim)
    else:
        attrs['starts'] = slice_start

    # ends
    if contain_var(slice_end):
        inputs['EndsTensorList'] = get_new_list_tensor(slice_end)
        for i, dim in enumerate(slice_end):
            if isinstance(dim, Variable):
                attrs['ends'].append(-1)
                infer_flags[i] = -1
            else:
                attrs['ends'].append(dim)
    else:
        attrs['ends'] = slice_end

    # strides
    if use_strided_slice:
        if contain_var(slice_step):
            inputs['StridesTensorList'] = get_new_list_tensor(slice_step)
            for i, dim in enumerate(slice_step):
                if isinstance(dim, Variable):
                    attrs['strides'].append(-1)
                    infer_flags[i] = -1
                else:
                    attrs['strides'].append(dim)
        else:
            attrs['strides'] = slice_step
    # infer_flags
    attrs['infer_flags'] = infer_flags

    out = var
    if not use_strided_slice and len(slice_axis) > 0:
        # append slice_op here
        slice_out_var = target_block.create_var(
            name=unique_name.generate_with_ignorable_key(var.name + "_slice"),
            dtype=var.dtype)

        target_block.append_op(
            type="slice",
            inputs=inputs,
            outputs={'Out': [slice_out_var]},
            attrs=attrs)

        out = slice_out_var
    elif use_strided_slice and len(slice_axis) > 0:
        strided_slice_out_var = target_block.create_var(
            name=unique_name.generate_with_ignorable_key(var.name +
                                                         "_strided_slice"),
            dtype=var.dtype)
        target_block.append_op(
            type="strided_slice",
            inputs=inputs,
            outputs={'Out': [strided_slice_out_var]},
            attrs=attrs)

        out = strided_slice_out_var

    if len(reverse_axis) > 0:
        reverse_out_var = target_block.create_var(
            name=unique_name.generate_with_ignorable_key(var.name +
                                                         "_slice_reverse"),
            dtype=var.dtype)
        target_block.append_op(
            type="reverse",
            inputs={'X': out},
            outputs={'Out': [reverse_out_var]},
            attrs={'axis': reverse_axis})

        out = reverse_out_var

    return out


@six.add_metaclass(VariableMetaClass)
class Variable(object):
"""
Expand Down Expand Up @@ -1848,160 +1650,7 @@ def __getitem__(self, item):
return _getitem_impl_(self, item)

    def __setitem__(self, item, value):
        inputs = {'Input': self}

        # 1. Parse item
        if not isinstance(item, tuple):
            item = [item]

        decrease_axes = []
        axes = []
        starts = []
        ends = []
        steps = []

        # sys.maxsize is a sentinel standing in for an unbounded start/end.
        max_integer = sys.maxsize

        def replace_ellipsis(item):
            # Use slice(None) to replace Ellipsis.
            # For var, var.shape = [3,4,5,6]
            #
            # var[..., 1:2] -> var[:, :, :, 1:2]
            # var[0, ...] -> var[0]
            # var[0, ..., 1:2] -> var[0, :, :, 1:2]

            item = list(item)

            # Filter out Variables first: list.count would otherwise invoke
            # Variable.__eq__ on them while counting Ellipsis.
            item_remove_var = [
                ele for ele in item if not isinstance(ele, Variable)
            ]
            ell_count = item_remove_var.count(Ellipsis)
            if ell_count == 0:
                return item
            elif ell_count > 1:
                raise IndexError(
                    "An index can only have a single ellipsis ('...')")

            ell_idx = item.index(Ellipsis)

            if ell_idx == len(item) - 1:
                return item[:-1]
            else:
                item[ell_idx:ell_idx + 1] = [slice(None)] * (
                    len(self.shape) - len(item) + 1)

            return item

        item = replace_ellipsis(item)

        for dim, slice_item in enumerate(item):
            if isinstance(slice_item, slice):
                start = slice_item.start
                end = slice_item.stop
                step = slice_item.step

                if start is None and end is None and step is None:
                    continue

                step = 1 if step is None else step

                # TODO: support cases when step < 1
                if not isinstance(step, Variable) and step == 0:
                    raise ValueError(
                        "When assigning a value to a paddle.Tensor, step can not be 0, "
                        "but received step is {}.".format(step))

                if isinstance(step, Variable) and (start is None or
                                                   end is None):
                    raise ValueError(
                        "When assigning a value to a paddle.Tensor, it's not supported that "
                        "the start or end is None when the type of step is paddle.Tensor."
                    )

                if start is None:
                    start = 0 if step > 0 else max_integer

                if end is None:
                    end = max_integer if step > 0 else (0 - max_integer)
            else:
                # Integer index: a length-1 slice whose axis is marked for
                # decrease, so the assigned region drops that dimension.
                decrease_axes.append(dim)
                start = slice_item
                end = slice_item + 1 if slice_item != -1 else max_integer
                step = 1

            axes.append(dim)
            starts.append(start)
            ends.append(end)
            steps.append(step)

        attrs = {
            'axes': axes,
            'starts': starts,
            'ends': ends,
            'steps': steps,
            'decrease_axes': decrease_axes
        }

        from .layers import utils
        # Bounds that are Variables are passed as *TensorList inputs instead
        # of static attrs, so their values are read at run time.
        if utils._contain_var(starts):
            inputs['StartsTensorList'] = utils._convert_to_tensor_list(starts)
            del attrs['starts']
        if utils._contain_var(ends):
            inputs['EndsTensorList'] = utils._convert_to_tensor_list(ends)
            del attrs['ends']
        if utils._contain_var(steps):
            inputs['StepsTensorList'] = utils._convert_to_tensor_list(steps)
            del attrs['steps']

        # 2. Parse value
        dtype = self.dtype
        attrs['dtype'] = dtype

        from .data_feeder import convert_dtype
        # 2.1 value is an integer or float
        if isinstance(value, (int, float)):
            value = np.array([value]).astype(convert_dtype(dtype))

        # 2.2 value is a np.ndarray
        if isinstance(value, np.ndarray):
            shape = list(value.shape)
            if dtype == core.VarDesc.VarType.BOOL:
                value_name = "bool_values"
                values = [bool(v) for v in value.flat]
            elif dtype == core.VarDesc.VarType.FP32:
                value_name = "fp32_values"
                values = [float(v) for v in value.flat]
            elif dtype == core.VarDesc.VarType.FP64:
                value_name = "fp64_values"
                values = [float(v) for v in value.flat]
            elif dtype == core.VarDesc.VarType.INT32:
                value_name = "int32_values"
                values = [int(v) for v in value.flat]
            elif dtype == core.VarDesc.VarType.INT64:
                value_name = "int64_values"
                values = [int(v) for v in value.flat]
            else:
                raise TypeError(
                    "When assigning a numpy.ndarray, integer or float to a paddle.Tensor, "
                    "the data type of the paddle.Tensor must be bool, float32, float64, "
                    "int32 or int64, but received %s." % convert_dtype(dtype))
            attrs[value_name] = values
            attrs["shape"] = shape

        elif isinstance(value, Variable):
            inputs["ValueTensor"] = value
        else:
            raise TypeError(
                "Only support to assign an integer, float, numpy.ndarray or "
                "paddle.Tensor to a paddle.Tensor, but received {}".format(
                    type(value)))

        cur_block = default_main_program().current_block()
        cur_block.append_op(
            type="set_value", inputs=inputs, outputs={'Out': self}, attrs=attrs)

        return self
        return _setitem_impl_(self, item, value)

    def get_value(self, scope=None):
        """
(Diff truncated; the diffs for the remaining two changed files are not shown.)
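
As a companion sketch, the __setitem__ path now routes through _setitem_impl_, which appends a set_value op; roughly (illustrative only, assuming static mode):

import numpy as np
import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[3, 4], dtype='float32')

x[0] = 1.0                      # scalar -> broadcast via fp32_values attr
x[1:3, ::2] = np.zeros((2, 2))  # ndarray -> values + shape attrs
x[..., 1] = paddle.ones([3])    # Variable value -> ValueTensor input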
