
Feat/layernorm #36

Merged: pommedeterresautee merged 70 commits into main from feat/layernorm on Sep 15, 2022

Conversation

pommedeterresautee (Member) commented on Sep 12, 2022:

A single-pass layernorm implementation based on Welford's formula.

fix #40

https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/
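
For reference, a minimal Python sketch of the single-pass update described in the linked article (function and variable names are illustrative, not taken from the kernel):

def welford_mean_var(xs):
    """Single-pass mean and population variance via Welford's update."""
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for n, x in enumerate(xs, start=1):
        delta = x - mean
        mean += delta / n          # update the running mean
        m2 += delta * (x - mean)   # note: uses the *updated* mean
    return mean, m2 / len(xs)      # population variance, as layernorm uses

print(welford_mean_var([1.0, 2.0, 4.0]))  # (2.333..., 1.555...)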

@pommedeterresautee added the benchmark (Measure, measure, measure) and optimization labels on Sep 12, 2022
@pommedeterresautee self-assigned this on Sep 12, 2022
@pommedeterresautee changed the base branch from main to fix/refactoring_benchmarks on September 12, 2022 at 13:54
# Conflicts:
#	README.md
#	implementations/activation_func.py
#	implementations/attention_masked_original.py
#	implementations/layer_norm.py
#	implementations/linear_layer.py
#	optimizer/dynamo_backend.py
#	optimizer/layer_norm.py
#	optimizer/linear.py
#	test/models/bert.py
#	test/test_attention.py
#	test/test_batched_matmul.py
#	test/test_layer_norm.py
#	test/test_linear_layer.py
#	test/test_torchdynamo_bert.py
Base automatically changed from fix/refactoring_benchmarks to main on September 13, 2022 at 15:44
# Conflicts:
#	implementations/layer_norm.py
#	test/test_layer_norm.py
#	test/test_linear_layer.py
#	test/test_torchdynamo_bert.py
@pommedeterresautee marked this pull request as ready for review on September 13, 2022 at 15:48

# CREDITS: Initially inspired by the Triton tutorial


@triton.jit
-def _layer_norm_fwd_fused(
+def _layer_norm_fwd_fused_single_pass(
Contributor:
I think it's missing documentation and naming: what is A? What size? Stride of what, and along which dimension?

Member Author:
Added:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
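
For illustration, the kind of signature documentation being requested could look like the sketch below; the descriptions are my reading of the kernel, and any parameter not visible in the quoted snippets (Weight, Bias, Rstd) is an assumption:

import triton
import triton.language as tl

@triton.jit
def _layer_norm_fwd_fused_single_pass(
    Out,     # pointer to the output tensor, same layout as A
    A,       # pointer to the input, viewed as rows of N elements to normalize
    Weight,  # (assumed) pointer to the learned scale, shape (N,)
    Bias,    # (assumed) pointer to the learned shift, shape (N,)
    Mean,    # pointer to per-row means, written back for the backward pass
    Rstd,    # (assumed name) pointer to per-row reciprocal standard deviations
    stride,  # elements between two consecutive rows of A (and Out)
    N,       # row length, i.e. the size of the normalized dimension
    eps,     # small constant added to the variance for numerical stability
    BLOCK_SIZE: tl.constexpr,  # columns processed per loop iteration
):
    ...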
"""
# position of elements processed by this program
row = tl.program_id(0)
Contributor:
We have a naming convention for this (_id or _idx).

Member Author (Sep 14, 2022):
From the original implementation.

Member Author:
Changed.

"""
# position of elements processed by this program
row = tl.program_id(0)
Out += row * stride
Contributor:
This is very bad practice, IMO.

Member Author:
From the original implementation.

Member Author:
Updated.
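
For illustration, a common way to address this is to compute row-local pointers instead of reassigning the kernel arguments (the local names here are assumptions, not necessarily the updated code):

# derive row-local pointers; the kernel arguments themselves stay untouched
a_row_ptr = A + row * stride
out_row_ptr = Out + row * stride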

# compute mean
mean = 0.0
var = 0.0
for start in range(0, N, BLOCK_SIZE):
Contributor:
IMO, start and end should have more explicit names (as in the attention kernel).

Member Author:
Changed.

for start in range(0, N, BLOCK_SIZE):
end = min((start + BLOCK_SIZE), N)
nb_block_col = end - start
cols = start + tl.arange(0, BLOCK_SIZE)
Contributor:
We call this offsets in other kernels.

Member Author:
Changed.

nb_block_col = end - start
cols = start + tl.arange(0, BLOCK_SIZE)
mask = cols < N
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_last").to(tl.float32)
Contributor:
Could you document why eviction_policy="evict_last"?

Member Author:
Added.
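
For context, the usual rationale for this policy is that each row of A is loaded twice, once to accumulate the statistics and once more to normalize, so the cache is hinted to keep the row resident between the two loops. A sketch of the kind of comment requested (the wording is mine, not necessarily what was committed):

# evict_last: this row of A is loaded again in the normalization loop,
# so ask the cache to evict it last, keeping it resident between passes
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_last").to(tl.float32)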

rstd = 1 / tl.sqrt(var + eps)

# write-back mean/rstd
tl.store(Mean + row, mean)
Contributor:
Could you add why we do this (future backward pass)?

Member Author:
Added.
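
To illustrate the "future backward" point: a hypothetical fragment of a backward kernel that reloads the saved statistics instead of recomputing the Welford reduction over the row (Rstd and a_hat are assumed names; the project's actual backward kernel may differ):

# hypothetical backward fragment: reuse the saved per-row statistics
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
a_hat = (a - mean) * rstd  # normalized activations, reused by the gradient formulas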

var += block_delta + delta_mean_sqr * (start * nb_block_col) / end

var = var / N
rstd = 1 / tl.sqrt(var + eps)
Contributor:
What does rstd mean? Root std?

Member Author:
Changed.
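
For the record, the var update quoted above is the standard chunk-combination rule for variance (Chan et al., covered by the Wikipedia page linked earlier), reconstructed here from the kernel's variable names: assuming delta_mean_sqr is the squared difference between the chunk means, with n_A = start elements already accumulated, n_B = nb_block_col new elements, and n_{AB} = end, the update is

M_{2,AB} = M_{2,A} + M_{2,B} + \delta^2 \cdot \frac{n_A \, n_B}{n_{AB}}, \qquad \delta = \bar{x}_B - \bar{x}_A

and at the end of the loop, var = M_2 / N and rstd = 1 / \sqrt{\mathrm{var} + \epsilon}, i.e. rstd is the reciprocal (not root) standard deviation.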

@pommedeterresautee pommedeterresautee merged commit 1ed7a03 into main Sep 15, 2022
@pommedeterresautee pommedeterresautee deleted the feat/layernorm branch September 15, 2022 07:52
Labels
benchmark (Measure, measure, measure), optimization

Projects
None yet

Development
Successfully merging this pull request may close: Implement efficient single pass on data variance computation in layernorm kernel

2 participants