Feat/layernorm #36
Conversation
… fix/refactoring_benchmarks
# Conflicts:
#	README.md
#	implementations/activation_func.py
#	implementations/attention_masked_original.py
#	implementations/layer_norm.py
#	implementations/linear_layer.py
#	optimizer/dynamo_backend.py
#	optimizer/layer_norm.py
#	optimizer/linear.py
#	test/models/bert.py
#	test/test_attention.py
#	test/test_batched_matmul.py
#	test/test_layer_norm.py
#	test/test_linear_layer.py
#	test/test_torchdynamo_bert.py
# Conflicts:
#	implementations/layer_norm.py
#	test/test_layer_norm.py
#	test/test_linear_layer.py
#	test/test_torchdynamo_bert.py
# CREDITS: Initially inspired by the Triton tutorial

@triton.jit
-def _layer_norm_fwd_fused(
+def _layer_norm_fwd_fused_single_pass(
I think it's missing documentation and naming: what is A? What size? Stride of what, and of which dimension?
added
implementations/layer_norm.py
Outdated
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
"""
# position of elements processed by this program
row = tl.program_id(0)
we have a naming convention for this (_id or _idx)
from the original implementation
changed
implementations/layer_norm.py
Outdated
""" | ||
# position of elements processed by this program | ||
row = tl.program_id(0) | ||
Out += row * stride |
this is a very bad practice IMO
from original implementation
updated
implementations/layer_norm.py
Outdated
# compute mean
mean = 0.0
var = 0.0
for start in range(0, N, BLOCK_SIZE):
IMO start and end should have more explicit names (like in the attention kernel)
changed
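For reference on the blocked accumulation discussed in this thread: each BLOCK_SIZE chunk contributes partial statistics that get folded into the running mean/variance. A plain-Python sketch of that merge step (Chan et al.'s parallel-variance formula from the linked Wikipedia page; the function name and signature are mine, not from the diff):

```python
def combine_stats(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
    """Merge two partial (count, mean, M2) summaries into one,
    where M2 is the sum of squared deviations from that part's mean.
    This is the algebra a blocked single-pass kernel uses to fold
    each block's statistics into the running ones."""
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2
```

Merging the stats of [1, 2] (count 2, mean 1.5, M2 0.5) and [3, 4, 5] (count 3, mean 4.0, M2 2.0) reproduces the stats of [1, 2, 3, 4, 5] exactly, which is what makes the blocked loop equivalent to one pass over the whole row.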
implementations/layer_norm.py
Outdated
for start in range(0, N, BLOCK_SIZE):
    end = min((start + BLOCK_SIZE), N)
    nb_block_col = end - start
    cols = start + tl.arange(0, BLOCK_SIZE)
we call these offsets in other kernels
changed
implementations/layer_norm.py
Outdated
    nb_block_col = end - start
    cols = start + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_last").to(tl.float32)
could you document why eviction_policy="evict_last"?
added
implementations/layer_norm.py
Outdated
rstd = 1 / tl.sqrt(var + eps)

# write-back mean/rstd
tl.store(Mean + row, mean)
could you add why we do this (future backward pass)?
added
var += block_delta + delta_mean_sqr * (start * nb_block_col) / end

var = var / N
rstd = 1 / tl.sqrt(var + eps)
what does rstd mean? root std?
changed
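For future readers of this thread: rstd is conventionally the *reciprocal* of the standard deviation, 1 / sqrt(var + eps); storing it lets the normalization be a multiply rather than a divide. A minimal NumPy sketch of a layer-norm forward pass under that reading (this is my illustration of the convention, not the kernel itself):

```python
import numpy as np

def layer_norm_fwd(a, weight, bias, eps=1e-5):
    # Row-wise mean and population variance (divide by N, as the kernel's var / N).
    mean = a.mean(axis=-1, keepdims=True)
    var = a.var(axis=-1, keepdims=True)
    rstd = 1.0 / np.sqrt(var + eps)      # reciprocal standard deviation
    out = (a - mean) * rstd * weight + bias  # multiply by rstd instead of dividing by std
    return out, mean, rstd               # mean/rstd are kept for the backward pass
```

With weight = 1 and bias = 0 each output row has (approximately) zero mean and unit variance; eps keeps the result finite when a row is constant.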
A single-pass layernorm implementation based on Welford's formula.
Fixes #40
https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/
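The linked post derives Welford's single-pass update; a direct transcription in plain Python (a scalar sketch of the idea, not the vectorized Triton kernel):

```python
def welford(xs):
    """Single-pass mean and population variance via Welford's update."""
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for n, x in enumerate(xs, start=1):
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)  # note: uses the already-updated mean
    return mean, m2 / len(xs)
```

Unlike the naive E[x²] − E[x]² formulation, this never subtracts two large nearly-equal quantities, which is why it stays numerically stable in a single pass.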