Optimize sharded tensor address generators #12223

Merged 1 commit on Sep 5, 2024.
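Summary, as read from the diff below: the sharded tensor address generators gain a get_page_location_with_contiguous_pages_in_row_in_bank() variant that, in addition to the owning core and page offset, reports how many pages starting at the queried page are contiguous within that core's bank row. The CCL worker chunk helpers use this to coalesce per-page NoC reads/writes into one transfer per contiguous run, and the tensor-slicer tests gain assertions covering the new query.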
tests/tt_eager/ops/ccl/test_ccl_tensor_slicers.cpp (18 additions, 0 deletions)
@@ -67,6 +67,12 @@ static void run_width_sharded_tensor_slice_indexer_get_page_location_test(
ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical));
ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical));
ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x));

auto const& result2 = addrgen.get_page_location_with_contiguous_pages_in_row_in_bank(page_id);
ASSERT_EQ(result2.core_location.noc_x, result.core_location.noc_x);
ASSERT_EQ(result2.core_location.noc_y, result.core_location.noc_y);
ASSERT_EQ(result2.page_offset, result.page_offset);
ASSERT_EQ(result2.contig_pages_in_row, pages_per_shard_x - px);
}

page_id++;
@@ -84,6 +90,12 @@ static void run_width_sharded_tensor_slice_indexer_get_page_location_test(
ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical));
ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical));
ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x));

auto const& result2 = addrgen.get_page_location_with_contiguous_pages_in_row_in_bank(page_id);
ASSERT_EQ(result2.core_location.noc_x, result.core_location.noc_x);
ASSERT_EQ(result2.core_location.noc_y, result.core_location.noc_y);
ASSERT_EQ(result2.page_offset, result.page_offset);
ASSERT_EQ(result2.contig_pages_in_row, pages_per_shard_x - px);
}
page_id++;
}
@@ -363,6 +375,12 @@ static void run_block_sharded_tensor_slice_indexer_get_page_location_test(
ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical));
ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical));
ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x));

auto const& result2 = addrgen.get_page_location_with_contiguous_pages_in_row_in_bank(page_id);
ASSERT_EQ(result2.core_location.noc_x, result.core_location.noc_x);
ASSERT_EQ(result2.core_location.noc_y, result.core_location.noc_y);
ASSERT_EQ(result2.page_offset, result.page_offset);
ASSERT_EQ(result2.contig_pages_in_row, pages_per_shard_x - px);
}

page_id++;
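The new assertions encode the key invariant: starting from local column px, exactly pages_per_shard_x - px pages remain contiguous in that shard row. A minimal host-side model of a width-sharded page indexer illustrates why (all names here are illustrative, not the repo's actual types):

    #include <cassert>
    #include <cstdint>

    struct PageLocation {
        uint32_t core_x, core_y;       // logical core owning the page (core_y fixed at 0 in this toy model)
        uint32_t page_offset;          // page index within that core's shard
        uint32_t contig_pages_in_row;  // pages from here to the end of the shard row
    };

    // Pages are numbered left to right across the tensor width, then row by row.
    static PageLocation locate(uint32_t page_id, uint32_t num_shards_x, uint32_t pages_per_shard_x) {
        const uint32_t pages_per_row = num_shards_x * pages_per_shard_x;
        const uint32_t py = page_id / pages_per_row;       // row of pages
        const uint32_t col = page_id % pages_per_row;      // global page column
        const uint32_t shard_x = col / pages_per_shard_x;  // which shard along the width
        const uint32_t px = col % pages_per_shard_x;       // column inside the shard
        return {shard_x, 0, px + py * pages_per_shard_x, pages_per_shard_x - px};
    }

    int main() {
        // Page 6 of a tensor split across two 4-page-wide shards sits at px == 2,
        // so pages 6 and 7 are contiguous in that shard's row.
        const PageLocation loc = locate(6, 2, 4);
        assert(loc.page_offset == 2 && loc.contig_pages_in_row == 2);
        return 0;
    }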
[second changed file — filename not captured in this export; it contains the CCL worker chunk read/write helpers shown below]
@@ -31,7 +31,10 @@ FORCE_INLINE void write_and_send_chunk_write_to_tensor_segment(
const uint32_t& page_size,
uint32_t l1_read_addr) {

for (uint32_t i = 0; i < num_pages; ++i) {
// for (uint32_t i = 0; i < num_pages; ++i) {
int32_t contig_pages = 1;
for (int32_t pages_remaining = num_pages; pages_remaining != 0; pages_remaining -= contig_pages) {
contig_pages = 1;
#ifdef ROW_MAJOR_LAYOUT
#ifdef INTERLEAVED_MEM_LAYOUT
uint64_t dst_noc_addr = get_noc_addr(output_page_idx, d);
@@ -53,12 +56,14 @@
noc_async_write_tile(output_page_idx, d, l1_read_addr);
#elif defined SHARDED_MEM_LAYOUT
// TODO: Make d.get_noc_addr work on host + device
auto const&[noc_yx, page_offset] = d.get_page_location(output_page_idx);
// auto const&[noc_yx, page_offset] = d.get_page_location(output_page_idx);
auto [noc_yx, page_offset, contig_pages_] = d.get_page_location_with_contiguous_pages_in_row_in_bank(output_page_idx);
contig_pages = std::min<int32_t>(pages_remaining, std::min<int32_t>(contig_pages_, num_cols - col_idx));
uint64_t dst_noc_addr = get_noc_addr(static_cast<uint32_t>(noc_yx.noc_x), static_cast<uint32_t>(noc_yx.noc_y), d.bank_base_address + (page_offset * d.page_size) + 0);
noc_async_write(l1_read_addr, dst_noc_addr, page_size);
noc_async_write(l1_read_addr, dst_noc_addr, page_size * contig_pages);
#endif
output_page_idx++;
col_idx++;
output_page_idx += contig_pages;
col_idx += contig_pages;
if (col_idx == num_cols) {
output_page_idx += col_offset;
col_idx = 0;
@@ -69,7 +74,7 @@ FORCE_INLINE void write_and_send_chunk_write_to_tensor_segment(
}
}
#endif
l1_read_addr += page_size;
l1_read_addr += page_size * contig_pages;
}
noc_async_write_barrier();
cb_pop_front(cb_id, num_pages);
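The rewritten loop replaces one NoC write per page with one write per contiguous run. A host-side sketch of the same control flow (with a stand-in for the contiguity query; num_pages, num_cols, col_offset, and pages_per_shard_x are made-up example values) shows how the three clamps interact:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Example values only; in the kernel these come from compile/runtime args.
        const int32_t num_pages = 10, num_cols = 3, col_offset = 5, pages_per_shard_x = 4;
        int32_t page_idx = 0, col_idx = 0, transfers = 0;
        int32_t contig_pages = 1;
        for (int32_t pages_remaining = num_pages; pages_remaining != 0; pages_remaining -= contig_pages) {
            // Stand-in for get_page_location_with_contiguous_pages_in_row_in_bank():
            // how many pages remain in the current bank row.
            const int32_t contig_in_bank = pages_per_shard_x - (page_idx % pages_per_shard_x);
            // Clamp to (a) pages still needed, (b) bank contiguity, (c) slice-row end.
            contig_pages = std::min(pages_remaining, std::min(contig_in_bank, num_cols - col_idx));
            ++transfers;  // one noc_async_write()/noc_async_read() of page_size * contig_pages bytes
            page_idx += contig_pages;
            col_idx += contig_pages;
            if (col_idx == num_cols) {  // step over the column stride at each slice-row boundary
                page_idx += col_offset;
                col_idx = 0;
            }
        }
        std::printf("%d pages moved in %d transfers\n", num_pages, transfers);  // 10 pages, 4 transfers
        return 0;
    }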
@@ -160,7 +165,10 @@ FORCE_INLINE void write_chunk(
const uint32_t& page_size) {
cb_wait_front(cb_id, num_pages);
uint32_t l1_read_addr = get_read_ptr(cb_id);
for (uint32_t i = 0; i < num_pages; ++i) {
int32_t contig_pages = 1;

for (int32_t pages_remaining = num_pages; pages_remaining != 0; pages_remaining -= contig_pages) {
contig_pages = 1;
#ifdef ROW_MAJOR_LAYOUT
#ifdef INTERLEAVED_MEM_LAYOUT
uint64_t dst_noc_addr = get_noc_addr(output_page_idx, d);
@@ -181,15 +189,15 @@
#ifdef INTERLEAVED_MEM_LAYOUT
noc_async_write_tile(output_page_idx, d, l1_read_addr);
#elif defined SHARDED_MEM_LAYOUT
auto const&[noc_yx, page_offset] = d.get_page_location(output_page_idx);

auto [noc_yx, page_offset, contig_pages_] = d.get_page_location_with_contiguous_pages_in_row_in_bank(output_page_idx);
contig_pages = std::min<int32_t>(pages_remaining, std::min<int32_t>(contig_pages_, num_cols - col_idx));
uint32_t local_address = d.bank_base_address + (page_offset * d.page_size) + 0;
uint64_t dst_noc_addr = get_noc_addr(static_cast<uint32_t>(noc_yx.noc_x), static_cast<uint32_t>(noc_yx.noc_y), local_address);
ASSERT(((dst_noc_addr >> 32) & 0xF) == 0);
noc_async_write(l1_read_addr, dst_noc_addr, page_size);
noc_async_write(l1_read_addr, dst_noc_addr, page_size * contig_pages);
#endif
output_page_idx++;
col_idx++;
output_page_idx += contig_pages;
col_idx += contig_pages;
if (col_idx == num_cols) {
output_page_idx += col_offset;
col_idx = 0;
@@ -200,7 +208,7 @@
}
}
#endif
l1_read_addr += page_size;
l1_read_addr += page_size * contig_pages;
}
noc_async_write_barrier();
cb_pop_front(cb_id, num_pages);
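For concreteness: if the indexer reports contig_pages_ = 6 while pages_remaining = 4 and num_cols - col_idx = 2, the clamp evaluates to min(4, min(6, 2)) = 2, so the write stops at the slice-row boundary even though the bank could serve more contiguous pages; the same iteration then advances output_page_idx by col_offset and resets col_idx.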
@@ -217,7 +225,10 @@ FORCE_INLINE void read_chunk_from_input_tensor(
const uint32_t end_read_idx = input_page_idx + num_pages;
cb_reserve_back(cb_id, num_pages);
uint32_t local_l1_read_addr = get_write_ptr(cb_id);
for (; input_page_idx < end_read_idx; ++input_page_idx) {
int32_t contig_pages = 1;

for (int32_t pages_remaining = num_pages; pages_remaining != 0; pages_remaining -= contig_pages) {
contig_pages = 1;
#ifdef ROW_MAJOR_LAYOUT
// #ifdef INTERLEAVED_MEM_LAYOUT || defined SHARDED_MEM_LAYOUT
uint64_t src_noc_addr = get_noc_addr(input_page_idx, s);
@@ -227,12 +238,14 @@
noc_async_read_tile(input_page_idx, s, local_l1_read_addr);
#elif defined SHARDED_MEM_LAYOUT
// TODO: Make d.get_noc_addr work on host + device
auto const&[noc_yx, page_offset] = s.get_page_location(input_page_idx);
auto const&[noc_yx, page_offset, contig_pages_] = s.get_page_location_with_contiguous_pages_in_row_in_bank(input_page_idx);
contig_pages = std::min<int32_t>(pages_remaining, contig_pages_);
uint64_t src_noc_addr = get_noc_addr(static_cast<uint32_t>(noc_yx.noc_x), static_cast<uint32_t>(noc_yx.noc_y), s.bank_base_address + (page_offset * s.page_size) + 0);
noc_async_read(src_noc_addr, local_l1_read_addr, page_size);
noc_async_read(src_noc_addr, local_l1_read_addr, page_size * contig_pages);
#endif
#endif
local_l1_read_addr += page_size;
local_l1_read_addr += (page_size * contig_pages);
input_page_idx += contig_pages;
}
noc_async_read_barrier();
cb_push_back(cb_id, num_pages);
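Note the read path from the input tensor has no column clamp: input pages are consumed strictly linearly, so only pages_remaining and the bank's contiguity bound apply. Clamping by pages_remaining also guarantees the counter lands exactly on zero, which matters because the loop condition is pages_remaining != 0 rather than > 0.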
@@ -254,7 +267,9 @@ FORCE_INLINE void read_chunk_from_output_tensor(
const uint32_t& page_size) {
cb_reserve_back(cb_id, num_pages);
uint32_t local_l1_read_addr = get_write_ptr(cb_id);
for (uint32_t i = 0; i < num_pages; ++i) {
int32_t contig_pages = 1;
for (int32_t pages_remaining = num_pages; pages_remaining != 0; pages_remaining -= contig_pages) {
contig_pages = 1;
#ifdef ROW_MAJOR_LAYOUT
#ifdef INTERLEAVED_MEM_LAYOUT
uint64_t src_noc_addr = get_noc_addr(input_page_idx, s);
@@ -277,12 +292,13 @@
noc_async_read_tile(input_page_idx, s, local_l1_read_addr);
#elif defined SHARDED_MEM_LAYOUT
// TODO: Make d.get_noc_addr work on host + device
auto const&[noc_yx, page_offset] = s.get_page_location(input_page_idx);
uint64_t src_noc_addr = get_noc_addr(static_cast<uint32_t>(noc_yx.noc_x), noc_yx.noc_y, s.bank_base_address + (page_offset * s.page_size) + 0);
noc_async_read(src_noc_addr, local_l1_read_addr, page_size);
auto [noc_yx, page_offset, contig_pages_] = s.get_page_location_with_contiguous_pages_in_row_in_bank(input_page_idx);
contig_pages = std::min<int32_t>(pages_remaining, std::min<int32_t>(contig_pages_, num_cols - col_idx));
uint64_t src_noc_addr = get_noc_addr(static_cast<uint32_t>(noc_yx.noc_x), static_cast<uint32_t>(noc_yx.noc_y), s.bank_base_address + (page_offset * s.page_size) + 0);
noc_async_read(src_noc_addr, local_l1_read_addr, page_size * contig_pages);
#endif
input_page_idx++;
col_idx++;
input_page_idx += contig_pages;
col_idx += contig_pages;
if (col_idx == num_cols) {
input_page_idx += col_offset;
col_idx = 0;
@@ -293,7 +309,7 @@
}
}
#endif
local_l1_read_addr += page_size;
local_l1_read_addr += page_size * contig_pages;
}
noc_async_read_barrier();
cb_push_back(cb_id, num_pages);
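The payoff of this batching is fewer, larger NoC transactions, amortizing per-transaction issue overhead: for example, a run of 32 contiguous 2 KB pages moves as a single 64 KB transfer instead of 32 separate ones.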