Feat/layernorm #36
Changes from 68 commits
@@ -2,12 +2,73 @@
import triton
import triton.language as tl
from triton import JITFunction


# CREDITS: Initially inspired by the Triton tutorial


@triton.jit
-def _layer_norm_fwd_fused(
+def _layer_norm_fwd_fused_single_pass(
    Out,
    A,
    Weight,
    Bias,
    Mean, Rstd,
    stride, N, eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Based on Welford's variance computation algorithm.
    https://changyaochen.github.io/welford/
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
    """
    # position of elements processed by this program
    row = tl.program_id(0)
Review: we have a naming convention for this (_id or _idx)
Reply: from the original implementation
Reply: changed
    Out += row * stride
Review: this is a very bad practice IMO
Reply: from the original implementation
Reply: updated
    A += row * stride
    # compute mean
    mean = 0.0
    var = 0.0
    for start in range(0, N, BLOCK_SIZE):
Review: IMO start and end should have more explicit names (like in the attention kernel)
Reply: changed
        end = min((start + BLOCK_SIZE), N)
        nb_block_col = end - start
        cols = start + tl.arange(0, BLOCK_SIZE)
Review: we call this offsets in other kernels
Reply: changed
        mask = cols < N
        # evict_last: this row is reloaded by the second loop below, so keep it in cache
        a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_last").to(tl.float32)
Review: could you document why eviction_policy="evict_last"?
Reply: added

        block_mean = tl.sum(a, axis=0) / nb_block_col
        # mean is 0 or has a mask applied to it, no need to mask delta_mean!
        delta_mean = block_mean - mean
        delta_mean_sqr = delta_mean * delta_mean

        block_delta = tl.sum((a - block_mean) * a, axis=0)
        # mean has a mask!
        mean += tl.sum((a - mean) * mask, axis=0) / end
        var += block_delta + delta_mean_sqr * (start * nb_block_col) / end

    var = var / N
    rstd = 1 / tl.sqrt(var + eps)  # reciprocal of the standard deviation
Review: what does rstd mean? root std?
Reply: changed

    # write-back mean/rstd so the backward pass can reuse them
    tl.store(Mean + row, mean)
Review: could you add why we do this (future backward)?
Reply: added
    tl.store(Rstd + row, rstd)

    # multiply by weight and add bias
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        weight = tl.load(Weight + cols, mask=mask)
        bias = tl.load(Bias + cols, mask=mask)
        # evict_first: this row will not be read again once normalized
        a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
        a_hat = (a - mean) * rstd
        out = a_hat * weight + bias
        # write-back
        tl.store(Out + cols, out, mask=mask)


@triton.jit
def _layer_norm_fwd_fused_multi_pass(
    Out,
    A,
    Weight,
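For reference, the running mean/variance update above is the parallel (Chan et al.) form of Welford's algorithm: each block contributes its own mean and sum of squared deviations, which are then merged into the running statistics (`start`, `nb_block_col`, and `end` play the roles of the old, block, and combined element counts). A minimal NumPy sketch of the same merge, checked against the naive computation; the names here are illustrative, not from the kernel:

```python
import numpy as np

def blockwise_mean_var(x: np.ndarray, block_size: int):
    """Merge per-block statistics the way the single-pass kernel does."""
    mean, m2, count = 0.0, 0.0, 0  # running mean, sum of squared deviations, count
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        n_b = len(block)
        block_mean = block.mean()
        delta = block_mean - mean
        # within-block sum of squared deviations: sum(b^2) - block_mean * sum(b)
        m2_b = ((block - block_mean) * block).sum()
        total = count + n_b
        # Chan et al. merge of (count, mean, m2) with (n_b, block_mean, m2_b)
        mean += (block - mean).sum() / total
        m2 += m2_b + delta * delta * count * n_b / total
        count = total
    return mean, m2 / count  # population variance, as LayerNorm uses

x = np.random.randn(1000)
mean, var = blockwise_mean_var(x, block_size=128)
assert np.isclose(mean, x.mean()) and np.isclose(var, x.var())
```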
@@ -49,11 +110,11 @@ def _layer_norm_fwd_fused(
        a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
        a_hat = (a - mean) * rstd
        out = a_hat * weight + bias
-        # # write-back
+        # write-back
        tl.store(Out + cols, out, mask=mask)


-def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
+def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, implementation: JITFunction = _layer_norm_fwd_fused_single_pass):
    # allocate output
    out = torch.empty_like(a)
    # reshape input data into 2D tensor
@@ -69,7 +130,7 @@ def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    eps = min(eps, 1e-6)  # >= 1e-5 may decrease Bert accuracy
-    _layer_norm_fwd_fused[(M,)](
+    implementation[(M,)](
        out,
        a_arg,
        weight,
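A quick usage sketch of the new `implementation` parameter, which selects between the two kernels at dispatch time (the shapes, dtype, and CUDA device here are assumptions for illustration; by default `layer_norm_forward` uses the single-pass kernel):

```python
import torch

# hypothetical smoke test, assuming the module above is imported in scope
x = torch.randn(32, 768, device="cuda", dtype=torch.float16)
weight = torch.ones(768, device="cuda", dtype=torch.float16)
bias = torch.zeros(768, device="cuda", dtype=torch.float16)

for kernel in (_layer_norm_fwd_fused_single_pass, _layer_norm_fwd_fused_multi_pass):
    out = layer_norm_forward(x, weight, bias, eps=1e-6, implementation=kernel)
    expected = torch.nn.functional.layer_norm(x, (768,), weight, bias, eps=1e-6)
    assert torch.allclose(out, expected, atol=1e-2, rtol=0)
```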
Review: I think it misses documentation and naming: what is A? What size? Stride of what, and of which dimension?
Reply: added
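For reference, a docstring sketch along the lines of what this review asks for; the exact wording added in the follow-up commit is not shown in this diff, so the text below is an assumption inferred from how `layer_norm_forward` launches the kernel:

```python
# Hypothetical docstring for _layer_norm_fwd_fused_single_pass; the shapes follow
# from layer_norm_forward, which views the input as a 2D (M, N) tensor.
"""
Fused LayerNorm forward pass; each program normalizes one row of A.

:param Out: output tensor, same shape as A
:param A: input tensor viewed as 2D of shape (M, N)
:param Weight: per-column scale, shape (N,)
:param Bias: per-column shift, shape (N,)
:param Mean: per-row mean, shape (M,), saved for the backward pass
:param Rstd: per-row reciprocal standard deviation, shape (M,), saved for the backward pass
:param stride: number of elements between two consecutive rows of A (row stride of the 2D view)
:param N: number of columns, i.e. the normalized dimension
:param eps: small constant added to the variance for numerical stability
:param BLOCK_SIZE: number of columns processed per loop iteration
"""
```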