add approx softcap
drisspg committed Aug 7, 2024
1 parent 738268e commit 757b03d
Showing 5 changed files with 140 additions and 19 deletions.
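
The commit adds an approx flag to generate_tanh_softcap: when set, the softcap's tanh is routed through a custom approx::tanh op that torch.compile lowers to the tanh.approx.f32 PTX instruction. A minimal usage sketch (mine, not part of the commit; assumes a CUDA build of PyTorch with FlexAttention available), mirroring the new test at the bottom of the diff:

import torch
from torch.nn.attention.flex_attention import flex_attention
from attn_gym.mods import generate_tanh_softcap

# Softcap the attention scores at 30, using the approximate tanh lowering.
query = key = value = torch.randn(1, 1, 128, 64, device="cuda", dtype=torch.float16)
softcap_mod = generate_tanh_softcap(30, approx=True)
out = torch.compile(flex_attention)(query, key, value, score_mod=softcap_mod)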
1 change: 1 addition & 0 deletions attn_gym/mods/__init__.py
@@ -1 +1,2 @@
 from attn_gym.mods.alibi import generate_alibi_bias
+from attn_gym.mods.softcapping import generate_tanh_softcap
59 changes: 56 additions & 3 deletions attn_gym/mods/softcapping.py
@@ -1,22 +1,75 @@
 """Implementation of a tanh softcapping score mod, popularized in the Gemma2 paper."""

-from torch import tanh
+import torch
+from torch import Tensor
 from torch.nn.attention.flex_attention import _score_mod_signature
+from torch._inductor.lowering import make_pointwise, register_lowering
+
+# Some internal torch.compile details
+from torch._inductor.virtualized import ops
+from functools import partial
+
+
+@torch.library.custom_op("approx::tanh", mutates_args=())
+def tanh_approx(inp: Tensor) -> Tensor:
+    return torch.tanh(inp)
+
+
+@tanh_approx.register_fake
+def _(inp: torch.Tensor) -> torch.Tensor:
+    return torch.tanh(inp)
+
+
+def tanh_approx_lowering(inp):
+    fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
+    return make_pointwise(fn)(inp)
+
+
+register_lowering(torch.ops.approx.tanh)(tanh_approx_lowering)
+
+
+class _TanhApprox(torch.autograd.Function):
+    @staticmethod
+    def forward(x):
+        return torch.ops.approx.tanh(x)
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        (x,) = inputs
+        result = output
+        ctx.save_for_backward(result)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (result,) = ctx.saved_tensors
+        return grad_output * (1 - result * result)
+
+    @staticmethod
+    def vmap(info, in_dims, x):
+        return torch.tanh(x), 0
+
+
+_tanh_approx = _TanhApprox.apply
+

-def generate_tanh_softcap(soft_cap: int) -> _score_mod_signature:
+def generate_tanh_softcap(soft_cap: int, approx: bool = False) -> _score_mod_signature:
     """Returns a tanh softcapping score_mod.
     Args:
         soft_cap: The soft cap value to use for normalizing logits
+        approx: Whether to use the tanh.approx PTX instruction
     Returns:
         tanh_softcap: score_mod
     """
+    tanh = _tanh_approx if approx else torch.tanh

     def tanh_softcap(score, b, h, q_idx, kv_idx):
         return score * tanh(score / soft_cap)

-    tanh_softcap.__name__ = f"tanh_softcap_{soft_cap}"
+    prefix = "tanh_softcap_approx" if approx else "tanh_softcap"
+    tanh_softcap.__name__ = f"{prefix}_{soft_cap}"

     return tanh_softcap
@@ -36,7 +89,7 @@ def make_tensor():

     query, key = make_tensor(), make_tensor()

-    tanh_softcap_score_mod = generate_tanh_softcap(30)
+    tanh_softcap_score_mod = generate_tanh_softcap(30, approx=True)

     visualize_attention_scores(
         query, key, score_mod=tanh_softcap_score_mod, device=device, name="tanh_softcap_score_mod"
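
_TanhApprox.backward above relies on the identity d/dx tanh(x) = 1 - tanh(x)^2, which is why setup_context saves the forward output rather than the input. A small eager-mode sketch (mine, not part of the diff) that checks the identity numerically:

import torch

x = torch.randn(8, dtype=torch.float64, requires_grad=True)
y = torch.tanh(x)
y.sum().backward()
torch.testing.assert_close(x.grad, 1 - y.detach() ** 2)  # tanh'(x) == 1 - tanh(x)^2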
63 changes: 48 additions & 15 deletions examples/benchmark.py
@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import Optional
+from typing import Optional, List

 import torch
 import torch.nn.functional as F

@@ -23,7 +23,7 @@
     generate_prefix_lm_mask,
     generate_doc_mask_mod,
 )
-from attn_gym.mods import generate_alibi_bias
+from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap


 torch.set_default_device("cuda")
@@ -204,20 +204,53 @@ def generate_random_lengths(total_length, num_documents):
     test_mask(mask_mod=document_causal_mask, S=32768)


-def main():
-    test_mask(mask_mod=causal_mask)
-    # Correctness check here is simple and only works with mask_fns and not actual score_mods
-    test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True)
-
-    sliding_window_mask = generate_sliding_window(window_size=1024)
-    test_mask(mask_mod=sliding_window_mask)
-
-    prefix_lm_mask = generate_prefix_lm_mask(prefix_length=1024)
-    test_mask(mask_mod=prefix_lm_mask)
-
-    # Document masking
-    run_document_masking(max_seq_len=32768, num_docs=12)
+def main(examples: List[str] = ["all"]):
+    """Run the benchmark with the given examples.
+    Args:
+        examples: List of examples to run. If "all" is specified, all examples will be run.
+    """
+    available_examples = {
+        "causal": lambda: test_mask(mask_mod=causal_mask),
+        "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
+        "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
+        "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
+        "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
+        "softcap": lambda: test_mask(
+            score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
+        ),
+        "softcap_approx": lambda: test_mask(
+            score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
+        ),
+    }
+
+    if "all" in examples:
+        ex_to_run = list(available_examples.keys())
+    else:
+        ex_to_run = examples
+
+    for ex in ex_to_run:
+        if ex in available_examples:
+            available_examples[ex]()
+        else:
+            print(f"Warning: Unknown example key '{ex}'. Skipping.")


 if __name__ == "__main__":
-    main()
+    try:
+        from jsonargparse import ArgumentParser
+    except ImportError:
+        raise ImportError("Be sure to run: pip install -e .'[viz]'")
+    parser = ArgumentParser(description="Run specific examples or all examples.")
+    parser.add_argument(
+        "--examples",
+        type=str,
+        nargs="+",
+        default=["all"],
+        help="List of examples to run. Use space to separate multiple examples. "
+        "Available options: causal, alibi, sliding_window, prefix_lm, "
+        "document, softcap, softcap_approx, or 'all' to run all examples.",
+    )
+
+    args = parser.parse_args()
+    main(**vars(args))
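
With this CLI in place, individual benchmarks can be selected by key, for example: python examples/benchmark.py --examples softcap softcap_approx. The default --examples all runs every entry in available_examples.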
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -29,7 +29,8 @@ dev = [
     "pytest",
     "ruff",
     "jsonargparse",
-    "docstring-parser"
+    "docstring-parser",
+    "pytest"
 ]

 viz = [
33 changes: 33 additions & 0 deletions test/test_mods.py
@@ -0,0 +1,33 @@
+import torch
+from torch.autograd import grad
+from torch.nn.attention.flex_attention import flex_attention
+import pytest
+from functools import partial
+from attn_gym.mods import generate_tanh_softcap
+
+
+def test_tanh_approx():
+    softcap_mod = generate_tanh_softcap(30, approx=False)
+    softcap_mod_approx = generate_tanh_softcap(30, approx=True)
+    make_tensor = partial(
+        torch.randn, 1, 1, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True
+    )
+
+    query, key, value = make_tensor(), make_tensor(), make_tensor()
+
+    flex_attention_compile = torch.compile(flex_attention)
+    out = flex_attention_compile(query, key, value, score_mod=softcap_mod)
+
+    grad_q, grad_k, grad_v = grad(out.sum(), (query, key, value))
+
+    out_approx = flex_attention_compile(query, key, value, score_mod=softcap_mod_approx)
+    grad_q_approx, grad_k_approx, grad_v_approx = grad(out_approx.sum(), (query, key, value))
+
+    for tensor_softcap, tensor_softcap_approx in zip(
+        [out, grad_q, grad_k, grad_v], [out_approx, grad_q_approx, grad_k_approx, grad_v_approx]
+    ):
+        torch.testing.assert_close(tensor_softcap, tensor_softcap_approx, atol=7e-5, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
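
The loose tolerances (atol=7e-5, rtol=1e-3) absorb the reduced precision of the tanh.approx.f32 kernel relative to exact tanh. Outside torch.compile the two paths should match exactly, since the eager implementation of approx::tanh simply calls torch.tanh; a quick sketch of that check (mine, not part of the commit):

import torch
from attn_gym.mods import generate_tanh_softcap  # noqa: F401 -- importing registers approx::tanh

x = torch.randn(16)
torch.testing.assert_close(torch.ops.approx.tanh(x), torch.tanh(x))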
