Skip to content

Commit

Permalink
#12544: support wide channels (> 256)
Browse files Browse the repository at this point in the history
  • Loading branch information
mywoodstock committed Sep 12, 2024
1 parent 6d3424c commit 7b96435
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 87 deletions.
14 changes: 10 additions & 4 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def run_max_pool(
else:
ttact = ttnn.from_torch(act_reshaped, dtype)

pre_shard = True
# pre_shard = False
# pre_shard = True
pre_shard = False

ttact_device = ttnn.to_device(ttact, device)
if pre_shard:
Expand Down Expand Up @@ -113,7 +113,7 @@ def run_max_pool(
# interleaved_mem_config = ttnn.L1_MEMORY_CONFIG
# output = ttnn.to_memory_config(output, interleaved_mem_config)
output_host = output.cpu()
output_pytorch_padded = ttnn.to_torch(output_host)
output_pytorch_padded = torch.Tensor(ttnn.to_torch(output_host))
output_pytorch = output_pytorch_padded[:, :, :, :in_c]

## reference
Expand All @@ -129,9 +129,11 @@ def run_max_pool(
## test for equivalence
golden_shape = golden_pytorch.shape
output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])
output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2)) ## N, C, H, W

# torch.save(output_pytorch, "output_pytorch.pt")
# torch.save(golden_pytorch, "golden_pytorch.pt")

output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2)) ## N, C, H, W
passing, pcc = assert_with_pcc(output_pytorch, golden_pytorch)

logger.debug(f"Passing: {passing}, PCC: {pcc}")
Expand Down Expand Up @@ -187,6 +189,10 @@ def run_max_pool(
# [64, 16, 528, 80], ## oom
# [128, 16, 528, 80], ## oom
# [256, 16, 528, 80], ## oom
## wide for vgg
[1, 256, 56, 56],
[1, 512, 28, 28],
[1, 512, 14, 14],
)
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,20 @@

#include <cstdint>

// #include "compute_kernel_api.h"
#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/reduce.h"
#include "compute_kernel_api/pack_untilize.h"
// #include "tools/profiler/kernel_profiler.hpp"

#define DEBUG_PRINT 0

#if DEBUG_PRINT == 1
#include "debug/dprint.h"
// #include "debug_macros.h"

// SliceRange srt = SliceRange{.h0 = 0, .h1 = 32, .hs = 8, .w0 = 0, .w1 = 32, .ws = 4};
// SliceRange srr = SliceRange{.h0 = 0, .h1 = 1, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1};
// SliceRange srr1 = SliceRange{.h0 = 1, .h1 = 2, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1};
// SliceRange src = SliceRange{.h0 = 0, .h1 = 32, .hs = 1, .w0 = 0, .w1 = 1, .ws = 1};

inline void print_tile_rows(uint32_t cb_id, uint32_t rows = 32, uint32_t tile_id = 0, bool untilize = false) {
// UNPACK(( DPRINT << "======" << ENDL() ));
for (uint16_t r = 0; r < rows; ++ r) {
SliceRange sr = SliceRange{.h0 = r, .h1 = (uint16_t)(r + 1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1};
// UNPACK(( DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL() ));
UNPACK(( DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize) ));
UNPACK(( DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL() ));
}
// UNPACK(( DPRINT << "++++++" << ENDL() ));
}
Expand All @@ -40,47 +31,45 @@
UNPACK(( DPRINT << "++++++" << ENDL() ));
}

// inline void print_cb_details(uint32_t cb_id) {
// DPRINT << "cb_id " << cb_id << ": { "
// << "size: " << cb_interface[cb_id].fifo_size << ", "
// << "limit: " << cb_interface[cb_id].fifo_limit << ", "
// << "page_size: " << cb_interface[cb_id].fifo_page_size << ", "
// << "num_pages: " << cb_interface[cb_id].fifo_num_pages << ", "
// << "rd_ptr: " << cb_interface[cb_id].fifo_rd_ptr << ", "
// << "wr_ptr: " << cb_interface[cb_id].fifo_wr_ptr << ", "
// << "wr_tile_ptr: " << cb_interface[cb_id].fifo_wr_tile_ptr << " }" << ENDL();
// }
// Debug helper: prints the bookkeeping fields of circular buffer `cb_id` on one line via DPRINT.
// The values are read from cb_interface[cb_id]: fifo_size, fifo_limit, fifo_page_size,
// fifo_num_pages, fifo_rd_ptr, fifo_wr_ptr and fifo_wr_tile_ptr.
// NOTE(review): appears to sit inside the `#if DEBUG_PRINT == 1` section, so it is presumably
// only compiled when debug printing is enabled -- confirm against the full file.
inline void print_cb_details(uint32_t cb_id) {
DPRINT << "cb_id " << cb_id << ": { "
<< "size: " << cb_interface[cb_id].fifo_size << ", "
<< "limit: " << cb_interface[cb_id].fifo_limit << ", "
<< "page_size: " << cb_interface[cb_id].fifo_page_size << ", "
<< "num_pages: " << cb_interface[cb_id].fifo_num_pages << ", "
<< "rd_ptr: " << cb_interface[cb_id].fifo_rd_ptr << ", "
<< "wr_ptr: " << cb_interface[cb_id].fifo_wr_ptr << ", "
<< "wr_tile_ptr: " << cb_interface[cb_id].fifo_wr_tile_ptr << " }" << ENDL();
}
#endif

