Commit

allow one to pass a tuple to the use_linear_attn kwargs in unet, to finely specify which layers get sparse linear attention, and to turn off attention altogether
lucidrains committed Aug 30, 2022
1 parent 59b5e55 commit d4c45ab
Showing 2 changed files with 21 additions and 8 deletions.
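In effect, use_linear_attn and use_linear_cross_attn now behave like layer_attns and layer_cross_attns: a single boolean still applies to every layer, a tuple specifies each layer individually, and a layer that gets False in both places ends up with no attention block at all. A minimal construction sketch (values are illustrative; the other kwargs follow the existing Unet signature):

from imagen_pytorch import Unet

# illustrative per-layer settings; only use_linear_attn / use_linear_cross_attn
# gain tuple support in this commit, the other kwargs already accepted tuples
unet = Unet(
    dim = 128,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    layer_attns = (False, False, False, True),        # full attention only at the deepest stage
    layer_cross_attns = (False, False, True, True),
    use_linear_attn = (False, True, True, False),     # sparse linear attention on the middle stages
    use_linear_cross_attn = (True, True, False, False)
)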
27 changes: 20 additions & 7 deletions imagen_pytorch/imagen_pytorch.py
@@ -1256,6 +1256,9 @@ def __init__(
         layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
         layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
 
+        use_linear_attn = cast_tuple(use_linear_attn, num_layers)
+        use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)
+
         assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
 
         # downsample klass
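The two new cast_tuple calls are what make both a single boolean and a per-layer tuple acceptable: the repo's cast_tuple helper repeats a lone value num_layers times and length-checks a tuple that is passed through. A rough sketch of that behavior (not the repository's exact implementation):

def cast_tuple(val, length = None):
    # repeat a lone value, pass a tuple through, enforce the requested length
    output = val if isinstance(val, tuple) else ((val,) * (length if length is not None else 1))
    if length is not None:
        assert len(output) == length
    return output

cast_tuple(True, 4)                        # -> (True, True, True, True)
cast_tuple((False, True, True, False), 4)  # -> unchanged, length verified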
@@ -1279,20 +1282,24 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns]
+        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
         reversed_layer_params = list(map(reversed, layer_params))
 
         # downsampling layers
 
         skip_connect_dims = [] # keep track of skip connection dimensions
 
-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(in_out, *layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
             is_last = ind >= (num_resolutions - 1)
 
-            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
             layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
 
-            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)
+            if layer_attn:
+                transformer_block_klass = TransformerBlock
+            elif layer_use_linear_attn:
+                transformer_block_klass = LinearAttentionTransformerBlock
+            else:
+                transformer_block_klass = Identity
 
             current_dim = dim_in
 
@@ -1336,11 +1343,17 @@ def __init__(
 
         upsample_fmap_dims = []
 
-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
             is_last = ind == (len(in_out) - 1)
-            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
 
             layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
-            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)
+
+            if layer_attn:
+                transformer_block_klass = TransformerBlock
+            elif layer_use_linear_attn:
+                transformer_block_klass = LinearAttentionTransformerBlock
+            else:
+                transformer_block_klass = Identity
 
             skip_connect_dim = skip_connect_dims.pop()
 
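The same if / elif / else selection appears in both the downsampling and upsampling loops above: a layer whose layer_attns flag is set gets a full TransformerBlock, a layer with only its use_linear_attn flag set gets a LinearAttentionTransformerBlock, and a layer with neither gets Identity, i.e. attention turned off entirely. A quick illustration with hypothetical flags:

# hypothetical per-layer flags, resolved with the same precedence as the diff above
layer_attns     = (False, False, True)
use_linear_attn = (True,  False, False)

for layer_attn, layer_use_linear_attn in zip(layer_attns, use_linear_attn):
    if layer_attn:
        block = 'TransformerBlock'                  # full attention
    elif layer_use_linear_attn:
        block = 'LinearAttentionTransformerBlock'   # sparse linear attention
    else:
        block = 'Identity'                          # no attention for this layer
    print(block)
# -> LinearAttentionTransformerBlock, Identity, TransformerBlock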
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.11.1'
+__version__ = '1.11.2'
