Skip to content

Commit

Permalink
remove einops exts
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 6, 2024
1 parent b9a73db commit 4451ab7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
9 changes: 7 additions & 2 deletions imagen_pytorch/imagen_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
from einops_exts.torch import EinopsToAndFrom

from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

Expand Down Expand Up @@ -1501,7 +1500,7 @@ def __init__(
mid_dim = dims[-1]

self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = EinopsToAndFrom('b c f h w', 'b (f h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_attn = Residual(Attention(mid_dim, **attn_kwargs)) if attend_at_middle else None
self.mid_temporal_peg = temporal_peg(mid_dim)
self.mid_temporal_attn = temporal_attn(mid_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
Expand Down Expand Up @@ -1872,8 +1871,14 @@ def forward(
x = self.mid_block1(x, t, c, **conv_kwargs)

if exists(self.mid_attn):
x = rearrange(x, 'b c f h w -> b f h w c')
x, ps = pack([x], 'b * c')

x = self.mid_attn(x)

x, = unpack(x, ps, 'b * c')
x = rearrange(x, 'b f h w c -> b c f h w')

if not ignore_time:
x = self.mid_temporal_peg(x)
x = self.mid_temporal_attn(x)
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.25.12'
__version__ = '1.26.0'
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,13 @@
'numpy',
'packaging',
'pillow',
'pydantic<=2',
'pydantic>=2',
'pytorch-warmup',
'sentencepiece',
'torch>=1.6',
'torchvision',
'transformers',
'tqdm',
'einops_exts'
'tqdm'
],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down

0 comments on commit 4451ab7

Please sign in to comment.