
how to reproduce memory snapshot in doc? #112

Open · LucQueen opened this issue Feb 19, 2024 · 5 comments
@LucQueen commented Feb 19, 2024

Hi, how do I reproduce the memory snapshot shown in the doc?

[screenshot: memory snapshot from the doc]

What I get instead is:

[screenshot: my memory snapshot]

I'm confused about why I can't see the add_decomposed_rel_pos stack frames in my memory snapshot, and about how to get the full stack backtrace.
The torch version is 2.2, following the installation instructions in https://github.com/pytorch-labs/segment-anything-fast/tree/main/experiments#installation-instructions.
Looking forward to a reply.
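
For what it's worth, a minimal sketch of how a snapshot with full Python stack traces can be recorded on torch 2.2 (run_sam_inference is a hypothetical stand-in for the actual workload, and _record_memory_history / _dump_snapshot are underscore-prefixed private APIs that may change):

    import torch

    # Record CUDA allocation events with stack traces.
    torch.cuda.memory._record_memory_history(
        max_entries=100_000,  # cap on recorded events
        context="all",        # record allocation/free context
        stacks="all",         # capture both Python and C++ frames
    )

    run_sam_inference()  # hypothetical stand-in for your actual workload

    # Dump the history, then drag the pickle into https://pytorch.org/memory_viz
    torch.cuda.memory._dump_snapshot("snapshot.pickle")

    # Stop recording.
    torch.cuda.memory._record_memory_history(enabled=None)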

@cpuhrsch (Contributor) commented

Hi @LucQueen - which GPU type are you running on? Thank you.

@LucQueen (Author) commented

@cpuhrsch thanks for your reply. The GPU type is A100 80G SXM.

@cpuhrsch (Contributor) commented

@LucQueen - oh ok! That's probably because you're using the fused kernel. add_decomposed_rel_pos has been shortened and no longer materializes the full attention bias; instead we're using flash_4 to fuse the bias construction into the attention kernel, which is a lot more memory efficient. The reason you're seeing a bigger memory footprint is probably that you're using a larger batch size than the one in that snapshot. That snapshot is from the unmodified segment-anything.

See

    if self.use_rel_pos:
        # Returns the decomposed relative-position terms instead of a full bias.
        rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
    q = q.view(B, self.num_heads, H * W, -1)
    k = k.view(B, self.num_heads, H * W, -1)
    v = v.view(B, self.num_heads, H * W, -1)
    if self.use_rel_pos:
        rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
        rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
        # Old path, kept for reference: materializes the full attention bias.
        # attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
        # x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
        # Fused path: flash_4 builds (rel_h + rel_w) inside the attention kernel.
        x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
    else:
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
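
As a rough illustration of what the fused path saves, here is a back-of-the-envelope estimate of the attn_bias the commented-out path would materialize; the shapes are assumptions for a SAM ViT-H global-attention block at batch size 16, not measured values:

    # Assumed shapes: batch 16, 16 heads, 64x64 token grid, bf16 (2 bytes/elem).
    B, num_heads, H, W = 16, 16, 64, 64
    elems = B * num_heads * (H * W) * (H * W)  # B x heads x (H*W) x (H*W)
    print(f"{elems * 2 / 2**30:.1f} GiB")      # ~8.0 GiB per materialized bias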

@LucQueen (Author) commented

@cpuhrsch thanks for your reply! I am using batch size 16, whereas the doc is using batch size 8, so that explains the bigger memory footprint. But I'm still confused: the memory snapshot in the doc also appears to use the fused kernel (marked with the red box in the picture below), so why can't I get the add_decomposed_rel_pos stack frames in my memory snapshot?

[screenshot: doc's memory snapshot, fused kernel marked with a red box]

@cpuhrsch (Contributor) commented

@LucQueen - Ah! Hm, I'm not sure. Is your picture from the latest version of segment-anything-fast?

The picture you reference is from a section within the blog and not based on the most recent version of segment-anything-fast. It was recorded from an earlier version without the fused kernels.
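
For context, the unfused path in upstream segment-anything looks roughly like the sketch below (paraphrased, not an exact copy of the repo); it is this code that materializes the full attention tensor, which is why add_decomposed_rel_pos frames appear in the older snapshot:

    # Paraphrased from upstream segment-anything's attention forward pass:
    attn = (q * self.scale) @ k.transpose(-2, -1)  # full (B*nH, H*W, H*W) tensor
    if self.use_rel_pos:
        # Adds the decomposed relative-position bias into the materialized attn,
        # so these frames show up in the allocator's stack traces.
        attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
    attn = attn.softmax(dim=-1)
    x = attn @ v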
