Up to 2x speedup on GPUs using memory efficient attention #532
@@ -22,6 +22,7 @@ We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for
| fp16 | 3.61s | x2.63 |
| channels last | 3.30s | x2.88 |
| traced UNet | 3.21s | x2.96 |
| memory efficient attention | 2.63s | x3.61 |

<em>
obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from
@@ -290,3 +291,41 @@ pipe.unet = TracedUNet()
with torch.inference_mode():
    image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```
## Memory Efficient Attention
Recent work on optimizing memory bandwidth in the attention block has produced large speedups and substantial savings in GPU memory usage. The most recent example is Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)).
Here are the speedups we obtain on a few NVIDIA GPUs when running inference at 512x512 with a batch size of 1 (one prompt):

| GPU | Base Attention FP16 | Memory Efficient Attention FP16 |
|------------------|---------------------|---------------------------------|
| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s |
Comment on lines +297 to +302: Thanks a lot for adding this, great doc!
| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s |
| NVIDIA A10G | 8.88it/s | 15.6it/s |
| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s |
| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s |
| A100-SXM4-40GB | 18.6it/s | 29it/s |
| A100-SXM-80GB | 18.7it/s | 29.5it/s |
To leverage it, just make sure you have:
- PyTorch > 1.12
- CUDA available
- the [xformers](https://github.com/facebookresearch/xformers) library installed
Review comment: We can also indicate that users should build from source if they want to correctly compile the binaries and get the same benefits.
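A minimal sketch for verifying this checklist before enabling the flag (assuming the `packaging` library is available; `is_xformers_available` is the helper this PR imports in the attention module):

```python
import torch
from packaging import version

from diffusers.utils.import_utils import is_xformers_available

# Mirror the checklist above: recent enough PyTorch, a CUDA device, and xformers installed.
assert version.parse(torch.__version__) > version.parse("1.12"), "PyTorch > 1.12 is required"
assert torch.cuda.is_available(), "memory efficient attention only runs on GPU"
assert is_xformers_available(), "see https://github.com/facebookresearch/xformers for install instructions"
```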
```python
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()

with torch.inference_mode():
    sample = pipe("a small cat")

# optional: You can disable it via
# pipe.disable_xformers_memory_efficient_attention()
```
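Numbers like the it/s figures in the table above can be reproduced with a rough timing loop along these lines (a sketch only; the prompt and step count are arbitrary, and results will vary with GPU, driver, and library versions):

```python
import time

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")


def iterations_per_second(steps=50):
    # One short warmup run so CUDA initialization does not skew the timing.
    pipe("a small cat", num_inference_steps=2)
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipe("a small cat", num_inference_steps=steps)
    torch.cuda.synchronize()
    return steps / (time.perf_counter() - start)


base = iterations_per_second()
pipe.enable_xformers_memory_efficient_attention()
xformers_speed = iterations_per_second()
print(f"base attention: {base:.2f} it/s, memory efficient attention: {xformers_speed:.2f} it/s")
```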
@@ -18,6 +18,15 @@
import torch.nn.functional as F
from torch import nn

from diffusers.utils.import_utils import is_xformers_available


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class AttentionBlock(nn.Module):
    """
@@ -150,6 +159,10 @@ def _set_attention_slice(self, slice_size):
        for block in self.transformer_blocks:
            block._set_attention_slice(slice_size)

    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        for block in self.transformer_blocks:
            block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def forward(self, hidden_states, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        batch, channel, height, weight = hidden_states.shape
@@ -206,6 +219,32 @@ def _set_attention_slice(self, slice_size):
        self.attn1._slice_size = slice_size
        self.attn2._slice_size = slice_size

    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
Review comment: (nit) Could we maybe do some more checks here? E.g. what if …
Review comment: In my experience it does at runtime but this might be a bit late indeed. Happy to update it.
Review comment: I think the pytorch version check should probably be done inside the function.
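Illustrative only, not part of this PR: one shape the PyTorch version guard discussed in this thread could take if it lived inside `_set_use_memory_efficient_attention_xformers` (the 1.12 floor comes from the doc section above):

```python
import torch
from packaging import version

# Hypothetical extra guard: fail early if torch is too old for the xformers kernels.
if version.parse(torch.__version__) <= version.parse("1.12"):
    raise ValueError(
        f"PyTorch > 1.12 is required for memory efficient attention, but found {torch.__version__}"
    )
```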
    def forward(self, hidden_states, context=None):
        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
@@ -239,6 +278,7 @@ def __init__(
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self._slice_size = None
        self._use_memory_efficient_attention_xformers = False

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -279,11 +319,13 @@ def forward(self, hidden_states, context=None, mask=None):
        # TODO(PVP) - mask is currently never used. Remember to re-implement when used

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
@@ -341,6 +383,11 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def _memory_efficient_attention_xformers(self, query, key, value):
        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states


class FeedForward(nn.Module):
    r"""
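For reference, a plain PyTorch sketch of what the `xformers.ops.memory_efficient_attention` call above computes on the `(batch * heads, seq_len, head_dim)` tensors whose layout `reshape_batch_dim_to_heads` later undoes; the xformers kernel should give the same result up to numerical precision, but without materializing the full `seq_len x seq_len` attention matrix, which is where the memory savings come from:

```python
import math

import torch


def attention_reference(query, key, value):
    # query, key, value: (batch * heads, seq_len, head_dim), as in CrossAttention above.
    scale = 1 / math.sqrt(query.shape[-1])
    # Full attention matrix of shape (batch * heads, seq_len, seq_len) -- this is the
    # intermediate that the memory efficient kernel avoids storing.
    scores = torch.bmm(query, key.transpose(1, 2)) * scale
    probs = scores.softmax(dim=-1)
    return torch.bmm(probs, value)


# Toy shapes, CPU only, just to show the expected layout.
q = k = v = torch.randn(2, 64, 40)
out = attention_reference(q, k, v)  # (2, 64, 40)
```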
@@ -113,6 +113,24 @@
            feature_extractor=feature_extractor,
        )

    def enable_xformers_memory_efficient_attention(self):
Review comment: very nice!
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.unet.set_use_memory_efficient_attention_xformers(True)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.unet.set_use_memory_efficient_attention_xformers(False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.
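A short usage sketch of how the two new toggles interact with the existing slicing switch (per the docstring warning above, xformers takes precedence when both are enabled; `pipe` is assumed to be a `StableDiffusionPipeline` already loaded on CUDA as in the doc example):

```python
pipe.enable_attention_slicing()                     # sliced attention only
pipe.enable_xformers_memory_efficient_attention()   # now the xformers kernel is used instead
image = pipe("a small cat").images[0]

pipe.disable_xformers_memory_efficient_attention()  # falls back to the (sliced) attention path
image = pipe("a small cat").images[0]
```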
Review comment: Very nice addition btw 🔥