sub-quadratic attention #1
base: main
@@ -0,0 +1,194 @@
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# unspecified
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# implementation of:
# "Self-attention Does Not Need O(n²) Memory":
# https://arxiv.org/abs/2112.05682v2

from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, Protocol, List
from ..utils.dynamic_slice import dynamic_slice

class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor

class SummarizeChunk(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...

class ComputeQueryChunkAttn(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> Tensor: ...
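Note (illustrative sketch, not part of the diff): these Protocol classes type bare callables, so no concrete class is needed to satisfy them; a functools.partial whose remaining signature matches is enough, which is how the module wires things up further down. The tensor shapes below are invented, and the snippet assumes it runs alongside the helpers defined later in this file.

import torch
from functools import partial

q = torch.randn(2, 4, 8)     # [batch * num_heads, q_tokens, channels_per_head]
k_t = torch.randn(2, 8, 16)  # [batch * num_heads, channels_per_head, k_tokens]
v = torch.randn(2, 16, 8)    # [batch * num_heads, k_tokens, channels_per_head]

# binding `scale` leaves the (query, key_t, value) -> AttnChunk signature that SummarizeChunk describes
summarize: SummarizeChunk = partial(_summarize_chunk, scale=8 ** -0.5)
chunk: AttnChunk = summarize(q, k_t, v)  # AttnChunk(exp_values, exp_weights_sum, max_score)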
def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
) -> AttnChunk:
    attn_weights = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key_t,
        alpha=scale,
        beta=0,
    )
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    exp_weights = torch.exp(attn_weights - max_score)
    exp_values = torch.bmm(exp_weights, value)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
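_summarize_chunk produces an unnormalized partial result for one key/value chunk: exp_values holds exp(scores - max) @ value, exp_weights_sum the per-query row sums, and max_score the per-row maximum subtracted for numerical stability. A small sanity sketch (not part of the PR; shapes invented, helper above assumed in scope): for a single chunk, dividing exp_values by exp_weights_sum recovers ordinary softmax attention, because the subtracted max cancels on division.

import torch

q = torch.randn(2, 4, 8)    # [batch * num_heads, q_tokens, channels_per_head]
k_t = torch.randn(2, 8, 6)  # [batch * num_heads, channels_per_head, k_tokens]
v = torch.randn(2, 6, 8)    # [batch * num_heads, k_tokens, channels_per_head]
scale = 8 ** -0.5

chunk = _summarize_chunk(q, k_t, v, scale=scale)
recon = chunk.exp_values / chunk.exp_weights_sum.unsqueeze(-1)
ref = torch.softmax(scale * q @ k_t, dim=-1) @ v  # plain attention for comparison
assert torch.allclose(recon, ref, atol=1e-5)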
def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        key_chunk = dynamic_slice(
            key_t,
            (0, 0, chunk_idx),
            (batch_x_heads, k_channels_per_head, kv_chunk_size)
        )
        value_chunk = dynamic_slice(
            value,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, v_channels_per_head)
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    chunks: List[AttnChunk] = [
        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
    ]
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights
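The reduction above is the log-sum-exp trick applied across chunks: each chunk's partial sums are rescaled by exp(chunk_max - global_max) so every chunk is expressed relative to the same maximum before the sums are combined. A hedged sketch (shapes invented; assumes the helpers in this file are in scope, and picks a kv_chunk_size that divides k_tokens so the fixed-size slices tile the key dimension exactly): the chunked result matches full softmax attention.

import torch
from functools import partial

q = torch.randn(2, 4, 8)
k_t = torch.randn(2, 8, 16)
v = torch.randn(2, 16, 8)
scale = 8 ** -0.5

out = _query_chunk_attention(
    q, k_t, v,
    summarize_chunk=partial(_summarize_chunk, scale=scale),
    kv_chunk_size=4,  # 4 key/value chunks of 4 tokens each
)
ref = torch.softmax(scale * q @ k_t, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-5)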
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
) -> Tensor:
    attn_scores = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key_t,
        alpha=scale,
        beta=0,
    )
    attn_probs = attn_scores.softmax(dim=-1)
    del attn_scores
    hidden_states_slice = torch.bmm(attn_probs, value)
    return hidden_states_slice
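This no-chunking path materializes the full attention matrix for a query chunk (memory proportional to query_chunk_size × k_tokens), so it needs none of the max/rescale bookkeeping; it is ordinary scaled-dot-product attention over the whole key/value set. A brief sketch (shapes invented, helper assumed in scope) showing that equivalence:

import torch

q = torch.randn(2, 4, 8)
k_t = torch.randn(2, 8, 16)
v = torch.randn(2, 16, 8)
scale = 8 ** -0.5

out = _get_attention_scores_no_kv_chunking(q, k_t, v, scale=scale)
ref = torch.softmax(scale * q @ k_t, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-5)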
class ScannedChunk(NamedTuple):
    chunk_idx: int
    attn_chunk: AttnChunk

def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
):
    """Computes efficient dot-product attention given query, transposed key, and value.
    This is the efficient version of attention presented in
    https://arxiv.org/abs/2112.05682v2, which comes with O(sqrt(n)) memory requirements.
    Args:
      query: queries for calculating attention with shape of
        `[batch * num_heads, tokens, channels_per_head]`.
      key_t: keys for calculating attention with shape of
        `[batch * num_heads, channels_per_head, tokens]`.
      value: values to be used in attention with shape of
        `[batch * num_heads, tokens, channels_per_head]`.
      query_chunk_size: int: query chunk size
      kv_chunk_size: Optional[int]: key/value chunk size. if None: defaults to sqrt(key_tokens)
      kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
      use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
    Returns:
      Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
            query,
            (0, chunk_idx, 0),
            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
        )

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
        _get_attention_scores_no_kv_chunking,
        scale=scale
    ) if k_tokens <= kv_chunk_size else (
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )
    )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key_t=key_t,
            value=value,
        ) for i in range(math.ceil(q_tokens / query_chunk_size))
    ], dim=1)
    return res
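End-to-end usage sketch (not part of the diff; sizes invented, and it assumes both files in this diff are importable as a package so the relative import of dynamic_slice resolves). Note that the key is passed pre-transposed, as key_t, and that the result matches naive attention:

import torch

batch_x_heads, n_tokens, channels = 2, 1024, 40
q = torch.randn(batch_x_heads, n_tokens, channels)
k = torch.randn(batch_x_heads, n_tokens, channels)
v = torch.randn(batch_x_heads, n_tokens, channels)

out = efficient_dot_product_attention(
    q,
    k.transpose(1, 2),      # key must be supplied transposed
    v,
    query_chunk_size=256,   # 4 query chunks
    kv_chunk_size=None,     # defaults to sqrt(1024) = 32 key/value tokens per chunk
    use_checkpoint=False,   # per the docstring, False is recommended for inference
)

ref = torch.softmax(channels ** -0.5 * q @ k.transpose(1, 2), dim=-1) @ v
assert out.shape == (batch_x_heads, n_tokens, channels)
assert torch.allclose(out, ref, atol=1e-4)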
@@ -0,0 +1,10 @@
from torch import Tensor
from typing import List

def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
    return x[slicing]

Review comment (on the `slicing = ...` line): this attempts to implement […]

Reply: Yeah that works also: No notable performance difference that I observed, but it's probably slightly more efficient nonetheless.
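The capture does not preserve what the reviewer was pointing at or which alternative was proposed, so the following is only a hedged illustration of one equivalent formulation (an assumption, not necessarily what was suggested): chaining torch.narrow produces the same view-based slicing without building Python slice objects.

import torch
from torch import Tensor
from typing import List

def dynamic_slice_narrow(x: Tensor, starts: List[int], sizes: List[int]) -> Tensor:
    # hypothetical alternative; narrow() returns a view, so no data is copied
    for dim, (start, size) in enumerate(zip(starts, sizes)):
        x = x.narrow(dim, start, size)
    return x

t = torch.arange(24).reshape(2, 3, 4)
assert torch.equal(dynamic_slice_narrow(t, (0, 1, 2), (2, 2, 2)), t[0:2, 1:3, 2:4])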
Review comment (on the torch.empty(1, 1, 1, ...) argument passed to torch.baddbmm): Shouldn't torch.zeros() be used here instead of torch.empty()?

Reply: Nope; it's actually an unused tensor (because beta=0), so we want whatever's the cheapest thing that passes the parameter validation. Unfortunately PyTorch complains if you pass None. Bad API design.
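A small sketch illustrating the reply (not part of the thread; shapes invented): with beta=0, torch.baddbmm never reads the values of its input argument, so an uninitialized torch.empty placeholder behaves the same as zeros, and the whole call reduces to a scaled batched matmul.

import torch

q = torch.randn(2, 4, 8)
k_t = torch.randn(2, 8, 6)
scale = 8 ** -0.5

a = torch.baddbmm(torch.empty(1, 1, 1), q, k_t, alpha=scale, beta=0)
b = torch.baddbmm(torch.zeros(1, 1, 1), q, k_t, alpha=scale, beta=0)
c = scale * torch.bmm(q, k_t)
assert torch.allclose(a, b) and torch.allclose(a, c, atol=1e-6)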