diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py
index d89a1b808c96..ccbd517cb36f 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py
@@ -459,7 +459,7 @@ def run_test_sdpa_decode_single_iter(
         # [32, 8, 1, 32768, 128, (8, 6), True, False],  # Llama2-70B
         # [4, 32, 8, 32768, 128, (8, 8), True, False],  # llama 3.1 8b
         [4, 32, 8, 32768, 128, (8, 8), True, True],  # llama 3.1 8b
-        # [4, 32, 8, 32768, 128, (8, 8), False, False],  # llama 3.1 8b
+        [32, 32, 8, 8192, 128, (8, 8), True, False],  # llama 3.1 8b
         # [4, 16, 4, 32768, 128, (8, 8), False, False],  # llama 3.1 8b
     ),
 )
@@ -722,7 +722,9 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc
     (
         [32, 8, 1, 32768, 128, (8, 6), True],  # Llama2-70B
         [4, 32, 8, 32768, 128, (8, 8), True],  # llama 3.1 8b
-        [4, 16, 4, 32768, 128, (8, 8), True],
+        # [4, 16, 4, 32768, 128, (8, 8), True],
+        # [32, 32, 8, 4096, 128, (8, 8), True],  # llama 3.1 8b
+        [8, 16, 4, 4096, 128, (8, 2), True],  # llama 3.1 8b N300
         # [1, 8, 1, 32768, 128, (8, 1), True],  # Llama2-70B
         # [16, 8, 1, 32768, 128, (8, 6), False, False],  # Llama2-70B
         # [8, 8, 1, 32768, 128, (8, 6), True, False],  # Llama2-70B
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp
index e70f213e893f..04f9a751503b 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp
@@ -199,6 +199,7 @@ void mul_block_bcast_scalar_inplace(uint32_t in0_cb, uint32_t in1_scalar_cb, uin
     }
 }
 
+template <bool pop_in1>
 void add_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) {
     // Precondition: in0_cb and in1_cb have num_tiles produced
     // Postcondition: in0_cb has num_tiles produced
@@ -216,29 +217,7 @@ void add_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) {
         cb_push_back(in0_cb, 1);
         release_dst(tt::DstMode::Half);
     }
-
-    cb_pop_front(in1_cb, num_tiles);
-}
-
-void add_block_inplace2(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) {
-    // Precondition: in0_cb and in1_cb have num_tiles produced
-    // Postcondition: in0_cb has num_tiles produced
-    // Postcondition: in1_cb has num_tiles consumed
-
-    add_tiles_init();
-    cb_wait_front(in0_cb, num_tiles);
-    cb_wait_front(in1_cb, num_tiles);
-    for (uint32_t i = 0; i < num_tiles; i++) {
-        acquire_dst(tt::DstMode::Half);
-        add_tiles(in0_cb, in1_cb, 0, i, 0);
-        cb_pop_front(in0_cb, 1);
-        cb_reserve_back(in0_cb, 1);
-        pack_tile(0, in0_cb);
-        cb_push_back(in0_cb, 1);
-        release_dst(tt::DstMode::Half);
-    }
-
-    cb_pop_front(in1_cb, num_tiles);
+    if (pop_in1) cb_pop_front(in1_cb, num_tiles);
 }
 
 void add_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t num_tiles) {
@@ -395,6 +374,7 @@ void MAIN {
     constexpr uint32_t num_cores_per_batch = get_compile_time_arg_val(16);
     constexpr uint32_t k_chunk_size = get_compile_time_arg_val(17);
     constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(18);
+    constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(19);
 
     constexpr uint32_t q_chunk_tiles = Sq_chunk_t * DHt;
     constexpr
uint32_t k_chunk_tiles = Sk_chunk_t * DHt; @@ -472,166 +452,169 @@ void MAIN { mm_init(); cb_wait_front(cb_q_in, q_chunk_tiles); - // loop while k_low < q_high - for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) { + for (uint32_t cur_head_work = 0; cur_head_work < num_heads_per_core; ++cur_head_work) { + // loop while k_low < q_high + for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) { - /* QK = Q_CHUNK @ K_CHUNK */ - unpack_reconfig_data_format(cb_q_in, cb_k_in); // DEBUG - pack_reconfig_data_format(cb_qk_im); - cb_matmul_blocks(cb_q_in, cb_k_in, cb_qk_im, Sq_chunk_t, Sk_chunk_t, DHt, qk_num_blocks, qk_in0_num_subblocks, qk_in1_num_subblocks, qk_in0_block_w, qk_subblock_h, qk_subblock_w, true /*transpose*/); + /* QK = Q_CHUNK @ K_CHUNK */ + unpack_reconfig_data_format(cb_q_in, cb_k_in); // DEBUG + pack_reconfig_data_format(cb_qk_im); + cb_matmul_blocks(cb_q_in, cb_k_in, cb_qk_im, Sq_chunk_t, Sk_chunk_t, DHt, qk_num_blocks, qk_in0_num_subblocks, qk_in1_num_subblocks, qk_in0_block_w, qk_subblock_h, qk_subblock_w, true /*transpose*/); - /* QK *= SCALE */ - mul_block_bcast_scalar_inplace(cb_qk_im, cb_scale_in, qk_chunk_tiles); - - // For decode, we only apply mask at the last chunk on reducer cor - if (k_chunk == k_chunk_end - 1 && do_reduce) { - /* QK += MASK */ - unpack_reconfig_data_format(cb_qk_im, cb_mask_in); - add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles); - } - - unpack_reconfig_data_format(cb_qk_im, cb_identity_scale_in); - pack_reconfig_data_format(cb_cur_max); - reduce_c(); - - if (k_chunk > k_chunk_start) { - unpack_reconfig_data_format(cb_cur_max, cb_prev_max); - max_block_inplace(cb_cur_max, cb_prev_max, Sq_chunk_t); - } - /* QK -= cb_cur_max */ - /* QK = exp(QK)*/ - unpack_reconfig_data_format(cb_qk_im, cb_cur_max); - pack_reconfig_data_format(cb_qk_im); - sub_exp_block_bcast_cols_inplace(cb_qk_im, cb_cur_max, Sq_chunk_t, Sk_chunk_t); - - /* cb_cur_sum = sum(cb_qk_im, dim=-1) */ - unpack_reconfig_data_format(cb_qk_im, cb_identity_scale_in); - pack_reconfig_data_format(cb_cur_sum); - reduce_c(); - - /* OUT_IM = QK @ V_CHUNK */ - unpack_reconfig_data_format(cb_qk_im, cb_v_in); // DEBUG - pack_reconfig_data_format(cb_out_im); - cb_matmul_blocks(cb_qk_im, cb_v_in, cb_out_im, Sq_chunk_t, DHt, Sk_chunk_t, out_num_blocks, out_in0_num_subblocks, out_in1_num_subblocks, out_in0_block_w, out_subblock_h, out_subblock_w, false /*transpose*/); - unpack_reconfig_data_format_srca(cb_out_im); - cb_pop_front(cb_qk_im, qk_chunk_tiles); - - /* OUT_ACC += OUT_IM */ - if (k_chunk == k_chunk_start) { - unpack_reconfig_data_format_srca(cb_out_im); - pack_reconfig_data_format(cb_out_accumulate_im); - copy_block(cb_out_im, cb_out_accumulate_im, out_chunk_tiles); - } else { - unpack_reconfig_data_format(cb_prev_max, cb_cur_max); // DEBUG - pack_reconfig_data_format(cb_exp_max_diff); - /* cb_exp_max_diff = torch.exp(cb_prev_max - cb_cur_max) */ - sub_exp_block(cb_prev_max, cb_cur_max, cb_exp_max_diff, Sq_chunk_t); - cb_pop_front(cb_prev_max, Sq_chunk_t); + /* QK *= SCALE */ + mul_block_bcast_scalar_inplace(cb_qk_im, cb_scale_in, qk_chunk_tiles); - /* cb_prev_sum *= cb_exp_max_diff */ - mul_block_inplace(cb_prev_sum, cb_exp_max_diff, Sq_chunk_t); + // For decode, we only apply mask at the last chunk on reducer cor + if (k_chunk == k_chunk_end - 1 && do_reduce) { + /* QK += MASK */ + unpack_reconfig_data_format(cb_qk_im, cb_mask_in); + add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles); + } - /* cb_out_accumulate_im *= cb_exp_max_diff */ - 
unpack_reconfig_data_format(cb_out_accumulate_im, cb_exp_max_diff); // DEBUG - pack_reconfig_data_format(cb_out_accumulate_im); - mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff, Sq_chunk_t, DHt); + unpack_reconfig_data_format(cb_qk_im, cb_identity_scale_in); + pack_reconfig_data_format(cb_cur_max); + reduce_c(); - /* cb_cur_sum += cb_prev_sum */ - unpack_reconfig_data_format(cb_cur_sum, cb_prev_sum); // DEBUG + if (k_chunk > k_chunk_start) { + unpack_reconfig_data_format(cb_cur_max, cb_prev_max); + max_block_inplace(cb_cur_max, cb_prev_max, Sq_chunk_t); + } + /* QK -= cb_cur_max */ + /* QK = exp(QK)*/ + unpack_reconfig_data_format(cb_qk_im, cb_cur_max); + pack_reconfig_data_format(cb_qk_im); + sub_exp_block_bcast_cols_inplace(cb_qk_im, cb_cur_max, Sq_chunk_t, Sk_chunk_t); + + /* cb_cur_sum = sum(cb_qk_im, dim=-1) */ + unpack_reconfig_data_format(cb_qk_im, cb_identity_scale_in); pack_reconfig_data_format(cb_cur_sum); - add_block_inplace(cb_cur_sum, cb_prev_sum, Sq_chunk_t); + reduce_c(); - /* cb_out_accumulate_im += cb_out_im */ - unpack_reconfig_data_format(cb_out_accumulate_im, cb_out_im); // DEBUG - pack_reconfig_data_format(cb_out_accumulate_im); - add_block_inplace(cb_out_accumulate_im, cb_out_im, out_chunk_tiles); - } - - if (k_chunk < k_chunk_end - 1 || do_reduce) { - // Set cb_prev_sum and cb_prev_max - unpack_reconfig_data_format(cb_cur_max, cb_cur_max); // DEBUG - pack_reconfig_data_format(cb_prev_max); - copy_block(cb_cur_max, cb_prev_max, Sq_chunk_t); - copy_block(cb_cur_sum, cb_prev_sum, Sq_chunk_t); - - } else{ - // Write o, m, l into cb_out - copy_block(cb_out_accumulate_im, cb_out_o, out_chunk_tiles); - copy_block(cb_cur_max, cb_out_m, Sq_chunk_t); - copy_block(cb_cur_sum, cb_out_l, Sq_chunk_t); - } - } - cb_pop_front(cb_q_in, q_chunk_tiles); - - // do reduction across intermediates from other cores if this is the reduction core - if (do_reduce) { - // cb_out_accumulate_im should contain o_1 - // cb_prev_max and cb_prev_sum should contain m_1 and l_1 - - if (k_chunk_end - k_chunk_start < k_num_chunks){ - // This indicates that there are computes done by other workers. 
Needs to wait for them and send to reducer's compute - for (uint32_t i = 0; i < num_cores_to_wait ; i++) { - cb_wait_front(cb_out_o, q_chunk_tiles); //o_2 - cb_wait_front(cb_m_in, Sq_chunk_t); //m_2 - cb_wait_front(cb_l_in, Sq_chunk_t); //l_2 - - // unpack_reconfig_data_format(cb_q_in, cb_q_in); // DEBUG - // pack_reconfig_data_format(cb_out_accumulate_im_2); - copy_block(cb_out_o, cb_out_accumulate_im_2, q_chunk_tiles); - copy_block(cb_l_in, cb_prev_sum_2, Sq_chunk_t); - max_block(cb_m_in, cb_prev_max, cb_cur_max, Sq_chunk_t); // pushed, pushed, popped - - // l = torch.exp(m_2 - m) * l_2 + torch.exp(m_1 - m) * l_1 - /// l1 = torch.exp(m_2 - m) * l_2 - // unpack_reconfig_data_format(cb_prev_max_2, cb_cur_max); // DEBUG - // pack_reconfig_data_format(cb_exp_max_diff_2); - sub_exp_block(cb_m_in, cb_cur_max, cb_exp_max_diff_2, Sq_chunk_t); - mul_block_inplace(cb_prev_sum_2, cb_exp_max_diff_2, Sq_chunk_t); - /// l2 = torch.exp(m_1 - m) * l_1 - // unpack_reconfig_data_format(cb_prev_max, cb_cur_max); // DEBUG - // pack_reconfig_data_format(cb_exp_max_diff); + /* OUT_IM = QK @ V_CHUNK */ + unpack_reconfig_data_format(cb_qk_im, cb_v_in); // DEBUG + pack_reconfig_data_format(cb_out_im); + cb_matmul_blocks(cb_qk_im, cb_v_in, cb_out_im, Sq_chunk_t, DHt, Sk_chunk_t, out_num_blocks, out_in0_num_subblocks, out_in1_num_subblocks, out_in0_block_w, out_subblock_h, out_subblock_w, false /*transpose*/); + unpack_reconfig_data_format_srca(cb_out_im); + cb_pop_front(cb_qk_im, qk_chunk_tiles); + + /* OUT_ACC += OUT_IM */ + if (k_chunk == k_chunk_start) { + unpack_reconfig_data_format_srca(cb_out_im); + pack_reconfig_data_format(cb_out_accumulate_im); + copy_block(cb_out_im, cb_out_accumulate_im, out_chunk_tiles); + } else { + unpack_reconfig_data_format(cb_prev_max, cb_cur_max); // DEBUG + pack_reconfig_data_format(cb_exp_max_diff); + /* cb_exp_max_diff = torch.exp(cb_prev_max - cb_cur_max) */ sub_exp_block(cb_prev_max, cb_cur_max, cb_exp_max_diff, Sq_chunk_t); + cb_pop_front(cb_prev_max, Sq_chunk_t); + + /* cb_prev_sum *= cb_exp_max_diff */ mul_block_inplace(cb_prev_sum, cb_exp_max_diff, Sq_chunk_t); - /// l = l1 + l2 - // unpack_reconfig_data_format(cb_cur_sum, cb_prev_sum); // DEBUG - // pack_reconfig_data_format(cb_cur_sum); - add_block(cb_prev_sum_2, cb_prev_sum, cb_cur_sum, Sq_chunk_t); - // unpack_reconfig_data_format(cb_out_accumulate_im, cb_exp_max_diff); // DEBUG - // pack_reconfig_data_format(cb_out_accumulate_im); + /* cb_out_accumulate_im *= cb_exp_max_diff */ + unpack_reconfig_data_format(cb_out_accumulate_im, cb_exp_max_diff); // DEBUG + pack_reconfig_data_format(cb_out_accumulate_im); mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff, Sq_chunk_t, DHt); - mul_block_bcast_cols_inplace(cb_out_accumulate_im_2, cb_exp_max_diff_2, Sq_chunk_t, DHt); - // unpack_reconfig_data_format(cb_out_accumulate_im, cb_out_accumulate_im_2); - // pack_reconfig_data_format(cb_out_accumulate_im); - add_block_inplace2(cb_out_accumulate_im, cb_out_accumulate_im_2, q_chunk_tiles); + /* cb_cur_sum += cb_prev_sum */ + unpack_reconfig_data_format(cb_cur_sum, cb_prev_sum); // DEBUG + pack_reconfig_data_format(cb_cur_sum); + add_block_inplace(cb_cur_sum, cb_prev_sum, Sq_chunk_t); - // copy tiles - // unpack_reconfig_data_format(cb_cur_max, cb_cur_max); // DEBUG - // pack_reconfig_data_format(cb_prev_max); - cb_pop_front(cb_prev_max, Sq_chunk_t); - cb_pop_front(cb_m_in, Sq_chunk_t); + /* cb_out_accumulate_im += cb_out_im */ + unpack_reconfig_data_format(cb_out_accumulate_im, cb_out_im); // DEBUG + 
pack_reconfig_data_format(cb_out_accumulate_im); + add_block_inplace(cb_out_accumulate_im, cb_out_im, out_chunk_tiles); + } + + if (k_chunk < k_chunk_end - 1 || do_reduce) { + // Set cb_prev_sum and cb_prev_max + unpack_reconfig_data_format(cb_cur_max, cb_cur_max); // DEBUG + pack_reconfig_data_format(cb_prev_max); copy_block(cb_cur_max, cb_prev_max, Sq_chunk_t); copy_block(cb_cur_sum, cb_prev_sum, Sq_chunk_t); + + } else{ + // Write o, m, l into cb_out + copy_block(cb_out_accumulate_im, cb_out_o, out_chunk_tiles); + copy_block(cb_cur_max, cb_out_m, Sq_chunk_t); + copy_block(cb_cur_sum, cb_out_l, Sq_chunk_t); } } - /* cb_cur_sum = 1.0 / cb_cur_sum */ - cb_push_back(cb_cur_sum, Sq_chunk_t); - - unpack_reconfig_data_format(cb_cur_sum, cb_cur_sum); // DEBUG - pack_reconfig_data_format(cb_cur_sum); - recip_block_inplace(cb_cur_sum, Sq_chunk_t); - - /* cb_out_accumulate_im *= cb_cur_sum */ - unpack_reconfig_data_format(cb_out_accumulate_im, cb_cur_sum); // DEBUG - pack_reconfig_data_format(cb_out_accumulate_im); - mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_cur_sum, Sq_chunk_t, DHt); - pack_reconfig_data_format(cb_out_final); - copy_block(cb_out_accumulate_im, cb_out_final, out_chunk_tiles); - - // free up cb_prev_max after K chunks - cb_pop_front(cb_prev_max, Sq_chunk_t); - cb_pop_front(cb_prev_sum, Sq_chunk_t); + + // do reduction across intermediates from other cores if this is the reduction core + if (do_reduce) { + // cb_out_accumulate_im should contain o_1 + // cb_prev_max and cb_prev_sum should contain m_1 and l_1 + + if (k_chunk_end - k_chunk_start < k_num_chunks){ + // This indicates that there are computes done by other workers. Needs to wait for them and send to reducer's compute + for (uint32_t i = 0; i < num_cores_to_wait ; i++) { + cb_wait_front(cb_out_o, q_chunk_tiles); //o_2 + cb_wait_front(cb_m_in, Sq_chunk_t); //m_2 + cb_wait_front(cb_l_in, Sq_chunk_t); //l_2 + + // unpack_reconfig_data_format(cb_q_in, cb_q_in); // DEBUG + // pack_reconfig_data_format(cb_out_accumulate_im_2); + copy_block(cb_out_o, cb_out_accumulate_im_2, q_chunk_tiles); + copy_block(cb_l_in, cb_prev_sum_2, Sq_chunk_t); + max_block(cb_m_in, cb_prev_max, cb_cur_max, Sq_chunk_t); // pushed, pushed, popped + + // l = torch.exp(m_2 - m) * l_2 + torch.exp(m_1 - m) * l_1 + /// l1 = torch.exp(m_2 - m) * l_2 + // unpack_reconfig_data_format(cb_prev_max_2, cb_cur_max); // DEBUG + // pack_reconfig_data_format(cb_exp_max_diff_2); + sub_exp_block(cb_m_in, cb_cur_max, cb_exp_max_diff_2, Sq_chunk_t); + mul_block_inplace(cb_prev_sum_2, cb_exp_max_diff_2, Sq_chunk_t); + /// l2 = torch.exp(m_1 - m) * l_1 + // unpack_reconfig_data_format(cb_prev_max, cb_cur_max); // DEBUG + // pack_reconfig_data_format(cb_exp_max_diff); + sub_exp_block(cb_prev_max, cb_cur_max, cb_exp_max_diff, Sq_chunk_t); + mul_block_inplace(cb_prev_sum, cb_exp_max_diff, Sq_chunk_t); + /// l = l1 + l2 + // unpack_reconfig_data_format(cb_cur_sum, cb_prev_sum); // DEBUG + // pack_reconfig_data_format(cb_cur_sum); + add_block(cb_prev_sum_2, cb_prev_sum, cb_cur_sum, Sq_chunk_t); + + // unpack_reconfig_data_format(cb_out_accumulate_im, cb_exp_max_diff); // DEBUG + // pack_reconfig_data_format(cb_out_accumulate_im); + mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff, Sq_chunk_t, DHt); + mul_block_bcast_cols_inplace(cb_out_accumulate_im_2, cb_exp_max_diff_2, Sq_chunk_t, DHt); + + // unpack_reconfig_data_format(cb_out_accumulate_im, cb_out_accumulate_im_2); + // pack_reconfig_data_format(cb_out_accumulate_im); + 
add_block_inplace(cb_out_accumulate_im, cb_out_accumulate_im_2, q_chunk_tiles); + + // copy tiles + // unpack_reconfig_data_format(cb_cur_max, cb_cur_max); // DEBUG + // pack_reconfig_data_format(cb_prev_max); + cb_pop_front(cb_prev_max, Sq_chunk_t); + cb_pop_front(cb_m_in, Sq_chunk_t); + copy_block(cb_cur_max, cb_prev_max, Sq_chunk_t); + copy_block(cb_cur_sum, cb_prev_sum, Sq_chunk_t); + } + } + /* cb_cur_sum = 1.0 / cb_cur_sum */ + cb_push_back(cb_cur_sum, Sq_chunk_t); + + unpack_reconfig_data_format(cb_cur_sum, cb_cur_sum); // DEBUG + pack_reconfig_data_format(cb_cur_sum); + recip_block_inplace(cb_cur_sum, Sq_chunk_t); + + /* cb_out_accumulate_im *= cb_cur_sum */ + unpack_reconfig_data_format(cb_out_accumulate_im, cb_cur_sum); // DEBUG + pack_reconfig_data_format(cb_out_accumulate_im); + mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_cur_sum, Sq_chunk_t, DHt); + pack_reconfig_data_format(cb_out_final); + copy_block(cb_out_accumulate_im, cb_out_final, out_chunk_tiles); + + // free up cb_prev_max after K chunks + cb_pop_front(cb_prev_max, Sq_chunk_t); + cb_pop_front(cb_prev_sum, Sq_chunk_t); + } + } + cb_pop_front(cb_q_in, q_chunk_tiles); } } diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp index 14f0475c3c54..52bfebe11dba 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp @@ -51,7 +51,8 @@ void kernel_main() { constexpr uint32_t page_table_page_size = get_compile_time_arg_val(15); constexpr uint32_t Bkv = get_compile_time_arg_val(16); constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(17); - constexpr uint32_t num_output_cores = get_compile_time_arg_val(18); + constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(18); + constexpr uint32_t num_output_cores = get_compile_time_arg_val(19); uint32_t arg_idx = 0; const uint32_t q_addr = get_arg_val(arg_idx++); @@ -61,7 +62,7 @@ void kernel_main() { const uint32_t page_table_addr = get_arg_val(arg_idx++); const bool is_worker = get_arg_val(arg_idx++) == 0; const bool is_output_core = get_arg_val(arg_idx++) == 1; - const uint32_t cur_head = get_arg_val(arg_idx++); + const uint32_t cur_head_group = get_arg_val(arg_idx++); const uint32_t cur_batch = get_arg_val(arg_idx++); const uint32_t core_num_in_reduce = get_arg_val(arg_idx++); const uint32_t core_num_in_output = get_arg_val(arg_idx++); @@ -203,113 +204,115 @@ void kernel_main() { .data_format = v_data_format }; - if constexpr (is_paged_attention) { - for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) { - - // Read K chunk in row-major order (to simplify page mapping). Write tiles to CB in transposed order. 
- const uint32_t k_chunk_start_row_num = k_chunk * Sk_chunk_t; - cb_reserve_back(cb_k_in, k_chunk_tiles); - uint32_t k_write_ptr = get_write_ptr(cb_k_in); - barrier_count = 0; - for (uint32_t row = 0; row < Sk_chunk_t; ++row) { - uint32_t k_write_ptr_col = k_write_ptr + row*k_tile_bytes; - uint32_t virtual_k_tile_row_num = k_chunk_start_row_num + row; - uint32_t physical_k_tile_id = virtual_seq_tile_id_to_physical_tile_id(virtual_k_tile_row_num, cur_head, page_table_ptr); - for (uint32_t col = 0; col < DHt; ++col) { - noc_async_read_tile(physical_k_tile_id, k_reader, k_write_ptr_col); - physical_k_tile_id += 1; // Go to next tile in row - k_write_ptr_col += Sk_chunk_t * k_tile_bytes; // Go to next column in CB + for (uint32_t cur_head = cur_head_group*num_heads_per_core; cur_head < cur_head_group*num_heads_per_core + num_heads_per_core; ++cur_head) { + if constexpr (is_paged_attention) { + for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) { - if (++barrier_count == barrier_threshold) { - noc_async_read_barrier(); - barrier_count = 0; - } - } - } - noc_async_read_barrier(); - cb_push_back(cb_k_in, k_chunk_tiles); - - // Read V chunk in row major order, write in row-major order - cb_reserve_back(cb_v_in, k_chunk_tiles); - uint32_t v_write_ptr = get_write_ptr(cb_v_in); - barrier_count = 0; - - for (uint32_t row = 0; row < Sk_chunk_t; ++row) { - uint32_t virtual_v_tile_row_num = k_chunk_start_row_num + row; - uint32_t physical_v_tile_id = virtual_seq_tile_id_to_physical_tile_id(virtual_v_tile_row_num, cur_head, page_table_ptr); - for (uint32_t col = 0; col < DHt; ++col) { - noc_async_read_tile(physical_v_tile_id, v_reader, v_write_ptr); - physical_v_tile_id += 1; - v_write_ptr += v_tile_bytes; + // Read K chunk in row-major order (to simplify page mapping). Write tiles to CB in transposed order. 
+ const uint32_t k_chunk_start_row_num = k_chunk * Sk_chunk_t; + cb_reserve_back(cb_k_in, k_chunk_tiles); + uint32_t k_write_ptr = get_write_ptr(cb_k_in); + barrier_count = 0; + for (uint32_t row = 0; row < Sk_chunk_t; ++row) { + uint32_t k_write_ptr_col = k_write_ptr + row*k_tile_bytes; + uint32_t virtual_k_tile_row_num = k_chunk_start_row_num + row; + uint32_t physical_k_tile_id = virtual_seq_tile_id_to_physical_tile_id(virtual_k_tile_row_num, cur_head, page_table_ptr); + for (uint32_t col = 0; col < DHt; ++col) { + noc_async_read_tile(physical_k_tile_id, k_reader, k_write_ptr_col); + physical_k_tile_id += 1; // Go to next tile in row + k_write_ptr_col += Sk_chunk_t * k_tile_bytes; // Go to next column in CB - if (++barrier_count == barrier_threshold) { - noc_async_read_barrier(); - barrier_count = 0; + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } } } - } - noc_async_read_barrier(); - cb_push_back(cb_v_in, k_chunk_tiles); + noc_async_read_barrier(); + cb_push_back(cb_k_in, k_chunk_tiles); - } + // Read V chunk in row major order, write in row-major order + cb_reserve_back(cb_v_in, k_chunk_tiles); + uint32_t v_write_ptr = get_write_ptr(cb_v_in); + barrier_count = 0; - } else { - // Offset for current batch - const uint32_t k_batch_offset = (cur_batch % Bkv) * num_kv_heads * St * DHt; - const uint32_t v_batch_offset = (cur_batch % Bkv) * num_kv_heads * St * DHt; - const uint32_t k_head_offset = cur_head * St * DHt; - const uint32_t v_head_offset = cur_head * St * DHt; - - // Then, read K, V, Mask k_chunk_tiles at a time - const uint32_t k_chunk_offset = k_chunk_start * Sk_chunk_t * DHt; - const uint32_t v_chunk_offset = k_chunk_start * Sk_chunk_t * DHt; - uint32_t k_start_tile_id = k_batch_offset + k_head_offset + k_chunk_offset; - uint32_t v_start_tile_id = v_batch_offset + v_head_offset + v_chunk_offset; - - for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) { - // Read K chunk transposed - cb_reserve_back(cb_k_in, k_chunk_tiles); - uint32_t k_write_ptr = get_write_ptr(cb_k_in); - barrier_count = 0; - for (uint32_t col = 0; col < DHt; ++col) { - uint32_t k_tile_id = k_start_tile_id + col; for (uint32_t row = 0; row < Sk_chunk_t; ++row) { - if (row <= valid_seq_len_tiles) { - noc_async_read_tile(k_tile_id, k_reader, k_write_ptr); + uint32_t virtual_v_tile_row_num = k_chunk_start_row_num + row; + uint32_t physical_v_tile_id = virtual_seq_tile_id_to_physical_tile_id(virtual_v_tile_row_num, cur_head, page_table_ptr); + for (uint32_t col = 0; col < DHt; ++col) { + noc_async_read_tile(physical_v_tile_id, v_reader, v_write_ptr); + physical_v_tile_id += 1; + v_write_ptr += v_tile_bytes; + if (++barrier_count == barrier_threshold) { noc_async_read_barrier(); barrier_count = 0; } } - k_tile_id += DHt; - k_write_ptr += k_tile_bytes; } + noc_async_read_barrier(); + cb_push_back(cb_v_in, k_chunk_tiles); + } - noc_async_read_barrier(); - cb_push_back(cb_k_in, k_chunk_tiles); - k_start_tile_id += k_chunk_tiles; - - // Read V chunk - cb_reserve_back(cb_v_in, k_chunk_tiles); - uint32_t v_write_ptr = get_write_ptr(cb_v_in); - barrier_count = 0; - uint32_t v_tile_id = v_start_tile_id; - for (uint32_t row = 0; row < Sk_chunk_t; ++row) { + + } else { + // Offset for current batch + const uint32_t k_batch_offset = (cur_batch % Bkv) * num_kv_heads * St * DHt; + const uint32_t v_batch_offset = (cur_batch % Bkv) * num_kv_heads * St * DHt; + const uint32_t k_head_offset = cur_head * St * DHt; + const uint32_t v_head_offset = cur_head * St * 
DHt; + + // Then, read K, V, Mask k_chunk_tiles at a time + const uint32_t k_chunk_offset = k_chunk_start * Sk_chunk_t * DHt; + const uint32_t v_chunk_offset = k_chunk_start * Sk_chunk_t * DHt; + uint32_t k_start_tile_id = k_batch_offset + k_head_offset + k_chunk_offset; + uint32_t v_start_tile_id = v_batch_offset + v_head_offset + v_chunk_offset; + + for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) { + // Read K chunk transposed + cb_reserve_back(cb_k_in, k_chunk_tiles); + uint32_t k_write_ptr = get_write_ptr(cb_k_in); + barrier_count = 0; for (uint32_t col = 0; col < DHt; ++col) { - if (row <= valid_seq_len_tiles) { - noc_async_read_tile(v_tile_id, v_reader, v_write_ptr); - if (++barrier_count == barrier_threshold) { - noc_async_read_barrier(); - barrier_count = 0; + uint32_t k_tile_id = k_start_tile_id + col; + for (uint32_t row = 0; row < Sk_chunk_t; ++row) { + if (row <= valid_seq_len_tiles) { + noc_async_read_tile(k_tile_id, k_reader, k_write_ptr); + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } } + k_tile_id += DHt; + k_write_ptr += k_tile_bytes; } - v_tile_id++; - v_write_ptr += v_tile_bytes; } + noc_async_read_barrier(); + cb_push_back(cb_k_in, k_chunk_tiles); + k_start_tile_id += k_chunk_tiles; + + // Read V chunk + cb_reserve_back(cb_v_in, k_chunk_tiles); + uint32_t v_write_ptr = get_write_ptr(cb_v_in); + barrier_count = 0; + uint32_t v_tile_id = v_start_tile_id; + for (uint32_t row = 0; row < Sk_chunk_t; ++row) { + for (uint32_t col = 0; col < DHt; ++col) { + if (row <= valid_seq_len_tiles) { + noc_async_read_tile(v_tile_id, v_reader, v_write_ptr); + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } + } + v_tile_id++; + v_write_ptr += v_tile_bytes; + } + } + noc_async_read_barrier(); + cb_push_back(cb_v_in, k_chunk_tiles); + v_start_tile_id += k_chunk_tiles; } - noc_async_read_barrier(); - cb_push_back(cb_v_in, k_chunk_tiles); - v_start_tile_id += k_chunk_tiles; } } } diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp index 84ed7436816f..c488ca069b00 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp @@ -5,6 +5,7 @@ #include "dataflow_api.h" #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp" #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp" +#include "debug/assert.h" #include "../../rt_args_common.hpp" @@ -239,9 +240,10 @@ void kernel_main() { constexpr uint32_t num_q_heads = get_compile_time_arg_val(12); constexpr uint32_t num_kv_heads = get_compile_time_arg_val(13); constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(14); - constexpr uint32_t num_reducer_cores = get_compile_time_arg_val(15); - constexpr uint32_t num_output_cores = get_compile_time_arg_val(16); - constexpr uint32_t ELEMENT_SIZE = get_compile_time_arg_val(17); + constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(15); + constexpr uint32_t num_reducer_cores = get_compile_time_arg_val(16); + constexpr uint32_t num_output_cores = get_compile_time_arg_val(17); + constexpr uint32_t ELEMENT_SIZE = get_compile_time_arg_val(18); uint32_t arg_idx = 0; const uint32_t out_addr = 
get_arg_val(arg_idx++); @@ -249,7 +251,7 @@ void kernel_main() { const uint32_t worker_id_for_output = get_arg_val(arg_idx++); const bool is_worker = get_arg_val(arg_idx++) == 0; const bool do_output = get_arg_val(arg_idx++) == 1; - const uint32_t cur_head = get_arg_val(arg_idx++); + const uint32_t cur_head_group = get_arg_val(arg_idx++); const uint32_t cur_batch = get_arg_val(arg_idx++); const uint32_t core_num_in_reduce = get_arg_val(arg_idx++); const uint32_t core_num_in_output = get_arg_val(arg_idx++); @@ -288,7 +290,7 @@ void kernel_main() { arg_idx += num_output_cores; tt_l1_ptr uint32_t * all_output_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx++)); - uint32_t reduce_core_index = (cur_batch * num_cores_per_batch) / num_cores_per_head + cur_head; + uint32_t reduce_core_index = (cur_batch * num_cores_per_batch) / num_cores_per_head + cur_head_group; uint32_t reduce_core_noc_x = all_reducer_noc_x[reduce_core_index]; uint32_t reduce_core_noc_y = all_reducer_noc_y[reduce_core_index]; @@ -322,6 +324,7 @@ void kernel_main() { generate_bcast_unary_scalar(cb_scale_in, scale_val); generate_reduce_scaler(cb_identity_scale_in, identity_scalar_packed); if (is_worker) { + ASSERT(num_heads_per_core == 1); // if there are workers, then head must be split across workers so there should not be more than one head per core worker_compute(in0_sender_semaphore_noc_addr, worker_id_for_reduce, reduce_core_noc_x, reduce_core_noc_y); return; } @@ -347,148 +350,151 @@ void kernel_main() { // generate and send mask to compute generate_mask(k_num_chunks, PSt, cur_pos); - if (k_chunk_end - k_chunk_start < k_num_chunks){ - // This indicates that there are computes done by other workers. Needs to wait for them and send to reducer's compute - // Wait for compute to deliver output chunk, and write to compute again for reduction - // data in cb_intermed_out is arranged as [o,m,l,o,m,l,...] 
with size (out_chunk_tiles + 2*PNHt)*num_cores_to_wait - // wait on in0 semaphore value to become VALID (set by sender) - noc_semaphore_wait(in0_receiver_semaphore_addr_ptr, num_cores_to_wait); - // noc_semaphore_set(in0_receiver_semaphore_addr_ptr, 0); - - // cb_wait_front(cb_intermed_out, num_tiles_to_wait); - constexpr uint32_t q_read_size = out_chunk_tiles*tile_bytes_intermed; - constexpr uint32_t ml_read_size = PNHt*tile_bytes_intermed; - for(uint32_t block = 0; block < num_cores_to_wait+1; ++block) { - - cb_reserve_back(cb_out_o, out_chunk_tiles); - cb_reserve_back(cb_m_in, PNHt); - cb_reserve_back(cb_l_in, PNHt); - - uint32_t q_write_ptr = get_read_ptr(cb_out_o); - noc_async_read(intermed_l1_read_addr, q_write_ptr, q_read_size); - intermed_l1_read_addr+=q_read_size; - noc_async_read_barrier(); - cb_push_back(cb_out_o, out_chunk_tiles); - - uint32_t m_write_ptr = get_read_ptr(cb_m_in); - noc_async_read(intermed_l1_read_addr, m_write_ptr, ml_read_size); - intermed_l1_read_addr+=ml_read_size; - noc_async_read_barrier(); - cb_push_back(cb_m_in, PNHt); - - uint32_t l_write_ptr = get_read_ptr(cb_l_in); - noc_async_read(intermed_l1_read_addr, l_write_ptr, ml_read_size); - intermed_l1_read_addr+=ml_read_size; - noc_async_read_barrier(); - cb_push_back(cb_l_in, PNHt); - } - } - // Offset for current batch - const uint32_t out_batch_offset = cur_batch * out_chunk_tiles; - - // Write entire out into its corresponding batch - uint32_t out_tile_id = out_batch_offset; - cb_wait_front(cb_out, out_chunk_tiles); + for (uint32_t cur_head = cur_head_group*num_heads_per_core; cur_head < cur_head_group*num_heads_per_core + num_heads_per_core; ++cur_head) { + if (k_chunk_end - k_chunk_start < k_num_chunks){ + ASSERT(num_heads_per_core == 1); // if there are workers, then head must be split across workers so there should not be more than one head per core + // This indicates that there are computes done by other workers. Needs to wait for them and send to reducer's compute + // Wait for compute to deliver output chunk, and write to compute again for reduction + // data in cb_intermed_out is arranged as [o,m,l,o,m,l,...] with size (out_chunk_tiles + 2*PNHt)*num_cores_to_wait + // wait on in0 semaphore value to become VALID (set by sender) + noc_semaphore_wait(in0_receiver_semaphore_addr_ptr, num_cores_to_wait); + // noc_semaphore_set(in0_receiver_semaphore_addr_ptr, 0); + + // cb_wait_front(cb_intermed_out, num_tiles_to_wait); + constexpr uint32_t q_read_size = out_chunk_tiles*tile_bytes_intermed; + constexpr uint32_t ml_read_size = PNHt*tile_bytes_intermed; + for(uint32_t block = 0; block < num_cores_to_wait+1; ++block) { + + cb_reserve_back(cb_out_o, out_chunk_tiles); + cb_reserve_back(cb_m_in, PNHt); + cb_reserve_back(cb_l_in, PNHt); + + uint32_t q_write_ptr = get_read_ptr(cb_out_o); + noc_async_read(intermed_l1_read_addr, q_write_ptr, q_read_size); + intermed_l1_read_addr+=q_read_size; + noc_async_read_barrier(); + cb_push_back(cb_out_o, out_chunk_tiles); - if constexpr(num_kv_heads > 1){ - // if gqa, we will need to write partial outputs for each head - constexpr uint32_t TILE_WIDTH = 32; - // we are assuming here that num_heads_to_write = nh/nkv is a power of 2 here, so that we don't write partial across phase - uint32_t num_heads_to_write = num_q_heads/num_kv_heads; // each head is one row in a tile - uint32_t SUBTILE_LINE_BYTES = 16*ELEMENT_SIZE; //size of 16 elements (in a row) - uint32_t starting_row = cur_head * num_heads_to_write; - uint32_t in_tile_offset_by_starting_head = starting_row < 16 ? 
starting_row * SUBTILE_LINE_BYTES : (starting_row - 16) * SUBTILE_LINE_BYTES + 512*ELEMENT_SIZE; + uint32_t m_write_ptr = get_read_ptr(cb_m_in); + noc_async_read(intermed_l1_read_addr, m_write_ptr, ml_read_size); + intermed_l1_read_addr+=ml_read_size; + noc_async_read_barrier(); + cb_push_back(cb_m_in, PNHt); - if (! is_out_sharded){ - for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) { + uint32_t l_write_ptr = get_read_ptr(cb_l_in); + noc_async_read(intermed_l1_read_addr, l_write_ptr, ml_read_size); + intermed_l1_read_addr+=ml_read_size; + noc_async_read_barrier(); + cb_push_back(cb_l_in, PNHt); + } + } + // Offset for current batch + const uint32_t out_batch_offset = cur_batch * out_chunk_tiles; + + // Write entire out into its corresponding batch + uint32_t out_tile_id = out_batch_offset; + cb_wait_front(cb_out, out_chunk_tiles); + + if constexpr(num_kv_heads > 1){ + // if gqa, we will need to write partial outputs for each head + constexpr uint32_t TILE_WIDTH = 32; + // we are assuming here that num_heads_to_write = nh/nkv is a power of 2 here, so that we don't write partial across phase + uint32_t num_heads_to_write = num_q_heads/num_kv_heads; // each head is one row in a tile + uint32_t SUBTILE_LINE_BYTES = 16*ELEMENT_SIZE; //size of 16 elements (in a row) + uint32_t starting_row = cur_head * num_heads_to_write; + uint32_t in_tile_offset_by_starting_head = starting_row < 16 ? starting_row * SUBTILE_LINE_BYTES : (starting_row - 16) * SUBTILE_LINE_BYTES + 512*ELEMENT_SIZE; + + if (! is_out_sharded){ + for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) { - uint64_t out_writer_noc_addr = get_noc_addr(out_tile_id, out_writer) + in_tile_offset_by_starting_head; - uint32_t l1_read_addr = get_read_ptr(cb_out) + tile*tile_bytes + in_tile_offset_by_starting_head; + uint64_t out_writer_noc_addr = get_noc_addr(out_tile_id, out_writer) + in_tile_offset_by_starting_head; + uint32_t l1_read_addr = get_read_ptr(cb_out) + tile*tile_bytes + in_tile_offset_by_starting_head; - // write partial output for each head - for (uint32_t head = 0; head < num_heads_to_write; ++head) { + // write partial output for each head + for (uint32_t head = 0; head < num_heads_to_write; ++head) { - // Write first phase - noc_async_write(l1_read_addr, out_writer_noc_addr, SUBTILE_LINE_BYTES); + // Write first phase + noc_async_write(l1_read_addr, out_writer_noc_addr, SUBTILE_LINE_BYTES); - // Write second phase - noc_async_write(l1_read_addr+256*ELEMENT_SIZE, out_writer_noc_addr+256*ELEMENT_SIZE, SUBTILE_LINE_BYTES); + // Write second phase + noc_async_write(l1_read_addr+256*ELEMENT_SIZE, out_writer_noc_addr+256*ELEMENT_SIZE, SUBTILE_LINE_BYTES); - l1_read_addr += SUBTILE_LINE_BYTES; - out_writer_noc_addr += SUBTILE_LINE_BYTES; + l1_read_addr += SUBTILE_LINE_BYTES; + out_writer_noc_addr += SUBTILE_LINE_BYTES; - if (++barrier_count == barrier_threshold) { - noc_async_writes_flushed(); - barrier_count = 0; + if (++barrier_count == barrier_threshold) { + noc_async_writes_flushed(); + barrier_count = 0; + } } - } - ++out_tile_id; + ++out_tile_id; + } } - } - // sharded out case - else if (do_output){ - // read from reducer cores - constexpr uint32_t num_reducers_per_output = num_reducer_cores / num_output_cores; - constexpr uint32_t num_reducers_to_wait = num_reducers_per_output-1; - volatile tt_l1_ptr uint32_t* output_self_semaphore_addr_ptr = reinterpret_cast(output_semaphore_addr); - noc_semaphore_wait(output_self_semaphore_addr_ptr, num_reducers_to_wait); + // sharded out case + else if (do_output){ + // read 
from reducer cores + constexpr uint32_t num_reducers_per_output = num_reducer_cores / num_output_cores; + constexpr uint32_t num_reducers_to_wait = num_reducers_per_output-1; + volatile tt_l1_ptr uint32_t* output_self_semaphore_addr_ptr = reinterpret_cast(output_semaphore_addr); + noc_semaphore_wait(output_self_semaphore_addr_ptr, num_reducers_to_wait); - uint32_t reduce_core_read_index_start = (cur_batch * num_cores_per_batch) / num_cores_per_head; + uint32_t reduce_core_read_index_start = (cur_batch * num_cores_per_batch) / num_cores_per_head; - for (uint32_t reduce_core_read_index = reduce_core_read_index_start + 1; reduce_core_read_index < reduce_core_read_index_start+num_reducers_per_output; reduce_core_read_index++){ - uint32_t reduce_core_read_noc_x = all_reducer_noc_x[reduce_core_read_index]; - uint32_t reduce_core_read_noc_y = all_reducer_noc_y[reduce_core_read_index]; + for (uint32_t reduce_core_read_index = reduce_core_read_index_start + 1; reduce_core_read_index < reduce_core_read_index_start+num_reducers_per_output; reduce_core_read_index++){ + uint32_t reduce_core_read_noc_x = all_reducer_noc_x[reduce_core_read_index]; + uint32_t reduce_core_read_noc_y = all_reducer_noc_y[reduce_core_read_index]; - uint64_t out_reader_base_noc_addr = get_noc_addr(reduce_core_read_noc_x, reduce_core_read_noc_y, get_read_ptr(cb_out)) + in_tile_offset_by_starting_head; + uint64_t out_reader_base_noc_addr = get_noc_addr(reduce_core_read_noc_x, reduce_core_read_noc_y, get_read_ptr(cb_out)) + in_tile_offset_by_starting_head; - for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) { - uint32_t l1_write_addr = get_write_ptr(cb_out) + tile*tile_bytes + in_tile_offset_by_starting_head; - uint32_t out_reader_noc_addr = out_reader_base_noc_addr; + for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) { + uint32_t l1_write_addr = get_write_ptr(cb_out) + tile*tile_bytes + in_tile_offset_by_starting_head; + uint32_t out_reader_noc_addr = out_reader_base_noc_addr; - // write partial output for each head - for (uint32_t head = 0; head < num_heads_to_write; ++head) { + // write partial output for each head + for (uint32_t head = 0; head < num_heads_to_write; ++head) { - // Write first phase - noc_async_read(out_reader_noc_addr, l1_write_addr, SUBTILE_LINE_BYTES); + // Write first phase + noc_async_read(out_reader_noc_addr, l1_write_addr, SUBTILE_LINE_BYTES); - // Write second phase - noc_async_read(out_reader_noc_addr+256*ELEMENT_SIZE, l1_write_addr+256*ELEMENT_SIZE, SUBTILE_LINE_BYTES); + // Write second phase + noc_async_read(out_reader_noc_addr+256*ELEMENT_SIZE, l1_write_addr+256*ELEMENT_SIZE, SUBTILE_LINE_BYTES); - l1_write_addr += SUBTILE_LINE_BYTES; - out_reader_noc_addr += SUBTILE_LINE_BYTES; + l1_write_addr += SUBTILE_LINE_BYTES; + out_reader_noc_addr += SUBTILE_LINE_BYTES; - if (++barrier_count == barrier_threshold) { - noc_async_read_barrier(); - barrier_count = 0; + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } } + out_reader_noc_addr += tile_bytes; } - out_reader_noc_addr += tile_bytes; } + noc_async_read_barrier(); + } else { + // tell output core that its output is ready + uint32_t output_core_noc_x = all_output_noc_x[cur_batch]; + uint32_t output_core_noc_y = all_output_noc_y[cur_batch]; + const uint64_t output_core_semaphore_noc_addr = get_noc_addr(output_core_noc_x, output_core_noc_y, output_semaphore_addr); + noc_semaphore_inc(output_core_semaphore_noc_addr, 1); } - noc_async_read_barrier(); } else { - // tell output core that its output 
is ready
-            uint32_t output_core_noc_x = all_output_noc_x[cur_batch];
-            uint32_t output_core_noc_y = all_output_noc_y[cur_batch];
-            const uint64_t output_core_semaphore_noc_addr = get_noc_addr(output_core_noc_x, output_core_noc_y, output_semaphore_addr);
-            noc_semaphore_inc(output_core_semaphore_noc_addr, 1);
-        }
-    } else {
-        // if mqa, we don't need to gather outputs for other heads so we can just write entire tiles to memory
-        if (! is_out_sharded){
-            uint32_t l1_read_addr = get_read_ptr(cb_out);
-            for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
-                noc_async_write_tile(out_tile_id, out_writer, l1_read_addr);
-                ++out_tile_id;
-                l1_read_addr += tile_bytes;
-                if (++barrier_count == barrier_threshold) {
-                    noc_async_writes_flushed();
-                    barrier_count = 0;
+            // if mqa, we don't need to gather outputs for other heads so we can just write entire tiles to memory
+            if (! is_out_sharded){
+                uint32_t l1_read_addr = get_read_ptr(cb_out);
+                for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
+                    noc_async_write_tile(out_tile_id, out_writer, l1_read_addr);
+                    ++out_tile_id;
+                    l1_read_addr += tile_bytes;
+                    if (++barrier_count == barrier_threshold) {
+                        noc_async_writes_flushed();
+                        barrier_count = 0;
+                    }
                 }
             }
         }
+        noc_async_write_barrier();
+        cb_pop_front(cb_out, out_chunk_tiles);
     }
-    noc_async_write_barrier();
-    cb_pop_front(cb_out, out_chunk_tiles);
 }
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
index f4cad7175178..9c24ee5a0980 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
@@ -140,11 +140,13 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
     // balance the number of cores to use based on batch
     uint32_t num_cores_per_batch = num_cores_available / B;
     uint32_t num_active_cores = num_cores_per_batch * B;
-    uint32_t num_cores_per_head = num_cores_per_batch / num_kv_heads;
-    uint32_t num_reducer_cores = (num_cores_per_head == 0) ? B : num_kv_heads*B;
+    //// for core assignment, it is the same whether there's 1 core for head or 1 core for many heads
+    uint32_t num_cores_per_head = std::max((uint32_t) 1, num_cores_per_batch / num_kv_heads);
+    uint32_t num_heads_per_core = std::max((uint32_t) 1, num_kv_heads / num_cores_per_batch);
+    uint32_t num_reducer_cores = num_kv_heads*B / num_heads_per_core;
     uint32_t num_output_cores = B;
 
-    TT_FATAL(num_cores_per_head > 0, "Case not supported for more n_kv_heads*batch > number of cores. Got batch={}, n_kv_heads={} and num_cores_available={}. Let's assume each core can handle at most one head", B, num_kv_heads, num_cores_available);
+    TT_FATAL(((num_cores_per_head >= 1) && (num_heads_per_core == 1)) || ((num_cores_per_head == 1) && (num_heads_per_core >= 1)), "This assertion should always be true, unless core assignment logic is wrong");
 
     // create core group, assume n batch and k_heads:
     // this is a 1D list of cores sorted by batch_output1, worker, ..., batch_output2, worker, ..., batch_output n, worker, ...
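
Reviewer aside (not part of the patch): a minimal Python sketch of the new core-assignment arithmetic in the hunk above, evaluated for the added N300 test case `[8, 16, 4, 4096, 128, (8, 2), True]`. Variable names mirror the factory code; the numbers are only a worked example under those assumed inputs.

```python
# Hypothetical worked example of the new core-assignment math (not part of the diff).
# Inputs taken from the added N300 test case: batch=8, n_q_heads=16, n_kv_heads=4, grid (8, 2).
B = 8
num_kv_heads = 4
num_cores_available = 8 * 2

num_cores_per_batch = num_cores_available // B                    # 2 cores serve each batch entry
num_active_cores = num_cores_per_batch * B                        # 16
# Either one KV head spans several cores, or one core serves several KV heads.
num_cores_per_head = max(1, num_cores_per_batch // num_kv_heads)  # 1
num_heads_per_core = max(1, num_kv_heads // num_cores_per_batch)  # 2
num_reducer_cores = num_kv_heads * B // num_heads_per_core        # 16: every active core is a reducer
num_output_cores = B                                              # 8

# Mirrors the new TT_FATAL: at most one of the two quantities exceeds 1.
assert (num_cores_per_head >= 1 and num_heads_per_core == 1) or \
       (num_cores_per_head == 1 and num_heads_per_core >= 1)
```

With `num_heads_per_core == 2`, each reducer core iterates over two KV heads locally (the new `cur_head_work` / `cur_head` loops in the compute, reader, and writer kernels) instead of splitting a single head across worker cores.
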
@@ -189,6 +191,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
     log_debug("num_cores_available: {}", num_cores_available);
     log_debug("num_cores_per_batch: {}", num_cores_per_batch);
     log_debug("num_cores_per_head: {}", num_cores_per_head);
+    log_debug("num_heads_per_core: {}", num_heads_per_core);
     log_debug("num_active_cores: {}", num_active_cores);
     log_debug("num_reducer_cores: {}", num_reducer_cores);
     log_debug("num_output_cores: {}", num_output_cores);
@@ -456,7 +459,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
 
     for (uint32_t i = 0; i < num_active_cores; ++i) {
         CoreCoord core = core_group[i];
-        uint32_t worker_id_for_reduce = (num_cores_per_head == 0) ? -1 : i % num_cores_per_head - 1;
+        uint32_t worker_id_for_reduce = i % num_cores_per_head - 1;
         bool do_reduce = (worker_id_for_reduce == -1);
         if (do_reduce) {
             reduce_core_noc_x = core.x;
@@ -506,7 +509,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
         B, PNHt, St, DHt, Sk_chunk_t, num_active_cores, is_q_sharded, num_cores_per_batch, k_chunk_size,
         log2_page_size, index_stick_size, (uint32_t)is_paged_attention, num_kv_heads, page_block_size_t,
-        log2_page_table_page_size, page_table_stick_size, Bkv, num_cores_per_head, num_output_cores
+        log2_page_table_page_size, page_table_stick_size, Bkv, num_cores_per_head, num_heads_per_core, num_output_cores
     };
 
     std::vector<uint32_t> writer_compile_time_args_common = {
@@ -522,6 +525,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
         num_q_heads,
         num_kv_heads,
         num_cores_per_head,
+        num_heads_per_core,
         num_reducer_cores,
         num_output_cores,
         output_tensor.element_size()
@@ -531,7 +535,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
         St, DHt, PNHt, Sk_chunk_t,
         qk_in0_block_w, qk_out_subblock_w, qk_out_subblock_h, qk_in0_num_subblocks, qk_in1_num_subblocks, qk_num_blocks,
         out_in0_block_w, out_out_subblock_w, out_out_subblock_h, out_in0_num_subblocks, out_in1_num_subblocks, out_num_blocks,
-        num_cores_per_batch, k_chunk_size, num_cores_per_head
+        num_cores_per_batch, k_chunk_size, num_cores_per_head, num_heads_per_core
     };
 
     std::map<std::string, std::string> defines;
@@ -585,17 +589,14 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
     // Set rt args
     for (uint32_t i = 0; i < num_active_cores; ++i) {
         CoreCoord core = core_group[i];
-        uint32_t worker_id_for_reduce = (num_cores_per_head == 0) ? -1 : i % num_cores_per_head - 1;
+        uint32_t worker_id_for_reduce = i % num_cores_per_head - 1;
         uint32_t worker_id_for_output = i % num_cores_per_batch - 1;
         bool do_reduce = (worker_id_for_reduce == -1);
         bool do_output = (worker_id_for_output == -1);
 
-        // 64 cores, 4 batch, 8 head
-        // num_cores_per_batch = 16
-        // num_cores_per_head = 2
-        uint32_t cur_head = (num_cores_per_head == 0) ? 0 : (i % num_cores_per_batch) / num_cores_per_head;
+        uint32_t cur_head = (i % num_cores_per_batch) / num_cores_per_head;
         uint32_t cur_batch = i / num_cores_per_batch;
-        uint32_t core_num_in_reduce = (num_cores_per_head == 0) ? 0 : i % num_cores_per_head;
+        uint32_t core_num_in_reduce = i % num_cores_per_head;
         uint32_t core_num_in_output = i % num_cores_per_batch;
 
         uint32_t cur_pos = use_cur_pos_tensor ? -1 : cur_pos_ids.at(cur_batch);
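
Reviewer aside (not part of the patch): the reducer-side combine that `sdpa_flash_decode.cpp` performs on partial results follows the standard flash-attention merge spelled out in the kernel's own `torch.exp` comments. A minimal PyTorch reference, assuming each partial consists of an un-normalized output `o`, a row max `m`, and a row sum of exponentials `l`:

```python
# Illustrative PyTorch reference for the reducer-core merge in sdpa_flash_decode.cpp
# (not part of the diff). o_*: [Sq, DH] un-normalized outputs; m_*, l_*: [Sq, 1].
import torch

def merge_partials(o_1, m_1, l_1, o_2, m_2, l_2):
    m = torch.maximum(m_1, m_2)        # max_block(cb_m_in, cb_prev_max, cb_cur_max, ...)
    exp_1 = torch.exp(m_1 - m)         # sub_exp_block(cb_prev_max, cb_cur_max, ...)
    exp_2 = torch.exp(m_2 - m)         # sub_exp_block(cb_m_in, cb_cur_max, ...)
    l = exp_2 * l_2 + exp_1 * l_1      # "l = torch.exp(m_2 - m) * l_2 + torch.exp(m_1 - m) * l_1"
    o = exp_1 * o_1 + exp_2 * o_2      # two mul_block_bcast_cols_inplace calls + add_block_inplace
    return o, m, l

def finalize(o, l):
    # recip_block_inplace(cb_cur_sum) followed by mul_block_bcast_cols_inplace(cb_out_accumulate_im, ...)
    return o / l
```

This multi-core merge path only runs when a head is split across cores (`k_chunk_end - k_chunk_start < k_num_chunks`); with the new `num_heads_per_core` loop the writer kernel now ASSERTs `num_heads_per_core == 1` in that case, since a core either shares one head with workers or owns several heads outright, never both.
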