Skip to content

Commit

Permalink
Enable even larger images with one simple torch.nn.functional.silu im…
Browse files Browse the repository at this point in the history
…port (invoke-ai#653)

Fixes:
File "stable-diffusion/ldm/modules/diffusionmodules/model.py", line 37, in nonlinearity
    return x*torch.sigmoid(x)
RuntimeError: CUDA out of memory. Tried to allocate 1.56 GiB [..]

Now up to 1536x1280 is possible on 8GB VRAM.
Also remove unused SiLU class.
  • Loading branch information
mh-dm authored and afiaka87 committed Sep 19, 2022
1 parent 16eeef3 commit b37680c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 21 deletions.
26 changes: 11 additions & 15 deletions ldm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import torch
import torch.nn as nn
from torch.nn.functional import silu
import numpy as np
from einops import rearrange

Expand Down Expand Up @@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
return emb


def nonlinearity(x):
# swish
return x*torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

Expand Down Expand Up @@ -122,14 +118,14 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,

def forward(self, x, temb):
h = self.norm1(x)
h = nonlinearity(h)
h = silu(h)
h = self.conv1(h)

if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = h + self.temb_proj(silu(temb))[:,:,None,None]

h = self.norm2(h)
h = nonlinearity(h)
h = silu(h)
h = self.dropout(h)
h = self.conv2(h)

Expand Down Expand Up @@ -368,7 +364,7 @@ def forward(self, x, t=None, context=None):
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = silu(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
Expand Down Expand Up @@ -402,7 +398,7 @@ def forward(self, x, t=None, context=None):

# end
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h

Expand Down Expand Up @@ -499,7 +495,7 @@ def forward(self, x):

# end
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h

Expand Down Expand Up @@ -611,7 +607,7 @@ def forward(self, z):
return h

h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
Expand Down Expand Up @@ -649,7 +645,7 @@ def forward(self, x):
x = layer(x)

h = self.norm_out(x)
h = nonlinearity(h)
h = silu(h)
x = self.conv_out(h)
return x

Expand Down Expand Up @@ -697,7 +693,7 @@ def forward(self, x):
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h

Expand Down Expand Up @@ -873,7 +869,7 @@ def forward(self,x):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
z = silu(z)

for submodel, downmodel in zip(self.model,self.downsampler):
z = submodel(z,temb=None)
Expand Down
6 changes: 0 additions & 6 deletions ldm/modules/diffusionmodules/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,6 @@ def normalization(channels):
return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
Expand Down

0 comments on commit b37680c

Please sign in to comment.