
Merge branch 'develop' into ppdiffuesers_ldm_to_original
JunnYu authored Nov 18, 2022
2 parents fc0b22b + 8842fbf commit a781506
Showing 1 changed file with 22 additions and 27 deletions.
49 changes: 22 additions & 27 deletions ppdiffusers/ppdiffusers/models/attention.py
@@ -267,7 +267,8 @@ def __init__(self,
         context_dim = context_dim if context_dim is not None else query_dim

         self.scale = dim_head**-0.5
-        self.heads = heads
+        self.num_heads = heads
+        self.head_dim = inner_dim // heads
         # for slice_size > 0 the attention score computation
         # is split across the batch axis to save memory
         # You can set slice_size with `set_attention_slice`
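
The renamed attributes are fixed by the constructor arguments. A minimal illustration, assuming inner_dim = dim_head * heads as in the upstream CrossAttention (the concrete numbers below are illustrative, not taken from this diff):

    heads, dim_head = 8, 64          # illustrative values
    inner_dim = heads * dim_head     # 512
    scale = dim_head ** -0.5         # 0.125, the factor applied to the attention scores
    head_dim = inner_dim // heads    # 64, the new self.head_dim (equals dim_head)
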
@@ -281,46 +282,32 @@ def __init__(self,
                                     nn.Dropout(dropout))

     def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(
-            [batch_size, seq_len, head_size, dim // head_size])
-        tensor = tensor.transpose([0, 2, 1, 3]).reshape(
-            [batch_size * head_size, seq_len, dim // head_size])
+        tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
+        tensor = tensor.transpose([0, 2, 1, 3])
         return tensor

     def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(
-            [batch_size // head_size, head_size, seq_len, dim])
-        tensor = tensor.transpose([0, 2, 1, 3]).reshape(
-            [batch_size // head_size, seq_len, dim * head_size])
+        tensor = tensor.transpose([0, 2, 1, 3])
+        tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
         return tensor

     def forward(self, hidden_states, context=None, mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
         key = self.to_k(context)
         value = self.to_v(context)

-        dim = query.shape[-1]
-
         query = self.reshape_heads_to_batch_dim(query)
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)

         # TODO(PVP) - mask is currently never used. Remember to re-implement when used

         # attention, what we cannot get enough of

-        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+        if self._slice_size is None:
             hidden_states = self._attention(query, key, value)
         else:
-            hidden_states = self._sliced_attention(query, key, value,
-                                                   sequence_length, dim)
+            hidden_states = self._sliced_attention(query, key, value)

         return self.to_out(hidden_states)
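
The new reshape helpers keep the batch and head axes separate ([bs, seqlen, inner_dim] to [bs, num_heads, seqlen, head_dim] and back) instead of folding the heads into the batch dimension as the old code, and the method names, did; a 0 entry tells Paddle's reshape to keep that dimension unchanged. A minimal standalone sketch of the round trip, with illustrative shapes:

    import paddle

    bs, seqlen, num_heads, head_dim = 2, 5, 8, 64
    x = paddle.randn([bs, seqlen, num_heads * head_dim])

    # new reshape_heads_to_batch_dim: split the heads out of the channel axis
    h = x.reshape([0, 0, num_heads, head_dim]).transpose([0, 2, 1, 3])
    print(h.shape)  # [2, 8, 5, 64]

    # new reshape_batch_dim_to_heads: exact inverse
    y = h.transpose([0, 2, 1, 3]).reshape([0, 0, num_heads * head_dim])
    print(y.shape)  # [2, 5, 512]
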

@@ -335,14 +322,19 @@ def _attention(self, query, key, value):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states

-    def _sliced_attention(self, query, key, value, sequence_length, dim):
-        batch_size_attention = query.shape[0]
+    def _sliced_attention(self, query, key, value):
+        # query, key, value flatten [bs*num_heads, seqlen, head_dim]
+        query = query.flatten(0, 1)
+        key = key.flatten(0, 1)
+        value = value.flatten(0, 1)
+
+        batch_size_attention, sequence_length = query.shape[0], query.shape[1]
         hidden_states = paddle.zeros(
-            (batch_size_attention, sequence_length, dim // self.heads),
+            (batch_size_attention, sequence_length, self.head_dim),
             dtype=query.dtype)
-        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[
-            0]
-        for i in range(hidden_states.shape[0] // slice_size):
+        slice_size = self._slice_size if self._slice_size is not None else batch_size_attention
+
+        for i in range(batch_size_attention // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
             attn_slice = (paddle.matmul(query[start_idx:end_idx],
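
The remainder of the attn_slice computation is collapsed in this view. Both _attention and the per-slice loop above boil down to ordinary scaled dot-product attention, softmax(Q·K^T * scale)·V; the sketch below illustrates that step on the 4-D layout _attention receives, as an assumption-based illustration rather than the file's exact code:

    import paddle
    import paddle.nn.functional as F

    def scaled_dot_product(query, key, value, scale):
        # query, key, value: [bs, num_heads, seqlen, head_dim]
        scores = paddle.matmul(query, key, transpose_y=True) * scale  # [bs, num_heads, seq_q, seq_k]
        probs = F.softmax(scores, axis=-1)
        return paddle.matmul(probs, value)                            # [bs, num_heads, seq_q, head_dim]
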
@@ -354,6 +346,9 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):

             hidden_states[start_idx:end_idx] = attn_slice

+        # reshape back to [bs, num_heads, seqlen, head_dim]
+        hidden_states = hidden_states.reshape(
+            [-1, self.num_heads, sequence_length, self.head_dim])
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
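
Slicing changes only how the bs*num_heads axis is processed, holding at most slice_size attention matrices in memory at a time; the result matches the unsliced path. A small self-contained check of that equivalence, using standalone helper functions that mirror the diff's logic rather than the module's methods:

    import paddle
    import paddle.nn.functional as F

    def full_attention(q, k, v, scale):
        # q, k, v: [bs * num_heads, seqlen, head_dim]
        probs = F.softmax(paddle.matmul(q, k, transpose_y=True) * scale, axis=-1)
        return paddle.matmul(probs, v)

    def sliced_attention(q, k, v, scale, slice_size):
        # Same result as full_attention, computed slice_size rows of the batch axis at a time
        # (assumes the batch axis divides evenly by slice_size, as the diff's loop does).
        out = paddle.zeros(q.shape, dtype=q.dtype)
        for i in range(q.shape[0] // slice_size):
            s, e = i * slice_size, (i + 1) * slice_size
            probs = F.softmax(paddle.matmul(q[s:e], k[s:e], transpose_y=True) * scale, axis=-1)
            out[s:e] = paddle.matmul(probs, v[s:e])
        return out

    q, k, v = (paddle.randn([8, 16, 64]) for _ in range(3))
    print(paddle.allclose(full_attention(q, k, v, 64 ** -0.5),
                          sliced_attention(q, k, v, 64 ** -0.5, slice_size=2)))
    # prints a bool Tensor equal to True, up to floating point tolerance
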
