[feat] Adding Visual Attention #329

Merged 1 commit into main from visual_attention on Jun 9, 2022

Conversation

@blefaudeux (Contributor) commented Jun 8, 2022

What does this PR do?

Fixes #319. Note that to reproduce the paper you need the Conv2DFeedforward introduced in #321, and a metaformer-like structure.
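
For reference, a minimal sketch of what instantiating this attention through the registry could look like. build_attention is the existing xformers factory and "visual" is the name added by this PR, but the constructor argument below (dim_model) is an assumption, not the confirmed signature:

```python
import torch

from xformers.components.attention import build_attention

# "visual" is the registry name from this PR; dim_model is an *assumed*
# constructor argument -- check the class definition for the real signature.
attn = build_attention({"name": "visual", "dim_model": 384})

B, HW, C = 2, 64, 384  # context length must be a perfect square (64 = 8 * 8)
q = torch.randn(B, HW, C)
out = attn(q=q, k=q, v=q)  # q, k and v share shapes (requires_same_k_q_dimensions)
```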

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • Did you update the changelog? (if needed)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 8, 2022
@blefaudeux (Contributor, Author):

cc @mannatsingh if you're interested in these things

@blefaudeux blefaudeux force-pushed the visual_attention branch 2 times, most recently from c9282b7 to bf34467 on June 8, 2022 23:03
@blefaudeux blefaudeux changed the title [DRAFT][feat] Adding Visual Attention [feat] Adding Visual Attention Jun 8, 2022
@blefaudeux blefaudeux requested review from fmassa and dianaml0 and removed request for fmassa and dianaml0 June 8, 2022 23:34
@blefaudeux (Contributor, Author) commented Jun 8, 2022:

Training a 6.7M-parameter metaformer-derived model on CIFAR-10 with the visual attention from this PR (for the record, ResNet-18 is 11M parameters and reaches 93%):
[Screenshot: training curves, 2022-06-08 16:33]

@blefaudeux blefaudeux requested a review from mannatsingh June 8, 2022 23:39
@blefaudeux blefaudeux force-pushed the visual_attention branch 2 times, most recently from 3adc3ac to 7fb47e6 on June 8, 2022 23:51
```diff
@@ -121,8 +121,8 @@ def forward(self, x):

 # Adjust batch depending on the available memory on your machine.
 # You can also use reversible layers to save memory
-REF_BATCH = 512
-BATCH = 512  # lower if not enough GPU memory
+REF_BATCH = 768
```
@blefaudeux (Contributor, Author):

Looks like a classic default for CIFAR-10.

```diff
@@ -31,6 +31,7 @@ def __init__(
     num_classes=10,
     dim=384,
     attention="scaled_dot_product",
+    feedforward="MLP",
```
@blefaudeux (Contributor, Author):

I'm not sure about the defaults here: how do we show that these can be used to reproduce "Visual Attention", for instance? Should we show different presets?
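
One possible shape, sketched below: a small preset table that maps paper-style configurations onto registry names. The preset names and helper are hypothetical; "scaled_dot_product", "pooling", "visual" and "MLP" appear in this PR, while the "Conv2DFeedforward" string and its pairing with "visual" are assumptions based on the description and #321:

```python
# Hypothetical presets, for illustration only -- not part of the library.
PRESETS = {
    "vit": {"attention": "scaled_dot_product", "feedforward": "MLP"},
    "poolformer": {"attention": "pooling", "feedforward": "Conv2DFeedforward"},
    "visual_attention_net": {"attention": "visual", "feedforward": "Conv2DFeedforward"},
}

def preset_kwargs(preset: str, **overrides) -> dict:
    """Hypothetical helper: resolve a preset, letting callers override fields."""
    return {**PRESETS[preset], **overrides}

# e.g. Classifier(**preset_kwargs("visual_attention_net", dim=384))
```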

```diff
@@ -62,6 +62,10 @@ def __init__(
         # This operator does not really handle q,k,v
         self.requires_same_k_q_dimensions = True
+
+        # This attention requires the 2d structure of the context,
+        # implicitly assumed to be a squared length
+        self.requires_squared_context = True
```
@blefaudeux (Contributor, Author):

This was already true before, but not formalized like this; I think it's cleaner. "pooling" (PoolingFormer) and "visual" both recover the 2D structure of the context, and both assume a squared context length to do so.
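
A minimal sketch of that shared pattern, mirroring the transpose/reshape in the hunk just below (the helper name is illustrative, not the library's internal API):

```python
import math

import torch

def recover_2d(x: torch.Tensor) -> torch.Tensor:
    """Fold a (B, HW, C) sequence back into a (B, C, H, H) feature map.

    Only valid when HW is a perfect square -- hence requires_squared_context.
    """
    B, HW, C = x.shape
    H = int(math.sqrt(HW))
    assert H * H == HW, "requires_squared_context: HW must be a perfect square"
    return x.transpose(-2, -1).reshape(B, C, H, H)

print(recover_2d(torch.randn(2, 64, 16)).shape)  # torch.Size([2, 16, 8, 8])
# recover_2d(torch.randn(2, 60, 16))  # would raise: 60 is not a perfect square
```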

```diff
 H = int(math.sqrt(HW))
 assert H * H == HW

 x = q.transpose(-2, -1).reshape(B, C, H, H)
```
@blefaudeux (Contributor, Author):

I've not benchmarked it, but it may be beneficial to call .contiguous() here, depending on the Conv2D kernels. A quick sanity check is sketched below.
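
A rough timing sketch for checking this on a given machine; the sizes are arbitrary, and whether the extra copy pays off depends on which Conv2D kernels get selected:

```python
import time

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
B, C, H = 16, 384, 32
conv = torch.nn.Conv2d(C, C, kernel_size=5, padding=2, groups=C).to(device)
q = torch.randn(B, H * H, C, device=device)

# Same transpose + reshape as in the diff above: the result is a strided,
# non-contiguous view of q.
x = q.transpose(-2, -1).reshape(B, C, H, H)

def bench(t: torch.Tensor, iters: int = 50) -> float:
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        conv(t)
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start

print(f"strided input      : {bench(x):.4f}s")
print(f".contiguous() input: {bench(x.contiguous()):.4f}s")
```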

@blefaudeux blefaudeux merged commit efdd15a into main Jun 9, 2022

Successfully merging this pull request may close these issues:

Support Visual Attention Network