#0: Conv act split reader speed up
Pavle Josipovic authored and pavlejosipovic committed Sep 20, 2024
1 parent b8aca23 commit 8e0e2e1
Showing 6 changed files with 139 additions and 273 deletions.
==== File 1 of 6 ====
@@ -581,7 +581,7 @@ def __init__(
reshard_if_not_optimal=False,
)
if whb0_and_b16:
- self.conv1_config.act_block_h_override = 64
+ self.conv1_config.act_block_h_override = 256

self.conv1_kernel_size = (4, 4)
self.conv1_stride = (1, 1)
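
For context on the change above: a larger act_block_h_override means the conv reader iterates over fewer, taller activation blocks, trading L1 capacity for less per-block overhead. A minimal sketch of that arithmetic, assuming the activation height is tiled into act_block_h-row blocks with ceil rounding; the helper name and the 4096-row height are illustrative, not taken from this repository:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper: number of height blocks the reader loops over,
// assuming the activation height is tiled into act_block_h-row blocks.
static uint32_t num_act_blocks(uint32_t act_height_rows, uint32_t act_block_h) {
    return (act_height_rows + act_block_h - 1) / act_block_h;  // ceil division
}

int main() {
    const uint32_t act_height_rows = 4096;  // illustrative activation height
    std::printf("act_block_h=64  -> %u blocks\n", num_act_blocks(act_height_rows, 64));   // 64 blocks
    std::printf("act_block_h=256 -> %u blocks\n", num_act_blocks(act_height_rows, 256));  // 16 blocks
    return 0;
}
```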
==== File 2 of 6 ====
@@ -12,7 +12,7 @@
@pytest.mark.parametrize(
"batch_size, test, expected_perf",
[
- [16, "16-act_dtype0-weight_dtype0-math_fidelity0-device_params0", 5020],
+ [16, "16-act_dtype0-weight_dtype0-math_fidelity0-device_params0", 5255],
],
)
def test_perf_device(batch_size, test, expected_perf):
==== File 3 of 6 ====
@@ -33,7 +33,7 @@
@pytest.mark.models_device_performance_bare_metal
@pytest.mark.parametrize(
"batch, groups, expected_device_perf_fps",
- ((2, 1, 650.0),),
+ ((2, 1, 683.0),),
)
def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float):
command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]"
==== File 4 of 6 ====
@@ -8,9 +8,6 @@

#define DILATION_W get_compile_time_arg_val(4)
void kernel_main() {
- constexpr uint32_t LOCAL_PACKED_READER_INDICES_MAX_SIZE = 128;
- uint32_t local_packed_reader_indices[LOCAL_PACKED_READER_INDICES_MAX_SIZE];
-
constexpr bool act_in_dram = get_compile_time_arg_val(0)== 1;
constexpr uint32_t stride_h = get_compile_time_arg_val(1);
constexpr uint32_t stride_w = get_compile_time_arg_val(2);
@@ -31,12 +28,8 @@ void kernel_main() {
constexpr uint32_t act_num_blocks_h = get_compile_time_arg_val(16);
constexpr uint32_t act_block_h_datums_last_block = get_compile_time_arg_val(25);

- uint32_t act_block_h_datums_read_last_block;
- if (act_block_h_datums_last_block > act_block_h_datums) {
-     act_block_h_datums_read_last_block = act_block_h_datums / 2;
- } else {
-     act_block_h_datums_read_last_block = act_block_h_datums_last_block / 2;
- }
+ constexpr uint32_t act_block_h_datums_read_last_block =
+     act_block_h_datums_last_block > act_block_h_datums ? act_block_h_datums / 2 : act_block_h_datums_last_block / 2;

uint32_t i = 0;
uint32_t noop = get_arg_val<uint32_t>(i); i+=1;
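
The hunk above replaces a runtime if/else with a constexpr ternary. Because both operands come from get_compile_time_arg_val, the value is known when the kernel is compiled, so no branch or variable needs to exist at runtime. A standalone sketch of the pattern, with ordinary constexpr constants standing in for the kernel's compile-time args (the values 512 and 640 are made up):

```cpp
#include <cstdint>

// Stand-ins for get_compile_time_arg_val(...); values are illustrative only.
constexpr uint32_t act_block_h_datums = 512;
constexpr uint32_t act_block_h_datums_last_block = 640;

// Folded entirely at compile time: the kernel binary carries just the result.
constexpr uint32_t act_block_h_datums_read_last_block =
    act_block_h_datums_last_block > act_block_h_datums
        ? act_block_h_datums / 2
        : act_block_h_datums_last_block / 2;

static_assert(act_block_h_datums_read_last_block == 256,
              "branch resolved by the compiler, not at runtime");

int main() { return 0; }
```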
@@ -68,24 +61,15 @@ void kernel_main() {
reader_offset += (dilation_h * conv_act_size_w_padded);
}

- #ifdef SPLIT_READER
- constexpr uint32_t act_block_h_datums_read = act_block_h_datums / 2; // Extra /2 because of packed uint16 reads
- constexpr uint32_t act_block_num_tiles_read = act_block_num_tiles;
- #else
constexpr uint32_t act_block_h_datums_read = act_block_h_datums / 2; // packed uint16 reads
constexpr uint32_t act_block_num_tiles_read = act_block_num_tiles;
- #endif
-

// LOOP TO FILL READER INDICES
constexpr uint32_t cb_reader_indices = tt::CB::c_in4;
volatile tt_l1_ptr uint32_t* packed_reader_indices_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_reader_indices));

uint32_t reader_idx = 0;

- // Copy packed reader indices to local memory for faster access
- constexpr bool cache_packed_reader_indices = act_block_h_datums_read <= LOCAL_PACKED_READER_INDICES_MAX_SIZE;
-
// TODO: need to make the read coalescing optimization cleaner
// pass coalesce_window_inner_reads as a compile time arg and num_coalesced_reads so we can constexpr the if
// currently works for the case of num_coalesced_reads == weight_size_w since these reads are contiguous on both src/dst side
@@ -109,13 +93,6 @@

uint32_t start_reader_idx = 0;
for (uint32_t bh = 0; bh < act_num_blocks_h; bh++) {
- #ifdef SPLIT_READER
-     if constexpr (cache_packed_reader_indices) {
-         for (uint32_t i = 0; i < act_block_h_datums_read; i++) {
-             local_packed_reader_indices[i] = packed_reader_indices_ptr[start_reader_idx+i];
-         }
-     }
- #endif
for (uint32_t outer = 0; outer < window_outer; outer++) {
// Reset reader_idx to finish act_block_h_datums
reader_idx = start_reader_idx;
@@ -128,11 +105,7 @@

for (uint32_t bhd = 0; bhd < act_block_h_datums_read_curr; bhd++) {
// local read from reader_index + reader_offset;
- #ifdef SPLIT_READER
- uint32_t two_reader_indices = cache_packed_reader_indices ? local_packed_reader_indices[bhd] : packed_reader_indices_ptr[reader_idx];
- #else // no split reader
uint32_t two_reader_indices = packed_reader_indices_ptr[reader_idx];
- #endif
uint32_t reader_idx_1 = two_reader_indices & 0xffff;
uint32_t reader_idx_2 = two_reader_indices >> 16;

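The masks above are the unpacking half of the scheme that motivates all the divide-by-two loop bounds in this kernel: each uint32_t in the reader-indices CB carries two 16-bit row indices, so a block of act_block_h_datums rows needs only act_block_h_datums / 2 packed words. A self-contained sketch of the pack/unpack arithmetic (the pack helper is illustrative; only the unpacking mirrors the kernel):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative packing: low 16 bits hold the first index, high 16 bits the second.
constexpr uint32_t pack_two_indices(uint16_t idx1, uint16_t idx2) {
    return static_cast<uint32_t>(idx1) | (static_cast<uint32_t>(idx2) << 16);
}

int main() {
    const uint32_t two_reader_indices = pack_two_indices(37, 1042);
    // Same unpacking as the kernel above.
    const uint32_t reader_idx_1 = two_reader_indices & 0xffff;
    const uint32_t reader_idx_2 = two_reader_indices >> 16;
    assert(reader_idx_1 == 37 && reader_idx_2 == 1042);
    return 0;
}
```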
==== File 5 of 6 ====
@@ -8,8 +8,6 @@

void kernel_main() {
// This writer is for output tensor in tile format
- constexpr uint32_t LOCAL_PACKED_READER_INDICES_MAX_SIZE = 128;
- uint32_t local_packed_reader_indices[LOCAL_PACKED_READER_INDICES_MAX_SIZE];

constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1;
constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1);
@@ -56,13 +54,10 @@ void kernel_main() {
constexpr uint32_t act_block_h_datums_first_reader = get_compile_time_arg_val(38);
constexpr uint32_t act_block_h_datums_last_block = get_compile_time_arg_val(39);

- uint32_t act_block_h_datums_read_last_block;
- if (act_block_h_datums_last_block > act_block_h_datums) {
-     act_block_h_datums_read_last_block = (act_block_h_datums_last_block - act_block_h_datums_first_reader) / 2;
- } else {
-     act_block_h_datums_read_last_block = 0;
- }
-
+ constexpr uint32_t act_block_h_datums_read_last_block =
+     act_block_h_datums_last_block > act_block_h_datums
+         ? (act_block_h_datums_last_block - act_block_h_datums_first_reader) / 2
+         : 0;
constexpr uint32_t total_weight_num_tiles = weight_block_height_num_outer * num_blocks_weight_h * weight_block_num_tiles;

uint32_t i = 0;
@@ -86,8 +81,8 @@ void kernel_main() {
volatile tt_l1_ptr uint32_t* weights_mcast_receiver_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(weights_mcast_receiver_semaphore_addr);
const uint64_t weights_mcast_sender_semaphore_noc_addr = get_noc_addr(weights_mcast_sender_noc_x, weights_mcast_sender_noc_y, weights_mcast_sender_semaphore_addr);

- const uint32_t tile_nbytes = get_tile_size(cb_id_out0);
- const DataFormat out_df = get_dataformat(cb_id_out0);
+ constexpr uint32_t tile_nbytes = get_tile_size(cb_id_out0);
+ constexpr DataFormat out_df = get_dataformat(cb_id_out0);

constexpr uint32_t cb_id_act_second_reader = 7;
constexpr uint32_t cb_id_sharded_act = 3;
@@ -98,20 +93,17 @@ void kernel_main() {
volatile tt_l1_ptr uint32_t* packed_reader_indices_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_reader_indices));
uint32_t reader_idx = 0;

- // Copy packed reader indices to local memory for faster access
- constexpr bool cache_packed_reader_indices = act_block_h_datums_read <= LOCAL_PACKED_READER_INDICES_MAX_SIZE;
-
const InterleavedAddrGenFast<out_in_dram> s = {
.bank_base_address = out_addr,
.page_size = tile_nbytes,
.data_format = out_df
};

// read in bias if enabled (done only once for all batches)
#ifdef FUSE_BIAS
constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3);
bool load_bias = true;
#endif


// OUTER most loop is looping over out blocks in width dim because blocks from compute are in col major order.
@@ -136,48 +128,41 @@ void kernel_main() {
// read weight blocks inner dim
// read weight slice - 1 block of weights in width dim and full weight matrix height
// read slice only once for all activation blocks
- if constexpr (cache_packed_reader_indices) {
-     for (uint32_t i = 0; i < act_block_h_datums_read; i++) {
-         local_packed_reader_indices[i] = packed_reader_indices_ptr[start_reader_idx+i];
-     }
- }
- if (read_weights) {
-
- // TODO: Not sure how this loop works with the additional reader; we don't have a use case for this right now
- for(uint32_t weight_tile_h_outer_i = 0; weight_tile_h_outer_i < weight_block_height_num_outer; weight_tile_h_outer_i++) {
-
- uint32_t reader_offset = act_l1_read_addr;
- for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) {
- // Do the second half of the reads for act
- noc_async_read_one_packet_set_state(get_noc_addr(act_l1_read_addr), coalesced_read_bytes);
- reader_idx = start_reader_idx;
- cb_reserve_back(cb_id_act_second_reader, act_block_num_tiles_read);
- uint32_t l1_write_addr_act = get_write_ptr(cb_id_act_second_reader);
- uint32_t act_block_h_datums_read_curr = bh == out_num_blocks_h - 1 ? act_block_h_datums_read_last_block : act_block_h_datums_read;
- for (uint32_t bhd = 0; bhd < act_block_h_datums_read_curr; bhd++) {
- // local read from reader_index + reader_offset;
- uint32_t two_reader_indices = cache_packed_reader_indices ? local_packed_reader_indices[bhd] : packed_reader_indices_ptr[reader_idx];
- uint32_t reader_idx_1 = two_reader_indices & 0xffff;
- uint32_t reader_idx_2 = two_reader_indices >> 16;
-
- act_l1_offset = reader_offset + (reader_idx_1 * conv_act_size_c_bytes);
- noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
- l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);
-
- act_l1_offset = reader_offset + (reader_idx_2 * conv_act_size_c_bytes);
- noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
- l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);
-
- reader_idx++;
- }
- noc_async_read_barrier();
- cb_push_back(cb_id_act_second_reader, act_block_num_tiles_read);
-
- reader_offset += window_outer_offset;
+ // TODO: Not sure how this loop works with the additional reader; we don't have a use case for this right now
+ for(uint32_t weight_tile_h_outer_i = 0; weight_tile_h_outer_i < weight_block_height_num_outer; weight_tile_h_outer_i++) {
+ uint32_t reader_offset = act_l1_read_addr;
+ for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) {
+ // Do the second half of the reads for act
+ noc_async_read_one_packet_set_state(get_noc_addr(act_l1_read_addr), coalesced_read_bytes);
+ reader_idx = start_reader_idx;
+ cb_reserve_back(cb_id_act_second_reader, act_block_num_tiles_read);
+ uint32_t l1_write_addr_act = get_write_ptr(cb_id_act_second_reader);
+ uint32_t act_block_h_datums_read_curr = bh == out_num_blocks_h - 1 ? act_block_h_datums_read_last_block : act_block_h_datums_read;
+ for (uint32_t bhd = 0; bhd < act_block_h_datums_read_curr; bhd++) {
+ // local read from reader_index + reader_offset;
+ uint32_t two_reader_indices = packed_reader_indices_ptr[reader_idx];
+ uint32_t reader_idx_1 = two_reader_indices & 0xffff;
+ uint32_t reader_idx_2 = two_reader_indices >> 16;
+
+ act_l1_offset = reader_offset + (reader_idx_1 * conv_act_size_c_bytes);
+ noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
+ l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);
+
+ act_l1_offset = reader_offset + (reader_idx_2 * conv_act_size_c_bytes);
+ noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
+ l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);
+
+ reader_idx++;
+ }
+ noc_async_read_barrier();
+ cb_push_back(cb_id_act_second_reader, act_block_num_tiles_read);
+
+ reader_offset += window_outer_offset;
+
// Receive weights
cb_reserve_back(cb_id_weight, weight_block_num_tiles);
if (bh == 0) {
// Set weights semaphore value to INVALID
noc_semaphore_set(weights_mcast_receiver_semaphore_addr_ptr, INVALID);

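The activation path in this block follows the usual circular-buffer producer ordering: reserve space in the CB, issue asynchronous NoC reads into the write pointer, barrier until they land, then push the tiles to the consumer. A minimal mock of that ordering; the stub functions below only stand in for the tt-metal primitives of the same names and do not reproduce their real behavior:

```cpp
#include <cstdint>
#include <vector>

// No-op stubs standing in for tt-metal CB/NoC calls; illustration only.
static void cb_reserve_back(int /*cb_id*/, uint32_t /*num_tiles*/) {}
static void cb_push_back(int /*cb_id*/, uint32_t /*num_tiles*/) {}
static void noc_async_read(uint32_t /*src*/, uint32_t /*dst*/, uint32_t /*nbytes*/) {}
static void noc_async_read_barrier() {}

// Producer-side ordering mirrored from the kernel: reserve -> read -> barrier -> push.
void produce_act_block(int cb_id, uint32_t num_tiles,
                       const std::vector<uint32_t>& src_offsets,
                       uint32_t l1_write_addr, uint32_t read_nbytes) {
    cb_reserve_back(cb_id, num_tiles);                    // wait for CB space
    for (uint32_t src : src_offsets) {
        noc_async_read(src, l1_write_addr, read_nbytes);  // queue async reads
        l1_write_addr += read_nbytes;
    }
    noc_async_read_barrier();                             // all reads landed
    cb_push_back(cb_id, num_tiles);                       // hand off to compute
}

int main() { return 0; }
```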
Expand All @@ -186,51 +171,14 @@ void kernel_main() {

// wait on weights semaphore value to become VALID (set by mcast sender after it multicasts data)
noc_semaphore_wait(weights_mcast_receiver_semaphore_addr_ptr, VALID);
}

cb_push_back(cb_id_weight, weight_block_num_tiles);
} // for num_blocks_weight_h
- } // for weight_block_height_num_outer

- read_weights = false;
- } else {
- cb_reserve_back(cb_id_weight, total_weight_num_tiles);
- cb_push_back(cb_id_weight, total_weight_num_tiles);
-
- noc_async_read_one_packet_set_state(get_noc_addr(act_l1_read_addr), coalesced_read_bytes);
- uint32_t reader_offset = act_l1_read_addr;
- for(uint32_t weight_tile_h_outer_i = 0; weight_tile_h_outer_i < weight_block_height_num_outer; weight_tile_h_outer_i++) {
- for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) {
- reader_idx = start_reader_idx;
-
- // Do the second half of the reads for act
- cb_reserve_back(cb_id_act_second_reader, act_block_num_tiles_read);
- uint32_t l1_write_addr_act = get_write_ptr(cb_id_act_second_reader);
- uint32_t act_block_h_datums_read_curr = bh == out_num_blocks_h - 1 ? act_block_h_datums_read_last_block : act_block_h_datums_read;
- for (uint32_t bhd = 0; bhd < act_block_h_datums_read_curr; bhd++) {
- // local read from reader_index + reader_offset;
- uint32_t two_reader_indices = cache_packed_reader_indices ? local_packed_reader_indices[bhd] : packed_reader_indices_ptr[reader_idx];
- uint32_t reader_idx_1 = two_reader_indices & 0xffff;
- uint32_t reader_idx_2 = two_reader_indices >> 16;
-
- act_l1_offset = reader_offset + (reader_idx_1 * conv_act_size_c_bytes);
- noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
- l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);
-
- act_l1_offset = reader_offset + (reader_idx_2 * conv_act_size_c_bytes);
- noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
- l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);
-
- reader_idx++;
- }
- noc_async_read_barrier();
- cb_push_back(cb_id_act_second_reader, act_block_num_tiles_read);
- cb_push_back(cb_id_weight, weight_block_num_tiles);
-
- reader_offset += window_outer_offset;
- }
- }
} // for weight_block_height_num_outer
}

#ifdef FUSE_BIAS
if (load_bias) {
cb_reserve_back(bias_cb_id, bias_ntiles);

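The weights side of the loop above is a multicast-receiver handshake: reset the local semaphore to INVALID, increment the sender's semaphore to signal readiness, then spin until the sender flips the local semaphore to VALID after multicasting the block. A mock of that ordering with std::atomic stand-ins; these are not the tt-metal semaphore APIs, just the same call sequence:

```cpp
#include <atomic>
#include <cstdint>

constexpr uint32_t INVALID = 0;
constexpr uint32_t VALID = 1;

// Stand-ins for the NoC semaphore primitives; illustration only.
static void noc_semaphore_set(std::atomic<uint32_t>* sem, uint32_t v) { sem->store(v); }
static void noc_semaphore_inc(std::atomic<uint32_t>* sem, uint32_t v) { sem->fetch_add(v); }
static void noc_semaphore_wait(std::atomic<uint32_t>* sem, uint32_t v) {
    while (sem->load() != v) { /* spin until the mcast sender sets VALID */ }
}

// Receiver-side ordering mirrored from the kernel.
void receive_weights_block(std::atomic<uint32_t>* local_sem,
                           std::atomic<uint32_t>* sender_sem) {
    noc_semaphore_set(local_sem, INVALID);  // mark local weights as stale
    noc_semaphore_inc(sender_sem, 1);       // tell the sender this core is ready
    noc_semaphore_wait(local_sem, VALID);   // block until the mcast data arrives
}

int main() { return 0; }
```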
@@ -246,9 +194,9 @@ void kernel_main() {
cb_push_back(bias_cb_id, bias_ntiles);
load_bias = false;
}
#endif

#ifndef SHARDED_OUT
uint32_t out_sbh_start_tile_id = out_block_h_start_tile_id;
uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; //
for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) {
@@ -291,7 +239,7 @@ void kernel_main() {
} // out_num_subblocks_h
out_block_h_start_tile_id += out_next_block_stride_h;
out_block_h_start_tile_id_h += out_block_height_num_tiles;
#endif

start_reader_idx = reader_idx + act_block_h_datums_read;
} // out_num_blocks_h
@@ -302,7 +250,7 @@ void kernel_main() {
weight_start_tile_id += weight_next_block_stride_w;
} // out_num_blocks_w

#ifdef SHARDED_OUT
cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h);
#endif
}
==== File 6 of 6 ==== (diff not loaded on the captured page)
