Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Up to 2x speedup on GPUs using memory efficient attention #532

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions docs/source/optimization/fp16.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for
| fp16 | 3.61s | x2.63 |
| channels last | 3.30s | x2.88 |
| traced UNet | 3.21s | x2.96 |
| memory efficient attention | 2.63s | x3.61 |

<em>
obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from
Expand Down Expand Up @@ -290,3 +291,41 @@ pipe.unet = TracedUNet()
with torch.inference_mode():
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```


## Memory Efficient Attention
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice addition btw 🔥

Recent work on optimizing the bandwitdh in the attention block have generated huge speed ups and gains in GPU memory usage. The most recent being Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)) .
Here are the speedups we obtain on a few Nvidia GPUs when running the inference at 512x512 with a batch size of 1 (one prompt):

| GPU | Base Attention FP16 | Memory Efficient Attention FP16 |
|------------------ |--------------------- |--------------------------------- |
| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s |
Comment on lines +297 to +302
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for adding this, great doc!

| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s |
| NVIDIA A10G | 8.88it/s | 15.6it/s |
| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s |
| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s |
| A100-SXM4-40GB | 18.6it/s | 29.it/s |
| A100-SXM-80GB | 18.7it/s | 29.5it/s |

To leverage it just make sure you have:
- PyTorch > 1.12
- Cuda available
- Installed the [xformers](https://github.com/facebookresearch/xformers) library
Copy link
Member

@NouamaneTazi NouamaneTazi Oct 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also indicate that users should build from source if they want to correctly compile the binaries and get the same benefits
python setup.py build develop

```python
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()

with torch.inference_mode():
sample = pipe("a small cat")

# optional: You can disable it via
# pipe.disable_xformers_memory_efficient_attention()
```
55 changes: 51 additions & 4 deletions src/diffusers/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
import torch.nn.functional as F
from torch import nn

from diffusers.utils.import_utils import is_xformers_available


if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None


class AttentionBlock(nn.Module):
"""
Expand Down Expand Up @@ -150,6 +159,10 @@ def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.transformer_blocks:
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
batch, channel, height, weight = hidden_states.shape
Expand Down Expand Up @@ -206,6 +219,32 @@ def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Could we maybe do some more checks here? E.g. what if xformers is available but the kernels are not compiled or available? What if Pytorch is < 1.12 - does xformers throw an error in this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my experience it does at runtime but this might a bit late indeed. Happy to update it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the pytorch version check should probably done inside the function is_xformers_available() as it is a package dependency. What do you think ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the pytorch version check should probably done inside the function is_xformers_available() as it is a package dependency. What do you think ?
Makes sense to me!


def forward(self, hidden_states, context=None):
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
Expand Down Expand Up @@ -239,6 +278,7 @@ def __init__(
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self._slice_size = None
self._use_memory_efficient_attention_xformers = False

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
Expand Down Expand Up @@ -279,11 +319,13 @@ def forward(self, hidden_states, context=None, mask=None):
# TODO(PVP) - mask is currently never used. Remember to re-implement when used

# attention, what we cannot get enough of

if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

# linear proj
hidden_states = self.to_out[0](hidden_states)
Expand Down Expand Up @@ -341,6 +383,11 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def _memory_efficient_attention_xformers(self, query, key, value):
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states


class FeedForward(nn.Module):
r"""
Expand Down
12 changes: 12 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ def set_attention_slice(self, slice_size):
for attn in self.attentions:
attn._set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
Expand Down Expand Up @@ -542,6 +546,10 @@ def set_attention_slice(self, slice_size):
for attn in self.attentions:
attn._set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()

Expand Down Expand Up @@ -1117,6 +1125,10 @@ def set_attention_slice(self, slice_size):

self.gradient_checkpointing = False

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(
self,
hidden_states,
Expand Down
11 changes: 11 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,17 @@ def set_attention_slice(self, slice_size):
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
module.gradient_checkpointing = value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ def __init__(
feature_extractor=feature_extractor,
)

def enable_xformers_memory_efficient_attention(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice!

r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

@torch.no_grad()
def __call__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

@torch.no_grad()
def __call__(
self,
Expand Down
16 changes: 16 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@
except importlib_metadata.PackageNotFoundError:
_accelerate_available = False

_xformers_available = importlib.util.find_spec("xformers") is not None
try:
_xformers_version = importlib_metadata.version("xformers")
if _torch_available:
import torch

if torch.__version__ < version.Version("1.12"):
raise ValueError("PyTorch should be >= 1.12")
logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
_xformers_available = False


def is_torch_available():
return _torch_available
Expand Down Expand Up @@ -205,6 +217,10 @@ def is_scipy_available():
return _scipy_available


def is_xformers_available():
return _xformers_available


def is_accelerate_available():
return _accelerate_available

Expand Down