Skip to content

Commit

Permalink
#0: Remove explicit initialization of NOC1 local state in cq_dispatch…
Browse files Browse the repository at this point in the history
…_async_handler
  • Loading branch information
tt-asaigal committed Sep 23, 2024
1 parent 551da4b commit 03f8b1a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
8 changes: 4 additions & 4 deletions tests/ttnn/unit_tests/test_multi_device_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
)
@pytest.mark.parametrize("use_all_gather", [True])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("enable_multi_cq", [True])
@pytest.mark.parametrize("device_params", [{"trace_region_size": 60000, "num_command_queues": 2}], indirect=True)
@pytest.mark.parametrize("enable_multi_cq", [False])
@pytest.mark.parametrize("device_params", [{"trace_region_size": 60000}], indirect=True)
def test_multi_device_single_trace(t3k_mesh_device, shape, use_all_gather, enable_async, enable_multi_cq):
if t3k_mesh_device.get_num_devices() <= 1:
pytest.skip("This test requires multiple devices")
Expand Down Expand Up @@ -134,8 +134,8 @@ def event_sync(event, record_cq, wait_cq):
)
@pytest.mark.parametrize("use_all_gather", [True, False])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("enable_multi_cq", [True])
@pytest.mark.parametrize("device_params", [{"trace_region_size": 200000, "num_command_queues": 2}], indirect=True)
@pytest.mark.parametrize("enable_multi_cq", [False])
@pytest.mark.parametrize("device_params", [{"trace_region_size": 200000}], indirect=True)
def test_multi_device_multi_trace(t3k_mesh_device, shape, use_all_gather, enable_async, enable_multi_cq):
torch.manual_seed(0)
if t3k_mesh_device.get_num_devices() <= 1:
Expand Down
43 changes: 36 additions & 7 deletions tt_metal/impl/dispatch/kernels/cq_dispatch_async_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "tt_metal/impl/dispatch/kernels/cq_common.hpp"
#include "tt_metal/impl/dispatch/kernels/packet_queue_ctrl.hpp"

// Index of the NOC command buffer that dispatch_s passes to noc_fast_atomic_increment for its
// atomic increments. Using a dedicated buffer keeps ncrisc and brisc from clobbering each other's
// resources when dispatch_s and dispatch_d are placed on the same tensix core.
constexpr uint32_t DISPATCH_S_ATOMIC_CMD_BUF = 2;
constexpr uint32_t cb_base = get_compile_time_arg_val(0);
constexpr uint32_t cb_log_page_size = get_compile_time_arg_val(1);
constexpr uint32_t cb_size = get_compile_time_arg_val(2);
Expand Down Expand Up @@ -52,6 +53,26 @@ constexpr uint8_t send_unicast = 0x2;

// Initialize the go_signal data that will be sent to workers over NOC1 in L1
// (16-byte aligned and placed in the "l1_data" section).
uint32_t aligned_go_signal __attribute__((aligned(16))) __attribute__((section("l1_data"))) __attribute__((used)) = RUN_MSG_GO;
// Staging word for the worker-completion count, with the same alignment and section as the go signal.
// It is compared against the value read from the worker semaphore, used as the source buffer of the
// NOC write towards dispatch_d, and zeroed when the worker semaphore itself is cleared.
uint32_t aligned_worker_update __attribute__((aligned(16))) __attribute__((section("l1_data"))) __attribute__((used)) = 0;

FORCE_INLINE
void dispatch_s_atomic_cmd_buf_init() {
    // Programs the return-address registers of DISPATCH_S_ATOMIC_CMD_BUF (on my_noc_index) with the
    // NOC address of this core's atomic_ret_val word.
    // NOTE(review): presumably this is where the NOC deposits values returned by atomic
    // operations issued through this command buffer - confirm against the NOC register spec.
    const uint64_t ret_addr = NOC_XY_ADDR(my_x, my_y, (uint32_t)(&atomic_ret_val));

    // Split the 64-bit NOC address into the two register-sized pieces.
    const uint32_t ret_addr_lo = (uint32_t)(ret_addr & 0xFFFFFFFF);
    const uint32_t ret_addr_coord = (uint32_t)(ret_addr >> NOC_ADDR_COORD_SHIFT);

    // Same write order as before: low address word first, then the coordinate word.
    NOC_CMD_BUF_WRITE_REG(my_noc_index, DISPATCH_S_ATOMIC_CMD_BUF, NOC_RET_ADDR_LO, ret_addr_lo);
    NOC_CMD_BUF_WRITE_REG(my_noc_index, DISPATCH_S_ATOMIC_CMD_BUF, NOC_RET_ADDR_COORDINATE, ret_addr_coord);
}

