#12527: refactor slice (#12791)
- use exclusive ends and move all the list-slicing-exclusive features to ttnn.slice
- use exclusive ends for ttnn.unpad
- add support for negative starts
- add negative starts and ends support to ttnn.slice
- move most pre-processing to C++
- refactor the C++ code
- remove the skip on the BERT ttnn.slice unit tests
- support 1D, 2D, and 3D inputs in ttnn.slice
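In short, ttnn.slice and ttnn.unpad now follow Python's half-open slicing convention: ends are exclusive, and negative starts/ends wrap around the dimension size. A minimal reference model of the new bounds handling in plain PyTorch (the helper name, shapes, and values here are illustrative assumptions, not part of ttnn):

```python
import torch

def slice_reference(x: torch.Tensor, starts, ends):
    """Reference model of the new semantics: half-open [start, end)
    per dimension, with negative bounds wrapped like Python slicing."""
    def norm(v, dim):
        return v + x.shape[dim] if v < 0 else v
    idx = tuple(slice(norm(s, d), norm(e, d)) for d, (s, e) in enumerate(zip(starts, ends)))
    return x[idx]

x = torch.rand(1, 1, 32, 64)
# Pre-refactor, the full-tensor end was the inclusive [0, 0, 31, 63];
# post-refactor the same slice uses the exclusive [1, 1, 32, 64].
assert slice_reference(x, (0, 0, 0, 0), (1, 1, 32, 64)).shape == (1, 1, 32, 64)
# Negative ends are now supported: drop the last 8 columns.
assert slice_reference(x, (0, 0, 0, 0), (1, 1, 32, -8)).shape == (1, 1, 32, 56)
```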
sjameelTT authored Sep 20, 2024
1 parent 00e1478 commit 047cdd9
Showing 43 changed files with 405 additions and 325 deletions.
12 changes: 6 additions & 6 deletions models/demos/falcon7b_common/tt/falcon_attention.py
@@ -618,7 +618,7 @@ def forward(
 key_layer = ttnn.slice(
     layer_past[0],
     [0, 0, 0, 0],
-    [batch - 1, 0, nearest_32(layer_past_len + 1) - 1, self.head_dim - 1],
+    [batch, 1, nearest_32(layer_past_len + 1), self.head_dim],
     memory_config=self.model_config["K_CACHE_SLICE_OUTPUT_MEMCFG"],
 )

@@ -747,7 +747,7 @@ def forward(
 value_layer = ttnn.slice(
     layer_past[1],
     [0, 0, 0, 0],
-    [batch - 1, 0, nearest_32(layer_past_len + 1) - 1, self.head_dim - 1],
+    [batch, 1, nearest_32(layer_past_len + 1), self.head_dim],
     memory_config=self.model_config["V_CACHE_SLICE_OUTPUT_MEMCFG"],
 )
 if self.model_config["l1_sharded"]:
@@ -800,10 +800,10 @@ def forward(
     attn_output,
     [0, 0, 0, 0],
     [
-        attn_output_shape[0] - 1,
-        self.num_heads - 1,
-        attn_output_shape[2] - 1,
-        attn_output_shape[3] - 1,
+        attn_output_shape[0],
+        self.num_heads,
+        attn_output_shape[2],
+        attn_output_shape[3],
     ],
     memory_config=self.model_config["POST_SOFTMAX_MM_OUTPUT_MEMCFG"],
 )
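The call-site updates in this commit all follow one mechanical rule: each inclusive end e becomes the exclusive end e + 1, so with zero starts the ends are simply the desired output shape. A hedged sketch of that migration (the helper and sample values are ours, not part of the commit):

```python
def migrate_ends(inclusive_ends):
    # Pre-refactor ends were inclusive; post-refactor each bound grows by one.
    return [e + 1 for e in inclusive_ends]

batch, padded_len, head_dim = 32, 64, 64  # illustrative values
old_ends = [batch - 1, 0, padded_len - 1, head_dim - 1]  # as in the K-cache slice above
assert migrate_ends(old_ends) == [batch, 1, padded_len, head_dim]
```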
2 changes: 1 addition & 1 deletion models/demos/falcon7b_common/tt/falcon_mlp.py
@@ -368,7 +368,7 @@ def forward(self, x: ttnn.Tensor) -> ttnn.Tensor:
 hidden_states = ttnn.slice(
     hidden_states,
     [0, 0, 0, 0],
-    [0, 0, batch_size - 1, self.hidden_size - 1],
+    [1, 1, batch_size, self.hidden_size],
     memory_config=self.model_config["DENSE_4H_TO_H_MM_OUTPUT_MEMCFG"],
 )

16 changes: 8 additions & 8 deletions models/demos/t3000/falcon40b/tt/falcon_attention.py
@@ -489,10 +489,10 @@ def fwd_decode(
     layer_past[0],
     [0, 0, 0, 0],
     [
-        batch - 1,
-        self.num_kv_heads // self.num_devices - 1,
-        padded_layer_past_len - 1,
-        self.head_dim - 1,
+        batch,
+        self.num_kv_heads // self.num_devices,
+        padded_layer_past_len,
+        self.head_dim,
     ],
     memory_config=self.model_config["DEFAULT_MEMCFG"],
 )
@@ -553,10 +553,10 @@ def fwd_decode(
     layer_past[1],
     [0, 0, 0, 0],
     [
-        batch - 1,
-        self.num_kv_heads // self.num_devices - 1,
-        padded_layer_past_len - 1,
-        self.head_dim - 1,
+        batch,
+        self.num_kv_heads // self.num_devices,
+        padded_layer_past_len,
+        self.head_dim,
     ],
     memory_config=self.model_config["DEFAULT_MEMCFG"],
 )
(additional file; name not shown)
@@ -96,7 +96,7 @@ def forward(self, xq, xk, rot_mat):
 xq = ttnn.slice(
     xq,
     [0, 0, 0, 0],
-    [1 - 1, 8 - 1, 128 - 1, self.head_dim - 1],
+    [1, 8, 128, self.head_dim],
 )

 xk = ttnn.pad(xk, [1, 32, 128, self.head_dim], [0, 0, 0, 0], 0.0)
@@ -113,7 +113,7 @@ def forward(self, xq, xk, rot_mat):
 xk = ttnn.slice(
     xk,
     [0, 0, 0, 0],
-    [1 - 1, 1 - 1, 128 - 1, self.head_dim - 1],
+    [1, 1, 128, self.head_dim],
 )

 return xq, xk
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -410,7 +410,7 @@ def prefill_forward(
 norm_out_replicated = ttnn.slice(
     norm_out_replicated,
     (0, 0, last_token_tile * 32, 0),
-    (0, 0, (last_token_tile + 1) * 32 - 1, dmodel - 1),
+    (1, 1, (last_token_tile + 1) * 32, dmodel),
     memory_config=ttnn.DRAM_MEMORY_CONFIG,
 )
 pc_lm_head = (
2 changes: 1 addition & 1 deletion models/demos/t3000/mixtral8x7b/tt/mixtral_model.py
@@ -101,7 +101,7 @@ def forward(
 # slicing for the last token
 if get_last_token != -1:
     x = ttnn.slice(
-        x, ttnn.Shape((0, 0, get_last_token, 0)), ttnn.Shape((0, 0, get_last_token + 31, 4095))
+        x, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, 4096)
     )  # [:, :, get_last_token:get_last_token+32, :]

 x_norm = self.norm(x)
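The trailing comment in this hunk spells out the intended equivalence; with exclusive ends the bounds now read exactly like the Python slice. Note this hunk also drops the ttnn.Shape wrappers in favor of plain tuples. A plain-PyTorch check of the equivalence (shape and token index are illustrative):

```python
import torch

x = torch.rand(1, 1, 128, 4096)
get_last_token = 96
a = x[:, :, get_last_token : get_last_token + 32, :]
# The new bounds (0, 0, t, 0) .. (1, 1, t + 32, 4096) express the same region.
b = x[0:1, 0:1, get_last_token : get_last_token + 32, 0:4096]
assert torch.equal(a, b)
```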
2 changes: 1 addition & 1 deletion models/demos/tg/llama3_70b/tt/llama_model_galaxy.py
@@ -418,7 +418,7 @@ def prefill_forward(
 norm_out = ttnn.slice(
     norm_out,
     [0, 0, seq_len - 32, 0],
-    [0, 0, seq_len - 1, dmodel - 1],
+    [1, 1, seq_len, dmodel],
 )

 ### Each device does an LM head fracture
(additional file; name not shown)
@@ -1020,7 +1020,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt

 if is_wormhole_b0() and self.batch_size == 16:
     xshape = x.shape_without_padding()
-    x = ttnn.slice(x, [0, 0, 0, 0], [xshape[0] - 1, xshape[1] - 1, xshape[2] - 1, xshape[3] - 1])
+    x = ttnn.slice(x, [0, 0, 0, 0], [xshape[0], xshape[1], xshape[2], xshape[3]])

 layer4_module1_input_shape = ttnn.Shape(x.get_legacy_shape())

(additional file; name not shown)
@@ -690,7 +690,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
 x = ttnn.slice(
     x,
     (0, 0, 0, 0),
-    (unpadded_shape[0] - 1, unpadded_shape[1] - 1, unpadded_shape[2] - 1, unpadded_shape[3] - 1),
+    (unpadded_shape[0], unpadded_shape[1], unpadded_shape[2], unpadded_shape[3]),
     memory_config=ttnn.L1_MEMORY_CONFIG,
 )

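With exclusive ends, the recurring "slice back to the unpadded shape" idiom passes the shape itself as the end bound, with no per-dimension `- 1`. A plain-PyTorch model of the idiom (shapes are illustrative):

```python
import torch

padded = torch.zeros(1, 1, 32, 32)   # e.g. a tile-padded activation
unpadded_shape = (1, 1, 30, 30)      # the logical shape to recover
# Exclusive ends == the unpadded shape itself.
view = padded[tuple(slice(0, e) for e in unpadded_shape)]
assert view.shape == unpadded_shape
```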
4 changes: 2 additions & 2 deletions models/demos/wormhole/mamba/tt/mamba_block.py
@@ -196,8 +196,8 @@ def forward(self, x):

 for i in range(0, 4):
     slice_start = (0, 0, x_ssm.shape[2] - (4 - i), 0)
-    slice_end = (0, 0, x_ssm.shape[2] - (4 - i), self.args.d_inner - 1)
-    entry = ttnn.slice(x_ssm, ttnn.Shape(slice_start), ttnn.Shape(slice_end))
+    slice_end = (1, 1, (x_ssm.shape[2] - (4 - i)) + 1, self.args.d_inner)
+    entry = ttnn.slice(x_ssm, slice_start, slice_end)
     self.convolution_cache.set(self.configs["current_user"], i, entry)
     ttnn.deallocate(entry)

2 changes: 1 addition & 1 deletion models/demos/wormhole/mamba/tt/mamba_conv.py
@@ -76,7 +76,7 @@ def prepare_input(self, input_tensor):
 split_size = self.config.input_channels // self.config.channels_split_factor
 for i in range(self.config.channels_split_factor):
     slice_start = ttnn.Shape((0, 0, 0, i * split_size))
-    slice_end = ttnn.Shape((0, self.config.input_length - 1, 0, (i + 1) * split_size - 1))
+    slice_end = ttnn.Shape((1, self.config.input_length, 1, (i + 1) * split_size))
     input_tensor_splits.append(ttnn.slice(input_tensor, slice_start, slice_end))
 ttnn.deallocate(input_tensor)
 return input_tensor_splits
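The channel-split loop above is a standard even partition; under half-open bounds split i covers [i * split_size, (i + 1) * split_size) with no gap or overlap. A quick pure-Python check (channel counts are illustrative):

```python
input_channels, channels_split_factor = 5120, 2
split_size = input_channels // channels_split_factor
bounds = [(i * split_size, (i + 1) * split_size) for i in range(channels_split_factor)]
# Half-open intervals tile the channel axis exactly once.
covered = [c for lo, hi in bounds for c in range(lo, hi)]
assert covered == list(range(input_channels))
```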
4 changes: 2 additions & 2 deletions models/demos/wormhole/mamba/tt/preprocessing.py
@@ -33,8 +33,8 @@ def split_sequence_length(x: ttnn.Tensor, batch: int = 0, chunk_size: int = 32):

 for i in range(0, L, chunk_size):
     slice_start = (0, 0, batch, i)
-    slice_end = (0, 0, batch, i + chunk_size - 1)
-    yield ttnn.slice(x, ttnn.Shape(slice_start), ttnn.Shape(slice_end))
+    slice_end = (1, 1, batch + 1, i + chunk_size)
+    yield ttnn.slice(x, slice_start, slice_end)


def select_chunk_size(sequence_length: int, max_chunk_size: int):
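split_sequence_length walks the sequence axis in fixed-size chunks, and with exclusive ends each chunk bound is simply i + chunk_size. A pure-Python sketch of the same iteration over a torch tensor (no device required; shapes are illustrative):

```python
import torch

def split_sequence_length(x: torch.Tensor, batch: int = 0, chunk_size: int = 32):
    # Yield [1, 1, 1, chunk_size] views of row `batch`, chunk by chunk.
    L = x.shape[3]
    for i in range(0, L, chunk_size):
        yield x[0:1, 0:1, batch : batch + 1, i : i + chunk_size]

x = torch.arange(2 * 128).reshape(1, 1, 2, 128)
chunks = list(split_sequence_length(x, batch=1, chunk_size=32))
assert len(chunks) == 4 and all(c.shape == (1, 1, 1, 32) for c in chunks)
```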
8 changes: 4 additions & 4 deletions models/demos/wormhole/stable_diffusion/demo/demo.py
@@ -59,10 +59,10 @@ def tt_guide(noise_pred, guidance_scale):  # will return latents
     noise_pred,
     [1, 0, 0, 0],
     [
-        noise_pred.shape[0] - 1,
-        noise_pred.shape[1] - 1,
-        noise_pred.shape[2] - 1,
-        noise_pred.shape[3] - 1,
+        noise_pred.shape[0],
+        noise_pred.shape[1],
+        noise_pred.shape[2],
+        noise_pred.shape[3],
     ],
 )
 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
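This call is the one site in the diff with a nonzero start ([1, 0, 0, 0]), so the output extent is end minus start: it peels the text-conditioned half off the batch axis. A plain-PyTorch model of the classifier-free-guidance split (how the unconditioned half is obtained is not shown in the hunk, so that part is an assumption):

```python
import torch

noise_pred = torch.rand(2, 4, 64, 64)   # [uncond, text] stacked on the batch axis
# Start [1, 0, 0, 0] with the full shape as the exclusive end -> the text half.
noise_pred_text = noise_pred[1:2, :, :, :]
noise_pred_uncond = noise_pred[0:1, :, :, :]
guidance_scale = 7.5
out = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert out.shape == (1, 4, 64, 64)
```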
(additional file; name not shown)
@@ -367,7 +367,7 @@ def time_sharded_attention(self, query, t_key, value, head_size, attn_type):
 k_slice = ttnn.slice(
     t_key,
     (j, i, 0, 0),
-    (j, i, self.key_len - 1, self.seq_len - 1),
+    (j + 1, i + 1, self.key_len, self.seq_len),
     memory_config=self.l1_interleaved_memory_config,
 )

@@ -407,7 +407,7 @@ def time_sharded_attention(self, query, t_key, value, head_size, attn_type):
 v_slice = ttnn.slice(
     value,
     (j, i, 0, 0),
-    (j, i, self.seq_len - 1, self.key_len - 1),
+    (j + 1, i + 1, self.seq_len, self.key_len),
     memory_config=self.l1_interleaved_memory_config,
 )
 mm_slice = ttnn.matmul(
(additional file; name not shown)
@@ -450,10 +450,10 @@ def __call__(
     hidden_states,
     [0, 0, 0, output_tensor_start_width_dim],
     [
-        hidden_states.shape[0] - 1,
-        hidden_states.shape[1] - 1,
-        hidden_states.shape[2] - 1,
-        output_tensor_end_width_dim - 1,
+        hidden_states.shape[0],
+        hidden_states.shape[1],
+        hidden_states.shape[2],
+        output_tensor_end_width_dim,
     ],
     memory_config=ttnn.L1_MEMORY_CONFIG,
 )
4 changes: 1 addition & 3 deletions models/experimental/bloom_old/bloom_utils.py
@@ -79,9 +79,7 @@ def create_padded_tensor(

 def create_unpadded_tensor(ttm_tensor, input_tensors_shape, input_tensor_start=[0, 0, 0, 0]):
     output_tensor_start = input_tensor_start
-    output_tensor_end = tuple(
-        input_tensor_start[i] + input_tensors_shape[i] - 1 for i in range(len(input_tensors_shape))
-    )
+    output_tensor_end = tuple(input_tensor_start[i] + input_tensors_shape[i] for i in range(len(input_tensors_shape)))
     ttm_tensor = ttm_tensor.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad(output_tensor_start, output_tensor_end)

     return ttm_tensor
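For unpad the rule is the mirror image: each exclusive end is start + size, so the recovered extent is exactly the requested shape. A quick sanity check of the new end computation (values are illustrative):

```python
input_tensor_start = (0, 0, 0, 0)
input_tensors_shape = (1, 1, 32, 128)
output_tensor_end = tuple(input_tensor_start[i] + input_tensors_shape[i] for i in range(4))
# With half-open bounds the recovered size is end - start, i.e. the shape itself.
assert tuple(e - s for s, e in zip(input_tensor_start, output_tensor_end)) == input_tensors_shape
```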
4 changes: 1 addition & 3 deletions models/experimental/nanogpt/nanogpt_utils.py
@@ -18,9 +18,7 @@ def unpad_from_zero(x, desired_shape):
 x = x.cpu()
 if x.get_layout() != ttnn.ROW_MAJOR_LAYOUT:
     x = x.to(ttnn.ROW_MAJOR_LAYOUT)
-x = x.unpad(
-    (0, 0, 0, 0), (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1)
-)
+x = x.unpad((0, 0, 0, 0), (desired_shape[0], desired_shape[1], desired_shape[2], desired_shape[3]))
 x = x.to_torch().to(torch.float)
 return x

8 changes: 4 additions & 4 deletions models/utility_functions.py
@@ -299,10 +299,10 @@ def unpad_from_zero(x, desired_shape):
 x = x.unpad(
     (0, 0, 0, 0),
     (
-        desired_shape[0] - 1,
-        desired_shape[1] - 1,
-        desired_shape[2] - 1,
-        desired_shape[3] - 1,
+        desired_shape[0],
+        desired_shape[1],
+        desired_shape[2],
+        desired_shape[3],
     ),
 )

16 changes: 10 additions & 6 deletions tests/tt_eager/python_api_testing/sweep_tests/generation_funcs.py
@@ -302,7 +302,7 @@ def gen_tensor_unpad_args(
 assert len(input_shapes[0]) == 4
 test_args = {}
 output_tensor_start = [random.randint(0, input_shapes[0][i] - 1) for i in range(4)]
-output_tensor_end = [random.randint(output_tensor_start[i], input_shapes[0][i] - 1) for i in range(4)]
+output_tensor_end = [random.randint(output_tensor_start[i] + 1, input_shapes[0][i]) for i in range(4)]

 test_args.update(
     {
@@ -917,8 +917,10 @@ def gen_unpad_args(
 if input_info is not None:
     if input_info["layout"][0] == ttnn.ROW_MAJOR_LAYOUT:
         output_tensor_start = [0, 0, 0, 0]
-        output_tensor_end = [random.randrange(output_tensor_start[i], input_shapes[0][i], 1) for i in range(4)]
-        if output_tensor_end[-1] % 2 == 0:
+        output_tensor_end = [
+            random.randrange(output_tensor_start[i] + 1, input_shapes[0][i], 1) for i in range(4)
+        ]
+        if output_tensor_end[-1] % 2 != 0:
             output_tensor_end[-1] += 1
         input_info.update(
             {
@@ -928,9 +930,11 @@ def gen_unpad_args(
     )
     elif input_info["layout"][0] == ttnn.TILE_LAYOUT:
         output_tensor_start = [0, 0, 0, 0]
-        output_tensor_end = [random.randrange(output_tensor_start[i], input_shapes[0][i], 1) for i in range(4)]
-        output_tensor_end[-2] = max(nearest_32(output_tensor_end[-2]), 32) - 1
-        output_tensor_end[-1] = max(nearest_32(output_tensor_end[-1]), 32) - 1
+        output_tensor_end = [
+            random.randrange(output_tensor_start[i] + 1, input_shapes[0][i], 1) for i in range(4)
+        ]
+        output_tensor_end[-2] = max(nearest_32(output_tensor_end[-2]), 32)
+        output_tensor_end[-1] = max(nearest_32(output_tensor_end[-1]), 32)
         input_info.update(
             {
                 "output_tensor_start": output_tensor_start,
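The sweep generators now draw ends from [start + 1, size], which guarantees a non-empty half-open interval in every dimension. A hedged miniature of the row-major branch above (the shape and the evenness fix-up mirror the diff; values are illustrative):

```python
import random

input_shape = (2, 3, 64, 96)
output_tensor_start = [0, 0, 0, 0]
# Exclusive ends: sample from [start + 1, size] so every dim keeps >= 1 element.
output_tensor_end = [random.randint(output_tensor_start[i] + 1, input_shape[i]) for i in range(4)]
if output_tensor_end[-1] % 2 != 0:  # the row-major branch forces an even last dim
    output_tensor_end[-1] += 1
assert all(s < e for s, e in zip(output_tensor_start, output_tensor_end))
```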
8 changes: 4 additions & 4 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -1338,10 +1338,10 @@ def pad(x, *args, output_tensor_shape, input_tensor_start, pad_value, **kwargs):

 def unpad(x, *args, output_tensor_start, output_tensor_end, **kwargs):
     out = x[
-        output_tensor_start[0] : output_tensor_end[0] + 1,
-        output_tensor_start[1] : output_tensor_end[1] + 1,
-        output_tensor_start[2] : output_tensor_end[2] + 1,
-        output_tensor_start[3] : output_tensor_end[3] + 1,
+        output_tensor_start[0] : output_tensor_end[0],
+        output_tensor_start[1] : output_tensor_end[1],
+        output_tensor_start[2] : output_tensor_end[2],
+        output_tensor_start[3] : output_tensor_end[3],
     ]

     return out
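After this change the PyTorch reference op is native slicing with no + 1 correction. A generalized, runnable variant of the reference above (the tuple-of-slices form is ours; the four hard-coded dimensions in the diff behave identically):

```python
import torch

def unpad_ref(x, output_tensor_start, output_tensor_end):
    # Bounds map one-to-one onto Python slices under the new convention.
    return x[tuple(slice(s, e) for s, e in zip(output_tensor_start, output_tensor_end))]

x = torch.rand(2, 3, 32, 32)
out = unpad_ref(x, (0, 0, 4, 4), (1, 2, 20, 20))
assert out.shape == (1, 2, 16, 16)
```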
(additional file; name not shown)
@@ -69,7 +69,7 @@ def test_indexed_slice(seed, B, b, tt_dtype, device):
 assert num_non_zeros == B - b

 a_pt = (
-    ttnn.slice(indices_tt, (0, 0, 0, 0), (0, 0, 0, num_non_zeros - 1), memory_config=mem_config)
+    ttnn.slice(indices_tt, (0, 0, 0, 0), (1, 1, 1, num_non_zeros), memory_config=mem_config)
     .cpu()
     .to(ttnn.ROW_MAJOR_LAYOUT)
     .to_torch()
(additional file; name not shown)
@@ -53,11 +53,11 @@ def test_run_padding_test(input_tensor_shape, output_tensor_shape, input_tensor_
 @pytest.mark.parametrize(
     "input_tensor_shape, output_tensor_start, output_tensor_end",
     (
-        ((1, 1, 5, 5), (0, 0, 1, 1), (0, 0, 3, 3)),
-        ((2, 2, 5, 5), (0, 0, 0, 0), (0, 0, 2, 2)),
-        ((1, 3, 32, 32), (0, 0, 0, 0), (0, 2, 29, 29)),
-        ((3, 5, 32, 32), (1, 2, 0, 0), (1, 4, 29, 29)),
-        ((3, 3, 64, 64), (0, 0, 32, 32), (0, 2, 61, 61)),
+        ((1, 1, 5, 5), (0, 0, 1, 1), (1, 1, 4, 4)),
+        ((2, 2, 5, 5), (0, 0, 0, 0), (1, 1, 3, 3)),
+        ((1, 3, 32, 32), (0, 0, 0, 0), (1, 3, 30, 30)),
+        ((3, 5, 32, 32), (1, 2, 0, 0), (2, 5, 30, 30)),
+        ((3, 3, 64, 64), (0, 0, 32, 32), (1, 3, 62, 62)),
     ),
 )
 def test_run_unpadding_test(input_tensor_shape, output_tensor_start, output_tensor_end):
@@ -72,18 +72,16 @@ def test_run_unpadding_test(input_tensor_shape, output_tensor_start, output_tensor_end):
 )

 # Unpad inputs on host
-output_tensor_shape = tuple(
-    output_tensor_end[i] - output_tensor_start[i] + 1 for i in range(len(input_tensor_shape))
-)
+output_tensor_shape = tuple(output_tensor_end[i] - output_tensor_start[i] for i in range(len(input_tensor_shape)))
 a_unpad = a.unpad(output_tensor_start, output_tensor_end)
 a_pt = a_unpad.to_torch()

 # Pytorch reference
 a_ref = inp[
-    output_tensor_start[0] : output_tensor_end[0] + 1,
-    output_tensor_start[1] : output_tensor_end[1] + 1,
-    output_tensor_start[2] : output_tensor_end[2] + 1,
-    output_tensor_start[3] : output_tensor_end[3] + 1,
+    output_tensor_start[0] : output_tensor_end[0],
+    output_tensor_start[1] : output_tensor_end[1],
+    output_tensor_start[2] : output_tensor_end[2],
+    output_tensor_start[3] : output_tensor_end[3],
 ]

# print("\n", a_pt.shape)
@@ -103,7 +101,7 @@ def test_run_unpadding_test(input_tensor_shape, output_tensor_start, output_tensor_end):
 def test_run_padding_and_add_test(input_tensor_shape, output_tensor_shape, input_tensor_start, pad_value, device):
     # Args for unpad
     output_tensor_start = input_tensor_start
-    output_tensor_end = tuple(input_tensor_start[i] + input_tensor_shape[i] - 1 for i in range(len(input_tensor_shape)))
+    output_tensor_end = tuple(input_tensor_start[i] + input_tensor_shape[i] for i in range(len(input_tensor_shape)))

     inp = torch.rand(*input_tensor_shape)
     ones = torch.ones(*input_tensor_shape)
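A quick way to sanity-check the updated parametrizations: under exclusive bounds, end - start is the output shape, and every end must stay within the input shape. Checking the first table entry from above:

```python
input_tensor_shape = (1, 1, 5, 5)
output_tensor_start = (0, 0, 1, 1)
output_tensor_end = (1, 1, 4, 4)
output_shape = tuple(e - s for s, e in zip(output_tensor_start, output_tensor_end))
assert output_shape == (1, 1, 3, 3)  # same extent as the old inclusive end (0, 0, 3, 3)
assert all(e <= d for e, d in zip(output_tensor_end, input_tensor_shape))
```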
(additional file; name not shown)
@@ -202,7 +202,7 @@ def test_time_sharded_attnention_hwb(
 v_slice = ttnn.slice(
     reference_value_layer,
     (0, (i * heads_per_slice), 0, 0),
-    (0, (i * heads_per_slice) + (heads_per_slice - 1), seq_len - 1, 63),
+    (1, (i * heads_per_slice) + (heads_per_slice), seq_len, 64),
     memory_config=dram_interleaved_memory_config,
 )

@@ -339,7 +339,7 @@ def test_time_sharded_attnention(
 k_slice = ttnn.slice(
     reference_key_layer_transposed,
     (0, (i * heads_per_slice), 0, 0),
-    (0, (i * heads_per_slice) + (heads_per_slice - 1), 63, seq_len - 1),
+    (1, (i * heads_per_slice) + (heads_per_slice), 64, seq_len),
     memory_config=l1_interleaved_memory_config,
 )
 mm_slice = ttnn.matmul(
@@ -376,7 +376,7 @@ def test_time_sharded_attnention(
 v_slice = ttnn.slice(
     reference_value_layer,
     (0, (i * heads_per_slice), 0, 0),
-    (0, (i * heads_per_slice) + (heads_per_slice - 1), seq_len - 1, 63),
+    (1, (i * heads_per_slice) + (heads_per_slice), seq_len, 64),
     memory_config=l1_interleaved_memory_config,
 )
 mm_slice = ttnn.matmul(
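These test hunks slice one group of heads per loop iteration; with exclusive ends the per-group bounds are [i * heads_per_slice, (i + 1) * heads_per_slice). A plain-PyTorch miniature of the loop (head counts and sequence length are illustrative):

```python
import torch

num_heads, heads_per_slice, seq_len = 16, 4, 64
value = torch.rand(1, num_heads, seq_len, 64)
slices = []
for i in range(num_heads // heads_per_slice):
    lo = i * heads_per_slice
    hi = lo + heads_per_slice  # exclusive upper bound
    slices.append(value[0:1, lo:hi, 0:seq_len, 0:64])
# Concatenating the per-group slices reassembles the original tensor exactly.
assert torch.equal(torch.cat(slices, dim=1), value)
```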