diff --git a/src/flash_attn_jax/flash.py b/src/flash_attn_jax/flash.py
index faebfc7..dd0ac13 100644
--- a/src/flash_attn_jax/flash.py
+++ b/src/flash_attn_jax/flash.py
@@ -19,6 +19,7 @@
 from jax.sharding import PositionalSharding
 
 from einops import rearrange
+import einops
 import math
 
 from .flash_sharding import _flash_mha_fwd_hlo_sharded, _flash_mha_bwd_hlo_sharded
@@ -104,28 +105,91 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, is_caus
 # ==== VMap rules ====
 
 def mha_fwd_batch(vector_arg_values, batch_axes, **kwargs):
-    assert tuple(batch_axes) == (0,0,0), "Only support vmapping mha over axis 0 for now,"
-    [q, k, v] = vector_arg_values
-    [b, n, l, h, d] = q.shape
-    [b, n, lk, hk, d] = k.shape
-    assert [b, n, lk, hk, d] == list(v.shape)
-    out, lse = _flash_mha_fwd_p.bind(q.reshape([b*n,l,h,d]),
-                                     k.reshape([b*n,lk,hk,d]),
-                                     v.reshape([b*n,lk,hk,d]),
-                                     **kwargs)
-    return (out.reshape([b,n,*out.shape[1:]]), lse.reshape([b,n,*lse.shape[1:]])), (0,0)
+    assert all(isinstance(b, int) or b is None for b in batch_axes)
+    mapped = tuple(isinstance(b, int) for b in batch_axes)
+    if mapped == (True, True, True):
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            dims = ['n', 'l', 'h', 'd']
+            dims.insert(axis, 'x')
+            dims = ' '.join(dims)
+            return einops.rearrange(val, f'{dims} -> (x n) l h d')
+        def unsquish(val):
+            return einops.rearrange(val, f'(x n) ... -> x n ...', x=x)
+        [q, k, v] = [squish(x, axis) for x, axis in zip(vector_arg_values, batch_axes)]
+        out, lse = _flash_mha_fwd_p.bind(q, k, v, **kwargs)
+        return (unsquish(out), unsquish(lse)), (0,0)
+    elif mapped == (True, False, False):
+        # This is just a GQA!
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            if axis is None:
+                return val
+            dims = ['n', 'l', 'h', 'd']
+            dims.insert(axis, 'x')
+            dims = ' '.join(dims)
+            return einops.rearrange(val, f'{dims} -> n l (h x) d')
+        def unsquish(val):
+            return einops.rearrange(val, 'n l (h x) d -> x n l h d', x=x)
+        [q, k, v] = [squish(x, axis) for x, axis in zip(vector_arg_values, batch_axes)]
+        out, lse = _flash_mha_fwd_p.bind(q, k, v, **kwargs)
+        out = einops.rearrange(out, 'n l (h x) d -> x n l h d', x=x)
+        lse = einops.rearrange(lse, 'n (h x) l -> x n h l', x=x)
+        return (out, lse), (0,0)
+    else:
+        raise NotImplementedError("MHA fwd only supports vmapping over q or (q,k,v) for now, got batch axes " + str(batch_axes))
 
 def mha_bwd_batch(vector_arg_values, batch_axes, **kwargs):
-    assert tuple(batch_axes) == (0,0,0,0,0,0), "Only support vmapping mha over axis 0 for now,"
-    dout, q, k, v, out, lse = vector_arg_values
-    b = dout.shape[batch_axes[0]]
-    def join(*args):
-        return [rearrange(a, 'b n ... -> (b n) ...') for a in args]
-    def unjoin(*args):
-        return [rearrange(a, '(b n) ... -> b n ...', b=b) for a in args]
-    dq, dk, dv = _flash_mha_bwd_p.bind(*join(dout,q,k,v,out,lse),
-                                       **kwargs)
-    return tuple(unjoin(dq,dk,dv)), (0,0,0)
+    assert all(isinstance(b, int) or b is None for b in batch_axes)
+    mapped = tuple(isinstance(b, int) for b in batch_axes)
+    if mapped == (True, True, True, True, True, True):
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            if len(val.shape) == 5:
+                # q/k/v/o
+                dims = ['n', 'l', 'h', 'd']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> (x n) l h d')
+            elif len(val.shape) == 4:
+                # lse
+                dims = ['n', 'h', 'l']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> (x n) h l')
+        do, q, k, v, o, lse = [squish(x, axis) for x, axis in zip(vector_arg_values, batch_axes)]
+        dq, dk, dv = _flash_mha_bwd_p.bind(do, q, k, v, o, lse, **kwargs)
+        dq = einops.rearrange(dq, '(x n) l h d -> x n l h d', x=x)
+        dk = einops.rearrange(dk, '(x n) l h d -> x n l h d', x=x)
+        dv = einops.rearrange(dv, '(x n) l h d -> x n l h d', x=x)
+        return (dq,dk,dv), (0,0,0)
+    elif mapped == (True, True, False, False, True, True):
+        # Everything is mapped except k and v, which is a GQA backward
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            if len(val.shape) == 5:
+                # q/k/v/o
+                dims = ['n', 'l', 'h', 'd']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> n l (h x) d')
+            elif len(val.shape) == 4:
+                # lse
+                dims = ['n', 'h', 'l']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> n (h x) l')
+        do = squish(vector_arg_values[0], batch_axes[0])
+        q = squish(vector_arg_values[1], batch_axes[1])
+        k = vector_arg_values[2]
+        v = vector_arg_values[3]
+        o = squish(vector_arg_values[4], batch_axes[4])
+        lse = squish(vector_arg_values[5], batch_axes[5])
+        dq, dk, dv = _flash_mha_bwd_p.bind(do, q, k, v, o, lse, **kwargs)
+        dq = einops.rearrange(dq, 'n l (h x) d -> x n l h d', x=x)
+        return (dq,dk,dv), (0,None,None)
+    else:
+        raise NotImplementedError("MHA bwd only supports vmapping over q or (q,k,v) for now, got batch axes " + str(batch_axes))
 
 batching.primitive_batchers[_flash_mha_fwd_p] = mha_fwd_batch
 batching.primitive_batchers[_flash_mha_bwd_p] = mha_bwd_batch
diff --git a/tests/test_flash.py b/tests/test_flash.py
index be62705..ed3d764 100644
--- a/tests/test_flash.py
+++ b/tests/test_flash.py
@@ -4,6 +4,7 @@
 sys.path.insert(0, glob.glob('build/lib.linux-*')[0])
 sys.path.insert(0,'./src')
 
+from functools import partial
 import pytest
 import jax
 import jax.numpy as jnp
@@ -100,22 +101,114 @@ def test_flash_fwd_vmap(n, seqlen, h, d, causal, local, dtype):
     k = jax.random.normal(jax.random.PRNGKey(1), [x, n, seqlen, h, d], dtype=jnp.float32)
     v = jax.random.normal(jax.random.PRNGKey(2), [x, n, seqlen, h, d], dtype=jnp.float32)
 
-    @jax.jit
     def ref(q,k,v):
         return ref_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
-    @jax.jit
     def flash(q,k,v):
         return flash_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
 
-    ref_out = jnp.stack([ref(q[i],k[i],v[i]) for i in range(x)])
+    ref_out = jax.vmap(ref)(q,k,v)
     q = q.astype(dtype)
     k = k.astype(dtype)
     v = v.astype(dtype)
-    f16_out = jnp.stack([ref(q[i],k[i],v[i]) for i in range(x)])
-
+    f16_out = jax.vmap(ref)(q,k,v)
     out = jax.vmap(flash)(q,k,v)
     check(ref_out, f16_out, out)
 
+@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
+@pytest.mark.parametrize("local", ['local',''])
+@pytest.mark.parametrize("causal", ['causal',''])
+@pytest.mark.parametrize("d", [59, 32])
+@pytest.mark.parametrize("h", [1, 4])
+@pytest.mark.parametrize("seqlen", [97, 128])
+@pytest.mark.parametrize("n", [1])
+def test_flash_fwd_vmapq(n, seqlen, h, d, causal, local, dtype):
+    window_size = (3,3) if local else (-1,-1)
+
+    x = 4
+    q = jax.random.normal(jax.random.PRNGKey(0), [x, n, seqlen, h, d], dtype=jnp.float32)
+    k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
+    v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
+
+    def ref(q,k,v):
+        return ref_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
+    def flash(q,k,v):
+        return flash_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
+
+    ref_out = jax.vmap(ref, in_axes=(0,None,None))(q,k,v)
+    q = q.astype(dtype)
+    k = k.astype(dtype)
+    v = v.astype(dtype)
+    f16_out = jax.vmap(ref, in_axes=(0,None,None))(q,k,v)
+
+    out = jax.vmap(flash, in_axes=(0,None,None))(q,k,v)
+    check(ref_out, f16_out, out)
+
+@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
+@pytest.mark.parametrize("local", ['local',''])
+@pytest.mark.parametrize("causal", ['causal',''])
+@pytest.mark.parametrize("d", [59, 32])
+@pytest.mark.parametrize("h", [1, 4])
+@pytest.mark.parametrize("seqlen", [97, 128])
+@pytest.mark.parametrize("n", [1])
+def test_flash_bwd_vmap(n, seqlen, h, d, causal, local, dtype):
+    window_size = (3,3) if local else (-1,-1)
+
+    x = 4
+    q = jax.random.normal(jax.random.PRNGKey(0), [x, n, seqlen, h, d], dtype=jnp.float32)
+    k = jax.random.normal(jax.random.PRNGKey(1), [x, n, seqlen, h, d], dtype=jnp.float32)
+    v = jax.random.normal(jax.random.PRNGKey(2), [x, n, seqlen, h, d], dtype=jnp.float32)
+    do = jax.random.normal(jax.random.PRNGKey(3), [x, n, seqlen, h, d], dtype=jnp.float32)
+
+    def func(mha, q,k,v):
+        @partial(jax.vmap, in_axes=(0,0,0))
+        def fwd(q,k,v):
+            return mha(q,k,v, is_causal=bool(causal), window_size=window_size)
+        o, bwd = jax.vjp(fwd,q,k,v)
+        return bwd(do)
+
+    ref_out = func(ref_mha, q,k,v)
+    q = q.astype(dtype)
+    k = k.astype(dtype)
+    v = v.astype(dtype)
+    do = do.astype(dtype)
+    f16_out = func(ref_mha, q,k,v)
+
+    out = func(flash_mha, q,k,v)
+    check(ref_out, f16_out, out)
+
+@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
+@pytest.mark.parametrize("local", ['local',''])
+@pytest.mark.parametrize("causal", ['causal',''])
+@pytest.mark.parametrize("d", [59, 32])
+@pytest.mark.parametrize("h", [1, 4])
+@pytest.mark.parametrize("seqlen", [97, 128])
+@pytest.mark.parametrize("n", [1])
+def test_flash_bwd_vmapq(n, seqlen, h, d, causal, local, dtype):
+    window_size = (3,3) if local else (-1,-1)
+
+    x = 4
+    q = jax.random.normal(jax.random.PRNGKey(0), [x, n, seqlen, h, d], dtype=jnp.float32)
+    k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
+    v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
+    do = jax.random.normal(jax.random.PRNGKey(3), [x, n, seqlen, h, d], dtype=jnp.float32)
+
+    def func(mha, q,k,v):
+        @partial(jax.vmap, in_axes=(0,None,None))
+        def fwd(q,k,v):
+            return mha(q,k,v, is_causal=bool(causal), window_size=window_size)
+        o, bwd = jax.vjp(fwd,q,k,v)
+        return bwd(do)
+
+    ref_out = func(ref_mha, q,k,v)
+    q = q.astype(dtype)
+    k = k.astype(dtype)
+    v = v.astype(dtype)
+    do = do.astype(dtype)
+    f16_out = func(ref_mha, q,k,v)
+
+    out = func(flash_mha, q,k,v)
+    check(ref_out, f16_out, out)
+
 
 if __name__ == '__main__':
     test_flash_bwd(1,4,1,32,4,False,False,jnp.float16)
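For reference, a minimal usage sketch (not part of the diff) of what the new batching rules enable, assuming flash_mha is importable as in the existing tests and follows the [n, l, h, d] layout used above; shapes and sizes here are made up for illustration:

import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha  # assumed import path, as in the test suite

n, l, h, d, x = 2, 128, 4, 32, 3
q = jnp.zeros([x, n, l, h, d], dtype=jnp.float16)  # extra leading vmapped axis
k = jnp.zeros([n, l, h, d], dtype=jnp.float16)
v = jnp.zeros([n, l, h, d], dtype=jnp.float16)

# vmap over q only: hits the (True, False, False) branch, which folds the
# mapped axis into the query head dimension (a GQA-style kernel call).
out = jax.vmap(flash_mha, in_axes=(0, None, None))(q, k, v)  # [x, n, l, h, d]

# vmap over q, k and v together: hits the (True, True, True) branch, which
# folds the mapped axis into the batch dimension before calling the kernel.
kx = jnp.zeros([x, n, l, h, d], dtype=jnp.float16)
vx = jnp.zeros([x, n, l, h, d], dtype=jnp.float16)
out2 = jax.vmap(flash_mha)(q, kx, vx)  # [x, n, l, h, d]

Any other combination of mapped arguments still raises NotImplementedError, as in the vmap rules above.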