template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, uint32_t nblocks, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim>
template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim, uint32_t in_nblocks_c>
inline void reduce_h_fused(
const uint32_t in_cb_id,
const uint32_t in_scalar_cb_id,
const uint32_t in_ntiles_hwc,
const uint32_t in_ntiles_hwc_block,
const uint32_t in_stick_index,
const uint32_t out_cb_id) {

constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks;
constexpr uint32_t num_output_tiles = out_ntiles_c / in_nblocks_c;
constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_out_rows = 1;
cb_reserve_back(out_cb_id, 1);
tile_regs_acquire();
for (uint32_t out_elem_i = 0; out_elem_i < nblocks; ++ out_elem_i) {
const uint32_t curr_in_cb_id = split_reader ? (in_cb_id + (in_stick_index * nblocks + out_elem_i)&0x1) : in_cb_id;
for (uint32_t c_i = 0; c_i < in_nblocks_c; ++ c_i) {
cb_reserve_back(out_cb_id, 1);
const uint32_t curr_in_cb_id = split_reader ? (in_cb_id + (in_stick_index & 0x1)) : in_cb_id;
cb_wait_front(curr_in_cb_id, 1);
unpack_tilizeA_B_block(curr_in_cb_id, in_scalar_cb_id, in_ntiles_hwc, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, num_faces_in_tile /* unpack 1 or 2 faces ) */, unpA_face_r_dim);
for (uint32_t c_i = 0; c_i < in_ntiles_c; ++c_i) {
reduce_tile_math(in_ntiles_c * out_elem_i + c_i, num_faces_in_tile /* reduce 1 or 2 faces */);
tile_regs_acquire();
unpack_tilizeA_B_block(curr_in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, num_faces_in_tile /* unpack 1 or 2 faces ) */, unpA_face_r_dim);
for (uint32_t c_i = 0; c_i < in_ntiles_c / in_nblocks_c; ++c_i) {
reduce_tile_math(c_i, num_faces_in_tile /* reduce 1 or 2 faces */);
}
cb_pop_front(curr_in_cb_id, 1);
tile_regs_wait();
tile_regs_commit();
pack_untilize_dst<num_output_tiles>(out_cb_id, 1/*out_subblock_h*/, 0, num_out_rows, num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */
tile_regs_release();
cb_push_back(out_cb_id, 1);
}

tile_regs_wait();
tile_regs_commit();
pack_untilize_dst<num_output_tiles>(out_cb_id, 1/*out_subblock_h*/, 0, num_out_rows, num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */
tile_regs_release();

cb_push_back(out_cb_id, 1);
}

namespace NAMESPACE {
Expand All @@ -95,33 +84,32 @@ void MAIN {
constexpr uint32_t out_h = get_compile_time_arg_val(4);
constexpr uint32_t out_w = get_compile_time_arg_val(5);
constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7);
constexpr uint32_t nblocks = get_compile_time_arg_val(8);

constexpr uint32_t split_reader = get_compile_time_arg_val(12);

constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(13);
constexpr uint32_t nsticks_per_core = get_compile_time_arg_val(13);
constexpr uint32_t in_c = get_compile_time_arg_val(14);
constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks;
constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(15);

constexpr uint32_t num_output_tiles = out_ntiles_c;

constexpr uint32_t in_cb_id = tt::CB::c_in0; // and tt::CB::c_in1 for split reader
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;
constexpr uint32_t in_tiled_cb_id = tt::CB::c_intermed0;
constexpr uint32_t out_cb_id = tt::CB::c_out0;

// const uint32_t TILE_WIDTH = 32;
constexpr bool is_partial_tile = in_c < 32;
static_assert((!is_partial_tile || (in_c == 16)), "Partial tile must have c_dim 16");
constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_out_rows = 1;

tilizeA_B_reduce_init<true>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc, out_cb_id, num_faces_in_tile, window_size_hw);
uint32_t in_ntiles_hwc_block = in_ntiles_hwc / in_nblocks_c;
tilizeA_B_reduce_init<true>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, out_cb_id, num_faces_in_tile, window_size_hw);
pack_untilize_dst_init_short<num_output_tiles>(out_cb_id, num_out_rows, num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */

cb_wait_front(in_scalar_cb_id, 1);
for (uint32_t i = 0; i < nsticks_per_core_by_nblocks; ++ i) {
// NOTE: Assuming in_ntiles_hw < 8 for now.
// TODO: subblocking to support this.
reduce_h_fused<in_ntiles_hw, in_ntiles_c, out_ntiles_c, nblocks, is_partial_tile, split_reader, window_size_hw>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc, i, out_cb_id);
for (uint32_t i = 0; i < nsticks_per_core; ++ i) {
reduce_h_fused<in_ntiles_hw, in_ntiles_c, out_ntiles_c, is_partial_tile, split_reader, window_size_hw, in_nblocks_c>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, i, out_cb_id);
}
cb_pop_front(in_scalar_cb_id, 1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ void kernel_main() {
const uint32_t in_cb_nsticks = get_compile_time_arg_val(7);

const uint32_t in_c = get_compile_time_arg_val(8);
const uint32_t nblocks = get_compile_time_arg_val(9);

const uint32_t split_reader = get_compile_time_arg_val(10);
const uint32_t reader_id = get_compile_time_arg_val(11);

// compile time args
// value of 1 in bf16 in a uint32_t
constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12);

constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13);

// static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2");

constexpr uint32_t TILE_WIDTH = 32;
Expand Down Expand Up @@ -91,24 +91,21 @@ void kernel_main() {

uint32_t in_w_padded = in_w + 2 * pad_w;

uint32_t npages_to_reserve = nblocks;
uint32_t npages_to_reserve = 1;
uint32_t counter = reader_id;
while (counter < reader_nindices) {
cb_reserve_back(in_cb_id, npages_to_reserve);

uint32_t out_l1_write_addr_base = get_write_ptr(in_cb_id);
uint32_t out_l1_write_addr = out_l1_write_addr_base;
for (uint32_t i = 0; i < nblocks; ++ i) {
uint16_t top_left_local_index = reader_indices_ptr[counter ++];
uint32_t h_multiples = 0;
for (uint32_t h = 0; h < window_h; ++ h, h_multiples += in_w_padded) {
uint32_t stick_offset = top_left_local_index + h_multiples;
uint32_t read_offset = in_l1_read_base_addr + (stick_offset << in_nbytes_c_log2);
noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, in_nbytes_c * window_w);
out_l1_write_addr += in_nbytes_c * window_w;
}
if (split_reader) counter++; // interleave the indices
uint16_t top_left_local_index = reader_indices_ptr[counter ++];
uint32_t h_multiples = 0;
for (uint32_t h = 0; h < window_h; ++ h, h_multiples += in_w_padded) {
uint32_t stick_offset = top_left_local_index + h_multiples;
uint32_t read_offset = in_l1_read_base_addr + (stick_offset << in_nbytes_c_log2);
noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, in_nbytes_c * window_w);
out_l1_write_addr += in_nbytes_c * window_w;
}
if (split_reader) counter++; // interleave the indices
noc_async_read_barrier();
cb_push_back(in_cb_id, npages_to_reserve);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#include <cstring>
#include "dataflow_api.h"

#define ENABLE_DEBUG_PRINT 0

#if ENABLE_DEBUG_PRINT == 1
#include "debug/dprint.h"

inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) {
volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(l1_addr) + start * pagelen;
for (uint32_t page = 0; page < npages; ++ page) {
DPRINT << start + page << ": ";
for (uint32_t j = 0; j < pagelen; ++ j, ++ ptr) {
DPRINT << BF16(*ptr) << " ";
}
DPRINT << ENDL();
}
}
#endif

#define ALWI inline __attribute__((always_inline))

// Fill an L1 buffer with `n` copies of the 16-bit value `val`.
//
//   begin_addr  start address of the buffer; it is reinterpreted as a pointer to 32-bit words
//               (assumes the address is suitably aligned for 32-bit stores -- TODO confirm)
//   n           number of 16-bit elements to write
//   val         16-bit pattern stored into every element
//
// Always returns true.
// WARNING: Use with caution as there's no memory protection. Make sure size is within limits
ALWI bool fill_with_val(uint32_t begin_addr, uint32_t n, uint16_t val) {
    // Two copies of val packed into one word, so every store fills a pair of elements.
    // Widen before shifting so the shift is done on an unsigned 32-bit value rather than
    // on the int that uint16_t would otherwise be promoted to.
    const uint32_t packed = static_cast<uint32_t>(val) | (static_cast<uint32_t>(val) << 16);

    volatile tt_l1_ptr uint32_t* words = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(begin_addr);
    const uint32_t npairs = n / 2;
    for (uint32_t i = 0; i < npairs; ++i) {
        words[i] = packed;
    }

    // An odd count leaves one trailing element that the pair loop does not reach.
    if ((n & 1u) != 0) {
        volatile tt_l1_ptr uint16_t* halves = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(begin_addr);
        halves[n - 1] = val;
    }
    return true;
}

/**
 * Max-pool 2D reader kernel.
 *
 * For every window index handled by this reader, copies the window's sticks out of the
 * local input shard into the reader's input circular buffer. The channel dimension is
 * processed in `in_nblocks_c` blocks; each block of a window is pushed as one CB page.
 * Reader 0 additionally fills the scalar CB with the bf16 value 1 before reading.
 */
void kernel_main() {
    // ---- compile-time arguments ----
    const uint32_t reader_nindices = get_compile_time_arg_val(0);
    const uint32_t window_h = get_compile_time_arg_val(1);
    const uint32_t window_w = get_compile_time_arg_val(2);

    const int32_t pad_w = get_compile_time_arg_val(3);

    // channel size in bytes, multiple of 32
    const uint32_t in_nbytes_c = get_compile_time_arg_val(4);
    const uint32_t in_nbytes_c_log2 = get_compile_time_arg_val(5);  // not referenced below

    // input tensor height / width / channels
    const int32_t in_w = get_compile_time_arg_val(6);
    const uint32_t in_cb_nsticks = get_compile_time_arg_val(7);  // not referenced below

    const uint32_t in_c = get_compile_time_arg_val(8);  // not referenced below

    const uint32_t split_reader = get_compile_time_arg_val(10);
    const uint32_t reader_id = get_compile_time_arg_val(11);

    // value of 1 in bf16 in a uint32_t
    constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12);

    // number of blocks the channel dimension is split into
    constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13);

    // static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2");

    constexpr uint32_t TILE_WIDTH = 32;

    // ---- circular buffers ----
    constexpr uint32_t in_cb_id = (reader_id == 1) ? tt::CB::c_in1 : tt::CB::c_in0;
    constexpr uint32_t in_shard_cb_id = tt::CB::c_in2;  // local input shard
    constexpr uint32_t in_reader_indices_cb_id = tt::CB::c_in3;
    constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;

    constexpr uint32_t ROW_HW = 64;

    // Bytes moved per stick for one channel block: 2 bytes per element, max 8 tiles wide.
    constexpr uint32_t block_nbytes = TILE_WIDTH * 8 * 2;
    // Every channel block of a window occupies a single page of the input CB.
    constexpr uint32_t pages_per_block = 1;

    // Reduce scalar = 1: only reader 0 produces the scalar page.
    if (reader_id == 0) {
        cb_reserve_back(in_scalar_cb_id, 1);
        // The bf16 pattern sits in the upper 16 bits of the 32-bit argument.
        const uint16_t bf16_one_u16 = static_cast<uint16_t>(bf16_one_u32 >> 16);
        // fill 1 row w/ scalar
        fill_with_val(get_write_ptr(in_scalar_cb_id), ROW_HW, bf16_one_u16);
        cb_push_back(in_scalar_cb_id, 1);
    }

    const uint32_t shard_base_addr = get_read_ptr(in_shard_cb_id);
    volatile tt_l1_ptr uint16_t* reader_indices_ptr =
        reinterpret_cast<volatile tt_l1_ptr uint16_t*>(get_read_ptr(in_reader_indices_cb_id));

    // Row pitch of the input in sticks, including the padding on both sides.
    const uint32_t in_w_padded = in_w + 2 * pad_w;

    // With a split reader the indices are interleaved, so each reader takes every other one.
    const uint32_t index_step = split_reader ? 2 : 1;

    for (uint32_t idx = reader_id; idx < reader_nindices; idx += index_step) {
        // Index (in sticks) of the window's top-left element within the shard.
        const uint16_t window_origin = reader_indices_ptr[idx];

        for (uint32_t block = 0; block < in_nblocks_c; ++block) {
            const uint32_t block_byte_offset = block * block_nbytes;

            cb_reserve_back(in_cb_id, pages_per_block);
            uint32_t l1_write_addr = get_write_ptr(in_cb_id);

            for (uint32_t row = 0; row < window_h; ++row) {
                const uint32_t row_first_stick = window_origin + row * in_w_padded;
                for (uint32_t col = 0; col < window_w; ++col) {
                    const uint32_t stick = row_first_stick + col;
                    const uint32_t read_addr = shard_base_addr + (stick * in_nbytes_c + block_byte_offset);
                    noc_async_read_one_packet(get_noc_addr(read_addr), l1_write_addr, block_nbytes);
                    l1_write_addr += block_nbytes;
                }
            }

            // Wait for all reads of this block to land before handing the page over.
            noc_async_read_barrier();
            cb_push_back(in_cb_id, pages_per_block);
        }
    }
} // kernel_main()
Loading

0 comments on commit 7b96435

Please sign in to comment.