From 165642cd4ea6940e7a8bc7e3a1e5d00646cb188e Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Mon, 24 Jun 2024 13:05:38 -0700
Subject: [PATCH] Fix merge attention grid size by passing M/seq_len as grid.x (#1141)

---
 xformers/ops/fmha/triton_splitk.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py
index d1a95acbb2..60fe792f94 100644
--- a/xformers/ops/fmha/triton_splitk.py
+++ b/xformers/ops/fmha/triton_splitk.py
@@ -752,11 +752,12 @@ def _splitK_reduce(
     G: tl.constexpr,
     WRITE_LSE: tl.constexpr,
 ):
-    off_zhg = tl.program_id(0).to(tl.int64)
+    # grid = (M, B * G * H, 1)
+    off_m = tl.program_id(0).to(tl.int64)
+    off_zhg = tl.program_id(1).to(tl.int64)
     off_z = off_zhg // (H * G)
     off_h = (off_zhg // G) % H
     off_g = off_zhg % G
-    off_m = tl.program_id(1)
 
     Out_splitK_ptr = (
         Out_splitK
@@ -863,11 +864,12 @@ def _splitK_reduce_varargs(
     This version of reduce kernel takes attention and LSE of chunks as lists of tensors,
     as opposed to _splitK_reduce, which takes each as a stacked tensor.
     """
-    off_zhg = tl.program_id(0).to(tl.int64)
+    # grid = (M, B * G * H, 1)
+    off_m = tl.program_id(0).to(tl.int64)
+    off_zhg = tl.program_id(1).to(tl.int64)
     off_z = off_zhg // (H * G)
     off_h = (off_zhg // G) % H
     off_g = off_zhg % G
-    off_m = tl.program_id(1)
 
     out_splitk_offset: "VAR_ARGS_ARRAY"  # noqa: F821
     for i in range(len(Out_splitK)):
@@ -986,11 +988,12 @@ def _splitK_reduce_varargs_backward(
     and outputs the corresponding gradients in the same format.
     """
 
-    off_zhg = tl.program_id(0).to(tl.int64)
+    # grid = (M, B * G * H, 1)
+    off_m = tl.program_id(0).to(tl.int64)
+    off_zhg = tl.program_id(1).to(tl.int64)
     off_z = off_zhg // (H * G)
     off_h = (off_zhg // G) % H
     off_g = off_zhg % G
-    off_m = tl.program_id(1)
 
     # Compute offsets inside each attention/LSE chunk.
     # Note that each chunk can have different strides, so offsets can also be different.
@@ -1694,7 +1697,7 @@ def merge_attentions(
 
     num_warps = 4 if B * G * H < 32 or torch.version.hip else 2
     splitK_pow2 = triton.next_power_of_2(split_k)
-    grid = (B * G * H, M, 1)
+    grid = (M, B * G * H, 1)
     _splitK_reduce[grid](
         attn_split,
         lse_split,
@@ -1820,7 +1823,7 @@ def _prepare_reduce_kernel_params(
     )
 
     num_warps = 4 if B * G * H < 32 or torch.version.hip else 2
-    grid = (B * G * H, M, 1)
+    grid = (M, B * G * H, 1)
 
     kernel_args = {
         "G": G,