diff --git a/models/demos/falcon7b_common/tt/falcon_attention.py b/models/demos/falcon7b_common/tt/falcon_attention.py
index 10185eac8a2..04f2704a929 100644
--- a/models/demos/falcon7b_common/tt/falcon_attention.py
+++ b/models/demos/falcon7b_common/tt/falcon_attention.py
@@ -618,7 +618,7 @@ def forward(
             key_layer = ttnn.slice(
                 layer_past[0],
                 [0, 0, 0, 0],
-                [batch - 1, 0, nearest_32(layer_past_len + 1) - 1, self.head_dim - 1],
+                [batch, 1, nearest_32(layer_past_len + 1), self.head_dim],
                 memory_config=self.model_config["K_CACHE_SLICE_OUTPUT_MEMCFG"],
             )
@@ -747,7 +747,7 @@ def forward(
             value_layer = ttnn.slice(
                 layer_past[1],
                 [0, 0, 0, 0],
-                [batch - 1, 0, nearest_32(layer_past_len + 1) - 1, self.head_dim - 1],
+                [batch, 1, nearest_32(layer_past_len + 1), self.head_dim],
                 memory_config=self.model_config["V_CACHE_SLICE_OUTPUT_MEMCFG"],
             )
             if self.model_config["l1_sharded"]:
@@ -800,10 +800,10 @@ def forward(
                 attn_output,
                 [0, 0, 0, 0],
                 [
-                    attn_output_shape[0] - 1,
-                    self.num_heads - 1,
-                    attn_output_shape[2] - 1,
-                    attn_output_shape[3] - 1,
+                    attn_output_shape[0],
+                    self.num_heads,
+                    attn_output_shape[2],
+                    attn_output_shape[3],
                 ],
                 memory_config=self.model_config["POST_SOFTMAX_MM_OUTPUT_MEMCFG"],
             )
diff --git a/models/demos/falcon7b_common/tt/falcon_mlp.py b/models/demos/falcon7b_common/tt/falcon_mlp.py
index b6812dd54fa..e9554250c4b 100644
--- a/models/demos/falcon7b_common/tt/falcon_mlp.py
+++ b/models/demos/falcon7b_common/tt/falcon_mlp.py
@@ -368,7 +368,7 @@ def forward(self, x: ttnn.Tensor) -> ttnn.Tensor:
             hidden_states = ttnn.slice(
                 hidden_states,
                 [0, 0, 0, 0],
-                [0, 0, batch_size - 1, self.hidden_size - 1],
+                [1, 1, batch_size, self.hidden_size],
                 memory_config=self.model_config["DENSE_4H_TO_H_MM_OUTPUT_MEMCFG"],
             )
diff --git a/models/demos/t3000/falcon40b/tt/falcon_attention.py b/models/demos/t3000/falcon40b/tt/falcon_attention.py
index e968b002f69..cc11aef5662 100644
--- a/models/demos/t3000/falcon40b/tt/falcon_attention.py
+++ b/models/demos/t3000/falcon40b/tt/falcon_attention.py
@@ -489,10 +489,10 @@ def fwd_decode(
             layer_past[0],
             [0, 0, 0, 0],
             [
-                batch - 1,
-                self.num_kv_heads // self.num_devices - 1,
-                padded_layer_past_len - 1,
-                self.head_dim - 1,
+                batch,
+                self.num_kv_heads // self.num_devices,
+                padded_layer_past_len,
+                self.head_dim,
             ],
             memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
@@ -553,10 +553,10 @@ def fwd_decode(
             layer_past[1],
             [0, 0, 0, 0],
             [
-                batch - 1,
-                self.num_kv_heads // self.num_devices - 1,
-                padded_layer_past_len - 1,
-                self.head_dim - 1,
+                batch,
+                self.num_kv_heads // self.num_devices,
+                padded_layer_past_len,
+                self.head_dim,
             ],
             memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
diff --git a/models/demos/t3000/llama2_70b/tests/unit_tests/test_reshape_rotary.py b/models/demos/t3000/llama2_70b/tests/unit_tests/test_reshape_rotary.py
index cf6e20261a7..108fdfb0879 100644
--- a/models/demos/t3000/llama2_70b/tests/unit_tests/test_reshape_rotary.py
+++ b/models/demos/t3000/llama2_70b/tests/unit_tests/test_reshape_rotary.py
@@ -96,7 +96,7 @@ def forward(self, xq, xk, rot_mat):
         xq = ttnn.slice(
             xq,
             [0, 0, 0, 0],
-            [1 - 1, 8 - 1, 128 - 1, self.head_dim - 1],
+            [1, 8, 128, self.head_dim],
         )
 
         xk = ttnn.pad(xk, [1, 32, 128, self.head_dim], [0, 0, 0, 0], 0.0)
@@ -113,7 +113,7 @@ def forward(self, xq, xk, rot_mat):
         xk = ttnn.slice(
             xk,
             [0, 0, 0, 0],
-            [1 - 1, 1 - 1, 128 - 1, self.head_dim - 1],
+            [1, 1, 128, self.head_dim],
         )
 
         return xq, xk
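# Review note (not part of the patch): the mechanical change applied throughout
# this diff is that ttnn.slice / Tensor.unpad ends are now exclusive, matching
# Python and PyTorch slicing. A minimal sketch, assuming a device tensor x of
# shape (1, 1, 64, 64):
#
#     # before this change: ends were inclusive
#     y = ttnn.slice(x, (0, 0, 0, 0), (0, 0, 31, 63))  # shape (1, 1, 32, 64)
#     # after this change: ends are exclusive, like x[:1, :1, :32, :64]
#     y = ttnn.slice(x, (0, 0, 0, 0), (1, 1, 32, 64))  # shape (1, 1, 32, 64)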
diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
index 8cfa1a77ba8..526db82b437 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -410,7 +410,7 @@ def prefill_forward(
             norm_out_replicated = ttnn.slice(
                 norm_out_replicated,
                 (0, 0, last_token_tile * 32, 0),
-                (0, 0, (last_token_tile + 1) * 32 - 1, dmodel - 1),
+                (1, 1, (last_token_tile + 1) * 32, dmodel),
                 memory_config=ttnn.DRAM_MEMORY_CONFIG,
             )
             pc_lm_head = (
diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py
index 0e94450bba4..c2f7c740281 100644
--- a/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py
+++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py
@@ -101,7 +101,7 @@ def forward(
         # slicing for the last token
         if get_last_token != -1:
             x = ttnn.slice(
-                x, ttnn.Shape((0, 0, get_last_token, 0)), ttnn.Shape((0, 0, get_last_token + 31, 4095))
+                x, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, 4096)
             )  # [:, :, get_last_token:get_last_token+32, :]
 
         x_norm = self.norm(x)
diff --git a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py
index 86995fb4fc5..0b0bfb4ba4f 100644
--- a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py
+++ b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py
@@ -418,7 +418,7 @@ def prefill_forward(
             norm_out = ttnn.slice(
                 norm_out,
                 [0, 0, seq_len - 32, 0],
-                [0, 0, seq_len - 1, dmodel - 1],
+                [1, 1, seq_len, dmodel],
             )
 
         ### Each device does an LM head fracture
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
index 198b0fb5091..322495b1611 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
@@ -1020,7 +1020,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
         if is_wormhole_b0() and self.batch_size == 16:
             xshape = x.shape_without_padding()
-            x = ttnn.slice(x, [0, 0, 0, 0], [xshape[0] - 1, xshape[1] - 1, xshape[2] - 1, xshape[3] - 1])
+            x = ttnn.slice(x, [0, 0, 0, 0], [xshape[0], xshape[1], xshape[2], xshape[3]])
 
         layer4_module1_input_shape = ttnn.Shape(x.get_legacy_shape())
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py
index 7e89e7484e2..a9712d3360e 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py
@@ -690,7 +690,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
         x = ttnn.slice(
             x,
             (0, 0, 0, 0),
-            (unpadded_shape[0] - 1, unpadded_shape[1] - 1, unpadded_shape[2] - 1, unpadded_shape[3] - 1),
+            (unpadded_shape[0], unpadded_shape[1], unpadded_shape[2], unpadded_shape[3]),
             memory_config=ttnn.L1_MEMORY_CONFIG,
         )
diff --git a/models/demos/wormhole/mamba/tt/mamba_block.py b/models/demos/wormhole/mamba/tt/mamba_block.py
index 8795a24500d..e3d333ae9f2 100644
--- a/models/demos/wormhole/mamba/tt/mamba_block.py
+++ b/models/demos/wormhole/mamba/tt/mamba_block.py
@@ -196,8 +196,8 @@ def forward(self, x):
         for i in range(0, 4):
             slice_start = (0, 0, x_ssm.shape[2] - (4 - i), 0)
-            slice_end = (0, 0, x_ssm.shape[2] - (4 - i), self.args.d_inner - 1)
-            entry = ttnn.slice(x_ssm, ttnn.Shape(slice_start), ttnn.Shape(slice_end))
+            slice_end = (1, 1, (x_ssm.shape[2] - (4 - i)) + 1, self.args.d_inner)
+            entry = ttnn.slice(x_ssm, slice_start, slice_end)
             self.convolution_cache.set(self.configs["current_user"], i, entry)
             ttnn.deallocate(entry)
diff --git a/models/demos/wormhole/mamba/tt/mamba_conv.py b/models/demos/wormhole/mamba/tt/mamba_conv.py
index 1d0033828a8..c4dd0d961ef 100644
--- a/models/demos/wormhole/mamba/tt/mamba_conv.py
+++ b/models/demos/wormhole/mamba/tt/mamba_conv.py
@@ -76,7 +76,7 @@ def prepare_input(self, input_tensor):
         split_size = self.config.input_channels // self.config.channels_split_factor
         for i in range(self.config.channels_split_factor):
             slice_start = ttnn.Shape((0, 0, 0, i * split_size))
-            slice_end = ttnn.Shape((0, self.config.input_length - 1, 0, (i + 1) * split_size - 1))
+            slice_end = ttnn.Shape((1, self.config.input_length, 1, (i + 1) * split_size))
             input_tensor_splits.append(ttnn.slice(input_tensor, slice_start, slice_end))
         ttnn.deallocate(input_tensor)
         return input_tensor_splits
diff --git a/models/demos/wormhole/mamba/tt/preprocessing.py b/models/demos/wormhole/mamba/tt/preprocessing.py
index e674028dd42..51970fd52ad 100644
--- a/models/demos/wormhole/mamba/tt/preprocessing.py
+++ b/models/demos/wormhole/mamba/tt/preprocessing.py
@@ -33,8 +33,8 @@ def split_sequence_length(x: ttnn.Tensor, batch: int = 0, chunk_size: int = 32):
     for i in range(0, L, chunk_size):
         slice_start = (0, 0, batch, i)
-        slice_end = (0, 0, batch, i + chunk_size - 1)
-        yield ttnn.slice(x, ttnn.Shape(slice_start), ttnn.Shape(slice_end))
+        slice_end = (1, 1, batch + 1, i + chunk_size)
+        yield ttnn.slice(x, slice_start, slice_end)
 
 
 def select_chunk_size(sequence_length: int, max_chunk_size: int):
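# Review note (not part of the patch): under exclusive ends, each chunk yielded
# by split_sequence_length above is x[0:1, 0:1, batch:batch+1, i:i+chunk_size].
# A hypothetical host-side reference, assuming a (1, 1, B, L) torch tensor:
#
#     import torch
#
#     def split_sequence_length_ref(x: torch.Tensor, batch: int = 0, chunk_size: int = 32):
#         L = x.shape[3]
#         for i in range(0, L, chunk_size):
#             yield x[0:1, 0:1, batch : batch + 1, i : i + chunk_size]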
diff --git a/models/demos/wormhole/stable_diffusion/demo/demo.py b/models/demos/wormhole/stable_diffusion/demo/demo.py
index 71b44821e83..53fca566f81 100644
--- a/models/demos/wormhole/stable_diffusion/demo/demo.py
+++ b/models/demos/wormhole/stable_diffusion/demo/demo.py
@@ -59,10 +59,10 @@ def tt_guide(noise_pred, guidance_scale):  # will return latents
         noise_pred,
         [1, 0, 0, 0],
         [
-            noise_pred.shape[0] - 1,
-            noise_pred.shape[1] - 1,
-            noise_pred.shape[2] - 1,
-            noise_pred.shape[3] - 1,
+            noise_pred.shape[0],
+            noise_pred.shape[1],
+            noise_pred.shape[2],
+            noise_pred.shape[3],
         ],
     )
     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_cross_attention.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_cross_attention.py
index daec08200de..d722cd035a8 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_cross_attention.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_cross_attention.py
@@ -367,7 +367,7 @@ def time_sharded_attention(self, query, t_key, value, head_size, attn_type):
                 k_slice = ttnn.slice(
                     t_key,
                     (j, i, 0, 0),
-                    (j, i, self.key_len - 1, self.seq_len - 1),
+                    (j + 1, i + 1, self.key_len, self.seq_len),
                     memory_config=self.l1_interleaved_memory_config,
                 )
@@ -407,7 +407,7 @@ def time_sharded_attention(self, query, t_key, value, head_size, attn_type):
                 v_slice = ttnn.slice(
                     value,
                     (j, i, 0, 0),
-                    (j, i, self.seq_len - 1, self.key_len - 1),
+                    (j + 1, i + 1, self.seq_len, self.key_len),
                     memory_config=self.l1_interleaved_memory_config,
                 )
                 mm_slice = ttnn.matmul(
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
index 89a1b0a6397..4e63fc9b13c 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
@@ -450,10 +450,10 @@ def __call__(
                 hidden_states,
                 [0, 0, 0, output_tensor_start_width_dim],
                 [
-                    hidden_states.shape[0] - 1,
-                    hidden_states.shape[1] - 1,
-                    hidden_states.shape[2] - 1,
-                    output_tensor_end_width_dim - 1,
+                    hidden_states.shape[0],
+                    hidden_states.shape[1],
+                    hidden_states.shape[2],
+                    output_tensor_end_width_dim,
                 ],
                 memory_config=ttnn.L1_MEMORY_CONFIG,
             )
diff --git a/models/experimental/bloom_old/bloom_utils.py b/models/experimental/bloom_old/bloom_utils.py
index 61b520a8a3b..bd0b1dad561 100644
--- a/models/experimental/bloom_old/bloom_utils.py
+++ b/models/experimental/bloom_old/bloom_utils.py
@@ -79,9 +79,7 @@ def create_padded_tensor(
 def create_unpadded_tensor(ttm_tensor, input_tensors_shape, input_tensor_start=[0, 0, 0, 0]):
     output_tensor_start = input_tensor_start
-    output_tensor_end = tuple(
-        input_tensor_start[i] + input_tensors_shape[i] - 1 for i in range(len(input_tensors_shape))
-    )
+    output_tensor_end = tuple(input_tensor_start[i] + input_tensors_shape[i] for i in range(len(input_tensors_shape)))
 
     ttm_tensor = ttm_tensor.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad(output_tensor_start, output_tensor_end)
     return ttm_tensor
diff --git a/models/experimental/nanogpt/nanogpt_utils.py b/models/experimental/nanogpt/nanogpt_utils.py
index 69236af6bfd..50226779bd1 100644
--- a/models/experimental/nanogpt/nanogpt_utils.py
+++ b/models/experimental/nanogpt/nanogpt_utils.py
@@ -18,9 +18,7 @@ def unpad_from_zero(x, desired_shape):
     x = x.cpu()
     if x.get_layout() != ttnn.ROW_MAJOR_LAYOUT:
         x = x.to(ttnn.ROW_MAJOR_LAYOUT)
-    x = x.unpad(
-        (0, 0, 0, 0), (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1)
-    )
+    x = x.unpad((0, 0, 0, 0), (desired_shape[0], desired_shape[1], desired_shape[2], desired_shape[3]))
     x = x.to_torch().to(torch.float)
     return x
diff --git a/models/utility_functions.py b/models/utility_functions.py
index c158ff65025..699237c1778 100644
--- a/models/utility_functions.py
+++ b/models/utility_functions.py
@@ -299,10 +299,10 @@ def unpad_from_zero(x, desired_shape):
         x = x.unpad(
             (0, 0, 0, 0),
             (
-                desired_shape[0] - 1,
-                desired_shape[1] - 1,
-                desired_shape[2] - 1,
-                desired_shape[3] - 1,
+                desired_shape[0],
+                desired_shape[1],
+                desired_shape[2],
+                desired_shape[3],
             ),
         )
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/generation_funcs.py b/tests/tt_eager/python_api_testing/sweep_tests/generation_funcs.py
index e45737ccfc7..d482b03f0ef 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/generation_funcs.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/generation_funcs.py
@@ -302,7 +302,7 @@ def gen_tensor_unpad_args(
     assert len(input_shapes[0]) == 4
     test_args = {}
     output_tensor_start = [random.randint(0, input_shapes[0][i] - 1) for i in range(4)]
-    output_tensor_end = [random.randint(output_tensor_start[i], input_shapes[0][i] - 1) for i in range(4)]
+    output_tensor_end = [random.randint(output_tensor_start[i] + 1, input_shapes[0][i]) for i in range(4)]
 
     test_args.update(
         {
@@ -917,8 +917,10 @@ def gen_unpad_args(
     if input_info is not None:
         if input_info["layout"][0] == ttnn.ROW_MAJOR_LAYOUT:
             output_tensor_start = [0, 0, 0, 0]
-            output_tensor_end = [random.randrange(output_tensor_start[i], input_shapes[0][i], 1) for i in range(4)]
-            if output_tensor_end[-1] % 2 == 0:
+            output_tensor_end = [
+                random.randrange(output_tensor_start[i] + 1, input_shapes[0][i], 1) for i in range(4)
+            ]
+            if output_tensor_end[-1] % 2 != 0:
                 output_tensor_end[-1] += 1
             input_info.update(
                 {
@@ -928,9 +930,11 @@ def gen_unpad_args(
             )
         elif input_info["layout"][0] == ttnn.TILE_LAYOUT:
             output_tensor_start = [0, 0, 0, 0]
-            output_tensor_end = [random.randrange(output_tensor_start[i], input_shapes[0][i], 1) for i in range(4)]
-            output_tensor_end[-2] = max(nearest_32(output_tensor_end[-2]), 32) - 1
-            output_tensor_end[-1] = max(nearest_32(output_tensor_end[-1]), 32) - 1
+            output_tensor_end = [
+                random.randrange(output_tensor_start[i] + 1, input_shapes[0][i], 1) for i in range(4)
+            ]
+            output_tensor_end[-2] = max(nearest_32(output_tensor_end[-2]), 32)
+            output_tensor_end[-1] = max(nearest_32(output_tensor_end[-1]), 32)
             input_info.update(
                 {
                     "output_tensor_start": output_tensor_start,
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
index c745be2f5f9..d77f952d7fd 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -1338,10 +1338,10 @@ def pad(x, *args, output_tensor_shape, input_tensor_start, pad_value, **kwargs):
 def unpad(x, *args, output_tensor_start, output_tensor_end, **kwargs):
     out = x[
-        output_tensor_start[0] : output_tensor_end[0] + 1,
-        output_tensor_start[1] : output_tensor_end[1] + 1,
-        output_tensor_start[2] : output_tensor_end[2] + 1,
-        output_tensor_start[3] : output_tensor_end[3] + 1,
+        output_tensor_start[0] : output_tensor_end[0],
+        output_tensor_start[1] : output_tensor_end[1],
+        output_tensor_start[2] : output_tensor_end[2],
+        output_tensor_start[3] : output_tensor_end[3],
     ]
 
     return out
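# Review note (not part of the patch): with exclusive ends the reference unpad
# above is plain Python slicing. For example, for start=(0, 0, 64, 32) and
# end=(1, 1, 96, 96):
#
#     out = x[0:1, 0:1, 64:96, 32:96]  # shape (1, 1, 32, 64)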
diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py
index 343abffed1a..e672856c3e2 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py
@@ -69,7 +69,7 @@ def test_indexed_slice(seed, B, b, tt_dtype, device):
     assert num_non_zeros == B - b
 
     a_pt = (
-        ttnn.slice(indices_tt, (0, 0, 0, 0), (0, 0, 0, num_non_zeros - 1), memory_config=mem_config)
+        ttnn.slice(indices_tt, (0, 0, 0, 0), (1, 1, 1, num_non_zeros), memory_config=mem_config)
         .cpu()
         .to(ttnn.ROW_MAJOR_LAYOUT)
         .to_torch()
diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_padding_test.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_padding_test.py
index c0cae9eb37a..726963c8465 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_padding_test.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_padding_test.py
@@ -53,11 +53,11 @@ def test_run_padding_test(input_tensor_shape, output_tensor_shape, input_tensor_
 @pytest.mark.parametrize(
     "input_tensor_shape, output_tensor_start, output_tensor_end",
     (
-        ((1, 1, 5, 5), (0, 0, 1, 1), (0, 0, 3, 3)),
-        ((2, 2, 5, 5), (0, 0, 0, 0), (0, 0, 2, 2)),
-        ((1, 3, 32, 32), (0, 0, 0, 0), (0, 2, 29, 29)),
-        ((3, 5, 32, 32), (1, 2, 0, 0), (1, 4, 29, 29)),
-        ((3, 3, 64, 64), (0, 0, 32, 32), (0, 2, 61, 61)),
+        ((1, 1, 5, 5), (0, 0, 1, 1), (1, 1, 4, 4)),
+        ((2, 2, 5, 5), (0, 0, 0, 0), (1, 1, 3, 3)),
+        ((1, 3, 32, 32), (0, 0, 0, 0), (1, 3, 30, 30)),
+        ((3, 5, 32, 32), (1, 2, 0, 0), (2, 5, 30, 30)),
+        ((3, 3, 64, 64), (0, 0, 32, 32), (1, 3, 62, 62)),
     ),
 )
 def test_run_unpadding_test(input_tensor_shape, output_tensor_start, output_tensor_end):
@@ -72,18 +72,16 @@ def test_run_unpadding_test(input_tensor_shape, output_tensor_start, output_tens
     )
 
     # Unpad inputs on host
-    output_tensor_shape = tuple(
-        output_tensor_end[i] - output_tensor_start[i] + 1 for i in range(len(input_tensor_shape))
-    )
+    output_tensor_shape = tuple(output_tensor_end[i] - output_tensor_start[i] for i in range(len(input_tensor_shape)))
     a_unpad = a.unpad(output_tensor_start, output_tensor_end)
     a_pt = a_unpad.to_torch()
 
     # Pytorch reference
     a_ref = inp[
-        output_tensor_start[0] : output_tensor_end[0] + 1,
-        output_tensor_start[1] : output_tensor_end[1] + 1,
-        output_tensor_start[2] : output_tensor_end[2] + 1,
-        output_tensor_start[3] : output_tensor_end[3] + 1,
+        output_tensor_start[0] : output_tensor_end[0],
+        output_tensor_start[1] : output_tensor_end[1],
+        output_tensor_start[2] : output_tensor_end[2],
+        output_tensor_start[3] : output_tensor_end[3],
     ]
 
     # print("\n", a_pt.shape)
@@ -103,7 +101,7 @@ def test_run_unpadding_test(input_tensor_shape, output_tensor_start, output_tens
 def test_run_padding_and_add_test(input_tensor_shape, output_tensor_shape, input_tensor_start, pad_value, device):
     # Args for unpad
     output_tensor_start = input_tensor_start
-    output_tensor_end = tuple(input_tensor_start[i] + input_tensor_shape[i] - 1 for i in range(len(input_tensor_shape)))
+    output_tensor_end = tuple(input_tensor_start[i] + input_tensor_shape[i] for i in range(len(input_tensor_shape)))
 
     inp = torch.rand(*input_tensor_shape)
     ones = torch.ones(*input_tensor_shape)
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py b/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py
index 1a13a97c6e9..1b45761e11c 100644
--- a/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py
+++ b/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py
@@ -202,7 +202,7 @@ def test_time_sharded_attnention_hwb(
             v_slice = ttnn.slice(
                 reference_value_layer,
                 (0, (i * heads_per_slice), 0, 0),
-                (0, (i * heads_per_slice) + (heads_per_slice - 1), seq_len - 1, 63),
+                (1, (i * heads_per_slice) + (heads_per_slice), seq_len, 64),
                 memory_config=dram_interleaved_memory_config,
             )
@@ -339,7 +339,7 @@ def test_time_sharded_attnention(
             k_slice = ttnn.slice(
                 reference_key_layer_transposed,
                 (0, (i * heads_per_slice), 0, 0),
-                (0, (i * heads_per_slice) + (heads_per_slice - 1), 63, seq_len - 1),
+                (1, (i * heads_per_slice) + (heads_per_slice), 64, seq_len),
                 memory_config=l1_interleaved_memory_config,
             )
             mm_slice = ttnn.matmul(
@@ -376,7 +376,7 @@ def test_time_sharded_attnention(
             v_slice = ttnn.slice(
                 reference_value_layer,
                 (0, (i * heads_per_slice), 0, 0),
-                (0, (i * heads_per_slice) + (heads_per_slice - 1), seq_len - 1, 63),
+                (1, (i * heads_per_slice) + (heads_per_slice), seq_len, 64),
                 memory_config=l1_interleaved_memory_config,
            )
            mm_slice = ttnn.matmul(
diff --git a/tests/ttnn/profiling/ops_for_profiling.py b/tests/ttnn/profiling/ops_for_profiling.py
index 61ca8f1c617..f84a445917c 100644
--- a/tests/ttnn/profiling/ops_for_profiling.py
+++ b/tests/ttnn/profiling/ops_for_profiling.py
@@ -1186,10 +1186,10 @@ def ttnn_slice(x):
     shape = x.get_legacy_shape()
 
     output_tensor_end = (
-        shape[0] - 1,
-        shape[1] - 1,
-        shape[2] - 33,
-        shape[3] - 33,
+        shape[0],
+        shape[1],
+        shape[2] - 32,
+        shape[3] - 32,
     )
 
     ttnn.slice(x, (0, 0, 0, 0), output_tensor_end)
diff --git a/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_unpad.py b/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_unpad.py
index 545182235a2..7955876b9c8 100644
--- a/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_unpad.py
+++ b/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_unpad.py
@@ -14,10 +14,10 @@
 def ref_unpad(x, *args, output_tensor_start, output_tensor_end, **kwargs):
     out = x[
-        output_tensor_start[0] : output_tensor_end[0] + 1,
-        output_tensor_start[1] : output_tensor_end[1] + 1,
-        output_tensor_start[2] : output_tensor_end[2] + 1,
-        output_tensor_start[3] : output_tensor_end[3] + 1,
+        output_tensor_start[0] : output_tensor_end[0],
+        output_tensor_start[1] : output_tensor_end[1],
+        output_tensor_start[2] : output_tensor_end[2],
+        output_tensor_start[3] : output_tensor_end[3],
     ]
 
     return out
diff --git a/tests/ttnn/unit_tests/operations/test_fold_op.py b/tests/ttnn/unit_tests/operations/test_fold_op.py
index f00d66ab2b1..4e645594936 100644
--- a/tests/ttnn/unit_tests/operations/test_fold_op.py
+++ b/tests/ttnn/unit_tests/operations/test_fold_op.py
@@ -135,7 +135,7 @@ def pad_and_fold_with_permute_and_reshape_on_device(
     activation_pyt_padded = ttnn.slice(
         activation_pyt_padded,
         (0, 0, 0, 0),
-        (n - 1, target_h - 1, target_w - 1, c - 1),
+        (n, target_h, target_w, c),
         memory_config=ttnn.L1_MEMORY_CONFIG,
     )
@@ -241,7 +241,7 @@ def pad_and_fold_with_permute_and_reshape_on_device_sharded(device, tt_input_ten
     tt_output_tensor = ttnn.slice(
         tt_output_tensor,
        (0, 0, 0, 0),
-        (n - 1, target_h - 1, target_w - 1, c - 1),
+        (n, target_h, target_w, c),
         memory_config=slice_sharded_memory_config,
     )
     print("output " + str(tt_output_tensor.shape))
diff --git a/tests/ttnn/unit_tests/operations/test_slice.py b/tests/ttnn/unit_tests/operations/test_slice.py
index 251749ef38b..331d37ec768 100644
--- a/tests/ttnn/unit_tests/operations/test_slice.py
+++ b/tests/ttnn/unit_tests/operations/test_slice.py
@@ -53,7 +53,7 @@ def run_slice_rm_sharded(device, n, c, h, w):
     tt_output_tensor = ttnn.slice(
         tt_input_tensor,
         (0, 0, 0, 0),
-        (n_unpadded - 1, c_unpadded - 1, h_unpadded - 1, w - 1),
+        (n_unpadded, c_unpadded, h_unpadded, w),
         memory_config=output_mem_config,
     )
     tt_output_tensor = ttnn.to_memory_config(tt_output_tensor, ttnn.L1_MEMORY_CONFIG)
@@ -99,7 +99,7 @@ def test_slice_rm(device, n, c, h, w):
     activation_pyt_padded = ttnn.slice(
         activation_pyt_padded,
         (0, 0, 2, 0),
-        (n - 1, 115 - 1, 115 - 1, w - 1),
+        (n, 115, 115, w),
         memory_config=ttnn.L1_MEMORY_CONFIG,
     )
     activation_pyt_padded_out = ttnn.to_memory_config(activation_pyt_padded, ttnn.L1_MEMORY_CONFIG)
@@ -133,10 +133,10 @@ def slice_test(
     # Pytorch reference
     a_ref = torch_input_tensor[
-        output_tensor_start[0] : output_tensor_end[0] + 1,
-        output_tensor_start[1] : output_tensor_end[1] + 1,
-        output_tensor_start[2] : output_tensor_end[2] + 1,
-        output_tensor_start[3] : output_tensor_end[3] + 1,
+        output_tensor_start[0] : output_tensor_end[0],
+        output_tensor_start[1] : output_tensor_end[1],
+        output_tensor_start[2] : output_tensor_end[2],
+        output_tensor_start[3] : output_tensor_end[3],
     ]
 
     return a_pt, a_ref, device.num_program_cache_entries()
@@ -160,19 +160,19 @@ def slice_test(
 @pytest.mark.parametrize(
     "input_tensor_shape_0, output_tensor_start_0, output_tensor_end_0",
     (
-        ((4, 3, 64, 64), (0, 0, 0, 0), (3, 2, 31, 31)),
-        ((1, 1, 64, 64), (0, 0, 0, 0), (0, 0, 31, 63)),
-        ((1, 1, 128, 96), (0, 0, 64, 32), (0, 0, 95, 95)),
-        ((1, 1, 128, 96), (0, 0, 64, 32), (0, 0, 95, 95)),
-        ((1, 3, 32, 32), (0, 1, 0, 0), (0, 2, 31, 31)),
-        ((1, 6, 32, 32), (0, 2, 0, 0), (0, 4, 31, 31)),
-        ((1, 6, 128, 64), (0, 2, 64, 32), (0, 4, 95, 63)),
-        ((4, 6, 128, 64), (1, 2, 64, 32), (2, 4, 95, 63)),
+        ((4, 3, 64, 64), (0, 0, 0, 0), (4, 3, 32, 32)),
+        ((1, 1, 64, 64), (0, 0, 0, 0), (1, 1, 32, 64)),
+        ((1, 1, 128, 96), (0, 0, 64, 32), (1, 1, 96, 96)),
+        ((1, 1, 128, 96), (0, 0, 64, 32), (1, 1, 96, 96)),
+        ((1, 3, 32, 32), (0, 1, 0, 0), (1, 2, 32, 32)),
+        ((1, 6, 32, 32), (0, 2, 0, 0), (1, 4, 32, 32)),
+        ((1, 6, 128, 64), (0, 2, 64, 32), (1, 4, 96, 64)),
+        ((4, 6, 128, 64), (1, 2, 64, 32), (2, 4, 96, 64)),
     ),
 )
 @pytest.mark.parametrize(
     "input_tensor_shape_1, output_tensor_start_1, output_tensor_end_1",
-    (((9, 8, 128, 128), (0, 0, 0, 0), (8, 7, 31, 31)),),
+    (((9, 8, 128, 128), (0, 0, 0, 0), (9, 8, 32, 32)),),
 )
 def test_run_slice_test(
     input_tensor_shape_0,
@@ -390,7 +390,7 @@ def test_slice_negative_ends(layout, dim, ends, device):
     ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)
 
     if dim == 3:
-        if layout == ttnn.ROW_MAJOR_LAYOUT:
+        if layout == ttnn.ROW_MAJOR_LAYOUT and ends == -32:
             pytest.skip("Page size will become 0 and we don't handle transforming pages to second last dimension")
         torch_output = torch_input[:, :, :, 0:ends]
         ttnn_output = ttnn_input[:, :, :, 0:ends]
@@ -450,13 +450,12 @@ def test_slice_bert(input_shape, input_start, input_ends, layout, device):
     assert_with_pcc(torch_output, ttnn_output, 0.99)
 
 
-@pytest.mark.xfail(reason="2D slices and negative ends are not supported in ttnn.slice path")
 @pytest.mark.parametrize(
     "input_shape, input_start, input_ends",
     (
         ((1, 1, 1, 256), (0, 0, 0, 0), (1, 1, 1, -1)),
-        ((1, 256), (0, 0), (-1, 1)),
-        ((1, 512), (0, 0), (-1, 1)),
+        ((1, 256), (0, 0), (-1, 256)),
+        ((1, 512), (0, 0), (-1, 512)),
         ((1, 512), (0, 0), (1, 256)),
     ),
 )
@@ -464,7 +463,11 @@ def test_slice_bert(input_shape, input_start, input_ends, layout, device):
     "layout",
     (ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT),
 )
-def test_ttnn_slice_bert(input_shape, input_start, input_ends, layout, device):
+@pytest.mark.parametrize(
+    "memory_config",
+    (ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
+)
+def test_ttnn_slice_bert(input_shape, input_start, input_ends, layout, memory_config, device):
     if layout == ttnn.TILE_LAYOUT:
         torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
         ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)
@@ -482,7 +485,7 @@ def test_ttnn_slice_bert(input_shape, input_start, input_ends, layout, memory_config, device):
     elif len(input_shape) == 2:
         torch_output = torch_input[input_start[0] : input_ends[0], input_start[1] : input_ends[1]]
 
-    ttnn_output = ttnn.slice(ttnn_input, list(input_start), list(input_ends))
+    ttnn_output = ttnn.slice(ttnn_input, input_start, input_ends, memory_config=memory_config)
     ttnn_output = ttnn.to_torch(ttnn_output)
 
     assert_with_pcc(torch_output, ttnn_output, 0.99)
@@ -496,7 +499,7 @@ def test_slice_output_tensor_rm(device):
    torch_output = torch_input[..., ::2, ::2]  # torch_output shape: [1, 3, 320, 320]
 
     pages_before = ttnn._ttnn.reports.get_buffer_pages()
-    ttnn.slice(ttnn_input, [0, 0, 0, 0], [0, 2, 319, 319], output_tensor=ttnn_output)
+    ttnn.slice(ttnn_input, [0, 0, 0, 0], [1, 3, 320, 320], output_tensor=ttnn_output)
     assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
     ttnn_output = ttnn.to_torch(ttnn_output)
@@ -514,7 +517,7 @@ def test_slice_output_tensor_tile(device):
     torch_output = torch_input[..., ::2, ::2]  # torch_output shape: [1, 3, 320, 320]
 
     pages_before = ttnn._ttnn.reports.get_buffer_pages()
-    ttnn.slice(ttnn_input, [0, 0, 0, 0], [0, 2, 319, 319], output_tensor=ttnn_output)
+    ttnn.slice(ttnn_input, [0, 0, 0, 0], [1, 3, 320, 320], output_tensor=ttnn_output)
     assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
     ttnn_output = ttnn.to_torch(ttnn_output)
diff --git a/tests/ttnn/unit_tests/test_getitem.py b/tests/ttnn/unit_tests/test_getitem.py
index 1333fce3ea9..455f82bcf47 100644
--- a/tests/ttnn/unit_tests/test_getitem.py
+++ b/tests/ttnn/unit_tests/test_getitem.py
@@ -83,7 +83,7 @@ def test_getitem_scalar_output():
     with pytest.raises(RuntimeError) as e:
         input_tensor[0, 0]
-    assert "ttnn.Tensor.__getitem__: cannot return a scalar!" in str(e.value)
+    assert "Host tensor slice cannot return a scalar or empty tensor" in str(e.value)
 
 
 @pytest.mark.parametrize("batch_sizes", [(), (1, 1)])
diff --git a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
index 8bdbdaf59db..6b93c80a49c 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
@@ -92,8 +92,9 @@ std::vector<Tensor> fold_with_transpose_(
     // slice
     n = output_shape.value()[0], w = output_shape.value()[1], h = output_shape.value()[2], c = output_shape.value()[3];
     tt::tt_metal::Array4D slice_output_tensor_start = {0, 0, 0, 0};
-    tt::tt_metal::Array4D slice_output_tensor_end = {n - 1, w - 1, h - 1, c - 1};
-    auto slice_output = ttnn::slice(transpose_hc_output2, slice_output_tensor_start, slice_output_tensor_end, std::nullopt, L1_mem_config);
+    tt::tt_metal::Array4D slice_output_tensor_end = {n, w, h, c};
+    tt::tt_metal::Array4D step = {1, 1, 1, 1};
+    auto slice_output = ttnn::slice(transpose_hc_output2, slice_output_tensor_start, slice_output_tensor_end, step, L1_mem_config);
 
     output_tensors.emplace_back(slice_output);
@@ -243,18 +244,19 @@ std::vector<Tensor> fold_with_transpose_sharded_(
     std::vector<Tensor> output_tensors;
 
     // override output shape
+    auto steps = tt::tt_metal::Array4D({1, 1, 1, 1});
     if (output_shape.has_value()) {
         // slice
         n = output_shape.value()[0], h = output_shape.value()[1], w = output_shape.value()[2], c = output_shape.value()[3];
         tt::tt_metal::Array4D slice_output_tensor_start = {0, 0, 0, 0};
-        tt::tt_metal::Array4D slice_output_tensor_end = {n - 1, h - 1, w - 1, c - 1};
+        tt::tt_metal::Array4D slice_output_tensor_end = {n, h, w, c};
         auto slice_mem_config = create_sharded_memory_config(
             ttnn::Shape(tt::tt_metal::Array4D{n, h, w, c}),
             grid_size,
             shard_spec.orientation,
             override_memory_config
         );
-        tt_output_tensor = ttnn::slice(tt_output_tensor, slice_output_tensor_start, slice_output_tensor_end, std::nullopt, slice_mem_config);
+        tt_output_tensor = ttnn::slice(tt_output_tensor, slice_output_tensor_start, slice_output_tensor_end, steps, slice_mem_config);
 
         output_tensors.emplace_back(tt_output_tensor);
 
@@ -263,14 +265,14 @@ std::vector<Tensor> fold_with_transpose_sharded_(
         // slice
         n = slice_output_shape[0], h = slice_output_shape[1], w = slice_output_shape[2], c = slice_output_shape[3];
         tt::tt_metal::Array4D slice_output_tensor_start = {0, 0, 0, 0};
-        tt::tt_metal::Array4D slice_output_tensor_end = {n - 1, h - 1, w - 1, c - 1};
+        tt::tt_metal::Array4D slice_output_tensor_end = {n, h, w, c};
         auto slice_mem_config = create_sharded_memory_config(
             ttnn::Shape(tt::tt_metal::Array4D{n, h, w, c}),
             grid_size,
             shard_spec.orientation,
             override_memory_config
         );
-        tt_output_tensor = ttnn::slice(tt_output_tensor, slice_output_tensor_start, slice_output_tensor_end, std::nullopt, slice_mem_config);
+        tt_output_tensor = ttnn::slice(tt_output_tensor, slice_output_tensor_start, slice_output_tensor_end, steps, slice_mem_config);
 
         output_tensors.emplace_back(tt_output_tensor);
diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp
index 6e3f5692ea1..2ef9c625563 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp
@@ -76,8 +76,7 @@ void SliceDeviceOperation::validate_with_output_tensors(
     TT_FATAL(input_tensor_a.get_legacy_shape().rank() == this->slice_start.rank() && this->slice_start.rank() == this->slice_end.rank(), "Error");
     for (uint32_t i = 0; i < input_tensor_a.get_legacy_shape().rank(); i++) {
         TT_FATAL(this->slice_start[i] < input_tensor_a.get_legacy_shape()[i], "Error");
-        TT_FATAL(this->slice_end[i] < input_tensor_a.get_legacy_shape()[i], "Error");
-
+        TT_FATAL(this->slice_end[i] <= input_tensor_a.get_legacy_shape()[i], "Ends {} must be less than or equal to the shape of the tensor {}", this->slice_end[i], input_tensor_a.get_legacy_shape()[i]);
         // Check if start shape is <= end shape
         TT_FATAL(this->slice_start[i] <= this->slice_end[i], "Error");
     }
@@ -91,6 +90,7 @@ void SliceDeviceOperation::validate_with_output_tensors(
         TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Strided slice is only supported for row major layout");
         TT_FATAL(!input_tensor_a.is_sharded(), "Strided slice is not supported for sharded tensor");
         TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16, "Strided slice is only supported for BFLOAT16");
+        TT_FATAL(this->step.value().size() == this->slice_end.rank(), "Number of steps {} must match number of ends/starts {}", this->step.value().size(), this->slice_end.rank());
     }
     if (input_tensor_a.get_layout() == Layout::TILE) {
         TT_FATAL(input_tensor_a.volume() % TILE_HW == 0, "Error");
@@ -122,16 +122,13 @@ std::vector<tt::tt_metal::LegacyShape> SliceDeviceOperation::compute_output_shapes(
     out_shape.reserve(rank);
     if (!step.has_value()) {
         for (uint32_t i = 0; i < rank; i++) {
-            out_shape.push_back(this->slice_end[i] - this->slice_start[i] + 1);
+            out_shape.push_back(this->slice_end[i] - this->slice_start[i]);
         }
     } else {
+
         auto output_dim_i = [this](size_t i) {
-            int res = 0;
-            for (int j = this->slice_start[i]; j < this->slice_end[i] + 1; j += this->step.value()[i]) {
-                res++;
-            }
-            return res;
+            return (this->slice_end[i] - this->slice_start[i] + this->step.value()[i] - 1) / this->step.value()[i];
         };
         for (uint32_t i = 0; i < rank; i++) {
             out_shape.push_back(output_dim_i(i));
diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp
index 4701e2e4b7d..ea939d7c058 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp
@@ -328,10 +328,10 @@ operation::ProgramWithCallbacks slice_rm_strided_single_core(const Tensor& a, Te
         output_tensor_start[2],
         output_tensor_start[1],
         output_tensor_start[0],
-        output_tensor_end[3] + 1,
-        output_tensor_end[2] + 1,
-        output_tensor_end[1] + 1,
-        output_tensor_end[0] + 1,
+        output_tensor_end[3],
+        output_tensor_end[2],
+        output_tensor_end[1],
+        output_tensor_end[0],
     });
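# Review note (not part of the patch): the closed-form expression that replaces
# the counting loop in compute_output_shapes is ceiling division over the
# half-open range [start, end). A quick check in Python for start=0, end=10,
# step=3 (indices 0, 3, 6, 9):
#
#     start, end, step = 0, 10, 3
#     out_dim = (end - start + step - 1) // step
#     assert out_dim == 4 == len(range(start, end, step))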
diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp
index acc4b204e2b..7e28ad1a99e 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp
@@ -8,141 +8,246 @@
 #include "device/slice_op.hpp"
 #include "ttnn/run_operation.hpp"
 #include "ttnn/operations/core/core.hpp"
+#include "ttnn/cpp/ttnn/operations/creation.hpp"
 #include "ttnn/common/constants.hpp"
 
 namespace ttnn::operations::data_movement {
 
+namespace detail {
+    uint32_t wrap_index(int index, int size) {
+        return index < 0 ? size + index : index;
+    }
+    uint32_t round_up_to_multiple_of_32(uint32_t value) {
+        return value == 0 ? 32 : ((value + 31) & ~31);
+    }
+}
+
+template <typename T>
 ttnn::Tensor SliceOperation::invoke(
     uint8_t queue_id,
     const ttnn::Tensor& input_tensor,
-    tt::tt_metal::LegacyShape output_tensor_start,
-    tt::tt_metal::LegacyShape output_tensor_end,
-    const std::optional<tt::tt_metal::LegacyShape> step,
+    const std::vector<T> &begins,
+    const std::vector<T> &ends,
+    const std::vector<T> &step,
     const std::optional<MemoryConfig>& memory_config_arg,
     const std::optional<Tensor>& optional_output_tensor) {
-    std::optional<tt::tt_metal::LegacyShape> modified_step = step;
-    if (modified_step.has_value()) {
-        if (std::all_of(modified_step->begin(), modified_step->end(), [](int32_t s) { return s == 1; })) {
-            modified_step = std::nullopt;
-        }
-    }
+
+    // Ensure start and end vectors have matching sizes and correct tensor rank
+    uint32_t input_rank = input_tensor.get_shape().rank();
+    TT_FATAL(input_rank == begins.size(), "Input rank {} and begins {} must have the same size", input_rank, begins.size());
+    TT_FATAL(begins.size() == ends.size(), "Start {} and end {} must have the same size", begins.size(), ends.size());
+    TT_FATAL(step.size() == begins.size(), "Step {} must have the same size as start {} and end", step.size(), begins.size());
+
+    // Create modified vectors with appropriate size (max rank 4) and wrap indices
+    Tensor input_4d = (input_rank < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor;
+    auto padded_4d_shape = input_4d.get_legacy_shape();
+    std::array<uint32_t, 4> modified_begins = {0, 0, 0, 0};
+    std::array<uint32_t, 4> modified_ends = {padded_4d_shape[0], padded_4d_shape[1], padded_4d_shape[2], padded_4d_shape[3]};
+    std::array<uint32_t, 4> modified_step = {1, 1, 1, 1};
+    uint32_t rank_diff = 4 - input_rank;
+
+    // Ideally we would call the 4D array implementation of slice here and then handle reshapes and padding outside of it but it's not ready yet
+    // Insert values for start, step and end, wrapping indices using detail::wrap_index
+    // should be able to skip wrap_index if T is uint32_t
+    for (size_t i = 0; i < begins.size(); ++i) {
+        modified_begins[i + rank_diff] = detail::wrap_index(begins[i], input_tensor.get_shape()[i]);
+        modified_ends[i + rank_diff] = detail::wrap_index(ends[i], input_tensor.get_shape()[i]);
+        modified_step[i + rank_diff] = step[i];
+    }
+
+    auto output_dim_i = [&modified_begins, &modified_step](size_t i, const std::array<uint32_t, 4> &modified_ends) {
+        return (modified_ends[i] - modified_begins[i] + modified_step[i] - 1) / modified_step[i];
+    };
+
+    std::array<uint32_t, 4> padded_ends = modified_ends;
+    if (input_tensor.layout() == Layout::TILE) {
+        padded_ends[2] = detail::round_up_to_multiple_of_32(padded_ends[2]);
+        padded_ends[3] = detail::round_up_to_multiple_of_32(padded_ends[3]);
+    }
+    std::vector<uint32_t> actual_shape, padded_shape;
+    actual_shape.reserve(input_rank);
+    padded_shape.reserve(input_rank);
+    bool empty = false;
+    for (int i = 0; i < input_rank; ++i) {
+        // Check that end indices are greater than or equal to start indices (empty tensor where end=start is supported)
+        TT_FATAL(modified_ends[i + rank_diff] >= modified_begins[i + rank_diff], "End {} must be greater than or equal to start {}", modified_ends[i + rank_diff], modified_begins[i + rank_diff]);
+        auto val = output_dim_i(i + rank_diff, modified_ends);
+        if (val == 0) {
+            empty = true;
+        }
+        actual_shape.push_back(val);
+        padded_shape.push_back(std::max(output_dim_i(i + rank_diff, padded_ends), (uint32_t)1));
+    }
+
+    ttnn::Shape output_shape(actual_shape, padded_shape);
+    // PyTorch supports final dimension = 0 (start = end, where end is inclusive) so >= is okay, just return an empty tensor
+    if (empty) {
+        TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Host tensor slice cannot return a scalar or empty tensor");
+        return ttnn::empty(output_shape, input_tensor.dtype(), input_tensor.layout(),
+            input_tensor.device(), memory_config_arg.value_or(input_tensor.memory_config()));
+    }
+
+    // Early exit if slice is a no-op (ends = padding ends and step = 1 for all dimensions)
+    bool no_step = std::all_of(step.begin(), step.end(), [](int i) { return i == 1; });
+    if (tt::tt_metal::LegacyShape(padded_shape) == input_tensor.get_legacy_shape() and no_step) {
+        return ttnn::reshape(input_tensor, output_shape);
+    }
 
     if (input_tensor.storage_type() != StorageType::DEVICE) {
-        TT_FATAL(!modified_step.has_value(), "Host tensor slice does not support strides");
-        tt::tt_metal::LegacyShape output_tensor_shape = {
-            output_tensor_end[0] - output_tensor_start[0] + 1,
-            output_tensor_end[1] - output_tensor_start[1] + 1,
-            output_tensor_end[2] - output_tensor_start[2] + 1,
-            output_tensor_end[3] - output_tensor_start[3] + 1,
-        };
+        TT_FATAL(no_step, "Host tensor slice does not support strides");
         // if we support negative strides, we can't do this early exit
-        if (input_tensor.get_legacy_shape() == output_tensor_shape) {
+        if (input_tensor.get_legacy_shape() == actual_shape) {
             return input_tensor;
        } else {
-            return input_tensor.unpad(output_tensor_start, output_tensor_end);
+            auto input_4d_rm = ttnn::to_layout(input_4d, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr);
+            auto output_4d = input_4d_rm.unpad(tt::tt_metal::LegacyShape(modified_begins), tt::tt_metal::LegacyShape(modified_ends));
+            auto output_4d_rm = ttnn::to_layout(output_4d, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr);
+            return ttnn::reshape(output_4d_rm, output_shape);
         }
     } else {
-        auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config());
         // TODO: Generalize this early exit of slice for other cases
-        auto& input_tensor_shape = input_tensor.get_legacy_shape();
-        if (input_tensor.is_sharded() && input_tensor.memory_config() == memory_config &&
-            input_tensor_shape.rank() > 1 && input_tensor_shape.rank() == output_tensor_start.rank() &&
-            output_tensor_start.rank() == output_tensor_end.rank()) {
-            TT_FATAL(!modified_step.has_value(), "Sharded tensor slice implementation does not support striding");
+        auto& input_tensor_shape = input_4d.get_legacy_shape();
+        auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config());
+        if (input_4d.is_sharded() && input_4d.memory_config() == memory_config &&
+            input_tensor_shape.rank() > 1) {
+            TT_FATAL(no_step, "Sharded tensor slice implementation does not support striding");
             uint32_t i;
             // Require all leading dims to be 1 (TODO: This can be relaxed to support outermost non-1 dim unpadding)
             bool in_place_unpad = true;
-            for (i = 0; i < input_tensor.get_legacy_shape().rank() - 2; ++i) {
+            for (i = 0; i < input_4d.get_legacy_shape().rank() - 2; ++i) {
                 in_place_unpad &=
-                    output_tensor_start[i] == 0 && output_tensor_end[i] == 0 && input_tensor_shape[i] == 1;
+                    modified_begins[i] == 0 && modified_ends[i] == 1 && input_tensor_shape[i] == 1;
             }
-            in_place_unpad &= output_tensor_start[i] == 0 &&
-                              tt::div_up(output_tensor_end[i] + 1, input_tensor.shard_spec().value().shape[0]) ==
-                                  tt::div_up(input_tensor_shape[i], input_tensor.shard_spec().value().shape[0]);
+            in_place_unpad &= modified_begins[i] == 0 &&
+                              tt::div_up(modified_ends[i], input_4d.shard_spec().value().shape[0]) ==
+                                  tt::div_up(input_tensor_shape[i], input_4d.shard_spec().value().shape[0]);
             i++;
-            in_place_unpad &= output_tensor_start[i] == 0 && output_tensor_end[i] == input_tensor_shape[i] - 1;
+            in_place_unpad &= modified_begins[i] == 0 && modified_ends[i] == input_tensor_shape[i];
             if (in_place_unpad) {
-                auto new_shape = input_tensor.get_legacy_shape();
-                auto new_pad = new_shape.padding();
-
-                std::size_t unpad_val = input_tensor_shape[-2] - output_tensor_end[-2] - 1;
-                new_shape[-2] -= unpad_val;
-                new_pad[-2].back -= std::min(unpad_val, new_pad[-2].back);
-                auto padded_shape = ttnn::Shape(tt::tt_metal::LegacyShape(new_shape, new_pad));
-                return Tensor(input_tensor.storage(), padded_shape, input_tensor.dtype(), input_tensor.layout());
+                return ttnn::reshape(input_tensor, output_shape);
             }
         }
 
-        return operation::run(
-                   SliceDeviceOperation{output_tensor_start, output_tensor_end, modified_step, memory_config}, {input_tensor}, {}, {optional_output_tensor}, queue_id)
+        auto res = operation::run(
+                   SliceDeviceOperation{
+                       tt::tt_metal::LegacyShape(modified_begins),
+                       tt::tt_metal::LegacyShape(padded_ends),
+                       no_step ? std::nullopt : std::optional<tt::tt_metal::LegacyShape>(tt::tt_metal::LegacyShape(modified_step)),
+                       memory_config},
+                   {input_4d}, {}, {optional_output_tensor}, queue_id)
                    .at(0);
+        res = ttnn::reshape(res, output_shape);
+        return res;
     }
 }
 
+template <typename T>
 ttnn::Tensor SliceOperation::invoke(
     const ttnn::Tensor& input_tensor,
-    tt::tt_metal::LegacyShape output_tensor_start,
-    tt::tt_metal::LegacyShape output_tensor_end,
-    const std::optional<tt::tt_metal::LegacyShape> step,
+    const std::vector<T> &begins,
+    const std::vector<T> &ends,
+    const std::vector<T> &step,
     const std::optional<MemoryConfig>& memory_config_arg,
     const std::optional<Tensor>& optional_output_tensor) {
-    return invoke(DefaultQueueId, input_tensor, output_tensor_start, output_tensor_end, step, memory_config_arg, optional_output_tensor);
+    return SliceOperation::invoke(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor);
 }
 
+template <typename T, std::size_t N>
 ttnn::Tensor SliceOperation::invoke(
     uint8_t queue_id,
     const ttnn::Tensor& input_tensor,
-    tt::tt_metal::Array1D output_tensor_start,
-    tt::tt_metal::Array1D output_tensor_end,
-    const std::optional<tt::tt_metal::Array1D> step,
+    const std::array<T, N> &output_tensor_start,
+    const std::array<T, N> &output_tensor_end,
+    const std::array<T, N> &step,
     const std::optional<MemoryConfig>& memory_config_arg,
     const std::optional<Tensor>& optional_output_tensor) {
-    return invoke(
-        queue_id,
-        input_tensor,
-        tt::tt_metal::LegacyShape(output_tensor_start),
-        tt::tt_metal::LegacyShape(output_tensor_end),
-        step.has_value() ? std::optional<tt::tt_metal::LegacyShape>(tt::tt_metal::LegacyShape(step.value())) : std::nullopt,
-        memory_config_arg,
-        optional_output_tensor);
-}
+    std::vector<T> start(output_tensor_start.begin(), output_tensor_start.end());
+    std::vector<T> end(output_tensor_end.begin(), output_tensor_end.end());
+    std::vector<T> step_vec(step.begin(), step.end());
+    return SliceOperation::invoke(queue_id, input_tensor, start, end, step_vec, memory_config_arg, optional_output_tensor);
+}
 
+template <typename T, std::size_t N>
 ttnn::Tensor SliceOperation::invoke(
-    uint8_t queue_id,
     const ttnn::Tensor& input_tensor,
-    tt::tt_metal::Array4D output_tensor_start,
-    tt::tt_metal::Array4D output_tensor_end,
-    const std::optional<tt::tt_metal::Array4D> step,
+    const std::array<T, N> &output_tensor_start,
+    const std::array<T, N> &output_tensor_end,
+    const std::array<T, N> &step,
     const std::optional<MemoryConfig>& memory_config_arg,
     const std::optional<Tensor>& optional_output_tensor) {
-    return invoke(
-        queue_id,
-        input_tensor,
-        tt::tt_metal::LegacyShape(output_tensor_start),
-        tt::tt_metal::LegacyShape(output_tensor_end),
-        step.has_value() ? std::optional<tt::tt_metal::LegacyShape>(tt::tt_metal::LegacyShape(step.value())) : std::nullopt,
-        memory_config_arg,
-        optional_output_tensor);
-}
+    return SliceOperation::invoke(ttnn::DefaultQueueId, input_tensor, output_tensor_start, output_tensor_end, step, memory_config_arg, optional_output_tensor);
+}
 
-ttnn::Tensor SliceOperation::invoke(
-    const ttnn::Tensor& input_tensor,
-    tt::tt_metal::Array4D output_tensor_start,
-    tt::tt_metal::Array4D output_tensor_end,
-    const std::optional<tt::tt_metal::Array4D> step,
-    const std::optional<MemoryConfig>& memory_config_arg,
-    const std::optional<Tensor>& optional_output_tensor) {
-    return invoke(DefaultQueueId, input_tensor, output_tensor_start, output_tensor_end, step, memory_config_arg, optional_output_tensor);
-}
-
-ttnn::Tensor SliceOperation::invoke(
-    const ttnn::Tensor& input_tensor,
-    tt::tt_metal::Array4D output_tensor_start,
-    tt::tt_metal::Array4D output_tensor_end,
-    const std::optional<tt::tt_metal::Array4D> step) {
-    return invoke(DefaultQueueId, input_tensor, output_tensor_start, output_tensor_end, step, std::nullopt, std::nullopt);
-}
+template ttnn::Tensor SliceOperation::invoke<int>(
+    uint8_t queue_id,
+    const ttnn::Tensor& input_tensor,
+    const std::vector<int> &begins,
+    const std::vector<int> &ends,
+    const std::vector<int> &step,
+    const std::optional<MemoryConfig>& memory_config_arg,
+    const std::optional<Tensor>& optional_output_tensor);
+
+template ttnn::Tensor SliceOperation::invoke<int>(
+    const ttnn::Tensor& input_tensor,
+    const std::vector<int> &begins,
+    const std::vector<int> &ends,
+    const std::vector<int> &step,
+    const std::optional<MemoryConfig>& memory_config_arg,
+    const std::optional<Tensor>& optional_output_tensor);
+
+template ttnn::Tensor SliceOperation::invoke<uint32_t>(
+    uint8_t queue_id,
+    const ttnn::Tensor& input_tensor,
+    const std::vector<uint32_t> &begins,
+    const std::vector<uint32_t> &ends,
+    const std::vector<uint32_t> &step,
+    const std::optional<MemoryConfig>& memory_config_arg,
+    const std::optional<Tensor>& optional_output_tensor);
+
+template ttnn::Tensor SliceOperation::invoke<uint32_t>(
+    const ttnn::Tensor& input_tensor,
+    const std::vector<uint32_t> &begins,
+    const std::vector<uint32_t> &ends,
+    const std::vector<uint32_t> &step,
+    const std::optional<MemoryConfig>& memory_config_arg,
+    const std::optional<Tensor>& optional_output_tensor);
+
+template ttnn::Tensor SliceOperation::invoke<uint32_t, 4>(
+    uint8_t queue_id,
+    const ttnn::Tensor& input_tensor,
+    const std::array<uint32_t, 4> &output_tensor_start,
+    const std::array<uint32_t, 4> &output_tensor_end,
+    const std::array<uint32_t, 4> &step,
+    const std::optional<MemoryConfig>& memory_config_arg,
+    const std::optional<Tensor>& optional_output_tensor);
+
+template ttnn::Tensor SliceOperation::invoke<uint32_t, 4>(
+    const ttnn::Tensor& input_tensor,
+    const std::array<uint32_t, 4> &output_tensor_start,
+    const std::array<uint32_t, 4> &output_tensor_end,
+    const std::array<uint32_t, 4> &step,
+    const std::optional<MemoryConfig>& memory_config_arg,
+    const std::optional<Tensor>& optional_output_tensor);
+
+template ttnn::Tensor SliceOperation::invoke<uint32_t, 1>(
+    uint8_t queue_id,
+    const ttnn::Tensor& input_tensor,
+    const std::array<uint32_t, 1> &output_tensor_start,
+    const std::array<uint32_t, 1> &output_tensor_end,
+    const std::array<uint32_t, 1> &step,
+    const std::optional<MemoryConfig>& memory_config_arg,
+    const std::optional<Tensor>& optional_output_tensor);
+
+template ttnn::Tensor SliceOperation::invoke<uint32_t, 1>(
+    const ttnn::Tensor& input_tensor,
+    const std::array<uint32_t, 1> &output_tensor_start,
+    const std::array<uint32_t, 1> &output_tensor_end,
+    const std::array<uint32_t, 1> &step,
+    const std::optional<MemoryConfig>& memory_config_arg,
+    const std::optional<Tensor>& optional_output_tensor);
 
 }  // namespace operations
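# Review note (not part of the patch): Python equivalents of the two helpers
# added in the detail namespace above, to make the index handling concrete:
#
#     def wrap_index(index: int, size: int) -> int:
#         # negative indices count from the end, as in Python slicing
#         return size + index if index < 0 else index
#
#     def round_up_to_multiple_of_32(value: int) -> int:
#         # TILE layout pads the last two dims; 0 rounds up to one full tile
#         return 32 if value == 0 else (value + 31) & ~31
#
#     assert wrap_index(-1, 256) == 255
#     assert round_up_to_multiple_of_32(33) == 64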
diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp
index aaa979a034b..f814df01cdc 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp
@@ -11,54 +11,45 @@ namespace operations {
 namespace data_movement {
 
 struct SliceOperation {
+    template <typename T>
     static ttnn::Tensor invoke(
         uint8_t queue_id,
         const ttnn::Tensor& input_tensor,
-        tt::tt_metal::LegacyShape output_tensor_start,
-        tt::tt_metal::LegacyShape output_tensor_end,
-        const std::optional<tt::tt_metal::LegacyShape> step = std::nullopt,
+        const std::vector<T> &begins,
+        const std::vector<T> &ends,
+        const std::vector<T> &step,
         const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
         const std::optional<Tensor>& optional_output_tensor = std::nullopt);
 
+    template <typename T>
     static ttnn::Tensor invoke(
         const ttnn::Tensor& input_tensor,
-        tt::tt_metal::LegacyShape output_tensor_start,
-        tt::tt_metal::LegacyShape output_tensor_end,
-        const std::optional<tt::tt_metal::LegacyShape> step = std::nullopt,
+        const std::vector<T> &output_tensor_start,
+        const std::vector<T> &output_tensor_end,
+        const std::vector<T> &step,
         const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
         const std::optional<Tensor>& optional_output_tensor = std::nullopt);
 
+    template <typename T, std::size_t N>
     static ttnn::Tensor invoke(
         uint8_t queue_id,
         const ttnn::Tensor& input_tensor,
-        tt::tt_metal::Array1D output_tensor_start,
-        tt::tt_metal::Array1D output_tensor_end,
-        const std::optional<tt::tt_metal::Array1D> step = std::nullopt,
+        const std::array<T, N> &output_tensor_start,
+        const std::array<T, N> &output_tensor_end,
+        const std::array<T, N> &step,
         const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
         const std::optional<Tensor>& optional_output_tensor = std::nullopt);
 
+    template <typename T, std::size_t N>
     static ttnn::Tensor invoke(
-        uint8_t queue_id,
         const ttnn::Tensor& input_tensor,
-        tt::tt_metal::Array4D output_tensor_start,
-        tt::tt_metal::Array4D output_tensor_end,
-        const std::optional<tt::tt_metal::Array4D> step = std::nullopt,
+        const std::array<T, N> &output_tensor_start,
+        const std::array<T, N> &output_tensor_end,
+        const std::array<T, N> &step,
         const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
         const std::optional<Tensor>& optional_output_tensor = std::nullopt);
 
-    static ttnn::Tensor invoke(
-        const ttnn::Tensor& input_tensor,
-        tt::tt_metal::Array4D output_tensor_start,
-        tt::tt_metal::Array4D output_tensor_end,
-        const std::optional<tt::tt_metal::Array4D> step = std::nullopt,
-        const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
-        const std::optional<Tensor>& optional_output_tensor = std::nullopt);
-
-    static ttnn::Tensor invoke(
-        const ttnn::Tensor& input_tensor,
-        tt::tt_metal::Array4D output_tensor_start,
-        tt::tt_metal::Array4D output_tensor_end,
-        const std::optional<tt::tt_metal::Array4D> step);
 };
 
 }  // namespace data_movement
diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp
index f5fa39ec2cc..6f1e1614152 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp
@@ -38,6 +38,7 @@ void bind_slice(py::module& module) {
             * :attr:`queue_id` (Optional[uint8]): command queue id
         )doc";
 
+    // TODO: implementing the array version and overloading the pybind with all the possible array sizes is better than a vector with a fixed size default value
     using OperationType = decltype(ttnn::slice);
     ttnn::bind_registered_operation(
         module,
@@ -46,18 +47,19 @@ void bind_slice(py::module& module) {
         ttnn::pybind_overload_t{
             [] (const OperationType& self,
                 const ttnn::Tensor& input_tensor,
-                const tt::tt_metal::Array4D & slice_start,
-                const tt::tt_metal::Array4D & slice_end,
-                const std::optional<tt::tt_metal::Array4D> &step,
+                const std::vector<int> &slice_start,
+                const std::vector<int> &slice_end,
+                const std::optional<std::vector<int>> &step,
                 const std::optional<MemoryConfig>& memory_config,
                 const std::optional<ttnn::Tensor>& optional_output_tensor,
                 uint8_t queue_id) {
-                return self(queue_id, input_tensor, slice_start, slice_end, step, memory_config, optional_output_tensor);
+                const auto step_value = step.value_or(std::vector<int>(slice_end.size(), 1));
+                return self(queue_id, input_tensor, slice_start, slice_end, step_value, memory_config, optional_output_tensor);
             },
             py::arg("input_tensor"),
             py::arg("slice_start"),
             py::arg("slice_end"),
-            py::arg("step") = std::nullopt,
+            py::arg("step") = std::nullopt,  // should consider a better default value
             py::kw_only(),
             py::arg("memory_config") = std::nullopt,
             py::arg("output_tensor") = std::nullopt,
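# Review note (not part of the patch): with the binding above, slice_start and
# slice_end are plain Python sequences and step is optional, defaulting to all
# ones. A hedged usage sketch, assuming a (1, 1, 64, 128) device tensor x
# (strided slice requires ROW_MAJOR layout and bfloat16 per the validation in
# slice_op.cpp):
#
#     y = ttnn.slice(x, [0, 0, 0, 0], [1, 1, 32, 128])                # unit step
#     z = ttnn.slice(x, [0, 0, 0, 0], [1, 1, 64, 128], [1, 1, 2, 2])  # strided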
diff --git a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
index 81a84ffb6db..abe323352a8 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
@@ -44,7 +44,7 @@ namespace detail {
 
         for (int i = 0; i < num_splits; i++) {
             auto start = i*chunk_len;
-            auto end = start + chunk_len - 1;
+            auto end = start + chunk_len;
 
             std::vector<uint32_t> start_shape(preproc_shape.size(), 0);
             start_shape[dim] = start;
@@ -54,14 +54,14 @@ namespace detail {
                 if (j == dim) {
                     end_shape[j] = end;
                 } else {
-                    end_shape[j] = preproc_shape[j] - 1;
+                    end_shape[j] = preproc_shape[j];
                 }
             }
 
             Tensor output_chunk = ttnn::slice(preprocessed,
-                tt::tt_metal::LegacyShape(start_shape),
-                tt::tt_metal::LegacyShape(end_shape),
-                std::nullopt,
+                start_shape,
+                end_shape,
+                std::vector<uint32_t>(end_shape.size(), 1),
                 mem_config);
             if (input_rank < 4) {
                 output_chunk = ttnn::squeeze_from_4D(output_chunk, input_rank);
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
index 76c20b737af..7461081449b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
@@ -470,12 +470,12 @@ std::vector<std::optional<Tensor>> ExecuteBackwardConcat::invoke(
     if (are_required_outputs[0]) {
         std::vector<uint32_t> start_index = {0, 0, 0, 0};
         std::vector<uint32_t> end_index = {
-            input.get_legacy_shape()[0] - 1,
-            input.get_legacy_shape()[1] - 1,
-            input.get_legacy_shape()[2] - 1,
-            input.get_legacy_shape()[3] - 1};
-
-        ttnn::slice(queue_id, grad, start_index, end_index, std::nullopt, std::nullopt, input_grad);
+            input.get_legacy_shape()[0],
+            input.get_legacy_shape()[1],
+            input.get_legacy_shape()[2],
+            input.get_legacy_shape()[3]};
+        std::vector<uint32_t> step = std::vector<uint32_t>({1, 1, 1, 1});
+        ttnn::slice(queue_id, grad, start_index, end_index, step, std::nullopt, input_grad);
         grad_tensor[0] = input_grad;
     }
 
@@ -492,11 +492,12 @@ std::vector<std::optional<Tensor>> ExecuteBackwardConcat::invoke(
         start_index_2 = {0, 0, 0, input.get_legacy_shape()[3]};
     }
     std::vector<uint32_t> end_index_2 = {
-        grad.get_legacy_shape()[0] - 1,
-        grad.get_legacy_shape()[1] - 1,
-        grad.get_legacy_shape()[2] - 1,
-        grad.get_legacy_shape()[3] - 1};
-    ttnn::slice(queue_id, grad, start_index_2, end_index_2, std::nullopt, std::nullopt, other_grad);
+        grad.get_legacy_shape()[0],
+        grad.get_legacy_shape()[1],
+        grad.get_legacy_shape()[2],
+        grad.get_legacy_shape()[3]};
+    std::vector<uint32_t> step_2 = std::vector<uint32_t>({1, 1, 1, 1});
+    ttnn::slice(queue_id, grad, start_index_2, end_index_2, step_2, std::nullopt, other_grad);
     grad_tensor[1] = other_grad;
     }
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
index f78c7db3610..8c9665cbfbb 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
@@ -1486,14 +1486,15 @@ std::vector<Tensor> ExecuteUnaryBackwardProd::invoke(
     }
     // all_dimensions = False
     Tensor updated_grad = prod_result;
+    auto step = std::vector<uint32_t>({1, 1, 1, 1});
     if (prod_result.get_legacy_shape().without_padding() != grad.get_legacy_shape()) {
         if (dim == 3 || dim == -1) {
             std::vector<int64_t> after_permute_dims = {0, 3, 1, 2};
             Tensor required = ttnn::permute(grad, after_permute_dims, output_memory_config);
             std::vector<uint32_t> start_index = {0, 0, 0, 0};
             std::vector<uint32_t> end_index = {
-                grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[2] - 1};
-            Tensor new_slice_tensor = ttnn::slice(DefaultQueueId, required, start_index, end_index);
+                grad.get_legacy_shape()[0], 1, grad.get_legacy_shape()[1], grad.get_legacy_shape()[2]};
+            Tensor new_slice_tensor = ttnn::slice(DefaultQueueId, required, start_index, end_index, step, std::nullopt);
             after_permute_dims = {0, 2, 3, 1};
             updated_grad = ttnn::permute(new_slice_tensor, after_permute_dims, output_memory_config);
             if(updated_grad.storage_type() != StorageType::DEVICE && updated_grad.storage_type() != StorageType::MULTI_DEVICE) {
@@ -1506,8 +1507,8 @@ std::vector<Tensor> ExecuteUnaryBackwardProd::invoke(
             Tensor required = ttnn::permute(grad, after_permute_dims, output_memory_config);
             std::vector<uint32_t> start_index = {0, 0, 0, 0};
             std::vector<uint32_t> end_index = {
-                grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[3] - 1};
-            Tensor new_slice_tensor = ttnn::slice(DefaultQueueId, required, start_index, end_index);
+                grad.get_legacy_shape()[0], 1, grad.get_legacy_shape()[1], grad.get_legacy_shape()[3]};
+            Tensor new_slice_tensor = ttnn::slice(DefaultQueueId, required, start_index, end_index, step, std::nullopt);
             updated_grad = ttnn::permute(new_slice_tensor, after_permute_dims, output_memory_config);
             if(updated_grad.get_layout()==Layout::ROW_MAJOR){
                 updated_grad = ttnn::operations::unary_backward::change_layout_to_tile(updated_grad, output_memory_config);
@@ -1553,11 +1554,12 @@ std::vector<Tensor> ExecuteUnaryBackwardProd::invoke(
     if (reciprocal_input.get_legacy_shape()[1] % 32 != 0) {
         std::vector<uint32_t> start_index = {0, 0, 0, 0};
         std::vector<uint32_t> end_index = {
-            input.get_legacy_shape()[0] - 1,
-            input.get_legacy_shape()[1] - 1,
-            input.get_legacy_shape()[2] - 1,
-            input.get_legacy_shape()[3] - 1};
-        grad_result = ttnn::slice(DefaultQueueId, result, start_index, end_index);
+            input.get_legacy_shape()[0],
+            input.get_legacy_shape()[1],
+            input.get_legacy_shape()[2],
+            input.get_legacy_shape()[3]};
+        auto step = std::vector<uint32_t>({1,1,1,1});
+        grad_result = ttnn::slice(DefaultQueueId, result, start_index, end_index, step, std::nullopt);
     }
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
@@ -1587,11 +1589,11 @@ std::vector<Tensor> ExecuteUnaryBackwardProd::invoke(
     if (reciprocal_input.get_legacy_shape()[0] % 32 != 0) {
         std::vector<uint32_t> start_index = {0, 0, 0, 0};
         std::vector<uint32_t> end_index = {
-            input.get_legacy_shape()[0] - 1,
-            input.get_legacy_shape()[1] - 1,
-            input.get_legacy_shape()[2] - 1,
-            input.get_legacy_shape()[3] - 1};
-        grad_result = ttnn::slice(DefaultQueueId, result, start_index, end_index);
+            input.get_legacy_shape()[0],
+            input.get_legacy_shape()[1],
+            input.get_legacy_shape()[2],
+            input.get_legacy_shape()[3]};
+        grad_result = ttnn::slice(DefaultQueueId, result, start_index, end_index, step, std::nullopt);
     }
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
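The mechanical migration rule applied throughout these backward ops: every old inclusive end `e` becomes `e + 1` (so `shape[i] - 1` becomes `shape[i]`, and a squeezed dimension's `0` becomes `1`), and the now-mandatory step is passed explicitly as all ones. A tiny sketch of why the selected elements are unchanged (hypothetical helper; a plain `std::vector` stands in for a tensor):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Exclusive-end strided selection over a 1-D buffer.
std::vector<int> slice_exclusive(const std::vector<int>& v, std::size_t begin, std::size_t end, std::size_t step) {
    std::vector<int> out;
    for (std::size_t i = begin; i < end; i += step) {  // i < end, not i <= end
        out.push_back(v[i]);
    }
    return out;
}

int main() {
    std::vector<int> v{10, 11, 12, 13, 14};
    // An old call site selecting indices 1..3 wrote end = 3 (inclusive);
    // it now writes end = 4 (exclusive) and picks exactly the same elements.
    auto out = slice_exclusive(v, 1, 4, 1);
    assert((out == std::vector<int>{11, 12, 13}));
}
```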
diff --git a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp
index 04133120761..5617c05c198 100644
--- a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp
@@ -158,12 +158,16 @@ Tensor AutoFormat::format_output_tensor(
         // Output can be unpadded and layout supports the shape
         if ((formatted_output.get_layout() == Layout::TILE && AutoFormat::legal_tile_shape(shape)) ||
             (formatted_output.get_layout() == Layout::ROW_MAJOR && AutoFormat::legal_rm_shape(shape))) {
+            auto begins = std::vector<uint32_t>({0, 0, 0, 0});
+            auto ends = std::vector<uint32_t>({shape[0], shape[1], shape[2], shape[3]});
+            auto step = std::vector<uint32_t>({1, 1, 1, 1});
+
             formatted_output = ttnn::slice(
                 DefaultQueueId,
                 formatted_output,
-                std::vector<uint32_t>({0, 0, 0, 0}),
-                std::vector<uint32_t>({shape[0] - 1, shape[1] - 1, shape[2] - 1, shape[3] - 1}),
-                std::nullopt,
+                begins,
+                ends,
+                step,
                 mem_config);
             return formatted_output;
             // Output is tile but shape cannot be tile. We leave in RM
@@ -185,12 +189,15 @@ Tensor AutoFormat::format_output_tensor(
         } else if (
             formatted_output.get_layout() == Layout::ROW_MAJOR && target_layout == Layout::TILE &&
             AutoFormat::legal_tile_shape(shape)) {
+            auto begins = std::vector<uint32_t>({0, 0, 0, 0});
+            auto ends = std::vector<uint32_t>({shape[0], shape[1], shape[2], shape[3]});
+            auto step = std::vector<uint32_t>({1, 1, 1, 1});
             formatted_output = ttnn::slice(
                 DefaultQueueId,
                 formatted_output,
-                std::vector<uint32_t>({0, 0, 0, 0}),
-                std::vector<uint32_t>({shape[0] - 1, shape[1] - 1, shape[2] - 1, shape[3] - 1}),
-                std::nullopt,
+                begins,
+                ends,
+                step,
                 mem_config);
             formatted_output = ttnn::tilize(formatted_output, mem_config);
             return formatted_output;
@@ -207,8 +214,11 @@ Tensor AutoFormat::format_output_tensor(
                 formatted_output = formatted_output.to(Layout::ROW_MAJOR);
                 convert_layout = formatted_output.get_layout() != target_layout;
             }
+            auto begins = std::vector<uint32_t>({0, 0, 0, 0});
+            auto ends = std::vector<uint32_t>({shape[0], shape[1], shape[2], shape[3]});
+            auto step = std::vector<uint32_t>({1, 1, 1, 1});
             formatted_output =
-                ttnn::slice(formatted_output, tt::tt_metal::Array4D({0, 0, 0, 0}), tt::tt_metal::Array4D({shape[0] - 1, shape[1] - 1, shape[2] - 1, shape[3] - 1}), std::nullopt, std::nullopt);
+                ttnn::slice(formatted_output, begins, ends, step, std::nullopt);
         }
         if (convert_layout) {
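AutoFormat's un-padding slices now use the logical shape itself as the end coordinates, which is only legal because an exclusive end may equal the dimension size. A standalone sketch of the bounds involved (the local `nearest_32` mirrors the repo's tile-rounding helper; the shapes are illustrative):

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Round a dimension up to the 32-wide tile edge.
constexpr uint32_t nearest_32(uint32_t x) { return ((x + 31) / 32) * 32; }

int main() {
    std::array<uint32_t, 4> logical{1, 1, 50, 72};
    std::array<uint32_t, 4> padded{1, 1, nearest_32(50), nearest_32(72)};  // {1, 1, 64, 96}

    // New convention: begins = {0,0,0,0}, ends = logical, step = all ones.
    // The old code had to pass logical[i] - 1 as each (inclusive) end instead.
    for (std::size_t i = 0; i < 4; i++) {
        assert(logical[i] <= padded[i]);               // exclusive end stays in bounds
        uint32_t extent = logical[i] /* end */ - 0 /* begin */;
        assert(extent == logical[i]);                  // output keeps the logical extent
    }
}
```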
diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp
index 39232ec236a..cecb97e7195 100644
--- a/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp
+++ b/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp
@@ -99,6 +99,7 @@ Tensor ProdOperation::invoke(const Tensor& input_a, bool all_dimensions, int64_t
     }
     Tensor result = prod_nc(temp, dim, output_mem_config);
     // Permute and unpad result for dim 2,3
+    auto step = std::vector<uint32_t>({1, 1, 1, 1});
     if (dim == 0 || dim == 1 || dim == -4 || dim == -3) {
         return result;
     } else if (dim == 2 || dim == -2) {
@@ -106,8 +107,8 @@ Tensor ProdOperation::invoke(const Tensor& input_a, bool all_dimensions, int64_t
         Tensor required = ttnn::permute(result, after_permute_dims, output_mem_config);
         tt::tt_metal::LegacyShape input_shape = input_a.get_legacy_shape();
         std::vector<uint32_t> start_index = {0, 0, 0, 0};
-        std::vector<uint32_t> end_index = {input_shape[0] - 1, input_shape[1] - 1, 0, input_shape[3] - 1};
-        return ttnn::slice(DefaultQueueId, required, start_index, end_index);
+        std::vector<uint32_t> end_index = {input_shape[0], input_shape[1], 1, input_shape[3]};
+        return ttnn::slice(DefaultQueueId, required, start_index, end_index, step, std::nullopt);
     } else {  // dim 3
         // permute
         std::vector<int64_t> after_permute_dims = {1, 2, 0, 3};
@@ -115,8 +116,8 @@ Tensor ProdOperation::invoke(const Tensor& input_a, bool all_dimensions, int64_t
         // unpad
         tt::tt_metal::LegacyShape input_shape = input_a.get_legacy_shape();
         std::vector<uint32_t> start_index = {0, 0, 0, 0};
-        std::vector<uint32_t> end_index = {input_shape[0] - 1, input_shape[1] - 1, 0, input_shape[2] - 1};
-        Tensor new_unpad_tensor = ttnn::slice(DefaultQueueId, required, start_index, end_index);
+        std::vector<uint32_t> end_index = {input_shape[0], input_shape[1], 1, input_shape[2]};
+        Tensor new_unpad_tensor = ttnn::slice(DefaultQueueId, required, start_index, end_index, step, std::nullopt);
         // permute back
         after_permute_dims = {0, 1, 3, 2};
         Tensor res_host = ttnn::permute(new_unpad_tensor, after_permute_dims, output_mem_config);
diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
index d55ad1b2596..ea4c0fc753a 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
@@ -1260,11 +1260,11 @@ Tensor unpad(const Tensor& tensor, const tt::tt_metal::LegacyShape& output_tenso
     for (auto i = 0; i < input_shape.rank(); i++) {
         // Check if tensor start and end indices are within input tensor shape
         TT_ASSERT(output_tensor_start[i] < input_shape[i]);
-        TT_ASSERT(output_tensor_end[i] < input_shape[i]);
-        // Check if start shape is <= end shape
-        TT_ASSERT(output_tensor_start[i] <= output_tensor_end[i]);
+        TT_ASSERT(output_tensor_end[i] <= input_shape[i]);
+        // Check if start shape is < end shape
+        TT_ASSERT(output_tensor_start[i] < output_tensor_end[i]);
         // Figure out output tensor shape
-        output_shape.push_back(output_tensor_end[i] - output_tensor_start[i] + 1);
+        output_shape.push_back(output_tensor_end[i] - output_tensor_start[i]);
     }

     auto unpad = [&input_shape, &input_strides, &output_shape, &output_tensor_start, &output_tensor_end](
@@ -1275,7 +1275,7 @@ Tensor unpad(const Tensor& tensor, const tt::tt_metal::LegacyShape& output_tenso
     auto output_buffer = owned_buffer::create<T>(compute_volume(output_shape));
     std::function<void(std::size_t)> unpad_from_tile = [&](std::size_t dim) -> void {
-        for (auto i = output_tensor_start[dim]; i <= output_tensor_end[dim]; i++) {
+        for (auto i = output_tensor_start[dim]; i < output_tensor_end[dim]; i++) {
             input_indices[dim] = i;
             if (dim == input_shape.rank() - 1) {
                 auto flat_input_index = compute_flat_input_index(input_indices, input_strides);
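The host-side `unpad` change is the same convention shift applied to the copy loop: iterate `i < end` rather than `i <= end`, and size the output as `end - start`. A self-contained 2-D analogue (illustrative helper, not the ttnn implementation):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Copy the half-open window [r0, r1) x [c0, c1) out of a row-major rows x cols buffer.
std::vector<int> unpad_2d(
    const std::vector<int>& in, std::size_t rows, std::size_t cols,
    std::size_t r0, std::size_t r1, std::size_t c0, std::size_t c1) {
    assert(r1 <= rows && c1 <= cols && r0 < r1 && c0 < c1);  // the new TT_ASSERT invariants
    std::vector<int> out;
    out.reserve((r1 - r0) * (c1 - c0));  // extent = end - start, no "+ 1"
    for (std::size_t r = r0; r < r1; r++) {      // i < end, previously i <= end
        for (std::size_t c = c0; c < c1; c++) {
            out.push_back(in[r * cols + c]);
        }
    }
    return out;
}

int main() {
    std::vector<int> in{0, 1, 2, 3, 4, 5, 6, 7, 8};  // 3x3 row-major
    auto out = unpad_2d(in, 3, 3, 0, 2, 1, 3);       // rows [0,2), cols [1,3)
    assert((out == std::vector<int>{1, 2, 4, 5}));
}
```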
diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp
index ef39c5d9aca..3d2a160c13b 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp
@@ -335,7 +335,7 @@ Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const tt::tt_metal::Le
     std::vector<uint32_t> output_tensor_end{};
     for (auto index = 0; index < input_tensor.get_legacy_shape().rank(); index++) {
         output_tensor_start.push_back(0);
-        output_tensor_end.push_back(output_tensor_shape[index] - 1);
+        output_tensor_end.push_back(output_tensor_shape[index]);
     }
     auto output = input_tensor.unpad(output_tensor_start, output_tensor_end);
     output = tt::tt_metal::set_tensor_id(output);
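`tensor_unpad_from_tile` now builds its end vector directly from the target shape. A sketch of that bounds construction (hypothetical helper, illustrative shape values):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Build unpad bounds from a target shape: starts are all zero,
// ends are the shape values themselves (no "- 1" per dimension).
std::pair<std::vector<uint32_t>, std::vector<uint32_t>> unpad_bounds(
    const std::vector<uint32_t>& target_shape) {
    std::vector<uint32_t> start(target_shape.size(), 0);
    std::vector<uint32_t> end(target_shape.begin(), target_shape.end());
    return {start, end};
}

int main() {
    auto [start, end] = unpad_bounds({1, 1, 32, 128});
    assert(start == std::vector<uint32_t>(4, 0));
    assert((end == std::vector<uint32_t>{1, 1, 32, 128}));
}
```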
diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py
index 76c3fd45386..c3d13f6bfbb 100644
--- a/ttnn/ttnn/operations/core.py
+++ b/ttnn/ttnn/operations/core.py
@@ -28,7 +28,6 @@ def _golden_function(input_tensor: ttnn.Tensor, slices):
 )
 def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor:
     input_rank = len(input_tensor.shape)
-    input_layout = input_tensor.layout

     if isinstance(slices, int):
         slices = (slice(None, slices, None),)
@@ -66,47 +65,15 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor:
         raise RuntimeError(f"Too many slices for tensor of rank {input_rank}")

     if input_rank <= 4:
-        input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-
-        while len(slices) != 4:
-            slices = (slice(None, None, None),) + slices
         slice_start = [_slice.start if _slice.start is not None else 0 for _slice in slices]
         slice_end = [
-            (max(input_tensor.shape[index] + _slice.stop, 1) if _slice.stop < 0 else _slice.stop)
-            if _slice.stop is not None
-            else input_tensor.shape[index]
-            for index, _slice in enumerate(slices)
+            _slice.stop if _slice.stop is not None else input_tensor.shape[i] for i, _slice in enumerate(slices)
         ]
         slice_step = [_slice.step if _slice.step is not None else 1 for _slice in slices]

-        padded_slice_end = list(slice_end)
-        if input_layout == ttnn.TILE_LAYOUT:
-            padded_slice_end[-1] = int(math.ceil((slice_end[-1]) / ttnn.TILE_SIZE)) * ttnn.TILE_SIZE
-            padded_slice_end[-2] = int(math.ceil((slice_end[-2]) / ttnn.TILE_SIZE)) * ttnn.TILE_SIZE
+        output = ttnn.slice(input_tensor, slice_start, slice_end, slice_step)

-        if list(padded_slice_end) == list(input_tensor.shape.with_tile_padding()) and (slice_step is None):
-            output = input_tensor
-        else:
-            padded_slice_end_minus_1 = [x - 1 for x in padded_slice_end]
-            if any([x < 0 for x in padded_slice_end_minus_1]):
-                raise RuntimeError("ttnn.Tensor.__getitem__: cannot return a scalar!")
-
-            if ttnn.is_tensor_storage_on_device(input_tensor):
-                output = ttnn.slice(input_tensor, slice_start, padded_slice_end_minus_1, slice_step)
-            else:
-                if any([x != 1 for x in slice_step]):
-                    raise NotImplementedError("ttnn.Tensor.__getitem__: step is not supported for host tensor!")
-                input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
-                output = input_tensor.unpad(slice_start, padded_slice_end_minus_1)
-                output = ttnn.to_layout(output, input_layout)
-            output_shape = [
-                0
-                if slices[i].stop is not None and slices[i].stop + input_tensor.shape[i] == slices[i].start
-                else len(range(start, end, step))
-                for i, (start, end, step) in enumerate(zip(slice_start, slice_end, slice_step))
-            ][-input_rank:]
-            padded_output_shape = list(output.shape.with_tile_padding())[-input_rank:]
-            return ttnn.reshape(output, shape=ttnn.Shape(output_shape, padded_output_shape))
+        return output

     raise NotImplementedError
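With the padded-end bookkeeping gone, `__getitem__` reduces to Python-style normalization of each slice: a missing start becomes 0, a missing stop becomes the dimension size, a missing step becomes 1, and the result is handed straight to the slice op, whose ends are now exclusive exactly like Python's. A C++ analogue of that per-dimension normalization (illustrative types, not ttnn code):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// A Python-like slice for one dimension: any of start/stop/step may be absent.
struct DimSlice {
    std::optional<uint32_t> start, stop, step;
};

struct Normalized {
    uint32_t begin, end, step;
};

// Defaults mirror the simplified __getitem__: start -> 0, stop -> dim size, step -> 1.
Normalized normalize(const DimSlice& s, uint32_t dim_size) {
    return {s.start.value_or(0), s.stop.value_or(dim_size), s.step.value_or(1)};
}

int main() {
    // x[:16] on a dimension of size 64 -> begin 0, end 16 (exclusive), step 1.
    auto n = normalize({std::nullopt, 16, std::nullopt}, 64);
    assert(n.begin == 0 && n.end == 16 && n.step == 1);
}
```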