Feat/layernorm #36
Merged
Commits (70)
a45997c feat: add attention (gaetansnl)
0745dd2 fix: use tuple in triton (gaetansnl)
d06fb8a docs: attention (gaetansnl)
2c4f4bf feat: add torchdynamo end to end fusion (gaetansnl)
0644ae6 feat: causal masked attention (gaetansnl)
e6e359b feat: benchmark dynamo backends (gaetansnl)
2ba239c Merge branch 'main' into feat/torchdynamo-fused (gaetansnl)
189ca47 fix: renaming (gaetansnl)
0e13b3c feat: add support for arbitrary stride (gaetansnl)
a0f880d fix: move output outside kernel (gaetansnl)
a6dcc70 feat: module replacement example (gaetansnl)
b142c1b fix: missing benchmark for masked attention (gaetansnl)
6aceec7 feat: add pattern and fix fx bug (gaetansnl)
a4f920f fix: refactoring (gaetansnl)
5157a8b feat: add layer_norm (gaetansnl)
90f989c fix: show speedup in benchmark display (gaetansnl)
f7a0ffa fix: update torchdyname and matcher (gaetansnl)
bf5ecfa fix: update matcher (gaetansnl)
ce73e78 fix: cuda graph (gaetansnl)
9590e78 fix: small seq_length (gaetansnl)
3534cb5 feat: viz server (gaetansnl)
26e5720 fix: bug in matcher and add complete graph report (gaetansnl)
1576ef5 fix: compatibility with pytorch stable (gaetansnl)
f5729df fix: add credit, rename variables, add doc (gaetansnl)
be04b88 fix: add test for shape change (gaetansnl)
7e2cb62 fix: attention renaming (gaetansnl)
672ea04 fix: add license (gaetansnl)
f5c1ae2 feat: add stride management on linear layer + replace cuda graph + ne… (pommedeterresautee)
c8bb059 feat: remove M, N masking (pommedeterresautee)
3998bc9 feat: improve autotune (pommedeterresautee)
dbf7c93 fix: command line (pommedeterresautee)
95fac33 fix: refactoring benchmarks (fix cuda graphs API, tests as dict, add … (pommedeterresautee)
368f3e5 fix: linear layer is now working (remove trick on max contiguous), re… (pommedeterresautee)
804335b fix: refactoring of the linear benchmark with sizes similar to bert o… (pommedeterresautee)
e93632b feat: add more instructions (pommedeterresautee)
a392150 fix: restore cuda graphs warmup (and remove todo) (pommedeterresautee)
efdabc9 feat: change linear implementation (pommedeterresautee)
d464f6b feat: change linear implementation (pommedeterresautee)
e9d6b04 feat: replace GELU + layernorm by more precise version, fix all preci… (pommedeterresautee)
a76b710 Merge remote-tracking branch 'origin/main' into fix/refactoring_bench… (pommedeterresautee)
b91456c fix: plural (pommedeterresautee)
d20619c fix: some doc (pommedeterresautee)
87cf360 fix: layer norm unit test (pommedeterresautee)
b112b4b feat: add cuda graph layer norm unit test (pommedeterresautee)
479254f fix: remove TODO (pommedeterresautee)
afc8d41 feat: add split K support (pommedeterresautee)
106cb24 Merge branch 'main' into feat/tools (pommedeterresautee)
ed6d1a1 Merge branch 'feat/tools' into fix/refactoring_benchmarks (pommedeterresautee)
df3daa8 fix: remove split k (pommedeterresautee)
cf8e7d0 Merge remote-tracking branch 'origin/fix/refactoring_benchmarks' into… (pommedeterresautee)
7d91384 feat: refactoring layernorm test (pommedeterresautee)
f2130a3 fix: add back bias and activation tests + refactoring (pommedeterresautee)
119e776 feat: make test display understandable (pommedeterresautee)
38ee01b fix: remove benchmark display (gaetansnl)
94dda21 fix: remove benchmark display (gaetansnl)
4bffeef Merge branch 'feat/tools' into fix/refactoring_benchmarks (pommedeterresautee)
8f91f79 fix: avoid OOM on reference implementation (pommedeterresautee)
61cbc1f fix: get input (pommedeterresautee)
f39219c Merge remote-tracking branch 'origin/fix/refactoring_benchmarks' into… (pommedeterresautee)
10703a9 fix: remove some OOM test for reference implementation (pommedeterresautee)
dabf536 fix: add tests (pommedeterresautee)
bd83caa feat: new layernorm single pass variance computation implementation (pommedeterresautee)
85ae2d8 Merge branch 'main' into fix/refactoring_benchmarks (pommedeterresautee)
f3564b0 fix: rename variables (pommedeterresautee)
7b06621 Merge branch 'fix/refactoring_benchmarks' into feat/layernorm (pommedeterresautee)
c2b8ade feat: add naive implem of layernorm (pommedeterresautee)
cca52e8 Merge branch 'main' into feat/layernorm (pommedeterresautee)
1100ef6 fix: store mean/var in layernorm (for bw pass) (pommedeterresautee)
2cedbfa fix: following review comments (pommedeterresautee)
7f3aefa fix: add manual seed (pommedeterresautee)
@@ -2,12 +2,87 @@

import triton
import triton.language as tl
from triton import JITFunction


# CREDITS: Initially inspired by the Triton tutorial


@triton.jit
-def _layer_norm_fwd_fused(
+def _layer_norm_fwd_fused_single_pass(
    Out,
    A,
    Weight,
    Bias,
    Mean, std,
    stride, N, eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Layernorm based on Welford's variance computation algorithm.
    https://changyaochen.github.io/welford/
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance

    :param Out: output tensor
    :param A: input tensor
    :param Weight: weights applied to the normalized input
    :param Bias: bias added to the normalized input
    :param Mean: save mean tensor for backward
    :param std: save standard deviation tensor for backward
    :param stride: stride of the input tensor
    :param N: number of elements per row in the input tensor
    :param eps: epsilon value to avoid division by zero
    :param BLOCK_SIZE: number of threads per block
    :return: None
    """
    # position of elements processed by this program
    _idx = tl.program_id(0)
    out_ptr = Out + _idx * stride
    a_ptr = A + _idx * stride
    # compute mean
    mean = 0.0
    var = 0.0
    for start_n_offset in range(0, N, BLOCK_SIZE):
        end_n_offset = min((start_n_offset + BLOCK_SIZE), N)
        nb_block_cols = end_n_offset - start_n_offset
        column_offset = start_n_offset + tl.arange(0, BLOCK_SIZE)
        mask = column_offset < N
        # eviction policy below has little impact now because of the new implementation. Kept as is.
        a = tl.load(a_ptr + column_offset, mask=mask, other=0., eviction_policy="evict_last").to(tl.float32)

        block_mean = tl.sum(a, axis=0) / nb_block_cols
        # mean is 0 or has a mask applied to it, no need to mask delta_mean!
        delta_mean = block_mean - mean
        delta_mean_sqr = delta_mean * delta_mean

        block_delta = tl.sum((a - block_mean) * a, axis=0)
        # mean has a mask!
        mean += tl.sum((a - mean) * mask, axis=0) / end_n_offset
        var += block_delta + delta_mean_sqr * (start_n_offset * nb_block_cols) / end_n_offset

    var = var / N
    rstd = 1 / tl.sqrt(var + eps)
Review comment on this line: what does rstd mean? Root std?
Reply: changed.
    # write-back mean/rstd for backward pass
    tl.store(Mean + _idx, mean)
    tl.store(std + _idx, rstd)

    # multiply by weight and add bias
    for off in range(0, N, BLOCK_SIZE):
        column_offset = off + tl.arange(0, BLOCK_SIZE)
        mask = column_offset < N
        weight = tl.load(Weight + column_offset, mask=mask)
        bias = tl.load(Bias + column_offset, mask=mask)
        # eviction policy helps to keep weights in cache (reused by other threads)
        a = tl.load(a_ptr + column_offset, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
        a_hat = (a - mean) * rstd
        out = a_hat * weight + bias
        # write-back
        tl.store(out_ptr + column_offset, out, mask=mask)
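For readers checking the variance accumulation in the loop above, here is a small standalone NumPy sketch (not part of the diff; all names are illustrative) of the same chunked combine: each block contributes its own sum of squared deviations plus a correction term delta_mean² · n_processed · n_block / n_total, which is exactly what the `var +=` line in the kernel computes.

```python
import numpy as np

# Standalone sketch (not from the PR) of the chunked mean/variance combine
# used by the single-pass kernel above (Welford / Chan et al. style update).
def chunked_mean_var(row: np.ndarray, block_size: int):
    mean, m2, count = 0.0, 0.0, 0  # running mean, sum of squared deviations, elements seen
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size].astype(np.float64)
        n_block = len(block)
        n_total = count + n_block
        block_mean = block.mean()
        delta = block_mean - mean
        # block's own spread + correction for the shift between the two means
        m2 += ((block - block_mean) ** 2).sum() + delta * delta * count * n_block / n_total
        # same update as `mean += tl.sum((a - mean) * mask) / end_n_offset` in the kernel
        mean += (block - mean).sum() / n_total
        count = n_total
    return mean, m2 / count  # population variance, as used by layer norm

row = np.random.randn(1024)
mean, var = chunked_mean_var(row, block_size=128)
assert np.allclose(mean, row.mean()) and np.allclose(var, row.var())
```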

@triton.jit
def _layer_norm_fwd_fused_multi_pass(
    Out,
    A,
    Weight,
@@ -49,18 +124,19 @@ def _layer_norm_fwd_fused(

        a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
        a_hat = (a - mean) * rstd
        out = a_hat * weight + bias
-       # # write-back
+       # write-back
        tl.store(Out + cols, out, mask=mask)


-def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
+def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, implementation: JITFunction = _layer_norm_fwd_fused_single_pass):
    # allocate output
    out = torch.empty_like(a)
    # reshape input data into 2D tensor
    a_arg = a.reshape(-1, a.shape[-1])
    M, N = a_arg.shape
    # tensors below for backward pass
    mean = torch.empty((M,), dtype=torch.float32, device="cuda")
-   rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
+   std = torch.empty((M,), dtype=torch.float32, device="cuda")
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // a.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
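A quick worked example of the heuristic above (not in the diff): for an fp16 input, a.element_size() is 2 bytes, so MAX_FUSED_SIZE = 65536 // 2 = 32768; with a BERT-base hidden size of N = 768, triton.next_power_of_2(768) = 1024, so BLOCK_SIZE = 1024, and the warp heuristic in the next hunk gives num_warps = min(max(1024 // 256, 1), 8) = 4.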
@@ -69,12 +145,12 @@ def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor

    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    eps = min(eps, 1e-6)  # >= 1e-5 may decrease Bert accuracy
-   _layer_norm_fwd_fused[(M,)](
+   implementation[(M,)](
        out,
        a_arg,
        weight,
        bias,
-       mean, rstd,
+       mean, std,
        a_arg.stride(0), N, eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
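A rough usage sketch of the new entry point (not part of the diff): the import path, the tensor shapes, and the assumption that layer_norm_forward returns the output tensor are mine; only the function signatures come from the changes above.

```python
import torch
import torch.nn.functional as F

# Hypothetical import path; the diff only shows the definitions themselves.
from implementations.layer_norm import (
    layer_norm_forward,
    _layer_norm_fwd_fused_multi_pass,
)

batch, seq_len, hidden = 8, 512, 768  # BERT-like shape, chosen for illustration
a = torch.randn((batch, seq_len, hidden), dtype=torch.float16, device="cuda")
weight = torch.randn(hidden, dtype=torch.float16, device="cuda")
bias = torch.randn(hidden, dtype=torch.float16, device="cuda")

# default: the new single-pass (Welford) kernel
out_single = layer_norm_forward(a, weight, bias, eps=1e-6)
# explicitly request the multi-pass kernel instead
out_multi = layer_norm_forward(a, weight, bias, eps=1e-6, implementation=_layer_norm_fwd_fused_multi_pass)

# compare both against the PyTorch reference computed in float32
expected = F.layer_norm(a.float(), (hidden,), weight.float(), bias.float(), eps=1e-6)
assert torch.allclose(out_single.float(), expected, atol=1e-2)
assert torch.allclose(out_multi.float(), expected, atol=1e-2)
```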
Review comment: I think it is missing documentation and naming: what is A? What size? Stride of what, and along which dimension?
Reply: added.