Commit
nvFuser integration for operation fusion (#357)
* added nvfuser implementation, benchmark for biasReluDropout
* reformatted fuse pattern
* revised benchmarking, nvfused patterns
* adds BiasDropoutRes and BiasDropoutResLayernorm patterns, minor edits
* unit testing for all fused patterns, minor edits
* benchmarking for all nvfused patterns
* mypy wip
* benchmarking nvfuser patterns, adding plots, minor testing changes
* fixing mypy errors
* fixed benchmarking bug, minor test change
* final benchmark plots, benchmark edits
* nvfuser documentation, minor edits
* fixing functorch version error, documentation revisions
* fixing circleci functorch errors, mypy errors
* circleci config wip
* circleci test wip
* wip2
* testing revisions, circleci fixes, minor changes
* changelog changes, fixes functorch flag bug
* circle-ci fix
* circle-ci spacing fix
* build error wip
* revised documentation, reverted circleci config
* Fix functorch errors, circleci issue, testing changes
* updating changelog

Co-authored-by: Chris Yuan <christopheryuan@learnfair1488.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@learnfair1481.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@learnfair1483.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@learnfair1492.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@learnfair1478.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@learnfair1479.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@learnfair1484.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@learnfair1477.h2.fair>
1 parent 3a7b713, commit 089f826
Showing 114 changed files with 864 additions and 22 deletions.
Binary files added (benchmark plots):

BIN +75.3 KB ...iasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_gelu_torch.float16.png
BIN +75.3 KB ...iasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_gelu_torch.float32.png
BIN +83.4 KB ...iasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_relu_torch.float16.png
BIN +76.7 KB ...iasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_relu_torch.float32.png
BIN +88.5 KB ...ationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float16.png
BIN +81.6 KB ...ationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float32.png
BIN +78.4 KB ...BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_gelu_torch.float16.png
BIN +75.6 KB ...BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_gelu_torch.float32.png
BIN +74.6 KB ...BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_relu_torch.float16.png
BIN +76.7 KB ...BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_relu_torch.float32.png
BIN +89.4 KB ...vationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float16.png
BIN +81.5 KB ...vationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float32.png
BIN +82.8 KB ...r/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_gelu_torch.float16.png
BIN +77.1 KB ...r/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_gelu_torch.float32.png
BIN +82.5 KB ...r/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_relu_torch.float16.png
BIN +76.9 KB ...r/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_relu_torch.float32.png
BIN +85.8 KB ...tivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_squared_relu_torch.float16.png
BIN +80.4 KB ...tivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_squared_relu_torch.float32.png
BIN +70.1 KB ...er/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_gelu_torch.float16.png
BIN +72.2 KB ...er/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_gelu_torch.float32.png
BIN +73.7 KB ...er/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_relu_torch.float16.png
BIN +77.5 KB ...er/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_relu_torch.float32.png
BIN +76.2 KB ...ctivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_squared_relu_torch.float16.png
BIN +80.5 KB ...ctivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_squared_relu_torch.float32.png
BIN +74.8 KB ...asActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_gelu_torch.float16.png
BIN +72.5 KB ...asActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_gelu_torch.float32.png
BIN +75.6 KB ...asActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_relu_torch.float16.png
BIN +72.3 KB ...asActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_relu_torch.float32.png
BIN +72.5 KB ...tionDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float16.png
BIN +70.2 KB ...tionDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float32.png
BIN +75.4 KB ...iasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_gelu_torch.float16.png
BIN +71 KB ...iasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_gelu_torch.float32.png
BIN +73.2 KB ...iasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_relu_torch.float16.png
BIN +71.2 KB ...iasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_relu_torch.float32.png
BIN +73.6 KB ...ationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float16.png
BIN +70 KB ...ationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float32.png
BIN +72.5 KB .../BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_gelu_torch.float16.png
BIN +70.2 KB .../BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_gelu_torch.float32.png
BIN +73 KB .../BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_relu_torch.float16.png
BIN +70 KB .../BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_relu_torch.float32.png
BIN +70.6 KB ...ivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_squared_relu_torch.float16.png
BIN +68.9 KB ...ivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_squared_relu_torch.float32.png
BIN +68.4 KB ...r/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_gelu_torch.float16.png
BIN +66.4 KB ...r/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_gelu_torch.float32.png
BIN +66.8 KB ...r/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_relu_torch.float16.png
BIN +65.4 KB ...r/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_relu_torch.float32.png
BIN +71.2 KB ...tivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_squared_relu_torch.float16.png
BIN +68.9 KB ...tivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_squared_relu_torch.float32.png
BIN +59.6 KB ...fuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_False_torch.float16.png
BIN +59.8 KB ...fuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_False_torch.float32.png
BIN +66.1 KB ...vfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_True_torch.float16.png
BIN +63 KB ...vfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_True_torch.float32.png
BIN +62.9 KB .../nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_False_torch.float16.png
BIN +63 KB .../nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_False_torch.float32.png
BIN +60.3 KB ...s/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_True_torch.float16.png
BIN +60.3 KB ...s/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_True_torch.float32.png
BIN +63 KB ...user/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_False_torch.float16.png
BIN +61.3 KB ...user/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_False_torch.float32.png
BIN +63.4 KB ...fuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_True_torch.float16.png
BIN +61.2 KB ...fuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_True_torch.float32.png
BIN +56 KB ...nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_False_torch.float16.png
BIN +56 KB ...nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_False_torch.float32.png
BIN +57.2 KB .../nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_True_torch.float16.png
BIN +56.1 KB .../nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_True_torch.float32.png
BIN +67.6 KB ...lLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_post.png
BIN +65 KB ...alLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_pre.png
BIN +67.4 KB ...lLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_post.png
BIN +65.1 KB ...alLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_pre.png
BIN +67.3 KB ...alLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_post.png
BIN +66.3 KB ...ualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_pre.png
BIN +66.2 KB ...alLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_post.png
BIN +68.2 KB ...ualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_pre.png
BIN +65.5 KB ...dualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_post.png
BIN +64.8 KB ...idualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_pre.png
BIN +67.4 KB ...dualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_post.png
BIN +66.6 KB ...idualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_pre.png
BIN +64.7 KB ...idualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_post.png
BIN +64.1 KB ...sidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_pre.png
BIN +66.6 KB ...idualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_post.png
BIN +66.3 KB ...sidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_pre.png
BIN +65 KB ...LayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_post.png
BIN +63.8 KB ...lLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_pre.png
BIN +62.7 KB ...LayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_post.png
BIN +61.9 KB ...lLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_pre.png
BIN +65.2 KB ...lLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_post.png
BIN +64.8 KB ...alLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_pre.png
BIN +62.7 KB ...lLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_post.png
BIN +61.7 KB ...alLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_pre.png
BIN +62.6 KB ...ualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_post.png
BIN +62.3 KB ...dualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_pre.png
BIN +62.6 KB ...ualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_post.png
BIN +62.2 KB ...dualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_pre.png
BIN +60.7 KB ...dualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_post.png
BIN +60.7 KB ...idualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_pre.png
BIN +63.4 KB ...dualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_post.png
BIN +63.6 KB ...idualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_pre.png
@@ -0,0 +1,32 @@
How to Enable Fused Operations Using AOTAutograd and NVFuser
=============================================================

AOT Autograd is a toolkit from FuncTorch_ which can be used to accelerate model training in xFormers.
Broadly, it extracts a computational graph of the forward and backward passes of a model ahead of time.
This allows for some joint graph optimizations and enables deep learning compilers such as NVFuser_ to perform operator fusion.
The `memory_efficient_fusion`_ wrapper function provides a convenient way to leverage AOTAutograd and NVFuser on GPU.

.. _FuncTorch: https://pytorch.org/functorch/stable/
.. _NVFuser: https://github.com/pytorch/pytorch/blob/release/1.12/torch/csrc/jit/codegen/cuda/README.md
.. _memory_efficient_fusion: https://pytorch.org/functorch/stable/generated/functorch.compile.memory_efficient_fusion.html#functorch.compile.memory_efficient_fusion
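
As a quick illustration, the sketch below wraps a plain Python function with `memory_efficient_fusion`. The function name, tensor shapes, and dropout probability are illustrative rather than part of the xFormers API, and a CUDA GPU with functorch installed is assumed.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from functorch.compile import memory_efficient_fusion

    def bias_gelu_dropout(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        # A chain of pointwise ops that NVFuser can fuse into a single kernel
        return F.dropout(F.gelu(x + bias), p=0.1)

    x = torch.randn(8, 1024, 1024, device="cuda", requires_grad=True)
    bias = torch.randn(1024, device="cuda", requires_grad=True)

    # AOTAutograd traces the forward and backward graphs ahead of time
    # and hands them to NVFuser for code generation.
    fused_fn = memory_efficient_fusion(bias_gelu_dropout)
    out = fused_fn(x, bias)
    out.sum().backward()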

xFormers uses `memory_efficient_fusion` to combine sequences of fusable operations into single fused function layers.
These layers can be found under `xformers/components/nvfuser`. A notable example is `NVFusedBiasActivationDropout`, which is readily used inside the `MLP`_ feedforward component.

.. _MLP: https://github.com/facebookresearch/xformers/blob/main/xformers/components/feedforward/mlp.py

A benchmark of these fused patterns across some representative shapes shows significant speedups over the unfused,
PyTorch eager approach: up to 3.5x for the forward pass and up to 2.2x for the forward and backward passes together. On average, peak memory usage of the fused patterns is also lower,
although we see infrequent cases of up to 1.6x the PyTorch peak memory usage on larger shapes. We also see better overall performance than our Triton-based implementation of fused Bias,
Activation, and Dropout (see_). Full benchmark plots can be found here_.

.. _see: https://github.com/facebookresearch/xformers/blob/main/xformers/triton/dropout.py
.. _here: https://github.com/facebookresearch/xformers/tree/main/docs/plots/nvfuser

As noted in the README, the `_is_functorch_available` flag must be enabled for xFormers to use these optimizations.
This makes the fused layers available and changes the behavior of the `MLP` feedforward component,
causing it to default to the fused `NVFusedBiasActivationDropout` layer.
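
As a rough usage sketch, mirroring the unit tests added in this commit, enabling the flag and building the `MLP` component could look like the following; the configuration values are illustrative.

.. code-block:: python

    import torch
    import xformers
    import xformers.components.feedforward as ff
    from xformers.components import Activation

    # Enable the functorch/nvFuser path before building components
    xformers._is_functorch_available = True

    # Same kind of config dict as used in the nvFuser unit tests
    mlp = ff.build_feedforward(
        {
            "name": "MLP",
            "dim_model": 128,
            "dropout": 0.1,
            "activation": Activation.GeLU,
            "hidden_layer_multiplier": 4,
            "bias": False,
        }
    )
    mlp.cuda().train()
    out = mlp(torch.rand(4, 256, 128, device="cuda"))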

AOT Autograd offers a great deal of flexibility, as `memory_efficient_fusion` can accept either a Python function or an entire `nn.Module` as input for fusion.
Currently, however, xFormers only uses it with Python function inputs, because initial attempts at fusing whole xFormers layers and blocks yielded memory issues and other CUDA errors.
We are exploring further testing and benchmarking.
@@ -6,6 +6,7 @@ Tutorials

sparse_vit
blocksparse
aotautograd_nvfuser
extend_attentions
use_attention
pytorch_encoder

@@ -0,0 +1,200 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from collections import OrderedDict
from contextlib import nullcontext

import pytest
import torch
import torch.nn as nn
from torch.cuda.amp.autocast_mode import autocast

import xformers
from xformers.components import Activation, ResidualNormStyle

# Store original and possible flag setting
flag_orig = xformers._is_functorch_available
flag_new = True
xformers._is_functorch_available = True


_gpu_available = torch.cuda.is_available()

try:
    import xformers.components.feedforward as ff
    from xformers.components.nvfuser import (
        NVFusedBiasActivationDropout,
        NVFusedBiasDropoutRes,
        NVFusedBiasDropoutResLayerNorm,
    )
    from xformers.components.nvfuser.utils import build_nvfused
except ImportError as e:
    logging.warning(f"Functorch is not available to run test_nvfuser.py. \nError {e}")
    flag_new = False

xformers._is_functorch_available = flag_orig

FUSED_PATTERNS = (
    [
        NVFusedBiasActivationDropout,
        NVFusedBiasDropoutRes,
        NVFusedBiasDropoutResLayerNorm,
    ]
    if flag_new
    else []
)

# Testing odd (non-power-of-two for instance) shapes on purpose
SHAPES = [
    (384, 512),
    (8, 384, 128),
    (8, 784, 512),
    (4, 16, 384),
    (4, 16, 1024),
    (2, 16, 2048),
    (2, 16, 4096),
    (1, 16, 12288),
]

BATCH = 4
SEQ = 256
EMBD = 16
LATENT = 128
DEVICES = [torch.device("cuda")]

ACTIVATIONS = [
    Activation.ReLU,
    Activation.GeLU,
    Activation.LeakyReLU,
    Activation.SquaredReLU,
    Activation.SmeLU,
]


@pytest.mark.skipif(not flag_new, reason="Functorch is not available")
@pytest.mark.skipif(not _gpu_available, reason="GPU is not available")
@pytest.mark.parametrize("fused_pattern", FUSED_PATTERNS)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("amp", [False, True])
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("activation", ACTIVATIONS)
@pytest.mark.parametrize("p", [0, 0.1, 0.5])
@pytest.mark.parametrize(
    "layer_norm_style", [None, ResidualNormStyle.Pre, ResidualNormStyle.Post]
)
def test_nvfused_pattern_parity(
    fused_pattern: nn.Module,
    shape: tuple,
    amp: bool,
    bias: bool,
    activation: Activation,
    p: float,
    layer_norm_style: ResidualNormStyle,
):
    # Enable global flag
    xformers._is_functorch_available = flag_new

    if (
        fused_pattern != NVFusedBiasDropoutResLayerNorm
        and layer_norm_style != ResidualNormStyle.Pre
    ):
        pytest.skip(
            "Layer norm style doesn't apply, the same relevant params already tested once."
        )

    torch.cuda.manual_seed_all(0)
    torch.random.manual_seed(0)
    x = torch.normal(0, 1, size=shape, device="cuda", requires_grad=True)
    x_cpu = x.clone().cpu()

    with autocast(enabled=amp), pytest.raises(
        ValueError
    ) if layer_norm_style is None else nullcontext():
        fused = build_nvfused(
            fused_pattern, shape, bias, activation, p, layer_norm_style
        )
        fused.train().cuda()
        nvfused_res = fused(x, x) if fused.requires_residual else fused(x)
        fused.cpu()
        torch_res = (
            fused(x_cpu, x_cpu).cuda()
            if fused.requires_residual
            else fused(x_cpu).cuda()
        )

        # Check if operation was actually fused
        assert isinstance(
            nvfused_res.grad_fn, torch.autograd.function.BackwardCFunction
        )

        if p == 0.0:
            # Check fused and unfused paths are the same
            assert torch.allclose(torch_res, nvfused_res, atol=1e-6, rtol=1e-2)

    # Restore original flag configuration
    xformers._is_functorch_available = flag_orig


@pytest.mark.skipif(not flag_new, reason="Functorch is not available")
@pytest.mark.skipif(not _gpu_available, reason="GPU is not available")
@pytest.mark.parametrize("activation", ACTIVATIONS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("p", [0, 0.1, 0.5])
def test_nvfused_mlp(activation: Activation, device: torch.device, p: float):
    test_config = {
        "name": "MLP",
        "dim_model": LATENT,
        "dropout": p,
        "activation": activation,
        "hidden_layer_multiplier": 4,
        "bias": False,
    }
    # Enable global flag
    xformers._is_functorch_available = flag_new

    torch.random.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    mlp = ff.build_feedforward(test_config)
    # Creates non-fused default MLP
    xformers._is_functorch_available = False
    mlp_default = ff.build_feedforward(test_config)
    xformers._is_functorch_available = flag_new

    inputs = torch.rand(BATCH, SEQ, LATENT, device=device)
    mlp.train()

    # Check fused pattern w/ unfused default (switch happens within NVFusedBiasActivationDropout)
    mlp.cuda()
    fused_res = mlp(inputs)

    mlp.cpu()
    unfused_res = mlp(inputs.cpu())

    if p == 0.0:
        assert torch.allclose(unfused_res.cuda(), fused_res, atol=1e-6, rtol=1e-2)

    # Check fused pattern w/ unfused default (switch happens within MLP)
    mlp.cuda()
    mlp_default.cuda()

    # Load same weight parameters into both models
    default_param_dict = OrderedDict(
        [
            ("mlp.2.weight", v) if k == "mlp.3.weight" else (k, v)
            for k, v in mlp_default.state_dict().items()
        ]
    )
    mlp.load_state_dict(default_param_dict)
    fused_res = mlp(inputs)
    unfused_res = mlp_default(inputs)

    if p == 0.0:
        assert torch.allclose(unfused_res, fused_res, atol=1e-6, rtol=1e-2)

    # Restore original flag configuration
    xformers._is_functorch_available = flag_orig