Skip to content

Commit

Permalink
Fix for Sm75
Browse files Browse the repository at this point in the history
  • Loading branch information
danthe3rd committed Sep 5, 2022
1 parent 49cd011 commit 13cd214
Showing 1 changed file with 16 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ class MmaPipelinedFromSharedMemory
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
bool hasNext = true;

if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
Expand All @@ -507,24 +508,28 @@ class MmaPipelinedFromSharedMemory
}

smem_write_stage_idx ^= 1;
hasNext = gemm_k_iterations > 1;
}

this->warp_tile_iterator_B_.set_kgroup_index(
(warp_mma_k + 1) % Base::kWarpGemmIterations);
// Only read the next if we need to
if (hasNext) {
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_mma_k + 1) % Base::kWarpGemmIterations);

this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);

++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

if (warp_mma_k == 0) {
iterator_B.load(tb_frag_B);
if (warp_mma_k == 0) {
iterator_B.load(tb_frag_B);

++iterator_B;
++iterator_B;

// Avoid reading out of bounds if this was the last loop iteration
iterator_B.clear_mask(gemm_k_iterations <= 2);
// Avoid reading out of bounds if this was the last loop iteration
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
}

warp_mma(
Expand Down

0 comments on commit 13cd214

Please sign in to comment.