
Flash attention #7977

Merged: 12 commits merged into Project-MONAI:dev on Aug 6, 2024

Conversation

@virginiafdez (Contributor) commented Aug 1, 2024

Fixes #7944.

Description

In response to issue #7944, I added the new scaled_dot_product_attention functionality from PyTorch to re-enable flash attention, which was present in the original MONAI Generative Models repository. It is only enabled for torch >= 2.0 and when the argument save_attn is False; errors are raised otherwise. I ran the quick tests and added checks to the test_selfattention and test_crossattention scripts to make sure the outputs are the same as when flash attention is not used.
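For context, here is a minimal sketch of the kind of dispatch this adds. The function and argument names are assumptions for illustration; the actual logic lives inside MONAI's self attention and cross attention blocks and may differ:

```python
import torch
import torch.nn.functional as F

def attention(q, k, v, use_flash_attention=False, save_attn=False, dropout_p=0.0):
    # q, k, v: (batch, num_heads, seq_len, head_dim); sketch only, not the MONAI code
    if use_flash_attention:
        if save_attn:
            # the fused kernel never materialises the attention matrix, so it cannot be saved
            raise ValueError("save_attn is not compatible with use_flash_attention")
        return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
    # non-flash path: explicit attention matrix, which can be saved if needed
    att_mat = (q @ k.transpose(-2, -1)) * (q.shape[-1] ** -0.5)
    att_mat = att_mat.softmax(dim=-1)
    att_mat = F.dropout(att_mat, p=dropout_p)
    return att_mat @ v
```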

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Virginia Fernandez and others added 3 commits July 31, 2024 23:09
…to issue Project-MONAI#7946.

Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
…h >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality.

Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
@ericspod (Member) commented Aug 1, 2024

A few minor comments but looks good otherwise. I would like others to test this change and see if memory performance improves. Thanks!
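For anyone testing the memory behaviour, a rough sketch of how the two paths could be compared on a CUDA device (names are assumptions; `block` stands for any attention module and `x` for a suitably shaped input):

```python
import torch

def peak_memory_mb(block, x):
    # Reset CUDA peak-memory statistics, run one forward pass, and report the peak in MB.
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        block(x)
    return torch.cuda.max_memory_allocated() / 1024 ** 2
```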

…h >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality.

>>> Implementation of proposed corrections

Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
…h >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality.

>>> Implementation of proposed corrections:
- Addition of the causal, dropout and scale arguments to the call to scaled_dot_product_attention
- For this, addition of self.dropout_rate as an attribute
- Raising an error when rel_pos_embedding is not None and use_flash_attention is True (a sketch follows this commit message)
- Fix of docstrings that had gone wrong (in cross attention, self attention and the transformer block)
- Addition of two tests to the self and cross attention block tests to cover the rel_pos_embedding error and to make sure that a causal=True call works.

Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
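A minimal sketch of the corrections listed above, using simplified class, attribute and argument names (assumptions, not the exact ones in the MONAI blocks):

```python
import torch.nn.functional as F


class AttentionSketch:  # hypothetical stand-in for the self/cross attention blocks
    def __init__(self, dropout_rate=0.0, causal=False, scale=None,
                 rel_pos_embedding=None, use_flash_attention=False):
        if use_flash_attention and rel_pos_embedding is not None:
            raise ValueError("rel_pos_embedding is not compatible with use_flash_attention")
        self.dropout_rate = dropout_rate  # kept as an attribute so it can be forwarded below
        self.causal = causal
        self.scale = scale
        self.use_flash_attention = use_flash_attention

    def _flash_attention(self, q, k, v):
        # causal masking, dropout and scaling are delegated to PyTorch's fused kernel
        return F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout_rate,
            is_causal=self.causal,
            scale=self.scale,  # the scale keyword needs a recent PyTorch (2.1+)
        )
```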
@Nic-Ma requested a review from yiheng-wang-nv on August 3, 2024 10:18
@Nic-Ma (Contributor) commented Aug 3, 2024

@yiheng-wang-nv Could you please also help review this PR?

Thanks in advance.

…h >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality.

>>>> It was necessary to transpose the query, key and value tensors passed to the PyTorch flash attention function, and then to transpose the result back, so that the behaviour is consistent with the xformers and non-flash implementations (see the sketch below).

Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
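A minimal sketch of the transpose handling described above, assuming the xformers-style (batch, seq_len, num_heads, head_dim) layout on the MONAI side (the exact code may differ):

```python
import torch
import torch.nn.functional as F

def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # PyTorch's fused kernel expects (batch, num_heads, seq_len, head_dim),
    # so transpose into that layout and back out again afterwards.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2)  # back to (batch, seq_len, num_heads, head_dim)
```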
@KumoLiu (Contributor) commented Aug 5, 2024

Perhaps we also need to include the test @ericspod shared, which compares the result computed without flash attention against the reference attention that PyTorch documents as equivalent to F.scaled_dot_product_attention. The dimensions appear to be highly error prone, and additional comprehensive testing is necessary to catch mistakes. Thanks!
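A sketch of the kind of equivalence check being suggested, using plain tensors rather than the MONAI block API (shapes and tolerances are assumptions):

```python
import torch
import torch.nn.functional as F


def reference_attention(q, k, v):
    # The explicit "math" formulation that F.scaled_dot_product_attention
    # is documented to be equivalent to (no mask, no dropout).
    att_mat = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    return att_mat.softmax(dim=-1) @ v


def test_flash_matches_reference():
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))  # (batch, heads, seq, head_dim)
    expected = reference_attention(q, k, v)
    actual = F.scaled_dot_product_attention(q, k, v)
    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
```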

@ericspod (Member) commented Aug 5, 2024

Related issues: #7991 #7992

Virginia Fernandez added 3 commits August 6, 2024 13:33
In particular:
- modified the line <att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))>, used in self_attention when causal is True, to <att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))> (see the sketch below these commit messages)
- added the pertinent transpose calls to cross attention to ensure that the behaviour matches that of xops and that the code also works for flash_attention=False
- added a SkipIfPytorch[...] clause before test_shape in test_cross_attention so that it does not error out for the cases in the case block that use flash_attention=True
- fixed one rogue space that had been added to the docstrings

I ran autofix and mypy; cross_attention was reformatted, and mypy did not suggest changes.

Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
>>>> FIX: I forgot to sign!

Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
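For illustration, a minimal sketch of the corrected causal masking described in the commit messages above; the mask construction and function name are assumptions, only the masked_fill line reflects the actual fix:

```python
import torch

seq_len = 16
# GPT-style lower-triangular mask, shown here only for illustration
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

def apply_causal_mask(att_mat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Index the mask by the sequence-length dimension x.shape[-2] (the fix above),
    # rather than x.shape[1].
    return att_mat.masked_fill(
        causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf")
    )
```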
@ericspod (Member) commented Aug 6, 2024

@KumoLiu @mingxin-zheng @yiheng-wang-nv we have flash attention added here. There are other changes we want to make in further PRs, such as #7996, but we should merge this one first. I know there are consistency issues between this layer and the one in GenerativeModels; I'm working on narrowing down where they come from. For now I would say we resolve whatever conversations are outstanding, run blossom, and hopefully merge soon. We can come back to these issues if there are any. Also, @virginiafdez will be away from tomorrow and less able to work on things.

@KumoLiu (Contributor) commented Aug 6, 2024

/build

@KumoLiu (Contributor) commented Aug 6, 2024

Ran more tests; the remaining issues will be addressed in the next PR.

@KumoLiu merged commit 6c23fd0 into Project-MONAI:dev on Aug 6, 2024
28 checks passed
Successfully merging this pull request may close this issue: Consideration of Flash Attention in Generative Components