FORCE_INLINE
void dispatch_s_noc_semaphore_inc(uint64_t addr, uint32_t incr, uint8_t noc_id) {
    // Atomic-increment helper private to dispatch_s.
    //   addr   - NOC address of the word to increment
    //   incr   - amount to add
    //   noc_id - NOC on which the transaction is issued
    // The increment is always routed through DISPATCH_S_ATOMIC_CMD_BUF so that ncrisc and brisc
    // do not clobber each other's resources when dispatch_s and dispatch_d live on the same
    // tensix core.
    constexpr int wrap = 31;
    constexpr bool linked = false;
    constexpr bool posted = false;

    WAYPOINT("NSIW");
    DEBUG_SANITIZE_NOC_ADDR(noc_id, addr, 4);
    DEBUG_INSERT_DELAY(TransactionAtomic);
    noc_fast_atomic_increment(
        noc_id,
        DISPATCH_S_ATOMIC_CMD_BUF,
        addr,
        NOC_UNICAST_WRITE_VC,
        incr,
        wrap,
        linked,
        posted);
    WAYPOINT("NSID");
}

FORCE_INLINE
uint32_t wrapped_distance(uint32_t num_pages_released, uint32_t num_pages_acquired) {
Expand All @@ -69,10 +90,10 @@ FORCE_INLINE
void update_worker_completion_count_on_dispatch_d() {
if constexpr(distributed_dispatcher) {
uint32_t num_workers_signalling_completion = *reinterpret_cast<volatile tt_l1_ptr uint32_t*>(worker_sem_addr);
if (num_workers_signalling_completion != curr_num_workers_completed) {
curr_num_workers_completed = num_workers_signalling_completion;
if (num_workers_signalling_completion != aligned_worker_update) {
aligned_worker_update = num_workers_signalling_completion;
uint64_t dispatch_d_dst = get_noc_addr_helper(dispatch_d_noc_xy, worker_sem_addr);
noc_async_write_one_packet(worker_sem_addr, dispatch_d_dst, sizeof(uint32_t));
noc_async_write_one_packet((uint32_t)(&aligned_worker_update), dispatch_d_dst, sizeof(uint32_t));
}
}
}
Expand All @@ -98,6 +119,12 @@ void cb_acquire_pages_dispatch_s(uint32_t n) {
num_pages_acquired += n;
}

template<uint32_t noc_xy, uint32_t sem_id>
FORCE_INLINE
void cb_release_pages_dispatch_s(uint32_t n) {
    // Adds n to semaphore sem_id on the core identified by noc_xy, using the dispatch_s specific
    // atomic increment helper on my_noc_index.
    const auto sem_local_addr = get_semaphore<fd_core_type>(sem_id);
    const uint64_t sem_noc_addr = get_noc_addr_helper(noc_xy, sem_local_addr);
    dispatch_s_noc_semaphore_inc(sem_noc_addr, n, my_noc_index);
}

FORCE_INLINE
void process_go_signal_mcast_cmd() {
volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr;
Expand Down Expand Up @@ -156,15 +183,17 @@ void process_dispatch_s_wait_cmd() {
while (*worker_sem < cmd->wait.count);
// Send updated worker count to dispatch_d
update_worker_completion_count_on_dispatch_d();
// Wait for updated count to get written and then clear the counter.
// Wait for updated count to get picked up by NOC and then clear the counter.
// dispatch_d will clear its own counter
while(!ncrisc_noc_nonposted_writes_flushed(1));
noc_async_write_barrier();
*worker_sem = 0;
aligned_worker_update = 0; // Local worker count should reflect state of worker semaphore
cmd_ptr += sizeof(CQDispatchCmd);
}

void kernel_main() {
// DPRINT << "Dispatch Handler Started: " << cb_base << " " << cb_end << ENDL();
noc_local_state_init(1);
// dispatch_s_atomic_cmd_buf_init();
cmd_ptr = cb_base;
bool done = false;
while(!done) {
Expand Down Expand Up @@ -194,7 +223,7 @@ void kernel_main() {
}
cmd_ptr = round_up_pow2(cmd_ptr, cb_page_size);
// Release a single page to prefetcher. Assumption is that all dispatch_s commands fit inside a single page for now.
cb_release_pages<my_noc_index, upstream_noc_xy, upstream_dispatch_cb_sem_id>(1);
cb_release_pages_dispatch_s<upstream_noc_xy, upstream_dispatch_cb_sem_id>(1);
if (cmd_ptr == cb_end) {
cmd_ptr = cb_base;
}
Expand Down

0 comments on commit 03f8b1a

Please sign in to comment.