diff --git a/.circleci/config.yml b/.circleci/config.yml
index 79c3110084..ac2e00fcdf 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -85,7 +85,7 @@ install_dep: &install_dep
# start installing
source activate /home/circleci/venv
- conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly -q
+ conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -q
$CONDA_PYTHON -m pip install -r requirements-benchmark.txt --progress-bar off
# Mark install as complete
@@ -102,7 +102,7 @@ install_dep_exp: &install_dep_exp
# start installing
source activate /home/circleci/venv
- conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly -q
+ conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -q
$CONDA_PYTHON -m pip install -r experimental/requirements.txt --progress-bar off
install_repo: &install_repo
@@ -374,7 +374,7 @@ jobs:
- ~/miniconda
- ~/venv
- key: cache-key-gpu-exp-114-{{ checksum "experimental/requirements.txt"}}-{{ checksum ".circleci/config.yml"}}
+ key: cache-key-gpu-exp-114-{{ checksum "experimental/requirements.txt" }}-{{ checksum ".circleci/config.yml" }}
- <<: *install_experimental_repo
- <<: *run_experimental_unittests
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3a9981788..7a6be360a0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,10 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## TBD
### Fixed
-- Removed dupliacated biases in the FusedMLP layers [#317]
+- Removed duplicated biases in the FusedMLP layers [#317]
- Rotary embeddings respecting input types [#326]
- Poolformer style instantiating useless projection layers [#349]
-- Fix layer position not being properly tracked, causing extra layernorms for programatic xformers [#348]
+- Fix layer position not being properly tracked, causing extra layernorms for programmatic xformers [#348]
### Added
- Four blocksparsity layouts from DeepSpeed [#320]
@@ -18,7 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Conv2DFeedforward feedforward part [#321]
- VisualAttention [#329]
- Automatic blocksparse for causal attention [#334]
-- Better hierarchical transformer generation [#345]
+- Better hierarchical transformer generation [#345]
+- Fused operations with AOTAutograd/NVFuser, integration into MLP [#357]
## [0.0.11] - 2022-05-30
### Fixed
@@ -40,7 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.0.10] - 2022-03-14
### Fixed
- Expose bias flag for feedforwards, same default as Timm [#220]
-- Update eps value for layernormm, same default as torch [#221]
+- Update eps value for layernorm, same default as torch [#221]
- PreNorm bugfix, only one input was normalized [#233]
- Fix bug where embedding dimensions that did not match model dim would lead to a crash [#244]
@@ -53,12 +54,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Experimental Ragged attention [#189]
- Mixture of Experts [#181]
- BlockSparseTensor [#202]
-- nd-tensor support for triton softmax [#210]
+- Nd-tensor support for triton softmax [#210]
### Fixed
-- bugfix Favor, single feature map [#183]
-- sanity check blocksparse settings [#207]
-- fixed some pickability [#204]
+- Bugfix Favor, single feature map [#183]
+- Sanity check blocksparse settings [#207]
+- Fixed some picklability [#204]
## [0.0.8] - 2022-01-07
### Fixed
diff --git a/HOWTO.md b/HOWTO.md
index 3c9aafeaeb..142680dccb 100644
--- a/HOWTO.md
+++ b/HOWTO.md
@@ -12,6 +12,7 @@ Let's present here a couple of code snippets on how to solve a couple of questio
- [Replace all attentions from an existing ViT model with a sparse equivalent ?](#replace-all-attentions-from-an-existing-vit-model-with-a-sparse-equivalent-)
- [Some more examples](#some-more-examples)
- [BlockSparseAttention](#blocksparseattention)
+ - [How to Enable Fused Operations Using AOTAutograd and NVFuser](#how-to-enable-fused-operations-using-aotautograd-and-nvfuser)
- [From cherry picking attentions to building whole models](#from-cherry-picking-attentions-to-building-whole-models)
- [Testing out an attention mechanism](#testing-out-an-attention-mechanism)
- [Building an encoder, comparing to PyTorch](#building-an-encoder-comparing-to-pytorch)
@@ -295,6 +296,18 @@ On a V100, with PyTorch 1.9, Triton 1.1 and xFormers 0.0.2 this reports somethin
Note that the pattern here is not that sparse (half of the matrix is empty), the more sparse it gets the more biased the result will get towards BlockSparseAttention.
+## How to Enable Fused Operations Using AOTAutograd and NVFuser
+
+AOT Autograd is a toolkit from [FuncTorch](https://pytorch.org/functorch/stable/) which can be used to accelerate model training in xFormers. Broadly, it extracts a computational graph of the forward and backward passes of a model ahead of time. This allows for some joint graph optimizations and enables deep learning compilers such as [NVFuser](https://github.com/pytorch/pytorch/blob/release/1.12/torch/csrc/jit/codegen/cuda/README.md) to perform operator fusion. The [`memory_efficient_fusion`](https://pytorch.org/functorch/stable/generated/functorch.compile.memory_efficient_fusion.html#functorch.compile.memory_efficient_fusion) wrapper function provides a convenient way to leverage AOTAutograd and NVFuser on GPU.
+
+xFormers uses `memory_efficient_fusion` to combine sequences of fusable operations together into single fused function layers. These parts can be found [here](xformers/components/nvfuser). A notable example is [`NVFusedBiasActivationDropout`](xformers/components/nvfuser/bias_act_dropout.py), which is readily used inside the [`MLP`](xformers/components/feedforward/mlp.py) feedforward component.
+
+A benchmark of these fused patterns across some representative shapes shows significant speed increases compared to the unfused, PyTorch eager approach—up to 3.5x speedup for the forward pass and 2.2x for the forward and backward passes together. On average, peak memory usage of fused patterns is also lower, although we see some infrequent cases of up to 1.6x PyTorch peak memory usage on larger shapes. We also see better overall performance against our implementation of fused Bias, Activation, and Dropout using Triton ([see](xformers/triton/dropout.py)) as well. Full benchmark plots can be found [here](docs/plots/nvfuser/).
+
+Please note, as mentioned in the README, that the `_is_functorch_available` flag must be enabled for xFormers to use these optimizations. This allows the fused layers to be used and changes the behavior of the `MLP` feedforward component, causing it to default to using the fused `NVFusedBiasActivationDropout` layer.
+
+AOT Autograd offers a great deal of flexibility to the user, as `memory_efficient_fusion` can accept either a Python function or an entire `nn.Module` as input for fusion. Currently in xFormers, however, it is only used with Python function inputs because initial attempts with fusing xFormers layers and blocks have yielded memory issues and other CUDA errors. We are currently exploring further testing and benchmarking.
+
## From cherry picking attentions to building whole models
### Testing out an attention mechanism
diff --git a/README.md b/README.md
index 5adf6804a8..a3950398a3 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,25 @@ Triton will cache the compiled kernels to `/tmp/triton` by default. If this beco
+ AOTAutograd/NVFuser
+
+Some parts of xFormers use AOT Autograd from the [FuncTorch](https://pytorch.org/functorch/stable/) library, and will only expose themselves if FuncTorch is installed, and a compatible GPU is present. If functorch was not installed as part of the testing procedure, you can install it directly through pip.
+
+ ```bash
+ pip install functorch
+ ```
+
+ Once installed, set the flag `_is_functorch_available = True` in `xformers/__init__.py`. You can optionally test that the installation is successful by running one of the functorch-related benchmarks `python3 xformers/benchmarks/benchmark_nvfuser.py`
+
+ If you are importing the xFormers library in a script, you can modify the flag as such:
+
+ ```python
+ import xformers
+ xformers._is_functorch_available = True
+ ```
+
+
+
### Testing the installation
This will run a benchmark of the attention mechanisms exposed by xFormers, and generate a runtime and memory plot.
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_gelu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_gelu_torch.float16.png
new file mode 100644
index 0000000000..b9bcd284f0
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_gelu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_gelu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_gelu_torch.float32.png
new file mode 100644
index 0000000000..98fc778a20
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_gelu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_relu_torch.float16.png
new file mode 100644
index 0000000000..4342553be2
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_relu_torch.float32.png
new file mode 100644
index 0000000000..e51e83d060
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float16.png
new file mode 100644
index 0000000000..162441e8c1
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float32.png
new file mode 100644
index 0000000000..6a0177bf98
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_gelu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_gelu_torch.float16.png
new file mode 100644
index 0000000000..24da5e9473
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_gelu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_gelu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_gelu_torch.float32.png
new file mode 100644
index 0000000000..2575da0db1
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_gelu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_relu_torch.float16.png
new file mode 100644
index 0000000000..940e0f5bda
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_relu_torch.float32.png
new file mode 100644
index 0000000000..04b29152c6
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float16.png
new file mode 100644
index 0000000000..d89ca06492
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float32.png
new file mode 100644
index 0000000000..dc0ea064f0
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_gelu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_gelu_torch.float16.png
new file mode 100644
index 0000000000..bbaec7f1e2
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_gelu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_gelu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_gelu_torch.float32.png
new file mode 100644
index 0000000000..a881020b3d
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_gelu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_relu_torch.float16.png
new file mode 100644
index 0000000000..991c0576a9
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_relu_torch.float32.png
new file mode 100644
index 0000000000..6cc3ed3f46
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_squared_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_squared_relu_torch.float16.png
new file mode 100644
index 0000000000..8a56d5fb9a
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_squared_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_squared_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_squared_relu_torch.float32.png
new file mode 100644
index 0000000000..631dc8b4c4
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_False_squared_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_gelu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_gelu_torch.float16.png
new file mode 100644
index 0000000000..371be96445
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_gelu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_gelu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_gelu_torch.float32.png
new file mode 100644
index 0000000000..421a27ad4a
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_gelu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_relu_torch.float16.png
new file mode 100644
index 0000000000..392dbd4665
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_relu_torch.float32.png
new file mode 100644
index 0000000000..c328dd75e4
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_squared_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_squared_relu_torch.float16.png
new file mode 100644
index 0000000000..21ec934098
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_squared_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_squared_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_squared_relu_torch.float32.png
new file mode 100644
index 0000000000..cc5f140789
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/MAXMEM_Bias_Act_Dropout_FW_True_squared_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_gelu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_gelu_torch.float16.png
new file mode 100644
index 0000000000..d332b443a5
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_gelu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_gelu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_gelu_torch.float32.png
new file mode 100644
index 0000000000..a8c5daaa5c
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_gelu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_relu_torch.float16.png
new file mode 100644
index 0000000000..e915dee88d
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_relu_torch.float32.png
new file mode 100644
index 0000000000..9d99daecbc
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float16.png
new file mode 100644
index 0000000000..5aed6851fa
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float32.png
new file mode 100644
index 0000000000..2c5e9e6861
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_False_squared_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_gelu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_gelu_torch.float16.png
new file mode 100644
index 0000000000..0fa9947035
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_gelu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_gelu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_gelu_torch.float32.png
new file mode 100644
index 0000000000..baf2561bc7
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_gelu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_relu_torch.float16.png
new file mode 100644
index 0000000000..0dac783299
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_relu_torch.float32.png
new file mode 100644
index 0000000000..e6f2fd7b13
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float16.png
new file mode 100644
index 0000000000..df686aa9d7
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float32.png
new file mode 100644
index 0000000000..344594e8d3
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW+BW_True_squared_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_gelu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_gelu_torch.float16.png
new file mode 100644
index 0000000000..89ebb0c9ed
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_gelu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_gelu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_gelu_torch.float32.png
new file mode 100644
index 0000000000..3c6e8e4016
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_gelu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_relu_torch.float16.png
new file mode 100644
index 0000000000..23cdbc3b10
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_relu_torch.float32.png
new file mode 100644
index 0000000000..23309977d2
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_squared_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_squared_relu_torch.float16.png
new file mode 100644
index 0000000000..296a3eaa55
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_squared_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_squared_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_squared_relu_torch.float32.png
new file mode 100644
index 0000000000..90f30357e4
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_False_squared_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_gelu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_gelu_torch.float16.png
new file mode 100644
index 0000000000..f4d85bcfbd
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_gelu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_gelu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_gelu_torch.float32.png
new file mode 100644
index 0000000000..84290d6a68
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_gelu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_relu_torch.float16.png
new file mode 100644
index 0000000000..31e1825b65
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_relu_torch.float32.png
new file mode 100644
index 0000000000..b4a2c40730
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_squared_relu_torch.float16.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_squared_relu_torch.float16.png
new file mode 100644
index 0000000000..97dcf4c1b9
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_squared_relu_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_squared_relu_torch.float32.png b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_squared_relu_torch.float32.png
new file mode 100644
index 0000000000..3b4f1975b0
Binary files /dev/null and b/docs/plots/nvfuser/BiasActivationDropout/RUNTIME_Bias_Act_Dropout_FW_True_squared_relu_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_False_torch.float16.png b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_False_torch.float16.png
new file mode 100644
index 0000000000..61f56e7d96
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_False_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_False_torch.float32.png b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_False_torch.float32.png
new file mode 100644
index 0000000000..f3bd453826
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_False_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_True_torch.float16.png b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_True_torch.float16.png
new file mode 100644
index 0000000000..a449886449
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_True_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_True_torch.float32.png b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_True_torch.float32.png
new file mode 100644
index 0000000000..99738ee1de
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW+BW_True_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_False_torch.float16.png b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_False_torch.float16.png
new file mode 100644
index 0000000000..090aaef3bf
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_False_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_False_torch.float32.png b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_False_torch.float32.png
new file mode 100644
index 0000000000..cf467b366c
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_False_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_True_torch.float16.png b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_True_torch.float16.png
new file mode 100644
index 0000000000..3f24fe440c
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_True_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_True_torch.float32.png b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_True_torch.float32.png
new file mode 100644
index 0000000000..2b115e1c04
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/MAXMEM_Bias_Dropout_Res_FW_True_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_False_torch.float16.png b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_False_torch.float16.png
new file mode 100644
index 0000000000..3185f40fb7
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_False_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_False_torch.float32.png b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_False_torch.float32.png
new file mode 100644
index 0000000000..cc1253a10c
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_False_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_True_torch.float16.png b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_True_torch.float16.png
new file mode 100644
index 0000000000..9c0a733c60
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_True_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_True_torch.float32.png b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_True_torch.float32.png
new file mode 100644
index 0000000000..39bd930845
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW+BW_True_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_False_torch.float16.png b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_False_torch.float16.png
new file mode 100644
index 0000000000..0e18ae8cdd
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_False_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_False_torch.float32.png b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_False_torch.float32.png
new file mode 100644
index 0000000000..76e50fd6d3
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_False_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_True_torch.float16.png b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_True_torch.float16.png
new file mode 100644
index 0000000000..2de1143ee1
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_True_torch.float16.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_True_torch.float32.png b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_True_torch.float32.png
new file mode 100644
index 0000000000..bc72551c8d
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidual/RUNTIME_Bias_Dropout_Res_FW_True_torch.float32.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_post.png
new file mode 100644
index 0000000000..b26fea0daa
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_pre.png
new file mode 100644
index 0000000000..d67f05b7f6
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_post.png
new file mode 100644
index 0000000000..1698343620
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_pre.png
new file mode 100644
index 0000000000..08818a5e08
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_post.png
new file mode 100644
index 0000000000..47d2c9c349
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_pre.png
new file mode 100644
index 0000000000..fa79645ca6
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_post.png
new file mode 100644
index 0000000000..6591dfe244
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_pre.png
new file mode 100644
index 0000000000..d436f270a7
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_post.png
new file mode 100644
index 0000000000..5dd4ef4182
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_pre.png
new file mode 100644
index 0000000000..965ef015e6
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_post.png
new file mode 100644
index 0000000000..a2a6581a96
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_pre.png
new file mode 100644
index 0000000000..e8e309e934
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_post.png
new file mode 100644
index 0000000000..bb0dce8d7d
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_pre.png
new file mode 100644
index 0000000000..2a7119aa48
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_post.png
new file mode 100644
index 0000000000..1f5083996a
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_pre.png
new file mode 100644
index 0000000000..71e22d740b
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/MAXMEM_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_post.png
new file mode 100644
index 0000000000..70002421f6
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_pre.png
new file mode 100644
index 0000000000..e057faac1e
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float16_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_post.png
new file mode 100644
index 0000000000..6a5aa26ce8
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_pre.png
new file mode 100644
index 0000000000..6c11ddac44
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_False_torch.float32_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_post.png
new file mode 100644
index 0000000000..b83df2c74f
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_pre.png
new file mode 100644
index 0000000000..270eb2e07a
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float16_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_post.png
new file mode 100644
index 0000000000..368580e409
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_pre.png
new file mode 100644
index 0000000000..b5785477e4
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW+BW_True_torch.float32_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_post.png
new file mode 100644
index 0000000000..316f82042c
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_pre.png
new file mode 100644
index 0000000000..483b025371
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float16_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_post.png
new file mode 100644
index 0000000000..0c3927e132
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_pre.png
new file mode 100644
index 0000000000..fbe2ff7957
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_False_torch.float32_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_post.png
new file mode 100644
index 0000000000..913d0ea3e6
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_pre.png
new file mode 100644
index 0000000000..ce347f3d37
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float16_pre.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_post.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_post.png
new file mode 100644
index 0000000000..1840967dfd
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_post.png differ
diff --git a/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_pre.png b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_pre.png
new file mode 100644
index 0000000000..1a172d91a1
Binary files /dev/null and b/docs/plots/nvfuser/BiasDropoutResidualLayerNorm/RUNTIME_Bias_Dropout_Res_LayerNorm_FW_True_torch.float32_pre.png differ
diff --git a/docs/source/tutorials/aotautograd_nvfuser.rst b/docs/source/tutorials/aotautograd_nvfuser.rst
new file mode 100644
index 0000000000..3f4931a02b
--- /dev/null
+++ b/docs/source/tutorials/aotautograd_nvfuser.rst
@@ -0,0 +1,32 @@
+How to Enable Fused Operations Using AOTAutograd and NVFuser
+===================================================================
+
+AOT Autograd is a toolkit from FuncTorch_ which can be used to accelerate model training in xFormers.
+Broadly, it extracts a computational graph of the forward and backward passes of a model ahead of time.
+This allows for some joint graph optimizations and enables deep learning compilers such as NVFuser_ to perform operator fusion.
+The `memory_efficient_fusion`_ wrapper function provides a convenient way to leverage AOTAutograd and NVFuser on GPU.
+
+.. _FuncTorch: https://pytorch.org/functorch/stable/
+.. _NVFuser: https://github.com/pytorch/pytorch/blob/release/1.12/torch/csrc/jit/codegen/cuda/README.md
+.. _memory_efficient_fusion: https://pytorch.org/functorch/stable/generated/functorch.compile.memory_efficient_fusion.html#functorch.compile.memory_efficient_fusion
+
+xFormers uses `memory_efficient_fusion` to combine sequences of fusible operations together into single fused function layers.
+These parts can be found inside `xformers/components/nvfuser`. A notable example is `NVFusedBiasActivationDropout`, which is readily used inside the `MLP`_ feedforward component.
+
+.. _MLP: https://github.com/facebookresearch/xformers/blob/main/xformers/components/feedforward/mlp.py
+
+A benchmark of these fused patterns across some representative shapes shows significant speed increases compared to the unfused,
+PyTorch eager approach—up to 3.5x speedup for the forward pass and 2.2x for the forward and backward passes together. On average, peak memory usage of fused patterns is also lower,
+although we see some infrequent cases of up to 1.6x PyTorch peak memory usage on larger shapes. We also see better overall performance against our implementation of fused Bias,
+Activation, and Dropout using Triton (see_) as well. Full benchmark plots can be found here_.
+
+.. _see: https://github.com/facebookresearch/xformers/blob/main/xformers/triton/dropout.py
+.. _here: https://github.com/facebookresearch/xformers/tree/main/docs/plots/nvfuser
+
+Please note from the README that the `_is_functorch_available` flag must be enabled for xFormers to use these optimizations.
+This allows the fused layers to be used and changes the behavior of the `MLP` feedforward component,
+causing it to default to using the fused `NVFusedBiasActivationDropout` layer.
+
+AOT Autograd offers a great deal of flexibility to the user, as `memory_efficient_fusion` can accept either a Python function or an entire `nn.Module` as input for fusion.
+Currently in xFormers, however, it is only used with Python function inputs because initial attempts at fusing xFormers layers and blocks have yielded memory issues and other CUDA errors.
+We are currently exploring further testing and benchmarking.
diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst
index 37af922a98..fd49bdd6fd 100644
--- a/docs/source/tutorials/index.rst
+++ b/docs/source/tutorials/index.rst
@@ -6,6 +6,7 @@ Tutorials
sparse_vit
blocksparse
+ aotautograd_nvfuser
extend_attentions
use_attention
pytorch_encoder
diff --git a/requirements-test.txt b/requirements-test.txt
index 41c8e0fd7a..f6d968768f 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -28,3 +28,7 @@ fairscale >= 0.4.5
# Dependency for fused layers, optional
triton == 2.0.0.dev20220701
+
+# Dependencies for fused layers using FuncTorch, optional
+git+https://github.com/pytorch/functorch@v0.2.0
+networkx == 2.8.4
diff --git a/tests/test_nvfuser.py b/tests/test_nvfuser.py
new file mode 100644
index 0000000000..2595f14821
--- /dev/null
+++ b/tests/test_nvfuser.py
@@ -0,0 +1,200 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+from collections import OrderedDict
+from contextlib import nullcontext
+
+import pytest
+import torch
+import torch.nn as nn
+from torch.cuda.amp.autocast_mode import autocast
+
+import xformers
+from xformers.components import Activation, ResidualNormStyle
+
+# Store original and possible flag setting
+flag_orig = xformers._is_functorch_available
+flag_new = True
+xformers._is_functorch_available = True
+
+
+_gpu_available = torch.cuda.is_available()
+
+try:
+ import xformers.components.feedforward as ff
+ from xformers.components.nvfuser import (
+ NVFusedBiasActivationDropout,
+ NVFusedBiasDropoutRes,
+ NVFusedBiasDropoutResLayerNorm,
+ )
+ from xformers.components.nvfuser.utils import build_nvfused
+except ImportError as e:
+ logging.warning(f"Functorch is not available to run test_nvfuser.py. \nError {e}")
+ flag_new = False
+
+xformers._is_functorch_available = flag_orig
+
+FUSED_PATTERNS = (
+ [
+ NVFusedBiasActivationDropout,
+ NVFusedBiasDropoutRes,
+ NVFusedBiasDropoutResLayerNorm,
+ ]
+ if flag_new
+ else []
+)
+
+# Testing odd (non-power-of-two for instance) shapes on purpose
+SHAPES = [
+ (384, 512),
+ (8, 384, 128),
+ (8, 784, 512),
+ (4, 16, 384),
+ (4, 16, 1024),
+ (2, 16, 2048),
+ (2, 16, 4096),
+ (1, 16, 12288),
+]
+
+BATCH = 4
+SEQ = 256
+EMBD = 16
+LATENT = 128
+DEVICES = [torch.device("cuda")]
+
+ACTIVATIONS = [
+ Activation.ReLU,
+ Activation.GeLU,
+ Activation.LeakyReLU,
+ Activation.SquaredReLU,
+ Activation.SmeLU,
+]
+
+
+@pytest.mark.skipif(not flag_new, reason="Functorch is not available")
+@pytest.mark.skipif(not _gpu_available, reason="GPU is not available")
+@pytest.mark.parametrize("fused_pattern", FUSED_PATTERNS)
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("amp", [False, True])
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
+@pytest.mark.parametrize("p", [0, 0.1, 0.5])
+@pytest.mark.parametrize(
+ "layer_norm_style", [None, ResidualNormStyle.Pre, ResidualNormStyle.Post]
+)
+def test_nvfused_pattern_parity(
+ fused_pattern: nn.Module,
+ shape: tuple,
+ amp: bool,
+ bias: bool,
+ activation: Activation,
+ p: float,
+ layer_norm_style: ResidualNormStyle,
+):
+ # Enable global flag
+ xformers._is_functorch_available = flag_new
+
+ if (
+ fused_pattern != NVFusedBiasDropoutResLayerNorm
+ and layer_norm_style != ResidualNormStyle.Pre
+ ):
+ pytest.skip(
+ "Layer norm style doesn't apply, the same relevant params already tested once."
+ )
+
+ torch.cuda.manual_seed_all(0)
+ torch.random.manual_seed(0)
+ x = torch.normal(0, 1, size=shape, device="cuda", requires_grad=True)
+ x_cpu = x.clone().cpu()
+
+ with autocast(enabled=amp), pytest.raises(
+ ValueError
+ ) if layer_norm_style is None else nullcontext():
+ fused = build_nvfused(
+ fused_pattern, shape, bias, activation, p, layer_norm_style
+ )
+ fused.train().cuda()
+ nvfused_res = fused(x, x) if fused.requires_residual else fused(x)
+ fused.cpu()
+ torch_res = (
+ fused(x_cpu, x_cpu).cuda()
+ if fused.requires_residual
+ else fused(x_cpu).cuda()
+ )
+
+ # Check if operation was actually fused
+ assert isinstance(
+ nvfused_res.grad_fn, torch.autograd.function.BackwardCFunction
+ )
+
+ if p == 0.0:
+ # Check fused and unfused paths are the same
+ assert torch.allclose(torch_res, nvfused_res, atol=1e-6, rtol=1e-2)
+
+ # Restore original flag configuration
+ xformers._is_functorch_available = flag_orig
+
+
+@pytest.mark.skipif(not flag_new, reason="Functorch is not available")
+@pytest.mark.skipif(not _gpu_available, reason="GPU is not available")
+@pytest.mark.parametrize("activation", ACTIVATIONS)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("p", [0, 0.1, 0.5])
+def test_nvfused_mlp(activation: Activation, device: torch.device, p: float):
+ test_config = {
+ "name": "MLP",
+ "dim_model": LATENT,
+ "dropout": p,
+ "activation": activation,
+ "hidden_layer_multiplier": 4,
+ "bias": False,
+ }
+ # Enable global flag
+ xformers._is_functorch_available = flag_new
+
+ torch.random.manual_seed(0)
+ torch.cuda.manual_seed_all(0)
+
+ mlp = ff.build_feedforward(test_config)
+ # Creates non-fused default MLP
+ xformers._is_functorch_available = False
+ mlp_default = ff.build_feedforward(test_config)
+ xformers._is_functorch_available = flag_new
+
+ inputs = torch.rand(BATCH, SEQ, LATENT, device=device)
+ mlp.train()
+
+ # Check fused pattern w/ unfused default (switch happens within NVFusedBiasActivationDropout)
+ mlp.cuda()
+ fused_res = mlp(inputs)
+
+ mlp.cpu()
+ unfused_res = mlp(inputs.cpu())
+
+ if p == 0.0:
+ assert torch.allclose(unfused_res.cuda(), fused_res, atol=1e-6, rtol=1e-2)
+
+ # Check fused pattern w/ unfused default (switch happens within MLP)
+ mlp.cuda()
+ mlp_default.cuda()
+
+ # Load same weight parameters into both models
+ default_param_dict = OrderedDict(
+ [
+ ("mlp.2.weight", v) if k == "mlp.3.weight" else (k, v)
+ for k, v in mlp_default.state_dict().items()
+ ]
+ )
+ mlp.load_state_dict(default_param_dict)
+ fused_res = mlp(inputs)
+ unfused_res = mlp_default(inputs)
+
+ if p == 0.0:
+ assert torch.allclose(unfused_res, fused_res, atol=1e-6, rtol=1e-2)
+
+ # Restore original flag configuration
+ xformers._is_functorch_available = flag_orig
diff --git a/xformers/__init__.py b/xformers/__init__.py
index 63c7f823dc..d44d51fbe0 100644
--- a/xformers/__init__.py
+++ b/xformers/__init__.py
@@ -10,8 +10,11 @@
# Please update the doc version in docs/source/conf.py as well.
__version__ = "0.0.12.dev"
-_is_sparse_available = True
-_is_triton_available = torch.cuda.is_available()
+_is_sparse_available: bool = True
+_is_triton_available: bool = torch.cuda.is_available()
+
+# Set to true to utilize functorch
+_is_functorch_available: bool = False
def _register_extensions():
@@ -77,3 +80,13 @@ def _register_extensions():
f"Triton is not available, some optimizations will not be enabled.\nError {e}"
)
_is_triton_available = False
+
+
+if _is_functorch_available:
+ try:
+ from xformers.components.nvfuser import NVFusedBiasActivationDropout # noqa
+ except ImportError as e:
+ logging.warning(
+ f"Functorch is not available, some optimizations will not be enabled.\nError {e}"
+ )
+ _is_functorch_available = False
diff --git a/xformers/benchmarks/benchmark_causal_blocksparse.py b/xformers/benchmarks/benchmark_causal_blocksparse.py
index 70a5e499b9..4bd206349c 100644
--- a/xformers/benchmarks/benchmark_causal_blocksparse.py
+++ b/xformers/benchmarks/benchmark_causal_blocksparse.py
@@ -113,7 +113,7 @@ def sdp_attention():
)
pretty_print(
results_mem,
- title=f"\n --- Type: {datatype}Block Size: {BS} --- ",
+ title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
units="peak memory usage in MB",
)
diff --git a/xformers/benchmarks/benchmark_nvfuser.py b/xformers/benchmarks/benchmark_nvfuser.py
new file mode 100644
index 0000000000..d43dfd4e45
--- /dev/null
+++ b/xformers/benchmarks/benchmark_nvfuser.py
@@ -0,0 +1,251 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+import triton
+
+from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
+from xformers.components import Activation, ResidualNormStyle, build_activation
+from xformers.components.nvfuser import (
+ NVFusedBiasActivationDropout,
+ NVFusedBiasDropoutRes,
+ NVFusedBiasDropoutResLayerNorm,
+)
+from xformers.components.nvfuser.bias_act_dropout import _fn as bias_act_dropout
+from xformers.components.nvfuser.bias_dropout_res import _fn as bias_dropout_res
+from xformers.components.nvfuser.bias_dropout_res_layernorm import (
+ _fn as bias_dropout_res_layernorm,
+)
+from xformers.components.nvfuser.utils import build_nvfused
+from xformers.triton import FusedDropoutBias
+
+SHAPES = [
+ (8, 256, 512),
+ (8, 512, 1024),
+ (4, 1024, 1024),
+ (2, 2048, 2048),
+ (1, 2048, 12288),
+ (2, 4096, 4096),
+]
+
+P = 0.1
+
+
+def build_torch_fn(
+ pattern: nn.Module,
+ shape: tuple,
+ bias: Optional[torch.Tensor],
+ activation: Optional[Activation],
+ p: float,
+ layer_norm_style: Optional[ResidualNormStyle],
+ dtype: torch.dtype,
+):
+ torch_act = build_activation(activation)
+ if pattern == NVFusedBiasActivationDropout:
+ return partial(bias_act_dropout, bias=bias, activation=torch_act, prob=p)
+ elif pattern == NVFusedBiasDropoutRes:
+ return partial(bias_dropout_res, bias=bias, prob=p)
+ elif pattern == NVFusedBiasDropoutResLayerNorm:
+ norm = nn.LayerNorm(shape[-1]).to(device=torch.device("cuda"), dtype=dtype)
+ return partial(
+ bias_dropout_res_layernorm,
+ bias=bias,
+ prob=p,
+ layer_norm_style=layer_norm_style,
+ norm=norm,
+ )
+ else:
+ raise ValueError
+
+
+def bench_nvfused(
+ fused_pattern: nn.Module,
+ bias: bool,
+ backward: bool,
+ activation: Optional[Activation],
+ layer_norm_style: Optional[ResidualNormStyle],
+):
+ device = torch.device("cuda")
+
+ pattern_str = {
+ NVFusedBiasActivationDropout: "Bias_Act_Dropout",
+ NVFusedBiasDropoutRes: "Bias_Dropout_Res",
+ NVFusedBiasDropoutResLayerNorm: "Bias_Dropout_Res_LayerNorm",
+ }[
+ fused_pattern # type: ignore
+ ]
+
+ for dtype in [
+ torch.float16,
+ torch.float32,
+ ]:
+ results: Dict[str, Any] = {}
+ results_mem: Dict[str, Any] = {}
+
+ for B, M, K in SHAPES:
+ a = torch.rand(
+ (B, M, K), device=device, dtype=dtype, requires_grad=backward
+ )
+ b = torch.rand(K, device=device, dtype=dtype, requires_grad=backward)
+
+ torch_fn = build_torch_fn(
+ fused_pattern,
+ (B, M, K),
+ b if bias else None,
+ activation,
+ P,
+ layer_norm_style,
+ dtype,
+ )
+
+ nvfuser_fn = build_nvfused(
+ fused_pattern, (B, M, K), bias, activation, P, layer_norm_style
+ )
+ nvfuser_fn.cuda()
+ nvfuser_fn.to(device=device, dtype=dtype)
+ residual = nvfuser_fn.requires_residual
+
+ triton_fn = (
+ FusedDropoutBias(
+ P, bias_shape=K if bias else None, activation=activation
+ )
+ if fused_pattern == NVFusedBiasActivationDropout
+ else None
+ )
+
+ def step(fn, residual, x):
+ y = fn(x=x, residual=x) if residual else fn(x)
+ if backward:
+ y.grad = None
+ torch.norm(y).backward()
+ return y
+
+ testcases = [
+ TestCase(
+ partial(step, fn=torch_fn, residual=residual),
+ "pytorch- bias: {} - fw{}{}{}".format(
+ bias,
+ "+bw" if backward else "",
+ f" - Act: {activation}" if activation is not None else "",
+ f" - Style: {layer_norm_style}"
+ if layer_norm_style is not None
+ else "",
+ ),
+ ),
+ TestCase(
+ partial(step, fn=nvfuser_fn, residual=residual),
+ "nvFuser- bias: {} - fw{}{}{}".format(
+ bias,
+ "+bw" if backward else "",
+ f" - Act: {activation}" if activation is not None else "",
+ f" - Style: {layer_norm_style}"
+ if layer_norm_style is not None
+ else "",
+ ),
+ ),
+ ]
+ if triton_fn is not None:
+ triton_test = TestCase(
+ partial(step, fn=triton_fn, residual=residual),
+ "triton- bias: {} - fw{}{}{}".format(
+ bias,
+ "+bw" if backward else "",
+ f" - Act: {activation}" if activation is not None else "",
+ f" - Style: {layer_norm_style}"
+ if layer_norm_style is not None
+ else "",
+ ),
+ )
+ testcases.append(triton_test)
+
+ for testcase in testcases:
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.synchronize()
+
+ time = triton.testing.do_bench(
+ lambda: testcase.function(x=a), grad_to_none=[a, b]
+ )[0]
+
+ torch.cuda.synchronize()
+ max_memory = torch.cuda.max_memory_allocated() // 2**20
+
+ key = f"B={B}, M={M}, K={K}"
+ if key not in results:
+ results[key] = {}
+
+ results[key][testcase.name] = f"{time:.3f}"
+
+ # Record peak mem usage
+ if key not in results_mem:
+ results_mem[key] = {}
+ results_mem[key][testcase.name] = f"{max_memory:.1f}"
+
+ pretty_print(
+ results,
+ title="\n --- RUNTIME Type: {} {} --- ".format(pattern_str, dtype),
+ units="ms",
+ )
+ pretty_print(
+ results_mem,
+ title="\n --- PEAK MEMORY Type: {} {} --- ".format(pattern_str, dtype),
+ units="MB",
+ )
+ pretty_plot(
+ results,
+ title="RUNTIME-{}-FW{}-{}{}-{}{}".format(
+ pattern_str,
+ "+BW" if backward else "",
+ bias,
+ f"-{activation}" if activation is not None else "",
+ dtype,
+ f"-{layer_norm_style}" if layer_norm_style is not None else "",
+ ),
+ units="ms",
+ dash_key="pytorch",
+ legend_loc="upper left",
+ )
+ pretty_plot(
+ results_mem,
+ title="MAXMEM-{}-FW{}-{}{}-{}{}".format(
+ pattern_str,
+ "+BW" if backward else "",
+ bias,
+ f"-{activation}" if activation is not None else "",
+ dtype,
+ f"-{layer_norm_style}" if layer_norm_style is not None else "",
+ ),
+ units="MB",
+ dash_key="pytorch",
+ legend_loc="upper left",
+ )
+
+
+PATTERNS = [
+ NVFusedBiasActivationDropout,
+ NVFusedBiasDropoutRes,
+ NVFusedBiasDropoutResLayerNorm,
+]
+
+for pattern in PATTERNS:
+ activations: List[Optional[Activation]] = (
+ [Activation.ReLU, Activation.GeLU, Activation.SquaredReLU]
+ if pattern == NVFusedBiasActivationDropout
+ else [None]
+ )
+ for activation in activations:
+ for bw in [True, False]:
+ for bias in [True, False]:
+ styles: List[Optional[ResidualNormStyle]] = (
+ [ResidualNormStyle.Pre, ResidualNormStyle.Post]
+ if pattern == NVFusedBiasDropoutResLayerNorm
+ else [None]
+ )
+ for style in styles:
+ bench_nvfused(pattern, bias, bw, activation, style) # type: ignore
diff --git a/xformers/components/feedforward/mlp.py b/xformers/components/feedforward/mlp.py
index 45ae600a63..adc869868e 100644
--- a/xformers/components/feedforward/mlp.py
+++ b/xformers/components/feedforward/mlp.py
@@ -9,15 +9,22 @@
import torch
import torch.nn as nn
+import xformers
from xformers.components import Activation, build_activation
from xformers.components.feedforward import Feedforward, FeedforwardConfig
+if xformers._is_functorch_available:
+ from xformers.components.nvfuser import ( # noqa
+ NVFusedBiasActivationDropout,
+ )
+
from . import register_feedforward
@dataclass
class MlpConfig(FeedforwardConfig):
hidden_layer_multiplier: int
+    bias: bool = True  # default keeps existing configs without a "bias" key valid
@register_feedforward("MLP", MlpConfig)
@@ -33,14 +40,42 @@ def __init__(
**kwargs,
):
super().__init__()
+ dim_mlp = hidden_layer_multiplier * dim_model
+ # check if fused Bias Activation Dropout is applicable
+ if xformers._is_functorch_available:
+
+ # Catch unimported fused layer
+ from xformers.components.nvfuser.bias_act_dropout import ( # noqa
+ NVFusedBiasActivationDropout,
+ )
- self.mlp = nn.Sequential(
- nn.Linear(dim_model, hidden_layer_multiplier * dim_model, bias=bias),
- build_activation(activation),
- nn.Dropout(dropout),
- nn.Linear(hidden_layer_multiplier * dim_model, dim_model, bias=bias),
- nn.Dropout(dropout),
- )
+ self.requires_cuda = True
+ self.mlp = nn.Sequential(
+ nn.Linear(
+ in_features=dim_model, out_features=dim_mlp, bias=False
+ ), # bias is handled in the next layer
+ NVFusedBiasActivationDropout(
+ p=dropout,
+ bias_shape=dim_mlp if bias else None,
+ activation=activation,
+ ),
+ nn.Linear(
+ in_features=dim_mlp, out_features=dim_model, bias=False
+ ), # bias is handled in the next layer
+ NVFusedBiasActivationDropout(
+ p=dropout,
+ bias_shape=dim_model if bias else None,
+ activation=None,
+ ),
+ )
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(in_features=dim_model, out_features=dim_mlp, bias=bias),
+ build_activation(activation),
+ nn.Dropout(dropout),
+ nn.Linear(in_features=dim_mlp, out_features=dim_model, bias=bias),
+ nn.Dropout(dropout),
+ )
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.mlp(inputs)
diff --git a/xformers/components/nvfuser/__init__.py b/xformers/components/nvfuser/__init__.py
new file mode 100644
index 0000000000..1a46d6ff50
--- /dev/null
+++ b/xformers/components/nvfuser/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from xformers import _is_functorch_available
+
+if _is_functorch_available: # noqa
+ try:
+ from .bias_act_dropout import NVFusedBiasActivationDropout # noqa
+ from .bias_dropout_res import NVFusedBiasDropoutRes # noqa
+ from .bias_dropout_res_layernorm import NVFusedBiasDropoutResLayerNorm # noqa
+
+ __all__ = [
+ "NVFusedBiasActivationDropout",
+ "NVFusedBiasDropoutResLayerNorm",
+ "NVFusedBiasDropoutRes",
+ ]
+ except ImportError:
+ __all__ = []
diff --git a/xformers/components/nvfuser/bias_act_dropout.py b/xformers/components/nvfuser/bias_act_dropout.py
new file mode 100644
index 0000000000..a6f052b1d5
--- /dev/null
+++ b/xformers/components/nvfuser/bias_act_dropout.py
@@ -0,0 +1,70 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from functorch.compile import memory_efficient_fusion
+
+from xformers.components import Activation, build_activation
+
+
+def _fn(
+ x: torch.Tensor,
+ bias: Optional[torch.nn.parameter.Parameter],
+ activation: nn.Module,
+ prob: float,
+) -> torch.Tensor:
+ if bias is not None:
+ x = torch.add(x, bias)
+ y = activation(x)
+ return torch.nn.functional.dropout(y, prob) if prob > 0.0 else y
+
+
+class NVFusedBiasActivationDropout(torch.nn.Module):
+ """
+ A layer which fuses the computation of Dropout(Activation(x + Bias))
+ with AOTAutograd and nvFuser
+ """
+
+ def __init__(
+ self,
+ p: float,
+ activation: Optional[Activation] = None,
+ bias_shape: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ self.p = float(p)
+ self.requires_residual = False
+ self.activation = activation
+ self.pytorch_activation = build_activation(self.activation)
+
+ self.bias = (
+ nn.Parameter(torch.zeros(bias_shape)) if bias_shape is not None else None
+ )
+
+ assert (
+ self.p < 1.0
+ ), f"We don't want to drop all the values, most probably p={self.p} is not properly set"
+
+ def init_weights(self, *args, **kwargs):
+ with torch.no_grad():
+ if self.bias is not None:
+ self.bias.fill_(0.0)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Train/inference
+ p = self.p if self.training else 0.0
+
+ # Catch a non-cuda setup, fallback to pytorch
+ if not x.is_cuda:
+ return _fn(x, self.bias, self.pytorch_activation, p)
+
+    # AOTAutograd, NVFuser backed path (NOTE(review): fused fn is rebuilt every call; consider caching)
+ aot_fn = memory_efficient_fusion(_fn, static_argnums=(2, 3))
+ return aot_fn(x, self.bias, self.pytorch_activation, p)
diff --git a/xformers/components/nvfuser/bias_dropout_res.py b/xformers/components/nvfuser/bias_dropout_res.py
new file mode 100644
index 0000000000..244d02b698
--- /dev/null
+++ b/xformers/components/nvfuser/bias_dropout_res.py
@@ -0,0 +1,64 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from functorch.compile import memory_efficient_fusion
+
+
+def _fn(
+ x: torch.Tensor,
+ bias: Optional[torch.nn.parameter.Parameter],
+ prob: float,
+ residual: torch.Tensor,
+) -> torch.Tensor:
+ a = torch.add(x, bias) if bias is not None else x
+ b = torch.nn.functional.dropout(a, prob) if prob > 0.0 else a
+ return torch.add(b, residual)
+
+
+class NVFusedBiasDropoutRes(torch.nn.Module):
+ """
+ A layer which fuses the computation of Dropout(x + Bias) + Residual
+ with AOTAutograd and nvFuser
+ """
+
+ def __init__(
+ self,
+ p: float,
+ bias_shape: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ self.p = float(p)
+ self.requires_residual = True
+
+ self.bias = (
+ nn.Parameter(torch.zeros(bias_shape)) if bias_shape is not None else None
+ )
+
+ assert (
+ self.p < 1.0
+ ), f"We don't want to drop all the values, most probably p={self.p} is not properly set"
+
+ def init_weights(self, *args, **kwargs):
+ with torch.no_grad():
+ if self.bias is not None:
+ self.bias.fill_(0.0)
+
+ def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+ # Train/inference
+ p = self.p if self.training else 0.0
+
+ # Catch a non-cuda setup, fallback to pytorch
+ if not x.is_cuda:
+ return _fn(x, self.bias, p, residual)
+
+    # AOTAutograd, NVFuser backed path (NOTE(review): fused fn is rebuilt every call; consider caching)
+    aot_fn = memory_efficient_fusion(fn=_fn, static_argnums=(2,))
+ return aot_fn(x, self.bias, p, residual)
diff --git a/xformers/components/nvfuser/bias_dropout_res_layernorm.py b/xformers/components/nvfuser/bias_dropout_res_layernorm.py
new file mode 100644
index 0000000000..2fb76e5f10
--- /dev/null
+++ b/xformers/components/nvfuser/bias_dropout_res_layernorm.py
@@ -0,0 +1,80 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from functorch.compile import memory_efficient_fusion
+
+from xformers.components import ResidualNormStyle
+
+
+def _fn(
+ x: torch.Tensor,
+ bias: Optional[torch.nn.parameter.Parameter],
+ prob: float,
+ layer_norm_style: Optional[ResidualNormStyle],
+ norm: nn.Module,
+ residual: torch.Tensor,
+) -> torch.Tensor:
+ a = torch.add(x, bias) if bias is not None else x
+ b = torch.nn.functional.dropout(a, prob) if prob > 0.0 else a
+ if layer_norm_style == ResidualNormStyle.Pre:
+ c = norm(b)
+ return torch.add(c, residual)
+ elif layer_norm_style == ResidualNormStyle.Post:
+ c = torch.add(b, residual)
+ return norm(c)
+ else:
+        raise ValueError(f"Unsupported layer_norm_style: {layer_norm_style}")
+
+
+class NVFusedBiasDropoutResLayerNorm(torch.nn.Module):
+
+ """
+ A layer which fuses the computation of LayerNorm, Residual, and Dropout(x + Bias)
+ operations with AOTAutograd and nvFuser based on specified layer norm style
+ """
+
+ def __init__(
+ self,
+ p: float,
+ d_model: int,
+ bias_shape: Optional[int] = None,
+ layer_norm_style: ResidualNormStyle = ResidualNormStyle.Post,
+ ) -> None:
+ super().__init__()
+
+ self.p = float(p)
+ self.requires_residual = True
+ self.layer_norm_style = layer_norm_style
+
+ self.bias = (
+ nn.Parameter(torch.zeros(bias_shape)) if bias_shape is not None else None
+ )
+ self.norm = nn.LayerNorm(d_model)
+
+ assert (
+ self.p < 1.0
+ ), f"We don't want to drop all the values, most probably p={self.p} is not properly set"
+
+ def init_weights(self, *args, **kwargs):
+ with torch.no_grad():
+ if self.bias is not None:
+ self.bias.fill_(0.0)
+
+ def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+ # Train/inference
+ p = self.p if self.training else 0.0
+
+ # Catch a non-cuda setup, fallback to pytorch
+ if not x.is_cuda:
+ return _fn(x, self.bias, p, self.layer_norm_style, self.norm, residual)
+
+    # AOTAutograd, NVFuser backed path (NOTE(review): fused fn is rebuilt every call; consider caching)
+ aot_fn = memory_efficient_fusion(fn=_fn, static_argnums=(2, 3, 4))
+ return aot_fn(x, self.bias, p, self.layer_norm_style, self.norm, residual)
diff --git a/xformers/components/nvfuser/utils.py b/xformers/components/nvfuser/utils.py
new file mode 100644
index 0000000000..d6fef065b1
--- /dev/null
+++ b/xformers/components/nvfuser/utils.py
@@ -0,0 +1,38 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, List, Optional
+
+import torch.nn as nn
+
+from xformers.components import Activation, ResidualNormStyle
+from xformers.components.nvfuser import (
+ NVFusedBiasActivationDropout,
+ NVFusedBiasDropoutRes,
+ NVFusedBiasDropoutResLayerNorm,
+)
+
+
+def build_nvfused(
+ fused_pattern: nn.Module,
+ shape: tuple,
+ bias: bool,
+ activation: Optional[Activation],
+ p: float,
+ layer_norm_style: Optional[ResidualNormStyle],
+):
+ bias_shape = shape[-1] if bias else None
+ d_model = shape[-1]
+ init_args: Dict[nn.Module, List[Any]] = {
+ NVFusedBiasActivationDropout: [p, activation, bias_shape], # type: ignore
+ NVFusedBiasDropoutRes: [p, bias_shape], # type: ignore
+ NVFusedBiasDropoutResLayerNorm: [ # type: ignore
+ p,
+ d_model,
+ bias_shape,
+ layer_norm_style,
+ ],
+ }
+ return fused_pattern(*init_args[fused_pattern])
diff --git a/xformers/triton/layer_norm.py b/xformers/triton/layer_norm.py
index 0ab872c9b9..b437daef9c 100644
--- a/xformers/triton/layer_norm.py
+++ b/xformers/triton/layer_norm.py
@@ -224,7 +224,7 @@ def layer_norm(
_triton_registered_warnings = True
logging.warning(
"Triton layernorm kernel register spillover or invalid image caught. "
- "Deactivating this kernel, please file an issue int the xFormers repository"
+ "Deactivating this kernel, please file an issue in the xFormers repository"
)
logging.warning(e)