Fix the racing condition of mixed-input gemm when writing the registers (#1931)

* move two warpgroup_wait

* merge main

---------

Co-authored-by: Siyuan Fu <siyuanf@nvidia.com>
IwakuraRein and Siyuan Fu authored Nov 8, 2024
1 parent d656afb commit 8aa95db
Showing 2 changed files with 4 additions and 3 deletions.
@@ -724,4 +724,4 @@ int main(int argc, char const **args) {
   return 0;
 }
 
-/////////////////////////////////////////////////////////////////////////////////////////////////
+/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1024,8 +1024,8 @@ struct CollectiveMma<
   tiled_mma.accumulate_ = GMMA::ScaleOut::One;
   warpgroup_commit_batch();
 
+  warpgroup_wait<K_BLOCK_MAX - 1>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier
   if (k_block == K_BLOCK_MAX - 1) {
-    warpgroup_wait<K_BLOCK_MAX - 1>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier
     pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
     ++smem_pipe_release;
   }
@@ -1076,8 +1076,9 @@ struct CollectiveMma<
   cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
   tiled_mma.accumulate_ = GMMA::ScaleOut::One;
   warpgroup_commit_batch();
+
+  warpgroup_wait<K_BLOCK_MAX - 1>();
   if (k_block == K_BLOCK_MAX - 1) { // release prior barrier
-    warpgroup_wait<K_BLOCK_MAX - 1>();
     pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
     ++smem_pipe_release;
   }
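
The effect of running warpgroup_wait on every k_block rather than only on the last one can be illustrated with a toy host-side model. This is only a sketch under stated assumptions, not CUTLASS code: AsyncMmaModel and count_register_hazards are hypothetical names, and wait_at_most(n) stands in for warpgroup_wait<N> by retiring the oldest batches until at most n remain in flight. The model also assumes, which the diff does not show, that after each k_block the loop rewrites the register slice used by the following k_block. A pending batch is treated conservatively as possibly still reading its slice.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>

// Toy host-side model of asynchronous MMA batches. Hypothetical names,
// not CUTLASS or CuTe APIs.
struct AsyncMmaModel {
  // Register slice read by each in-flight batch, oldest first.
  std::deque<int> pending;

  // Stands in for cute::gemm(...) followed by warpgroup_commit_batch().
  void issue(int slice) { pending.push_back(slice); }

  // Stands in for warpgroup_wait<N>(): retire the oldest batches until
  // at most n remain in flight.
  void wait_at_most(std::size_t n) {
    while (pending.size() > n) {
      pending.pop_front();
    }
  }

  // True if some in-flight batch may still be reading this slice.
  bool is_read_in_flight(int slice) const {
    for (int s : pending) {
      if (s == slice) return true;
    }
    return false;
  }
};

// Counts how often the loop overwrites a register slice that an in-flight
// batch may still be reading.
//   wait_every_k_block == true  : wait after every commit
//   wait_every_k_block == false : wait only on the last k_block of a tile
int count_register_hazards(bool wait_every_k_block, int k_tiles, int k_block_max) {
  assert(k_tiles > 0 && k_block_max > 1);
  AsyncMmaModel mma;
  int hazards = 0;
  for (int tile = 0; tile < k_tiles; ++tile) {
    for (int k_block = 0; k_block < k_block_max; ++k_block) {
      mma.issue(k_block);
      const bool last = (k_block == k_block_max - 1);
      if (wait_every_k_block || last) {
        mma.wait_at_most(static_cast<std::size_t>(k_block_max - 1));
      }
      // Assumption: the loop now rewrites the slice for the next k_block.
      const int next_slice = (k_block + 1) % k_block_max;
      if (mma.is_read_in_flight(next_slice)) {
        ++hazards;
      }
    }
  }
  return hazards;
}
```

Under these assumptions, count_register_hazards(true, 3, 4) reports no hazards, while count_register_hazards(false, 3, 4) reports 6: with the wait only on the last k_block, every tile after the first rewrites k_block_max - 1 slices that a batch from the previous tile may still be reading.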