Fix merge attention grid size by passing M/seq_len as grid.x (faceboo…
jianyuh authored Jun 24, 2024
1 parent 44b8dd9 commit 165642c
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions xformers/ops/fmha/triton_splitk.py
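Context for the fix (my reading of the commit title; the diff itself does not state the rationale): CUDA caps gridDim.y and gridDim.z at 65535 blocks, while gridDim.x allows roughly 2**31 - 1, so moving the sequence length M to grid axis 0 keeps the split-K reduce kernels launchable when M exceeds 65535. A minimal sketch of the old vs. new launch layout, reusing the B, G, H, M names from the diff; choose_grid is a hypothetical helper, not part of xformers:

    # Hypothetical helper illustrating the grid-layout fix; not part of xformers.
    MAX_GRID_YZ = 65535  # CUDA limit on gridDim.y / gridDim.z; gridDim.x allows ~2**31 - 1 blocks.

    def choose_grid(M: int, B: int, G: int, H: int):
        # Old layout: grid = (B * G * H, M, 1) -> cannot launch once the sequence length M > 65535.
        # New layout: grid = (M, B * G * H, 1) -> M sits on axis 0, whose limit is effectively unreachable.
        assert B * G * H <= MAX_GRID_YZ, "batch * groups * heads must still fit in gridDim.y"
        return (M, B * G * H, 1)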
@@ -752,11 +752,12 @@ def _splitK_reduce(
     G: tl.constexpr,
     WRITE_LSE: tl.constexpr,
 ):
-    off_zhg = tl.program_id(0).to(tl.int64)
+    # grid = (M, B * G * H, 1)
+    off_m = tl.program_id(0).to(tl.int64)
+    off_zhg = tl.program_id(1).to(tl.int64)
     off_z = off_zhg // (H * G)
     off_h = (off_zhg // G) % H
     off_g = off_zhg % G
-    off_m = tl.program_id(1)
 
     Out_splitK_ptr = (
         Out_splitK
@@ -863,11 +864,12 @@ def _splitK_reduce_varargs(
     This version of reduce kernel takes attention and LSE of chunks as lists of tensors,
     as opposed to _splitK_reduce, which takes each as a stacked tensor.
     """
-    off_zhg = tl.program_id(0).to(tl.int64)
+    # grid = (M, B * G * H, 1)
+    off_m = tl.program_id(0).to(tl.int64)
+    off_zhg = tl.program_id(1).to(tl.int64)
     off_z = off_zhg // (H * G)
     off_h = (off_zhg // G) % H
     off_g = off_zhg % G
-    off_m = tl.program_id(1)
 
     out_splitk_offset: "VAR_ARGS_ARRAY"  # noqa: F821
     for i in range(len(Out_splitK)):
@@ -986,11 +988,12 @@ def _splitK_reduce_varargs_backward(
     and outputs the corresponding gradients in the same format.
     """
 
-    off_zhg = tl.program_id(0).to(tl.int64)
+    # grid = (M, B * G * H, 1)
+    off_m = tl.program_id(0).to(tl.int64)
+    off_zhg = tl.program_id(1).to(tl.int64)
     off_z = off_zhg // (H * G)
     off_h = (off_zhg // G) % H
     off_g = off_zhg % G
-    off_m = tl.program_id(1)
 
     # Compute offsets inside each attention/LSE chunk.
     # Note that each chunk can have different strides, so offsets can also be different.
@@ -1694,7 +1697,7 @@ def merge_attentions(
 
     num_warps = 4 if B * G * H < 32 or torch.version.hip else 2
     splitK_pow2 = triton.next_power_of_2(split_k)
-    grid = (B * G * H, M, 1)
+    grid = (M, B * G * H, 1)
     _splitK_reduce[grid](
         attn_split,
         lse_split,
@@ -1820,7 +1823,7 @@ def _prepare_reduce_kernel_params(
     )
 
     num_warps = 4 if B * G * H < 32 or torch.version.hip else 2
-    grid = (B * G * H, M, 1)
+    grid = (M, B * G * H, 1)
 
     kernel_args = {
         "G": G,
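Putting the launch-side and kernel-side pieces of the change together: with grid = (M, B * G * H, 1), program id 0 now indexes the query position and program id 1 the flattened batch/head/group block, which the kernels decompose as shown above. A stand-in Triton kernel sketching only that index bookkeeping (the real reduce kernels also load, reduce and store the split-K partials and LSEs):

    import triton
    import triton.language as tl

    @triton.jit
    def _grid_layout_sketch(H: tl.constexpr, G: tl.constexpr):
        # Launched with grid = (M, B * G * H, 1).
        off_m = tl.program_id(0).to(tl.int64)    # query position along the sequence (axis 0, 0..M-1)
        off_zhg = tl.program_id(1).to(tl.int64)  # flattened index z * H * G + h * G + g (axis 1)
        off_z = off_zhg // (H * G)               # batch index
        off_h = (off_zhg // G) % H               # head index
        off_g = off_zhg % G                      # group index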
