Skip to content

Commit

Permalink
Polish code for _getitem_impl_
Browse files Browse the repository at this point in the history
  • Loading branch information
liym27 committed May 12, 2021
1 parent 6b3bb79 commit 9539b31
Showing 1 changed file with 40 additions and 104 deletions.
144 changes: 40 additions & 104 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,29 +792,16 @@ def _getitem_impl_(var, item):
if not isinstance(item, tuple):
item = [item]

decrease_axis = []
slice_axis = []
slice_start = []
slice_end = []
slice_step = []
decrease_axes = []
axes = []
starts = []
ends = []
steps = []

use_strided_slice = False
reverse_axis = []
target_block = default_main_program().current_block()

def fill_constant(shape, value, force_cpu=False, out=None):
var.block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
})
out.stop_gradient = True
return out

max_integer = 2**31 - 1
for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice):
start = slice_item.start
Expand All @@ -824,8 +811,7 @@ def fill_constant(shape, value, force_cpu=False, out=None):
if start is None and end is None and step is None:
continue

if step is None:
step = 1
step = 1 if step is None else step

if start is None and end is None:
assert (step == -1)
Expand All @@ -836,106 +822,56 @@ def fill_constant(shape, value, force_cpu=False, out=None):
start = 0

if end is None:
end = 10000000

if step != 1:
use_strided_slice = True
end = max_integer

slice_axis.append(dim)
slice_start.append(start)
slice_end.append(end)
slice_step.append(step)
else:
decrease_axis.append(dim)
slice_axis.append(dim)
slice_start.append(slice_item)
slice_step.append(1)
if isinstance(slice_item, Variable):
temp_1 = var.block.create_var(dtype=slice_item.dtype)
fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = target_block.create_var(dtype=slice_item.dtype)
target_block.append_op(
type='elementwise_add',
inputs={'X': slice_item,
'Y': temp_1},
outputs={'Out': temp_end},
attrs={'axis': -1})
slice_end.append(temp_end)
else:
slice_end.append(slice_item + 1
if slice_item != -1 else 10000000)
decrease_axes.append(dim)
start = slice_item
step = 1
end = slice_item + 1 if slice_item != -1 else max_integer

def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False

def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = var.block.create_var(dtype='int64')
fill_constant([1], dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
axes.append(dim)
starts.append(start)
ends.append(end)
steps.append(step)
use_strided_slice = True if step != 1 else use_strided_slice

inputs = {'Input': [var]}
attrs = {
'axes': slice_axis,
'axes': axes,
'starts': [],
'ends': [],
'decrease_axis': decrease_axis
'decrease_axis': decrease_axes
}
if (use_strided_slice == True):
if use_strided_slice == True:
attrs['strides'] = []
infer_flags = list(1 for i in range(len(slice_axis)))

# starts
if contain_var(slice_start):
inputs['StartsTensorList'] = get_new_list_tensor(slice_start)
for i, dim in enumerate(slice_start):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
else:
attrs['starts'] = slice_start

# ends
if contain_var(slice_end):
inputs['EndsTensorList'] = get_new_list_tensor(slice_end)
for i, dim in enumerate(slice_end):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
else:
attrs['ends'] = slice_end

# strides
if use_strided_slice == True:
if contain_var(slice_step):
inputs['StridesTensorList'] = get_new_list_tensor(slice_step)
for i, dim in enumerate(slice_step):
infer_flags = list(1 for i in range(len(axes)))
from .layers import utils

def deal_attrs(attr, attr_name, tensor_attr_name, inputs, infer_flags):
if utils._contain_var(attr):
inputs[tensor_attr_name] = utils._convert_to_tensor_list(
attr, dtype="int64")
for i, dim in enumerate(attr):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
attrs[attr_name].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
attrs[attr_name].append(dim)
else:
attrs['strides'] = slice_step
attrs[attr_name] = attr

deal_attrs(starts, "starts", "StartsTensorList", inputs, infer_flags)
deal_attrs(ends, "ends", "EndsTensorList", inputs, infer_flags)
deal_attrs(steps, "strides", "StridesTensorList", inputs, infer_flags)

# infer_flags
attrs['infer_flags'] = infer_flags

out = var
if use_strided_slice == False and len(slice_axis) > 0:
target_block = default_main_program().current_block()
if use_strided_slice == False and len(axes) > 0:
# append slice_op here
slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + "_slice"),
Expand All @@ -948,7 +884,7 @@ def get_new_list_tensor(old_list):
attrs=attrs)

out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0:
elif use_strided_slice == True and len(axes) > 0:
strided_slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name +
"_strided_slice"),
Expand Down

0 comments on commit 9539b31

Please sign in to comment.