#12451: add negative ends support for slice with list splicing format
sjameelTT committed Sep 10, 2024
1 parent 135eee3 commit b4b21dd
Showing 2 changed files with 21 additions and 5 deletions.
12 changes: 12 additions & 0 deletions tests/ttnn/unit_tests/operations/test_slice.py
@@ -380,3 +380,15 @@ def test_slice_ellipses(device):
    ttnn_output = ttnn_input[...]
    ttnn_output = ttnn.to_torch(ttnn_output)
    assert_with_pcc(torch_output, ttnn_output, 0.99)
+
+
+@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
+@pytest.mark.parametrize("ends", [-2, -4, -6])
+def test_slice_negative_ends(layout, ends, device):
+    torch_input = torch.randn(32, 32, 32, 32)
+    ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)
+
+    torch_output = torch_input[:, :, :, 0:ends]
+    ttnn_output = ttnn_input[:, :, :, 0:ends]
+    ttnn_output = ttnn.to_torch(ttnn_output)
+    assert_with_pcc(torch_output, ttnn_output, 0.99)
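
The new test mirrors torch's negative-stop semantics: a stop of -2 on a size-32 axis keeps indices 0 through 29. A minimal standalone sketch of the behaviour being tested, assuming a single device opened with ttnn.open_device (the device_id and shapes here are illustrative, not part of this commit):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randn(32, 32, 32, 32)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

# A negative stop counts back from the end of the axis, so 0:-2 keeps 30 of 32 elements.
ttnn_output = ttnn.to_torch(ttnn_input[:, :, :, 0:-2])
assert ttnn_output.shape == torch_input[:, :, :, 0:-2].shape  # torch.Size([32, 32, 32, 30])

ttnn.close_device(device)
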
14 changes: 9 additions & 5 deletions ttnn/ttnn/operations/core.py
@@ -73,7 +73,9 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor:
slices = (slice(None, None, None),) + slices
slice_start = [_slice.start if _slice.start is not None else 0 for _slice in slices]
slice_end = [
-    (_slice.stop if _slice.stop is not None else input_tensor.shape[index])
+    (max(input_tensor.shape[index] + _slice.stop, 1) if _slice.stop < 0 else _slice.stop)
+    if _slice.stop is not None
+    else input_tensor.shape[index]
    for index, _slice in enumerate(slices)
]
slice_step = [_slice.step if _slice.step is not None else 1 for _slice in slices]
@@ -98,10 +100,12 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor:
input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
output = input_tensor.unpad(slice_start, padded_slice_end_minus_1)
output = ttnn.to_layout(output, input_layout)

-output_shape = [len(range(start, end, step)) for (start, end, step) in zip(slice_start, slice_end, slice_step)][
-    -input_rank:
-]
+output_shape = [
+    0
+    if slices[i].stop is not None and slices[i].stop + input_tensor.shape[i] == slices[i].start
+    else len(range(start, end, step))
+    for i, (start, end, step) in enumerate(zip(slice_start, slice_end, slice_step))
+][-input_rank:]
padded_output_shape = list(output.shape.with_tile_padding())[-input_rank:]
return ttnn.reshape(output, shape=ttnn.Shape(output_shape, padded_output_shape))
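
In short, a negative stop is shifted by the dimension size and clamped to at least 1 so that unpad always receives a valid end; because that clamp turns an empty slice into a length-1 range, the output shape is patched back to 0 whenever the shifted stop lands on the start. A rough pure-Python sketch of that arithmetic (the helper name and arguments are illustrative, not identifiers from this file):

def sliced_dim_length(dim_size, start, stop, step=1):
    # Shift a negative stop by the dimension size and clamp to 1,
    # mirroring the slice_end computation above.
    if stop is None:
        end = dim_size
    elif stop < 0:
        end = max(dim_size + stop, 1)
    else:
        end = stop
    # The clamp would report length 1 for an empty slice, so force it to 0
    # when the shifted stop coincides with the start (the output_shape patch above).
    if stop is not None and stop + dim_size == start:
        return 0
    return len(range(start, end, step))

assert sliced_dim_length(32, 0, -2) == 30    # 0:-2 on a size-32 axis -> 30 elements
assert sliced_dim_length(32, 0, -32) == 0    # 0:-32 -> empty slice
assert sliced_dim_length(32, 0, None) == 32  # no stop -> full axis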
