Commit d43cbca

Expanded vmap support for flash_mha. Vmapping q but not k,v reduces to a grouped-query attention, which we now support.
nshepperd committed May 1, 2024
1 parent 4367317 commit d43cbca
Showing 2 changed files with 182 additions and 25 deletions.
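
What the change enables, in user terms: flash_mha can now be vmapped over an extra leading axis of q while k and v stay unmapped. A minimal usage sketch, assuming flash_mha is importable from the flash_attn_jax package and mirroring the shapes and arguments used in the tests below (the array sizes are illustrative):

import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

x, n, l, h, d = 4, 1, 128, 4, 32   # extra vmap axis, batch, seqlen, heads, head dim
q = jax.random.normal(jax.random.PRNGKey(0), [x, n, l, h, d], dtype=jnp.float16)
k = jax.random.normal(jax.random.PRNGKey(1), [n, l, h, d], dtype=jnp.float16)  # shared across x
v = jax.random.normal(jax.random.PRNGKey(2), [n, l, h, d], dtype=jnp.float16)  # shared across x

def flash(q, k, v):
    return flash_mha(q, k, v, is_causal=False, window_size=(-1, -1))

# Mapping only q: internally this is handled as grouped-query attention,
# where the x vmapped copies of each q head share a single k/v head.
out = jax.vmap(flash, in_axes=(0, None, None))(q, k, v)   # out.shape == (x, n, l, h, d)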
104 changes: 84 additions & 20 deletions src/flash_attn_jax/flash.py
@@ -19,6 +19,7 @@
from jax.sharding import PositionalSharding

from einops import rearrange
+import einops
import math

from .flash_sharding import _flash_mha_fwd_hlo_sharded, _flash_mha_bwd_hlo_sharded
@@ -104,28 +105,91 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, is_causal
# ==== VMap rules ====

def mha_fwd_batch(vector_arg_values, batch_axes, **kwargs):
-    assert tuple(batch_axes) == (0,0,0), "Only support vmapping mha over axis 0 for now,"
-    [q, k, v] = vector_arg_values
-    [b, n, l, h, d] = q.shape
-    [b, n, lk, hk, d] = k.shape
-    assert [b, n, lk, hk, d] == list(v.shape)
-    out, lse = _flash_mha_fwd_p.bind(q.reshape([b*n,l,h,d]),
-                                     k.reshape([b*n,lk,hk,d]),
-                                     v.reshape([b*n,lk,hk,d]),
-                                     **kwargs)
-    return (out.reshape([b,n,*out.shape[1:]]), lse.reshape([b,n,*lse.shape[1:]])), (0,0)
+    assert all(isinstance(b, int) or b is None for b in batch_axes)
+    mapped = tuple(isinstance(b, int) for b in batch_axes)
+    if mapped == (True, True, True):
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            dims = ['n', 'l', 'h', 'd']
+            dims.insert(axis, 'x')
+            dims = ' '.join(dims)
+            return einops.rearrange(val, f'{dims} -> (x n) l h d')
+        def unsquish(val):
+            return einops.rearrange(val, f'(x n) ... -> x n ...', x=x)
+        [q, k, v] = [squish(x, axis) for x, axis in zip(vector_arg_values, batch_axes)]
+        out, lse = _flash_mha_fwd_p.bind(q, k, v, **kwargs)
+        return (unsquish(out), unsquish(lse)), (0,0)
+    elif mapped == (True, False, False):
+        # This is just a GQA!
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            if axis is None:
+                return val
+            dims = ['n', 'l', 'h', 'd']
+            dims.insert(axis, 'x')
+            dims = ' '.join(dims)
+            return einops.rearrange(val, f'{dims} -> n l (h x) d')
+        def unsquish(val):
+            return einops.rearrange(val, 'n l (h x) d -> x n l h d', x=x)
+        [q, k, v] = [squish(x, axis) for x, axis in zip(vector_arg_values, batch_axes)]
+        out, lse = _flash_mha_fwd_p.bind(q, k, v, **kwargs)
+        out = einops.rearrange(out, 'n l (h x) d -> x n l h d', x=x)
+        lse = einops.rearrange(lse, 'n (h x) l -> x n h l', x=x)
+        return (out, lse), (0,0)
+    else:
+        raise NotImplementedError("MHA fwd only support vmapping over q or (q,k,v) for now, got batch axes " + str(batch_axes))

def mha_bwd_batch(vector_arg_values, batch_axes, **kwargs):
-    assert tuple(batch_axes) == (0,0,0,0,0,0), "Only support vmapping mha over axis 0 for now,"
-    dout, q, k, v, out, lse = vector_arg_values
-    b = dout.shape[batch_axes[0]]
-    def join(*args):
-        return [rearrange(a, 'b n ... -> (b n) ...') for a in args]
-    def unjoin(*args):
-        return [rearrange(a, '(b n) ... -> b n ...', b=b) for a in args]
-    dq, dk, dv = _flash_mha_bwd_p.bind(*join(dout,q,k,v,out,lse),
-                                       **kwargs)
-    return tuple(unjoin(dq,dk,dv)), (0,0,0)
+    assert all(isinstance(b, int) or b is None for b in batch_axes)
+    mapped = tuple(isinstance(b, int) for b in batch_axes)
+    if mapped == (True, True, True, True, True, True):
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            if len(val.shape) == 5:
+                # q/k/v/o
+                dims = ['n', 'l', 'h', 'd']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> (x n) l h d')
+            elif len(val.shape) == 4:
+                # lse
+                dims = ['n', 'h', 'l']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> (x n) h l')
+        do, q, k, v, o, lse = [squish(x, axis) for x, axis in zip(vector_arg_values, batch_axes)]
+        dq, dk, dv = _flash_mha_bwd_p.bind(do, q, k, v, o, lse, **kwargs)
+        dq = einops.rearrange(dq, '(x n) l h d -> x n l h d', x=x)
+        dk = einops.rearrange(dk, '(x n) l h d -> x n l h d', x=x)
+        dv = einops.rearrange(dv, '(x n) l h d -> x n l h d', x=x)
+        return (dq,dk,dv), (0,0,0)
+    elif mapped == (True, True, False, False, True, True):
+        # Everything is mapped except k and v, which is a GQA backward
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            if len(val.shape) == 5:
+                # q/k/v/o
+                dims = ['n', 'l', 'h', 'd']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> n l (h x) d')
+            elif len(val.shape) == 4:
+                # lse
+                dims = ['n', 'h', 'l']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> n (h x) l')
+        do = squish(vector_arg_values[0], batch_axes[0])
+        q = squish(vector_arg_values[1], batch_axes[1])
+        k = vector_arg_values[2]
+        v = vector_arg_values[3]
+        o = squish(vector_arg_values[4], batch_axes[4])
+        lse = squish(vector_arg_values[5], batch_axes[5])
+        dq, dk, dv = _flash_mha_bwd_p.bind(do, q, k, v, o, lse, **kwargs)
+        dq = einops.rearrange(dq, 'n l (h x) d -> x n l h d', x=x)
+        return (dq,dk,dv), (0,None,None)
+    else:
+        raise NotImplementedError("MHA bwd only support vmapping over q or (q,k,v) for now, got batch axes " + str(batch_axes))

batching.primitive_batchers[_flash_mha_fwd_p] = mha_fwd_batch
batching.primitive_batchers[_flash_mha_bwd_p] = mha_bwd_batch
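
The GQA reduction in the new batching rules is pure shape bookkeeping: the vmapped axis of q is folded into its head dimension, the kernel then sees a grouped-query problem (x*h query heads reading h k/v heads), and the result is unfolded afterwards. A small standalone illustration of that bookkeeping with einops, separate from the module code above (no kernel call; the names are illustrative):

import numpy as np
import einops

x, n, l, h, d = 4, 2, 16, 3, 8                 # vmap size, batch, seqlen, heads, head dim
q = np.zeros([x, n, l, h, d])                  # vmapped query
k = np.zeros([n, l, h, d])                     # shared key (v is analogous)

# Fold the vmap axis into q's head axis: h*x query heads per batch element.
q_folded = einops.rearrange(q, 'x n l h d -> n l (h x) d')
assert q_folded.shape == (n, l, h * x, d)
# Query head i of q_folded belongs to group i // x, so it reads k/v head i // x,
# which is exactly the grouped-query layout the attention kernel supports natively.

# After the (hypothetical) kernel call, unfold the output back to the vmapped layout.
out_folded = np.zeros([n, l, h * x, d])        # stand-in for the kernel output
out = einops.rearrange(out_folded, 'n l (h x) d -> x n l h d', x=x)
assert out.shape == (x, n, l, h, d)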
103 changes: 98 additions & 5 deletions tests/test_flash.py
@@ -4,6 +4,7 @@
sys.path.insert(0, glob.glob('build/lib.linux-*')[0])
sys.path.insert(0,'./src')

+from functools import partial
import pytest
import jax
import jax.numpy as jnp
@@ -100,22 +101,114 @@ def test_flash_fwd_vmap(n, seqlen, h, d, causal, local, dtype):
    k = jax.random.normal(jax.random.PRNGKey(1), [x, n, seqlen, h, d], dtype=jnp.float32)
    v = jax.random.normal(jax.random.PRNGKey(2), [x, n, seqlen, h, d], dtype=jnp.float32)

-    @jax.jit
    def ref(q,k,v):
        return ref_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
-    @jax.jit
    def flash(q,k,v):
        return flash_mha(q,k,v, is_causal=bool(causal), window_size=window_size)

-    ref_out = jnp.stack([ref(q[i],k[i],v[i]) for i in range(x)])
+    ref_out = jax.vmap(ref)(q,k,v)
    q = q.astype(dtype)
    k = k.astype(dtype)
    v = v.astype(dtype)
-    f16_out = jnp.stack([ref(q[i],k[i],v[i]) for i in range(x)])
-
+    f16_out = jax.vmap(ref)(q,k,v)
+
    out = jax.vmap(flash)(q,k,v)
    check(ref_out, f16_out, out)

@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@pytest.mark.parametrize("local", ['local',''])
@pytest.mark.parametrize("causal", ['causal',''])
@pytest.mark.parametrize("d", [59, 32])
@pytest.mark.parametrize("h", [1, 4])
@pytest.mark.parametrize("seqlen", [97, 128])
@pytest.mark.parametrize("n", [1])
def test_flash_fwd_vmapq(n, seqlen, h, d, causal, local, dtype):
window_size = (3,3) if local else (-1,-1)

x = 4
q = jax.random.normal(jax.random.PRNGKey(0), [x, n, seqlen, h, d], dtype=jnp.float32)
k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)

def ref(q,k,v):
return ref_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
def flash(q,k,v):
return flash_mha(q,k,v, is_causal=bool(causal), window_size=window_size)

ref_out = jax.vmap(ref, in_axes=(0,None,None))(q,k,v)
q = q.astype(dtype)
k = k.astype(dtype)
v = v.astype(dtype)
f16_out = jax.vmap(ref, in_axes=(0,None,None))(q,k,v)

out = jax.vmap(flash, in_axes=(0,None,None))(q,k,v)
check(ref_out, f16_out, out)

@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@pytest.mark.parametrize("local", ['local',''])
@pytest.mark.parametrize("causal", ['causal',''])
@pytest.mark.parametrize("d", [59, 32])
@pytest.mark.parametrize("h", [1, 4])
@pytest.mark.parametrize("seqlen", [97, 128])
@pytest.mark.parametrize("n", [1])
def test_flash_bwd_vmap(n, seqlen, h, d, causal, local, dtype):
window_size = (3,3) if local else (-1,-1)

x = 4
q = jax.random.normal(jax.random.PRNGKey(0), [x, n, seqlen, h, d], dtype=jnp.float32)
k = jax.random.normal(jax.random.PRNGKey(1), [x, n, seqlen, h, d], dtype=jnp.float32)
v = jax.random.normal(jax.random.PRNGKey(2), [x, n, seqlen, h, d], dtype=jnp.float32)
do = jax.random.normal(jax.random.PRNGKey(3), [x, n, seqlen, h, d], dtype=jnp.float32)

def func(mha, q,k,v):
@partial(jax.vmap, in_axes=(0,0,0))
def fwd(q,k,v):
return mha(q,k,v, is_causal=bool(causal), window_size=window_size)
o, bwd = jax.vjp(fwd,q,k,v)
return bwd(do)

ref_out = func(ref_mha, q,k,v)
q = q.astype(dtype)
k = k.astype(dtype)
v = v.astype(dtype)
do = do.astype(dtype)
f16_out = func(ref_mha, q,k,v)

out = func(flash_mha, q,k,v)
check(ref_out, f16_out, out)

@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@pytest.mark.parametrize("local", ['local',''])
@pytest.mark.parametrize("causal", ['causal',''])
@pytest.mark.parametrize("d", [59, 32])
@pytest.mark.parametrize("h", [1, 4])
@pytest.mark.parametrize("seqlen", [97, 128])
@pytest.mark.parametrize("n", [1])
def test_flash_bwd_vmapq(n, seqlen, h, d, causal, local, dtype):
window_size = (3,3) if local else (-1,-1)

x = 4
q = jax.random.normal(jax.random.PRNGKey(0), [x, n, seqlen, h, d], dtype=jnp.float32)
k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
do = jax.random.normal(jax.random.PRNGKey(3), [x, n, seqlen, h, d], dtype=jnp.float32)

def func(mha, q,k,v):
@partial(jax.vmap, in_axes=(0,None,None))
def fwd(q,k,v):
return mha(q,k,v, is_causal=bool(causal), window_size=window_size)
o, bwd = jax.vjp(fwd,q,k,v)
return bwd(do)

ref_out = func(ref_mha, q,k,v)
q = q.astype(dtype)
k = k.astype(dtype)
v = v.astype(dtype)
do = do.astype(dtype)
f16_out = func(ref_mha, q,k,v)

out = func(flash_mha, q,k,v)
check(ref_out, f16_out, out)

if __name__ == '__main__':
    test_flash_bwd(1,4,1,32,4,False,False,jnp.float16)
