diff --git a/README.md b/README.md
index fe0fc14..aa6293a 100644
--- a/README.md
+++ b/README.md
@@ -11,10 +11,11 @@ Please cite (see below) and credit FlashAttention if you use it.
 ## Installation and features
 
 Requirements:
-- CUDA 11.6 and above.
+- CUDA 11.8 and above.
 - Linux. Same story as with the pytorch repo. I haven't tested compilation of the jax bindings on windows.
+- JAX >=`0.4.24`. The custom sharding used for ring attention requires somewhat advanced features.
 
-To install: TODO
+To install: For now, download the appropriate release from the releases page and install it with pip.
 
 Interface: `src/flash_attn_jax/flash.py`
@@ -28,6 +29,17 @@ Accepts q,k,v with shape `[n, l, h, d]`, and returns `[n, l, h, d]`. `softmax_scale` is the
 multiplier for the softmax, defaulting to `1/sqrt(d)`. Set window_size to positive values for
 sliding window attention.
 
+### Now Supports Ring Attention
+
+Use `jax.Array` and shard your tensors along the length dimension, and `flash_mha` will automatically use the ring attention algorithm:
+
+```py
+with Mesh(devices, axis_names=('len',)) as mesh:
+    sharding = NamedSharding(mesh, P(None,'len',None)) # n l d
+    tokens = jax.device_put(tokens, sharding)
+    # invoke your jax.jit'd transformer.forward
+```
+
 FlashAttention-2 currently supports:
 1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
    GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
diff --git a/src/flash_attn_jax/__init__.py b/src/flash_attn_jax/__init__.py
index 04ee448..876e930 100644
--- a/src/flash_attn_jax/__init__.py
+++ b/src/flash_attn_jax/__init__.py
@@ -1,2 +1,2 @@
 from .flash import flash_mha
-__version__ = 'v2.5.0'
+__version__ = 'v2.5.5'
diff --git a/src/flash_attn_jax/flash_sharding.py b/src/flash_attn_jax/flash_sharding.py
index ad7360c..33766e9 100644
--- a/src/flash_attn_jax/flash_sharding.py
+++ b/src/flash_attn_jax/flash_sharding.py
@@ -30,53 +30,6 @@ from jax._src.ad_checkpoint import _optimization_barrier
 
-def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q,k,v):
-    [n,l,h,d] = q.shape
-
-    q_ix = jax.lax.axis_index(axis_name)
-    k_ix = jax.lax.axis_index(axis_name)
-
-    o = jnp.zeros([n,l,h,d], jnp.float32)
-    lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
-
-    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
-    def f(c, a):
-        (k, v, o, lse, k_ix) = c
-
-        o1, lse1 = o, lse
-        if is_causal:
-            o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
-                [
-                    lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
-                    lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
-                    lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
-                ], q, k, v)
-        else:
-            o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-        o2 = o2.astype(jnp.float32)
-
-        mx = jnp.maximum(lse1,lse2)
-        mn = jnp.minimum(lse1,lse2)
-        lse = jnp.log1p(jnp.exp(mn-mx)) + mx
-
-        o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
-             o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
-
-        k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-        v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-        k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-
-        return ((k2, v2, o, lse, k_ix), None)
-    acc = (k,v,o,lse,k_ix)
-    # We sadly have to manually unroll this because scan breaks the axis context preventing us from using ppermute (unroll=axis_size doesn't help either).
-    # Optimization barrier prevents instruction reordering so that ppermute and flash_mha execute concurrently.
-    for _ in range(axis_size):
-        acc, _ = f(acc, None)
-        acc = _optimization_barrier(acc)
-    (_,_,o,lse,_) = acc
-    # (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
-    return o.astype(q.dtype), lse
-
 def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
     result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
@@ -147,17 +100,32 @@ def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
     o_sharding = arg_shardings[4]
     lse_sharding = arg_shardings[5]
     if isinstance(q_sharding, PositionalSharding):
-        do_sharding = q_sharding.replicate((1,3))
-        [n, l, h, d] = do_sharding.shape
-        lse_sharding = do_sharding.reshape(n,l,h).transpose(0,2,1) # n h l
-        result_shardings = (do_sharding,)*3
-        arg_shardings = (do_sharding,)*5 + (lse_sharding,)
+        assert q_sharding == k_sharding, "Expect q and k sharding to match"
+        assert q_sharding == v_sharding, "Expect q and v sharding to match"
+        [n, l, h, d] = q_sharding.shape
+        assert d == 1, "Sharding across `d` won't be efficient, so it's not supported."
+        assert l == 1, "For ring attention, use `with Mesh(...) as mesh` and NamedSharding."
+        lse_sharding = q_sharding.reshape(n,h,1) # n h l
+        result_shardings = (q_sharding,)*3
+        arg_shardings = (q_sharding,)*5 + (lse_sharding,)
     elif isinstance(q_sharding, NamedSharding):
         mesh = q_sharding.mesh
         [n,l,h,d] = q_sharding.spec
-        do_sharding = NamedSharding(mesh, P(n,None,h,None))
-        lse_sharding = NamedSharding(mesh, P(n,h,None))
-        result_shardings = (do_sharding,)*3
+        assert d == None, "Sharding across `d` won't be efficient, so it's not supported."
+        if l != None:
+            # assert not is_causal and window_size == (-1,-1), "Ring attention doesn't support causal or local masking yet."
+            assert window_size == (-1,-1), "Ring attention doesn't support local masking yet."
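+            # Length axis is sharded: use the ring algorithm for the backward pass,
+            # passing blocks around the device ring instead of gathering them.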
+            result_shardings = q_sharding, q_sharding, q_sharding
+            lse_sharding = NamedSharding(mesh, P(n,h,l))
+            arg_shardings = (q_sharding,)*5 + (lse_sharding,)
+            axis_name = l
+            axis_size = mesh.shape[axis_name]
+            # ring attention
+            return mesh, partial(ring_bwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
+        else:
+            result_shardings = q_sharding, q_sharding, q_sharding
+            lse_sharding = NamedSharding(mesh, P(n,h,l))
+            arg_shardings = (q_sharding,)*5 + (lse_sharding,)
     def fwd(*args):
         return _flash_mha_bwd_hlo(*args, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
     return mesh, fwd, result_shardings, arg_shardings
@@ -165,3 +133,103 @@ def fwd(*args):
 _flash_mha_bwd_hlo_sharded.def_partition(
     infer_sharding_from_operands=infer_sharding_bwd,
     partition=partition_bwd)
+
+# ==== Ring Forward ====
+
+def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q,k,v):
+    [n,l,h,d] = q.shape
+
+    q_ix = jax.lax.axis_index(axis_name)
+    k_ix = jax.lax.axis_index(axis_name)
+
+    o = jnp.zeros([n,l,h,d], jnp.float32)
+    lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
+
+    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+    def f(c, a):
+        (k, v, o, lse, k_ix) = c
+
+        o1, lse1 = o, lse
+        if is_causal:
+            o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
+                [
+                    lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
+                    lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
+                    lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
+                ], q, k, v)
+        else:
+            o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+        o2 = o2.astype(jnp.float32)
+
+        mx = jnp.maximum(lse1,lse2)
+        mn = jnp.minimum(lse1,lse2)
+        lse = jnp.log1p(jnp.exp(mn-mx)) + mx
+
+        o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
+             o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
+
+        k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+
+        return ((k2, v2, o, lse, k_ix), None)
+    acc = (k,v,o,lse,k_ix)
+    # We sadly have to manually unroll this because scan breaks the axis context preventing us from using ppermute (unroll=axis_size doesn't help either).
+    # Optimization barrier prevents instruction reordering so that ppermute and flash_mha execute concurrently.
+    for _ in range(axis_size):
+        acc, _ = f(acc, None)
+        acc = _optimization_barrier(acc)
+    (_,_,o,lse,_) = acc
+    # (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
+    return o.astype(q.dtype), lse
+
+# ==== Ring Backward ====
+
+# This doesn't seem like the most efficient way to do this: we waste compute by calculating every dq,dk,dv twice.
+# Should we instead send the accumulators for dk,dv cross-device, relying on the fact that after a full cycle they return to the starting device?
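+# How the ring works: each device keeps its dq/dk/dv accumulators local, while its
+# (do, q, k, v, o, lse) block and origin index rotate one hop around the ring per step;
+# dq accumulates against each visiting k/v block, and dk/dv against each visiting do/q block.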
+def ring_bwd(softmax_scale, is_causal, axis_name, axis_size, do,q,k,v,o,lse):
+    [n,l,h,d] = q.shape
+
+    ix = jax.lax.axis_index(axis_name)
+
+    dq = jnp.zeros([n,l,h,d], jnp.float32)
+    dk = jnp.zeros([n,l,h,d], jnp.float32)
+    dv = jnp.zeros([n,l,h,d], jnp.float32)
+
+    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+    def f(acc, a):
+        (do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc
+
+        cmp = (ix2 < ix).astype(jnp.int32) + (ix2 <= ix).astype(jnp.int32)
+        # 0: ix < ix2
+        # 1: ix = ix2
+        # 2: ix > ix2
+        if is_causal:
+            dqa = jax.lax.switch(cmp, [
+                lambda q,k,v: jnp.zeros([n,l,h,d], q.dtype),
+                lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[0],
+                lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[0],
+            ], q, k, v)
+            dka,dva = jax.lax.switch(cmp, [
+                lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[1:],
+                lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[1:],
+                lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype),jnp.zeros([n,l,h,d], q.dtype)),
+            ], q, k, v)
+        else:
+            dqa,_,_ = _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+            _,dka,dva = _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+
+        dq += dqa
+        dk += dka
+        dv += dva
+
+        (do2,q2,k2,v2,o2,lse2,ix2) = jax.lax.ppermute((do2,q2,k2,v2,o2,lse2,ix2), axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+
+        return ((do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv), None)
+    acc = (do,q,k,v,o,lse,ix,dq,dk,dv)
+    # Unrolled as above.
+    for _ in range(axis_size):
+        acc, _ = f(acc, None)
+        acc = _optimization_barrier(acc)
+    (do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc
+    return dq.astype(q.dtype),dk.astype(q.dtype),dv.astype(q.dtype)
\ No newline at end of file
diff --git a/tests/test_sharding.py b/tests/test_sharding.py
index 71bb970..14dde64 100644
--- a/tests/test_sharding.py
+++ b/tests/test_sharding.py
@@ -103,8 +103,7 @@ def with_sharding(q_sharding, kv_sharding=None):
 @pytest.mark.parametrize("d", [32])
 @pytest.mark.parametrize("h", [4])
 @pytest.mark.parametrize("seqlen", [128])
-@pytest.mark.parametrize("shard_dim", [0,2])
-def test_flash_bwd_sharded_hlo(seqlen, h, d, causal, local, dtype, shard_dim):
+def test_flash_bwd_sharded_hlo(seqlen, h, d, causal, local, dtype):
     window_size = (3,3) if local else (-1,-1)
 
     devices = jax.local_devices()[:4]
@@ -117,19 +116,35 @@ def test_flash_bwd_sharded_hlo(seqlen, h, d, causal, local, dtype, shard_dim):
     def flash(qkv):
         return (flash_mha(*qkv, is_causal=bool(causal), window_size=window_size)**2).sum()
 
-    q = jax.random.normal(jax.random.PRNGKey(0), [n, seqlen, h, d], dtype=dtype)
-    k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=dtype)
-    v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=dtype)
+    def with_sharding(sharding):
+        q = jax.random.normal(jax.random.PRNGKey(0), [n, seqlen, h, d], dtype=dtype)
+        k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=dtype)
+        v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=dtype)
+        (q,k,v) = jax.device_put((q,k,v), sharding)
+        hlo = flash.lower((q,k,v)).compile().as_text()
+        return hlo
 
-    shape = [1,1,1,1]
-    shape[shard_dim] = n
-    sharding = PositionalSharding(devices).reshape(shape)
+    hlo = with_sharding(PositionalSharding(devices).reshape(n,1,1,1))
+    assert 'all-gather' not in hlo
+    assert 'dynamic-slice' not in hlo
 
-    q,k,v = jax.device_put((q,k,v), sharding)
-    hlo = flash.lower((q,k,v)).compile().as_text()
+    hlo = with_sharding(PositionalSharding(devices).reshape(1,1,n,1))
     assert 'all-gather' not in hlo
     assert 'dynamic-slice' not in hlo
 
+    if not local:
+        with Mesh(np.array(devices), axis_names=('x',)) as mesh:
+            sharding = NamedSharding(mesh, P(None,'x',None,None))
+            hlo = with_sharding(sharding)
+            # No resharding should occur, only manual collective-permute.
+            assert 'all-gather' not in hlo
+            assert 'dynamic-slice' not in hlo
+            assert 'collective-permute' in hlo
+            # The kernel and the permute should run concurrently, i.e. a custom-call always appears between permute start and done.
+            import re
+            collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo))
+            assert 'collective-permute-start collective-permute-done' not in collectives, hlo
+
 @pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
 @pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
 @pytest.mark.parametrize("local", ['local',''])
@@ -181,8 +196,7 @@ def check_sharding(sharding,q,k,v):
 @pytest.mark.parametrize("d", [32])
 @pytest.mark.parametrize("h", [4, 8])
 @pytest.mark.parametrize("seqlen", [128])
-@pytest.mark.parametrize("shard_dim", [0,2])
-def test_flash_bwd_sharded(seqlen, h, d, causal, local, dtype, shard_dim):
+def test_flash_bwd_sharded(seqlen, h, d, causal, local, dtype):
     window_size = (3,3) if local else (-1,-1)
 
     devices = jax.local_devices()
@@ -200,23 +214,25 @@ def flash(qkv):
     k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
     v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
 
-    if q.shape[shard_dim] % n != 0:
-        pytest.skip(f"{q.shape[shard_dim]} doesn't divide into {n} so we can't shard it.")
-
     ref_out = ref((q,k,v))
 
     q = q.astype(dtype)
     k = k.astype(dtype)
     v = v.astype(dtype)
-    repl_out = flash((q,k,v))
+    ref16_out = flash((q,k,v))
+
+    def check_sharding(sharding,q,k,v):
+        (q,k,v) = jax.device_put((q,k,v), sharding)
+        out = flash((q,k,v))
+        check(ref_out,ref16_out,out)
 
-    shape = [1,1,1,1]
-    shape[shard_dim] = n
-    sharding = PositionalSharding(devices).reshape(shape)
+    check_sharding(PositionalSharding(devices).reshape(n,1,1,1),q,k,v)
+    check_sharding(PositionalSharding(devices).reshape(1,1,n,1),q,k,v)
 
-    (q,k,v) = jax.device_put((q,k,v), sharding)
-    hlo = flash.lower((q,k,v)).compile().as_text()
-    out = flash((q,k,v))
-    check(ref_out, repl_out, out)
+    if not local:
+        # Ring attention
+        with Mesh(np.array(devices), axis_names=('x',)) as mesh:
+            sharding = NamedSharding(mesh, P(None,'x',None,None))
+            check_sharding(sharding,q,k,v)
 
 if __name__ == '__main__':
     test_flash_fwd_sharded_hlo(128,4,32,False,False,jnp.float16)
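For reference, here is a minimal usage sketch of the ring-attention path that the new tests exercise. It assumes at least two local GPUs and an installed `flash_attn_jax` wheel; the shapes, dtype, and variable names are illustrative only, not part of the library API.

```py
# Illustrative sketch: device count, shapes, and dtype are assumptions, not requirements.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from flash_attn_jax import flash_mha

devices = np.array(jax.local_devices())  # assumes >= 2 GPUs
n, l, h, d = 1, 8192, 8, 64

q = jax.random.normal(jax.random.PRNGKey(0), [n, l, h, d], dtype=jnp.float16)
k = jax.random.normal(jax.random.PRNGKey(1), [n, l, h, d], dtype=jnp.float16)
v = jax.random.normal(jax.random.PRNGKey(2), [n, l, h, d], dtype=jnp.float16)

with Mesh(devices, axis_names=('len',)) as mesh:
    # Shard q, k, v along the sequence-length axis; the custom partitioner above then
    # lowers flash_mha to local kernels plus collective-permutes (ring attention).
    sharding = NamedSharding(mesh, P(None, 'len', None, None))  # n l h d
    q, k, v = jax.device_put((q, k, v), sharding)
    out = jax.jit(flash_mha)(q, k, v)  # [n, l, h, d], still sharded along 'len'
```

Only the length axis is sharded here, which is what routes the call to `ring_fwd`/`ring_bwd`; the partitioner asserts that the head dimension `d` is not sharded, and ring attention currently rejects local (`window_size`) masking.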