diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu index 8a99cc275..6a3a8f806 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -724,4 +724,4 @@ int main(int argc, char const **args) { return 0; } -///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 8c838df12..a3efc67e8 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -1024,8 +1024,8 @@ struct CollectiveMma< tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier if (k_block == K_BLOCK_MAX - 1) { - warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it ++smem_pipe_release; } @@ -1076,8 +1076,9 @@ struct CollectiveMma< cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); + + warpgroup_wait(); if (k_block == K_BLOCK_MAX - 1) { // release prior barrier - warpgroup_wait(); pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it ++smem_pipe_release; }