Flash attention #7977
Conversation
…to issue Project-MONAI#7946. Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
…h >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality. Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
A few minor comments but looks good otherwise. I would like others to test this change and see if memory performance improves. Thanks!
…h >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality. >>> Implementation of proposed corrections Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
…h >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality. >>> Implementation of proposed corrections:
- Addition of causal, dropout and scale to the call to scaled_dot_product_attention
- For this, addition of self.dropout_rate as an attribute
- Raising an error when rel_pos_embedding is not None and use_flash_attention is True
- Fix of docstrings that had gone wrong (in cross and self attention and transformer block)
- Addition of two tests to the self and cross attention block tests to account for the rel_pos_embedding error and to make sure that the causal = True call works.
Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
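A minimal sketch of what this commit describes, using hypothetical names (flash_block_attention, dropout_rate, causal) rather than MONAI's actual block attributes:

```python
import torch
import torch.nn.functional as F

def flash_block_attention(q, k, v, dropout_rate: float = 0.0, causal: bool = False,
                          rel_pos_embedding=None):
    # The fused kernel cannot add a relative positional embedding bias in this form,
    # hence the explicit error mentioned above.
    if rel_pos_embedding is not None:
        raise ValueError("rel_pos_embedding is not compatible with use_flash_attention")
    # is_causal builds the causal mask internally; inside an nn.Module, dropout_p
    # would typically be passed only while training.
    return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate, is_causal=causal)
```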
@yiheng-wang-nv Could you please also help review this PR? Thanks in advance.
…h >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality. >>>> It was necessary to transpose the query, key and value passed to the PyTorch flash attention module, and to transpose the result back, to get behavior consistent with the xformers and non-flash paths. Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
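A sketch of the layout handling this commit describes, assuming xformers-style tensors of shape (batch, seq_len, num_heads, head_dim); the function name and shapes are illustrative, not MONAI's exact code:

```python
import torch
import torch.nn.functional as F

def flash_attention_xformers_layout(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # PyTorch's scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim),
    # so transpose in, run the fused kernel, and transpose back out.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2)  # back to (batch, seq_len, num_heads, head_dim)
```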
Perhaps we also need to include the test @ericspod shared, which compares the result without flash attention against the attention that PyTorch says is the same as
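A hedged sketch of such a consistency check (not the exact test being referenced): compare PyTorch's fused kernel against a plain softmax(QK^T / sqrt(d)) V reference and require agreement within floating-point tolerance.

```python
import torch
import torch.nn.functional as F

def manual_attention(q, k, v):
    # Reference implementation that materialises the attention matrix explicitly.
    att = (q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5).softmax(dim=-1)
    return att @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))  # (batch, heads, seq, head_dim)
assert torch.allclose(F.scaled_dot_product_attention(q, k, v), manual_attention(q, k, v), atol=1e-5)
```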
In particular:
- Modified this line, <att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))>, used when use_causal is True in self_attention, to <att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))>
- Added the pertinent transpose calls to cross attention to ensure that the behaviour matches that of xops and that the code also works for flash_attention=False
- Added a SkipIfPytorch[...] clause before test_shape in test_cross_attention to make sure it does not error out for cases in the case block that use flash_attention = True
- Fixed one rogue space in the docstrings that had been added
I ran autofix and mypy. cross_attention was reformatted; mypy did not suggest changes. Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
>>>> FIX: I forgot to sign! Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
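A runnable sketch of the causal-mask indexing fix described in these commits, assuming attention inputs shaped (batch, num_heads, seq_len, head_dim) and a pre-built lower-triangular mask; the names here are illustrative rather than MONAI's actual attributes:

```python
import torch

def causal_attention_scores(q: torch.Tensor, k: torch.Tensor, causal_mask: torch.Tensor) -> torch.Tensor:
    att_mat = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    # The sequence length is dimension -2 in this layout; shape[1] would pick num_heads instead.
    seq_len = q.shape[-2]
    att_mat = att_mat.masked_fill(causal_mask[:, :, :seq_len, :seq_len] == 0, float("-inf"))
    return att_mat.softmax(dim=-1)

max_len = 8
mask = torch.tril(torch.ones(max_len, max_len)).view(1, 1, max_len, max_len)
q = k = torch.randn(2, 4, 6, 16)  # (batch, heads, seq, head_dim)
scores = causal_attention_scores(q, k, mask)
```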
Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
@KumoLiu @mingxin-zheng @yiheng-wang-nv we have flash attention added here. There are other changes we want to make in further PRs such as #7996, but we should merge this one first. I know there are consistency issues between this layer and the one in GenerativeModels, and I'm working on narrowing down where that's coming from. For now I would say we resolve whatever conversations are outstanding, run blossom, and hopefully merge soon. We can come back to these issues if there are any. Also, @virginiafdez will be away from tomorrow and less able to work on things.
/build
Will run more tests and address the remaining issues in the next PR.
Fixes #7944.
Description
In response to Issue #7944, I added PyTorch's scaled_dot_product_attention functionality to re-enable flash attention, which was present in the original MONAI Generative Models repository. It is available for torch >= 2.0 and only when the argument save_attn = False; errors are raised otherwise. I ran quick tests and added checks to the test_selfattention and test_crossattention scripts to make sure the outputs are the same as when not using flash attention.
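A minimal sketch of the approach (an illustration, not MONAI's exact implementation), assuming query/key/value shaped (batch, num_heads, seq_len, head_dim):

```python
import torch
import torch.nn.functional as F

def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
              use_flash_attention: bool = False, save_attn: bool = False) -> torch.Tensor:
    if use_flash_attention and save_attn:
        # The fused kernel never materialises the attention matrix, so it cannot be saved.
        raise ValueError("save_attn is not compatible with use_flash_attention")
    if use_flash_attention:
        # Fused flash/memory-efficient kernels, available in torch >= 2.0.
        return F.scaled_dot_product_attention(q, k, v)
    # Fallback: explicit softmax(QK^T / sqrt(d)) V.
    att_mat = (q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5).softmax(dim=-1)
    return att_mat @ v
```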
Types of changes
./runtests.sh -f -u --net --coverage
./runtests.sh --quick --unittests --disttests
make html command in the docs/ folder