Commit
[docs][minor] Adding more fused layers data to the HTML docs (#174)
blefaudeux authored Jan 4, 2022
1 parent f66f9d1 commit 3683752
Showing 18 changed files with 104 additions and 63 deletions.
6 changes: 2 additions & 4 deletions BENCHMARKS.md
@@ -52,9 +52,8 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused linear layer

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for an nVidia V100, Triton 1.1 and PyTorch 1.9.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.

**As of October 2021, these Triton kernels are only competitive with PyTorch in float16; this is a work in progress.**

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png)

@@ -78,8 +77,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused layer norm

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for an nVidia V100, Triton 1.1 and PyTorch 1.9.
Note that in the Triton case the slowdowns at extreme sizes are caused by register spilling; A100s get much better performance.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.

![Fused layer norm throughput in fp16 - inference](docs/plots/layer_norm/LayerNorm_FW_torch.float16.png)

Binary file removed docs/plots/fused_linear/FusedLinear_fp16_FW.png
Binary file removed docs/plots/fused_linear/FusedLinear_fp16_FW_BW.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float32.png
161 changes: 102 additions & 59 deletions docs/source/tutorials/triton.rst
@@ -26,52 +26,24 @@ Log-softmax is also available. The actual Triton kernel is very similar to `this
The expected throughput, when compared to PyTorch on an nVidia V100, is along these lines:

+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
| torch.float16 | Unit: GB/s | | | | | | |
+===================================+======================+====================+====================+======================+====================+=====================+==========================+
| | B=8, M=384, K=128 | B=8, M=784, K=512 | B=4, M=2048, K=384| B=4, M=3136, K=1024 | B=2, M=1024, K=2048| B=2, M=2048, K=4096 | B=2, M=4096, K=4096 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|pytorch - fw | 170.7 | 501.8 | 512.0 | 597.3 | 399.6 | 524.3 | 553.0 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|triton - fw | 153.6 | 522.7 | 512.0 | 716.8 | 606.8 | 736.4 | 775.6 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|pytorch - log - fw | 192.0 | 545.4 | 534.3 | 669.0 | 496.5 | 601.2 | 615.4 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|triton - log - fw | 153.6 | 570.2 | 558.5 | 748.9 | 682.7 | 780.2 | 799.2 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|pytorch - fw+bw | 71.4 | 170.7 | 168.3 | 205.6 | 164.7 | 196.5 | 203.5 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|triton - fw+bw | 69.8 | 218.2 | 211.9 | 264.8 | 224.4 | 271.4 | 284.3 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|pytorch - log - fw+bw | 78.8 | 207.3 | 204.8 | 255.3 | 206.1 | 247.3 | 255.5 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|triton - log - fw+bw | 71.4 | 220.1 | 213.7 | 266.9 | 229.1 | 273.6 | 285.6 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+


+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
| torch.float32 | Unit: GB/s | | | | | | |
+===================================+======================+====================+====================+======================+====================+=====================+==========================+
| | B=8, M=384, K=128 | B=8, M=784, K=512 | B=4, M=2048, K=384 | B=4, M=3136, K=1024 | B=2, M=1024, K=2048| B=2, M=2048, K=4096 | B=2, M=4096, K=4096 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|pytorch - fw | 341.3 | 660.2 | 682.7 | 760.2 | 555.4 | 636.3 | 650.5 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|triton - fw | 307.2 | 678.1 | 682.7 | 784.0 | 712.3 | 789.6 | 809.1 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|pytorch - log - fw | 384.0 | 696.9 | 702.2 | 777.9 | 537.2 | 541.6 | 543.9 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|triton - log - fw | 307.2 | 696.9 | 702.2 | 796.4 | 744.7 | 799.2 | 814.1 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|pytorch - fw+bw | 133.6 | 203.1 | 204.0 | 229.9 | 193.9 | 211.1 | 215.3 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|triton - fw+bw | 136.5 | 254.7 | 257.3 | 290.9 | 263.2 | 294.5 | 301.0 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|pytorch - log - fw+bw | 149.9 | 252.1 | 252.1 | 289.6 | 234.1 | 251.6 | 254.5 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
|triton - log - fw+bw | 136.5 | 257.3 | 258.7 | 291.7 | 265.3 | 295.2 | 301.3 |
+-----------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+---------------------+--------------------------+
.. image:: ../../plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png
:width: 600
:alt: Softmax throughput in fp16 - inference


.. image:: ../../plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png
:width: 600
:alt: Softmax throughput in fp16 - training

.. image:: ../../plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png
:width: 600
:alt: Softmax throughput in fp32 - inference


.. image:: ../../plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png
:width: 600
:alt: Softmax throughput in fp32 - training

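As a rough illustration of what these bandwidth figures mean, the sketch below times a plain PyTorch fp16 softmax and converts the elapsed time into GB/s with a simple read-plus-write traffic model. This is an illustrative baseline only: the helper name and the traffic model are assumptions, and the actual numbers in the plots come from the xformers benchmark suite.

.. code-block:: python

    # Illustrative only: time an eager PyTorch fp16 softmax and report an
    # effective bandwidth in GB/s, mirroring the units used above. The traffic
    # model (one read + one write per pass) is a deliberate simplification.
    import torch

    def softmax_gbps(B: int, M: int, K: int, backward: bool = False, steps: int = 100) -> float:
        x = torch.randn(B, M, K, device="cuda", dtype=torch.float16, requires_grad=backward)

        def step():
            y = torch.softmax(x, dim=-1)  # torch.log_softmax for the log-softmax variant
            if backward:
                y.backward(torch.ones_like(y))

        step()  # warm-up
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(steps):
            step()
        stop.record()
        torch.cuda.synchronize()
        seconds = start.elapsed_time(stop) / 1e3 / steps  # elapsed_time is in ms

        passes = 2 if backward else 1              # forward only, or forward + backward
        bytes_moved = passes * 2 * x.numel() * 2   # (read + write) * elements * 2 bytes per fp16 value
        return bytes_moved / seconds / 1e9

    print(softmax_gbps(8, 384, 128))                 # forward-only throughput
    print(softmax_gbps(8, 384, 128, backward=True))  # forward + backward throughput
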
Fused linear layer
-------------------
This is a drop-in replacement for two PyTorch operations: a `torch.nn.Linear` and an activation, like `torch.nn.ReLU`. It is Torch AMP and autograd aware, and can be used very simply:
@@ -89,22 +61,93 @@
It is possible to skip either the bias or the activation (just use `None` in that case). As of September 2021, this layer is **faster than PyTorch for non-sigmoid activations and fp16**.
In all other use cases, you will be better served using PyTorch.
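
The usage snippet from the tutorial is hidden in the collapsed portion of this diff, so the sketch below only illustrates the pattern described above: the two eager PyTorch operations being fused, with the assumed fused swap left in comments. The `xformers.triton.FusedLinear` import path and its `activation` argument are assumptions and may differ between releases; check the xformers API reference.

.. code-block:: python

    # Illustrative sketch only: the eager PyTorch pair that the fused kernel
    # replaces. The commented-out fused variant assumes the layer is exposed as
    # `xformers.triton.FusedLinear`; verify the exact name in your version.
    import torch

    # Baseline: torch.nn.Linear followed by an activation, two separate kernels.
    baseline = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024, bias=True),
        torch.nn.ReLU(),
    ).half().cuda()

    # Assumed fused drop-in (bias and/or activation can be skipped with None):
    # from xformers.triton import FusedLinear
    # fused = FusedLinear(1024, 1024, bias=True, activation="relu").half().cuda()

    x = torch.randn(8, 256, 1024, device="cuda", dtype=torch.float16)
    y = baseline(x)  # a fused layer would be called the same way: y = fused(x)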

The following is an example of the measured performance on an nVidia V100.

+-----------------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+
| torch.float16 | Unit: TFlops | | | | |
+=========================================+======================+====================+====================+======================+====================+
| | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024| B=2, M=2048, K=2048 | B=2, M=4096, K=4096|
+-----------------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+
| pytorch - squared_relu - bias - fw | 6.3 | 12.4 | 12.3 | 17.1 | 19.0 |
+-----------------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+
| triton - squared_relu - bias - fw | 13.8 | 18.9 | 18.9 | 21.9 | 21.7 |
+-----------------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+
| pytorch - squared_relu - bias - fw+bw | 4.0 | 7.6 | 7.7 | 10.7 | 12.6 |
+-----------------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+
| triton - squared_relu - bias - fw+bw | 8.4 | 13.5 | 13.3 | 15.9 | 16.8 |
+-----------------------------------------+----------------------+--------------------+--------------------+----------------------+--------------------+
The following is an example of the measured performance on a laptop nVidia 3080, using Triton 1.1 and PyTorch 1.10.

.. image:: ../../plots/fused_linear/FusedLinear_fp16_FW_gelu.png
:width: 600
:alt: Fused linear layers throughput in fp16 - inference - GeLU

.. image:: ../../plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png
:width: 600
:alt: Fused linear layers throughput in fp16 - training - GeLU

--

.. image:: ../../plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png
:width: 600
:alt: Fused linear layers throughput in fp16 - inference - LeakyReLU

.. image:: ../../plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png
:width: 600
:alt: Fused linear layers throughput in fp16 - training - LeakyReLU

--

.. image:: ../../plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png
:width: 600
:alt: Fused linear layers throughput in fp16 - inference - Squared ReLU

.. image:: ../../plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png
:width: 600
:alt: Fused linear layers throughput in fp16 - training - Squared ReLU

--

.. image:: ../../plots/fused_linear/FusedLinear_fp16_FW_relu.png
:width: 600
:alt: Fused linear layers throughput in fp16 - inference - ReLU

.. image:: ../../plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png
:width: 600
:alt: Fused linear layers throughput in fp16 - training - ReLU




Fused layer norm
-----------------

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.

.. image:: ../../plots/layer_norm/LayerNorm_FW_torch.float16.png
:width: 600
:alt: Fused layer norm throughput in fp16 - inference

.. image:: ../../plots/layer_norm/LayerNorm_FW+BW_torch.float16.png
:width: 600
:alt: Fused layer norm throughput in fp16 - training

.. image:: ../../plots/layer_norm/LayerNorm_FW_torch.float32.png
:width: 600
:alt: Fused layer norm throughput in fp32 - inference

.. image:: ../../plots/layer_norm/LayerNorm_FW+BW_torch.float32.png
:width: 600
:alt: Fused layer norm throughput in fp32 - training
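
Before relying on throughput plots like these, a quick parity check against `torch.nn.LayerNorm` is a reasonable sanity step. The sketch below is illustrative only; the commented-out `xformers.triton.FusedLayerNorm` import path is an assumption and may not match your xformers version.

.. code-block:: python

    # Illustrative parity check: compare a fused layer norm against the eager
    # torch.nn.LayerNorm reference in fp16 before benchmarking. The fused import
    # path is an assumption, so it is left commented out.
    import torch

    B, M, K = 4, 2048, 1024
    x = torch.randn(B, M, K, device="cuda", dtype=torch.float16)

    reference = torch.nn.LayerNorm(K).half().cuda()
    y_ref = reference(x)

    # Assumed fused counterpart, used exactly like torch.nn.LayerNorm:
    # from xformers.triton import FusedLayerNorm
    # fused = FusedLayerNorm(K).half().cuda()
    # assert torch.allclose(fused(x), y_ref, atol=1e-2, rtol=1e-2)  # fp16 tolerances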


Fused dropout + bias + activation
---------------------------------

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s.
These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.


.. image:: ../../plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png
:width: 600
:alt: Fused dropout+ bias throughput in fp16 - inference - GeLU

.. image:: ../../plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png
:width: 600
:alt: Fused dropout+ bias throughput in fp16 - training - GeLU

.. image:: ../../plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png
:width: 600
:alt: Fused dropout+ bias throughput in fp16 - inference - Squared ReLU

.. image:: ../../plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png
:width: 600
:alt: Fused dropout+ bias throughput in fp16 - training - Squared ReLU

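For reference, the unfused baseline these plots compare against is simply a bias add, an activation and a dropout executed as three separate eager PyTorch operations; a minimal sketch of that baseline (the helper name is illustrative) is shown below.

.. code-block:: python

    # Illustrative baseline only: the three separate eager operations that the
    # fused Triton kernel covers (bias add, activation, dropout), each of which
    # otherwise costs its own round trip through GPU memory.
    import torch
    import torch.nn.functional as F

    def unfused_dropout_bias_act(x: torch.Tensor, bias: torch.Tensor, p: float = 0.1) -> torch.Tensor:
        y = x + bias                              # bias add
        y = F.gelu(y)                             # activation (GeLU or squared ReLU in the plots above)
        return F.dropout(y, p=p, training=True)   # dropout

    x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
    bias = torch.zeros(1024, device="cuda", dtype=torch.float16)
    out = unfused_dropout_bias_act(x, bias)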

.. _Triton: https://triton-lang.org/