From 2a59e0c7632ffdd8441a0cb87ac6eb2d659adec7 Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Fri, 16 Sep 2022 16:12:49 +0200 Subject: [PATCH 01/15] 2x speedup using memory efficient attention --- src/diffusers/models/attention.py | 96 ++++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index af441ef86181..85c378a952a6 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,12 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Optional +import os +from inspect import isfunction +from typing import Any, Optional import torch import torch.nn.functional as F from torch import nn +import xformers +import xformers.ops +from einops import rearrange + + +_USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1 + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + class AttentionBlock(nn.Module): """ @@ -190,11 +209,12 @@ def __init__( checkpoint: bool = True, ): super().__init__() - self.attn1 = CrossAttention( + AttentionBuilder = MemoryEfficientCrossAttention if _USE_MEMORY_EFFICIENT_ATTENTION else CrossAttention + self.attn1 = AttentionBuilder( query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( + self.attn2 = AttentionBuilder( query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) @@ -213,6 +233,76 @@ def forward(self, hidden_states, context=None): return hidden_states +class MemoryEfficientCrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def _maybe_init(self, x): + """ + Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x + : B, Head, Length + """ + if self.attention_op is not None: + return + + _, K, M = x.shape + try: + self.attention_op = xformers.ops.AttentionOpDispatch( + dtype=x.dtype, + device=x.device, + k=K, + attn_bias_type=type(None), + has_dropout=False, + kv_len=M, + q_len=M, + ).op + + except NotImplementedError as err: + raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map( + lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h).contiguous(), + (q, k, v), + ) + + # init the attention op, if required, using the proper dimensions + self._maybe_init(q) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, 
attn_bias=None, op=self.attention_op) + + # TODO: Use this directly in the attention operation, as a bias + if exists(mask): + raise NotImplementedError + # mask = rearrange(mask, "b ... -> b (...)") + # max_neg_value = -torch.finfo(sim.dtype).max + # mask = repeat(mask, "b j -> (b h) () j", h=h) + # sim.masked_fill_(~mask, max_neg_value) + + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + class CrossAttention(nn.Module): r""" A cross attention layer. From 68e1ef5448ffab2a853d93275f5249585a9aced2 Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Fri, 16 Sep 2022 18:57:57 +0200 Subject: [PATCH 02/15] remove einops dependency --- src/diffusers/models/attention.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 85c378a952a6..5843553d3406 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -22,7 +22,6 @@ import xformers import xformers.ops -from einops import rearrange _USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1 @@ -241,6 +240,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. self.scale = dim_head**-0.5 self.heads = heads + self.dim_head = dim_head self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) @@ -273,15 +273,18 @@ def _maybe_init(self, x): raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") def forward(self, x, context=None, mask=None): - h = self.heads - q = self.to_q(x) context = default(context, x) k = self.to_k(context) v = self.to_v(context) + b, _, _ = q.shape q, k, v = map( - lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h).contiguous(), + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), (q, k, v), ) @@ -294,12 +297,12 @@ def forward(self, x, context=None, mask=None): # TODO: Use this directly in the attention operation, as a bias if exists(mask): raise NotImplementedError - # mask = rearrange(mask, "b ... 
-> b (...)") - # max_neg_value = -torch.finfo(sim.dtype).max - # mask = repeat(mask, "b j -> (b h) () j", h=h) - # sim.masked_fill_(~mask, max_neg_value) - - out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) return self.to_out(out) From db557e8ce0555bdfd74084417216a7835d28d293 Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Thu, 22 Sep 2022 09:48:26 +0200 Subject: [PATCH 03/15] Swap K, M in op instantiation --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5843553d3406..5557c2629417 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -257,7 +257,7 @@ def _maybe_init(self, x): if self.attention_op is not None: return - _, K, M = x.shape + _, M, K = x.shape try: self.attention_op = xformers.ops.AttentionOpDispatch( dtype=x.dtype, From 9d9aea05a77613d71da3669cdfb268689ab8b753 Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Thu, 22 Sep 2022 10:52:50 +0200 Subject: [PATCH 04/15] Simplify code, remove unnecessary maybe_init call and function, remove unused self.scale parameter --- src/diffusers/models/attention.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5557c2629417..eaf2ea57b4c5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -238,7 +238,6 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head**-0.5 self.heads = heads self.dim_head = dim_head @@ -249,29 +248,6 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. 
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def _maybe_init(self, x): - """ - Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x - : B, Head, Length - """ - if self.attention_op is not None: - return - - _, M, K = x.shape - try: - self.attention_op = xformers.ops.AttentionOpDispatch( - dtype=x.dtype, - device=x.device, - k=K, - attn_bias_type=type(None), - has_dropout=False, - kv_len=M, - q_len=M, - ).op - - except NotImplementedError as err: - raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") - def forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) @@ -288,9 +264,6 @@ def forward(self, x, context=None, mask=None): (q, k, v), ) - # init the attention op, if required, using the proper dimensions - self._maybe_init(q) - # actually compute the attention, what we cannot get enough of out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) From 54c9c1557b69e84e02b45eeb3621d942b479a314 Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Tue, 25 Oct 2022 14:20:16 +0200 Subject: [PATCH 05/15] make xformers a soft dependency --- src/diffusers/models/attention.py | 12 +++++++++--- src/diffusers/utils/import_utils.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index eaf2ea57b4c5..81cb98b9a7d5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -20,11 +20,17 @@ import torch.nn.functional as F from torch import nn -import xformers -import xformers.ops +from diffusers.utils.import_utils import is_xformers_available -_USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1 +if is_xformers_available(): + import xformers + import xformers.ops + + _USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1 +else: + xformers = None + _USE_MEMORY_EFFICIENT_ATTENTION = False def exists(val): diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 2a5f7f64dd07..a8d477b5ccbf 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -168,6 +168,13 @@ except importlib_metadata.PackageNotFoundError: _accelerate_available = False +_xformers_available = importlib.util.find_spec("xformers") is not None +try: + _xformers_version = importlib_metadata.version("xformers") + logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + def is_torch_available(): return _torch_available @@ -205,6 +212,10 @@ def is_scipy_available(): return _scipy_available +def is_xformers_available(): + return _xformers_available + + def is_accelerate_available(): return _accelerate_available From e59ea367a7622c253d5204e5d2680fce54e4dacb Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Mon, 31 Oct 2022 10:06:36 +0100 Subject: [PATCH 06/15] remove one-liner functions --- src/diffusers/models/attention.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 81cb98b9a7d5..62ad8a1648f6 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -33,16 +33,6 @@ 
_USE_MEMORY_EFFICIENT_ATTENTION = False -def exists(val): - return val is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted @@ -242,7 +232,7 @@ class MemoryEfficientCrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) + context_dim = query_dim if context_dim is None else context_dim self.heads = heads self.dim_head = dim_head @@ -256,7 +246,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. def forward(self, x, context=None, mask=None): q = self.to_q(x) - context = default(context, x) + context = x if context is None else context k = self.to_k(context) v = self.to_v(context) @@ -274,7 +264,7 @@ def forward(self, x, context=None, mask=None): out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) # TODO: Use this directly in the attention operation, as a bias - if exists(mask): + if mask is not None: raise NotImplementedError out = ( out.unsqueeze(0) From 321c39028c25daf8e840f87d031273d923a9b4a7 Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Mon, 31 Oct 2022 11:12:50 +0100 Subject: [PATCH 07/15] change one letter variable to appropriate names --- src/diffusers/models/attention.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 62ad8a1648f6..def5e87cf804 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -13,7 +13,6 @@ # limitations under the License. import math import os -from inspect import isfunction from typing import Any, Optional import torch @@ -245,32 +244,32 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. 
self.attention_op: Optional[Any] = None def forward(self, x, context=None, mask=None): - q = self.to_q(x) + queries = self.to_q(x) context = x if context is None else context - k = self.to_k(context) - v = self.to_v(context) + keys = self.to_k(context) + values = self.to_v(context) - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) + batch_size, _, _ = queries.shape + queries, keys, values = map( + lambda tensor: tensor.unsqueeze(3) + .reshape(batch_size, tensor.shape[1], self.heads, self.dim_head) .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) + .reshape(batch_size * self.heads, tensor.shape[1], self.dim_head) .contiguous(), - (q, k, v), + (queries, keys, values), ) # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + out = xformers.ops.memory_efficient_attention(queries, keys, values, attn_bias=None, op=self.attention_op) # TODO: Use this directly in the attention operation, as a bias if mask is not None: raise NotImplementedError out = ( out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) + .reshape(batch_size, self.heads, out.shape[1], self.dim_head) .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) + .reshape(batch_size, out.shape[1], self.heads * self.dim_head) ) return self.to_out(out) From 7079dab4de3995b8e513ab7c96b3a3100bec6a55 Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Mon, 31 Oct 2022 12:06:57 +0100 Subject: [PATCH 08/15] Remove Env variable dependency, remove MemoryEfficientCrossAttention class and use enable_xformers_memory_efficient_attention method --- src/diffusers/models/attention.py | 85 ++++++------------- src/diffusers/models/unet_2d_blocks.py | 12 +++ src/diffusers/models/unet_2d_condition.py | 11 +++ .../pipeline_stable_diffusion.py | 18 ++++ 4 files changed, 67 insertions(+), 59 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index def5e87cf804..6afc2764a4d6 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math -import os -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F @@ -25,11 +24,8 @@ if is_xformers_available(): import xformers import xformers.ops - - _USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1 else: xformers = None - _USE_MEMORY_EFFICIENT_ATTENTION = False class AttentionBlock(nn.Module): @@ -163,6 +159,10 @@ def _set_attention_slice(self, slice_size): for block in self.transformer_blocks: block._set_attention_slice(slice_size) + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for block in self.transformer_blocks: + block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward(self, hidden_states, context=None): # note: if no context is given, cross-attention defaults to self-attention batch, channel, height, weight = hidden_states.shape @@ -203,12 +203,11 @@ def __init__( checkpoint: bool = True, ): super().__init__() - AttentionBuilder = MemoryEfficientCrossAttention if _USE_MEMORY_EFFICIENT_ATTENTION else CrossAttention - self.attn1 = AttentionBuilder( + self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = AttentionBuilder( + self.attn2 = CrossAttention( query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) @@ -220,6 +219,13 @@ def _set_attention_slice(self, slice_size): self.attn1._slice_size = slice_size self.attn2._slice_size = slice_size + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if is_xformers_available(): + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + else: + raise ModuleNotFoundError(name="xformers") + def forward(self, hidden_states, context=None): hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states @@ -227,53 +233,6 @@ def forward(self, hidden_states, context=None): return hidden_states -class MemoryEfficientCrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): - super().__init__() - inner_dim = dim_head * heads - context_dim = query_dim if context_dim is None else context_dim - - self.heads = heads - self.dim_head = dim_head - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - self.attention_op: Optional[Any] = None - - def forward(self, x, context=None, mask=None): - queries = self.to_q(x) - context = x if context is None else context - keys = self.to_k(context) - values = self.to_v(context) - - batch_size, _, _ = queries.shape - queries, keys, values = map( - lambda tensor: tensor.unsqueeze(3) - .reshape(batch_size, tensor.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(batch_size * self.heads, tensor.shape[1], self.dim_head) - .contiguous(), - (queries, keys, values), - ) - - # actually compute the attention, what we 
cannot get enough of - out = xformers.ops.memory_efficient_attention(queries, keys, values, attn_bias=None, op=self.attention_op) - - # TODO: Use this directly in the attention operation, as a bias - if mask is not None: - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(batch_size, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(batch_size, out.shape[1], self.heads * self.dim_head) - ) - return self.to_out(out) - - class CrossAttention(nn.Module): r""" A cross attention layer. @@ -300,6 +259,7 @@ def __init__( # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` self._slice_size = None + self._use_memory_efficient_attention_xformers = False self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) @@ -340,11 +300,13 @@ def forward(self, hidden_states, context=None, mask=None): # TODO(PVP) - mask is currently never used. Remember to re-implement when used # attention, what we cannot get enough of - - if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value) + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value) else: - hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) # linear proj hidden_states = self.to_out[0](hidden_states) @@ -402,6 +364,11 @@ def _sliced_attention(self, query, key, value, sequence_length, dim): hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states + def _memory_efficient_attention_xformers(self, query, key, value): + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + class FeedForward(nn.Module): r""" diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index f4081c5c1cac..ae4fe2d8bba7 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -367,6 +367,10 @@ def set_attention_slice(self, slice_size): for attn in self.attentions: attn._set_attention_slice(slice_size) + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): @@ -542,6 +546,10 @@ def set_attention_slice(self, slice_size): for attn in self.attentions: attn._set_attention_slice(slice_size) + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -1117,6 +1125,10 @@ def set_attention_slice(self, slice_size): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, 
use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward( self, hidden_states, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index d271b78a6525..7f7f3ecd4435 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -225,6 +225,17 @@ def set_attention_slice(self, slice_size): if hasattr(block, "attentions") and block.attentions is not None: block.set_attention_slice(slice_size) + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): module.gradient_checkpointing = value diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5927f36b12a1..3c1eb734a49d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -113,6 +113,24 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. 
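Up to this point in the series, `CrossAttention.forward` routes to `_memory_efficient_attention_xformers`, which calls `xformers.ops.memory_efficient_attention` on the same `(batch * heads, seq_len, head_dim)` tensors that `reshape_batch_dim_to_heads` later folds back, so the xformers kernel has to be numerically interchangeable with the existing softmax attention. A minimal standalone check of that equivalence (not part of the patch; the shapes and tolerances are illustrative, and it assumes xformers plus a CUDA device are available) could look like this:

```python
# Sanity-check sketch: compare xformers' memory-efficient kernel against a plain
# scaled-dot-product reference on the (batch * heads, seq_len, head_dim) layout
# used by CrossAttention. Shapes are illustrative: 8 heads of dim 40, 4096 tokens.
import torch
import xformers.ops


def reference_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    scale = q.shape[-1] ** -0.5                          # same 1/sqrt(dim_head) scaling
    scores = torch.bmm(q, k.transpose(-1, -2)) * scale   # explicit (B*H, N, N) score matrix
    return torch.bmm(scores.softmax(dim=-1), v)


if torch.cuda.is_available():
    q, k, v = (torch.randn(8, 4096, 40, device="cuda", dtype=torch.float16) for _ in range(3))
    # Run the reference in fp32 so its own rounding does not dominate the comparison.
    out_ref = reference_attention(q.float(), k.float(), v.float()).to(q.dtype)
    out_xformers = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    # The two kernels reduce in different orders; allow a loose fp16 tolerance.
    assert torch.allclose(out_ref, out_xformers, atol=1e-2, rtol=1e-2)
```

Note that `memory_efficient_attention` applies the `1/sqrt(dim_head)` scaling internally by default, which is why the earlier commit could drop the unused `self.scale` from the memory-efficient path.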
From eff0c4275ea3a2ec061c463f35b686ccfe21a109 Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Mon, 31 Oct 2022 14:02:50 +0100 Subject: [PATCH 09/15] Add memory efficient attention toggle to img2img and inpaint pipelines --- .../pipeline_stable_diffusion_img2img.py | 18 ++++++++++++++++++ .../pipeline_stable_diffusion_inpaint.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 00c364f8e5e3..e61fb27acc1e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -151,6 +151,24 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `set_attention_slice` self.enable_attention_slicing(None) + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + @torch.no_grad() def __call__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 57f9b65716ee..bbe6ee60832c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -151,6 +151,24 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. 
+ """ + self.unet.set_use_memory_efficient_attention_xformers(False) + @torch.no_grad() def __call__( self, From 3f109ca488d2acd6c0dcb4454ac1b06de946522f Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Mon, 31 Oct 2022 17:13:33 +0100 Subject: [PATCH 10/15] Clearer management of xformers' availability --- src/diffusers/models/attention.py | 25 ++++++++++++++++++++++--- src/diffusers/utils/import_utils.py | 4 ++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6afc2764a4d6..1f9cf641c32d 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -220,11 +220,30 @@ def _set_attention_slice(self, slice_size): self.attn2._slice_size = slice_size def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - if is_xformers_available(): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - else: - raise ModuleNotFoundError(name="xformers") def forward(self, hidden_states, context=None): hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index a8d477b5ccbf..cc1315e0bec5 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -19,6 +19,8 @@ import sys from collections import OrderedDict +import torch + from packaging import version from . 
import logging @@ -171,6 +173,8 @@ _xformers_available = importlib.util.find_spec("xformers") is not None try: _xformers_version = importlib_metadata.version("xformers") + if torch.__version__ < version.Version("1.12"): + raise ValueError("PyTorch should be >= 1.12") logger.debug(f"Successfully imported xformers version {_xformers_version}") except importlib_metadata.PackageNotFoundError: _xformers_available = False From 13b187eca8431ac755d22347ee6c254fc256bb5e Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Mon, 31 Oct 2022 17:40:04 +0100 Subject: [PATCH 11/15] update optimizations markdown to add info about memory efficient attention --- docs/source/optimization/fp16.mdx | 36 +++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index f12c067ba5ee..e57442eed398 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -290,3 +290,39 @@ pipe.unet = TracedUNet() with torch.inference_mode(): image = pipe([prompt] * 1, num_inference_steps=50).images[0] ``` + + +## Memory Efficient Attention +Recent work on optimizing the bandwitdh in the attention block have generated huge speed ups and gains in GPU memory usage. The most recent being Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)) . +Here are the speedups we obtain on a few Nvidia GPUs: +| GPU | Base Attention FP16 | Memory Efficient Attention FP16 | +|------------------ |--------------------- |--------------------------------- | +| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s | +| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s | +| NVIDIA A10G | 8.88it/s | 15.6it/s | +| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s | +| A100-SXM4-40GB | 18.6it/s | 29.it/s | +| A100-SXM-80GB | 18.7it/s | 29.5it/s | + +To leverage it just make sure you have: + - PyTorch > 1.12 + - Cuda available + - Installed the [xformers](https://github.com/facebookresearch/xformers) library +```python +from diffusers import StableDiffusionPipeline +import torch + +pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="fp16", + torch_dtype=torch.float16, +).to("cuda") + +pipe.enable_xformers_memory_efficient_attention() + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): + sample = pipe("a small cat") + +# You can disable it via +pipe.disable_xformers_memory_efficient_attention() +``` \ No newline at end of file From 24e71bece9266123f9976ea9d6fdeac2971ccb37 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 31 Oct 2022 18:54:14 +0000 Subject: [PATCH 12/15] add benchmarks for TITAN RTX --- docs/source/optimization/fp16.mdx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index e57442eed398..eb4fa3588029 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -22,6 +22,7 @@ We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for | fp16 | 3.61s | x2.63 | | channels last | 3.30s | x2.88 | | traced UNet | 3.21s | x2.96 | +| memory efficient attention | 2.63s | x3.61 | obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from @@ -301,6 +302,7 @@ Here are the speedups we obtain on a few Nvidia GPUs: | NVIDIA 3060 RTX | 4.6it/s | 7.8it/s | | NVIDIA A10G | 8.88it/s | 15.6it/s | | NVIDIA RTX A6000 | 11.7it/s | 21.09it/s | +| NVIDIA TITAN RTX | 12.51it/s | 
18.22it/s | | A100-SXM4-40GB | 18.6it/s | 29.it/s | | A100-SXM-80GB | 18.7it/s | 29.5it/s | @@ -323,6 +325,6 @@ pipe.enable_xformers_memory_efficient_attention() with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): sample = pipe("a small cat") -# You can disable it via -pipe.disable_xformers_memory_efficient_attention() +# optional: You can disable it via +# pipe.disable_xformers_memory_efficient_attention() ``` \ No newline at end of file From 1fba7eb188cc6abbf41f8fcd2db5d847cfd7d28e Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Mon, 31 Oct 2022 20:04:06 +0100 Subject: [PATCH 13/15] More detailed explanation of how the mem eff benchmark were ran --- docs/source/optimization/fp16.mdx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index eb4fa3588029..b804ee8301d0 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -295,7 +295,8 @@ with torch.inference_mode(): ## Memory Efficient Attention Recent work on optimizing the bandwitdh in the attention block have generated huge speed ups and gains in GPU memory usage. The most recent being Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)) . -Here are the speedups we obtain on a few Nvidia GPUs: +Here are the speedups we obtain on a few Nvidia GPUs when running the inference at 512x512 with a batch size of 1 (one prompt): + | GPU | Base Attention FP16 | Memory Efficient Attention FP16 | |------------------ |--------------------- |--------------------------------- | | NVIDIA Tesla T4 | 3.5it/s | 5.5it/s | @@ -325,6 +326,6 @@ pipe.enable_xformers_memory_efficient_attention() with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): sample = pipe("a small cat") -# optional: You can disable it via +# optional: You can disable it via # pipe.disable_xformers_memory_efficient_attention() ``` \ No newline at end of file From 0f75d572eb103a5b824b6732e69f028519f0f259 Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Mon, 31 Oct 2022 20:15:23 +0100 Subject: [PATCH 14/15] Removing autocast from optimization markdown --- docs/source/optimization/fp16.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index b804ee8301d0..4371daacc903 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -296,7 +296,7 @@ with torch.inference_mode(): ## Memory Efficient Attention Recent work on optimizing the bandwitdh in the attention block have generated huge speed ups and gains in GPU memory usage. The most recent being Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)) . 
Here are the speedups we obtain on a few Nvidia GPUs when running the inference at 512x512 with a batch size of 1 (one prompt): - + | GPU | Base Attention FP16 | Memory Efficient Attention FP16 | |------------------ |--------------------- |--------------------------------- | | NVIDIA Tesla T4 | 3.5it/s | 5.5it/s | @@ -323,7 +323,7 @@ pipe = StableDiffusionPipeline.from_pretrained( pipe.enable_xformers_memory_efficient_attention() -with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): +with torch.inference_mode(): sample = pipe("a small cat") # optional: You can disable it via From 3db0a2fbd8acc6617a59693c92d18ea71f303aef Mon Sep 17 00:00:00 2001 From: MatthieuTPHR Date: Tue, 1 Nov 2022 21:19:12 +0100 Subject: [PATCH 15/15] import_utils: import torch only if is available --- src/diffusers/utils/import_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index cc1315e0bec5..4ea02dcc94da 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -19,8 +19,6 @@ import sys from collections import OrderedDict -import torch - from packaging import version from . import logging @@ -173,8 +171,11 @@ _xformers_available = importlib.util.find_spec("xformers") is not None try: _xformers_version = importlib_metadata.version("xformers") - if torch.__version__ < version.Version("1.12"): - raise ValueError("PyTorch should be >= 1.12") + if _torch_available: + import torch + + if torch.__version__ < version.Version("1.12"): + raise ValueError("PyTorch should be >= 1.12") logger.debug(f"Successfully imported xformers version {_xformers_version}") except importlib_metadata.PackageNotFoundError: _xformers_available = False
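With the final patch applied, the public surface of the series is a pair of pipeline methods gated on the soft xformers dependency and on CUDA being available. An end-to-end usage sketch (not part of the diff; the model id, prompt, and fallback choice mirror the examples and warnings added to the docs above):

```python
# Usage sketch: enable the xformers path when the soft dependency resolves,
# otherwise fall back to the existing sliced-attention memory saver.
import torch

from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

if is_xformers_available():
    # Flips _use_memory_efficient_attention_xformers on every CrossAttention layer
    # of the UNet; raises with installation hints if xformers or CUDA is unusable.
    pipe.enable_xformers_memory_efficient_attention()
else:
    # Sliced attention is ignored when the xformers path is active, so only one
    # of the two memory savers needs to be enabled.
    pipe.enable_attention_slicing()

with torch.inference_mode():
    image = pipe("a small cat").images[0]
```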