partial cherrypick of @neonsecret's optimized attention CompVis#177
I didn't implement the most consequential part (splitting the softmax in two) because M1 Mac is not so VRAM-constrained.
but I implemented the reference-freeing, and also freed x earlier.
Birch-san committed Sep 11, 2022
1 parent 37fdde1 commit dab78e9
Showing 1 changed file with 4 additions and 0 deletions.
ldm/modules/attention.py (4 additions, 0 deletions)
@@ -172,18 +172,22 @@ def forward(self, x, context=None, mask=None):

         q = self.to_q(x)
         context = default(context, x)
+        del x
         k = self.to_k(context)
         v = self.to_v(context)
+        del context

         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        del q, k

         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)
+            del mask

         # attention, what we cannot get enough of
         attn = sim.softmax(dim=-1)

7 comments on commit dab78e9

@neonsecret

it's a lot more than that..

@Birch-san commented on dab78e9 Sep 11, 2022

thanks for reviewing!

I think the only part I missed was this?
[screenshot]

ah, certainly I neglected to free the reference to sim.
ah, and you're re-using sim's storage (to hold the softmax() result and then the einsum() result)?

okay, I'll certainly add those.

but the sim[4:] split… what does this achieve? it reduces concurrency. so if I have enough VRAM (I have 64GB), presumably it's faster to avoid doing this?
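
to make sure I understand the intent: here's a rough sketch (my own paraphrase, not code from either repo) of chunking the attention over the batch·heads dimension, with an illustrative chunk_size:

```python
import torch
from torch import einsum

def chunked_attention(q, k, v, scale, chunk_size=4):
    # q, k, v: (b*h, n, d), as produced by the rearrange above.
    # Only a chunk_size-sized slice of the attention matrix is
    # materialized at a time, at the cost of running the einsums
    # sequentially instead of as one big launch.
    out = torch.empty_like(q)
    for i in range(0, q.shape[0], chunk_size):
        sim = einsum('b i d, b j d -> b i j', q[i:i+chunk_size], k[i:i+chunk_size]) * scale
        # rebinding sim drops the only reference to the pre-softmax logits
        sim = sim.softmax(dim=-1)
        out[i:i+chunk_size] = einsum('b i j, b j d -> b i d', sim, v[i:i+chunk_size])
        del sim
    return out
```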

@neonsecret

take a look at my current fork https://github.com/neonsecret/stable-diffusion
and its changes. it's a whole lot more than that..

@Birch-san

thanks! I'll have a dig. it's a 91-file diff though, and in most cases it adds whole new files, so it's hard to find the important bits…
CompVis/stable-diffusion@main...neonsecret:stable-diffusion:main

any key files you'd recommend looking at?

@neonsecret

yeah, that's why you probably shouldn't try to merge it; it differs too much now

@Birch-san

okay, I've reviewed the attention.py in your branch. certainly there's more there than was in CompVis#177.

however, to my understanding, the attention.py changes only reduce memory usage, and they come at the expense of inference speed?
I have the opposite problem. M1 GPUs are slow at inference, but have loads of VRAM.
we also cannot use torch.cuda.memory_stats(device), torch.cuda.mem_get_info(torch.cuda.current_device()) or torch.cuda.empty_cache() (because we do not have CUDA) and cannot use FP16.
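
for illustration, this is the kind of branching we end up with (a rough sketch assuming a recent stock PyTorch; report_memory is just an illustrative name, not something in the repo):

```python
import torch

def report_memory(device: torch.device) -> None:
    # The CUDA allocator exposes stats and a cache we can flush;
    # the MPS backend has no equivalent, so we can only skip it.
    if device.type == 'cuda':
        free, total = torch.cuda.mem_get_info(device)
        print(f'free: {free / 2**30:.2f} GiB / total: {total / 2**30:.2f} GiB')
        torch.cuda.empty_cache()
    else:
        print(f'no allocator stats available on {device.type}')
```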

is there anything you'd recommend for improving inference speed?

@Birch-san

wait, are you load-balancing between CUDA and CPU? is that for speed or for memory?
