Feat/layernorm #36

Merged: 70 commits, Sep 15, 2022
Changes from 68 commits

Commits (70)
a45997c
feat: add attention
gaetansnl Aug 11, 2022
0745dd2
fix: use tuple in triton
gaetansnl Aug 11, 2022
d06fb8a
docs: attention
gaetansnl Aug 11, 2022
2c4f4bf
feat: add torchdynamo end to end fusion
gaetansnl Aug 12, 2022
0644ae6
feat: causal masked attention
gaetansnl Aug 12, 2022
e6e359b
feat: benchmark dynamo backends
gaetansnl Aug 12, 2022
2ba239c
Merge branch 'main' into feat/torchdynamo-fused
gaetansnl Aug 16, 2022
189ca47
fix: renaming
gaetansnl Aug 16, 2022
0e13b3c
feat: add support for arbitrary stride
gaetansnl Aug 16, 2022
a0f880d
fix: move output outside kernel
gaetansnl Aug 16, 2022
a6dcc70
feat: module replacement example
gaetansnl Aug 17, 2022
b142c1b
fix: missing benchmark for masked attention
gaetansnl Aug 18, 2022
6aceec7
feat: add pattern and fix fx bug
gaetansnl Aug 19, 2022
a4f920f
fix: refactoring
gaetansnl Aug 19, 2022
5157a8b
feat: add layer_norm
gaetansnl Aug 19, 2022
90f989c
fix: show speedup in benchmark display
gaetansnl Aug 23, 2022
f7a0ffa
fix: update torchdynamo and matcher
gaetansnl Aug 24, 2022
bf5ecfa
fix: update matcher
gaetansnl Aug 24, 2022
ce73e78
fix: cuda graph
gaetansnl Aug 25, 2022
9590e78
fix: small seq_length
gaetansnl Aug 25, 2022
3534cb5
feat: viz server
gaetansnl Aug 26, 2022
26e5720
fix: bug in matcher and add complete graph report
gaetansnl Aug 26, 2022
1576ef5
fix: compatibility with pytorch stable
gaetansnl Aug 31, 2022
f5729df
fix: add credit, rename variables, add doc
gaetansnl Aug 31, 2022
be04b88
fix: add test for shape change
gaetansnl Aug 31, 2022
7e2cb62
fix: attention renaming
gaetansnl Aug 31, 2022
672ea04
fix: add license
gaetansnl Aug 31, 2022
f5c1ae2
feat: add stride management on linear layer + replace cuda graph + ne…
pommedeterresautee Sep 3, 2022
c8bb059
feat: remove M, N masking
pommedeterresautee Sep 3, 2022
3998bc9
feat: improve autotune
pommedeterresautee Sep 3, 2022
dbf7c93
fix: command line
pommedeterresautee Sep 5, 2022
95fac33
fix: refactoring benchmarks (fix cuda graphs API, tests as dict, add …
pommedeterresautee Sep 5, 2022
368f3e5
fix: linear layer is now working (remove trick on max contiguous), re…
pommedeterresautee Sep 5, 2022
804335b
fix: refactoring of the linear benchmark with sizes similar to bert o…
pommedeterresautee Sep 5, 2022
e93632b
feat: add more instructions
pommedeterresautee Sep 6, 2022
a392150
fix: restore cuda graphs warmup (and remove todo)
pommedeterresautee Sep 6, 2022
efdabc9
feat: change linear implementation
pommedeterresautee Sep 6, 2022
d464f6b
feat: change linear implementation
pommedeterresautee Sep 6, 2022
e9d6b04
feat: replace GELU + layernorm by more precise version, fix all preci…
pommedeterresautee Sep 6, 2022
a76b710
Merge remote-tracking branch 'origin/main' into fix/refactoring_bench…
pommedeterresautee Sep 6, 2022
b91456c
fix: plural
pommedeterresautee Sep 6, 2022
d20619c
fix: some doc
pommedeterresautee Sep 6, 2022
87cf360
fix: layer norm unit test
pommedeterresautee Sep 6, 2022
b112b4b
feat: add cuda graph layer norm unit test
pommedeterresautee Sep 6, 2022
479254f
fix: remove TODO
pommedeterresautee Sep 6, 2022
afc8d41
feat: add split K support
pommedeterresautee Sep 7, 2022
106cb24
Merge branch 'main' into feat/tools
pommedeterresautee Sep 8, 2022
ed6d1a1
Merge branch 'feat/tools' into fix/refactoring_benchmarks
pommedeterresautee Sep 8, 2022
df3daa8
fix: remove split k
pommedeterresautee Sep 8, 2022
cf8e7d0
Merge remote-tracking branch 'origin/fix/refactoring_benchmarks' into…
pommedeterresautee Sep 8, 2022
7d91384
feat: refactoring layernorm test
pommedeterresautee Sep 9, 2022
f2130a3
fix: add back bias and activation tests + refactoring
pommedeterresautee Sep 9, 2022
119e776
feat: make test display understandable
pommedeterresautee Sep 9, 2022
38ee01b
fix: remove benchmark display
gaetansnl Sep 9, 2022
94dda21
fix: remove benchmark display
gaetansnl Sep 9, 2022
4bffeef
Merge branch 'feat/tools' into fix/refactoring_benchmarks
pommedeterresautee Sep 9, 2022
8f91f79
fix: avoid OOM on reference implementation
pommedeterresautee Sep 9, 2022
61cbc1f
fix: get input
pommedeterresautee Sep 9, 2022
f39219c
Merge remote-tracking branch 'origin/fix/refactoring_benchmarks' into…
pommedeterresautee Sep 9, 2022
10703a9
fix: remove some OOM test for reference implementation
pommedeterresautee Sep 9, 2022
dabf536
fix: add tests
pommedeterresautee Sep 9, 2022
bd83caa
feat: new layernorm single pass variance computation implementation
pommedeterresautee Sep 12, 2022
85ae2d8
Merge branch 'main' into fix/refactoring_benchmarks
pommedeterresautee Sep 13, 2022
f3564b0
fix: rename variables
pommedeterresautee Sep 13, 2022
7b06621
Merge branch 'fix/refactoring_benchmarks' into feat/layernorm
pommedeterresautee Sep 13, 2022
c2b8ade
feat: add naive implem of layernorm
pommedeterresautee Sep 13, 2022
cca52e8
Merge branch 'main' into feat/layernorm
pommedeterresautee Sep 13, 2022
1100ef6
fix: store mean/var in layernorm (for bw pass)
pommedeterresautee Sep 13, 2022
2cedbfa
fix: following review comments
pommedeterresautee Sep 14, 2022
7f3aefa
fix: add manual seed
pommedeterresautee Sep 15, 2022
69 changes: 65 additions & 4 deletions implementations/layer_norm.py
@@ -2,12 +2,73 @@

import triton
import triton.language as tl
from triton import JITFunction


# CREDITS: Initially inspired by the Triton tutorial


@triton.jit
def _layer_norm_fwd_fused(
def _layer_norm_fwd_fused_single_pass(
Contributor: I think it's missing documentation and naming: what is A? What size? Stride of what, and which dimension?

Member Author: added

Out,
A,
Weight,
Bias,
Mean, Rstd,
stride, N, eps,
BLOCK_SIZE: tl.constexpr,
):
"""
Based on Welford's variance computation algorithm.
https://changyaochen.github.io/welford/
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
"""
# position of elements processed by this program
row = tl.program_id(0)
Contributor: we have a naming convention for this (_id or _idx)

Member Author (@pommedeterresautee, Sep 14, 2022): from the original implementation

Member Author: changed

Out += row * stride
Contributor: this is a very bad practice IMO

Member Author: from the original implementation

Member Author: updated

A += row * stride
# compute mean
mean = 0.0
var = 0.0
for start in range(0, N, BLOCK_SIZE):
Contributor: IMO start and end should have more explicit names (like in the attention kernel)

Member Author: changed

end = min((start + BLOCK_SIZE), N)
nb_block_col = end - start
cols = start + tl.arange(0, BLOCK_SIZE)
Contributor: we call this offsets in other kernels

Member Author: changed

mask = cols < N
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_last").to(tl.float32)
Contributor: could you document why eviction_policy="evict_last"?

Member Author: added


block_mean = tl.sum(a, axis=0) / nb_block_col
# mean is 0 or has a mask applied to it, no need to mask delta_mean!
delta_mean = block_mean - mean
delta_mean_sqr = delta_mean * delta_mean

block_delta = tl.sum((a - block_mean) * a, axis=0)
# mean has a mask!
mean += tl.sum((a - mean) * mask, axis=0) / end
var += block_delta + delta_mean_sqr * (start * nb_block_col) / end

var = var / N
rstd = 1 / tl.sqrt(var + eps)
Contributor: what does rstd mean? root std?

Member Author: changed


# write-back mean/rstd
tl.store(Mean + row, mean)
Contributor: could you add why we do this (future backward pass)?

Member Author: added

tl.store(Rstd + row, rstd)

# multiply by weight and add bias
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# write-back
tl.store(Out + cols, out, mask=mask)


@triton.jit
def _layer_norm_fwd_fused_multi_pass(
Out,
A,
Weight,
@@ -49,11 +110,11 @@ def _layer_norm_fwd_fused(
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# # write-back
# write-back
tl.store(Out + cols, out, mask=mask)


def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, implementation: JITFunction = _layer_norm_fwd_fused_single_pass):
# allocate output
out = torch.empty_like(a)
# reshape input data into 2D tensor
@@ -69,7 +130,7 @@ def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
eps = min(eps, 1e-6) # >= 1e-5 may decrease Bert accuracy
_layer_norm_fwd_fused[(M,)](
implementation[(M,)](
out,
a_arg,
weight,
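For reference, here is a minimal plain-PyTorch sketch (not part of this PR) of the blocked Welford/Chan one-pass mean/variance merge that the single-pass kernel above is built on. The helper name, block_size value, and tolerance are illustrative assumptions; the Triton kernel's bookkeeping differs in its details but follows the same idea: each block contributes a partial mean and sum of squared deviations, and the running statistics are merged without a second pass over the row.

import torch


def one_pass_layer_norm(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
                        eps: float = 1e-6, block_size: int = 1024) -> torch.Tensor:
    # Illustrative sketch only: names and defaults are not taken from the repo.
    # Assumes a 2D (rows, hidden) input.
    out = torch.empty_like(a)
    n = a.shape[-1]
    for row in range(a.shape[0]):
        x = a[row].float()
        mean, m2, count = 0.0, 0.0, 0
        # single pass over the row, one block at a time
        for start in range(0, n, block_size):
            blk = x[start:start + block_size]
            blk_count = blk.numel()
            blk_mean = blk.mean().item()
            blk_m2 = ((blk - blk_mean) ** 2).sum().item()
            # Chan et al. merge of the running (mean, m2, count) with the block statistics
            delta = blk_mean - mean
            total = count + blk_count
            mean += delta * blk_count / total
            m2 += blk_m2 + delta * delta * count * blk_count / total
            count = total
        var = m2 / n                      # population variance, as in layer norm
        rstd = 1.0 / (var + eps) ** 0.5   # reciprocal of the standard deviation
        out[row] = ((x - mean) * rstd) * weight.float() + bias.float()
    return out.to(a.dtype)


# quick check against PyTorch's reference implementation
x = torch.randn(4, 3072)
w, b = torch.ones(3072), torch.zeros(3072)
expected = torch.nn.functional.layer_norm(x, (3072,), w, b, 1e-6)
assert torch.allclose(one_pass_layer_norm(x, w, b), expected, atol=1e-4)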
23 changes: 18 additions & 5 deletions test/test_layer_norm.py
@@ -4,17 +4,30 @@
import pytest

from implementations.cuda_graph import cuda_graphs_wrapper
from implementations.layer_norm import layer_norm_forward
from implementations.layer_norm import layer_norm_forward, _layer_norm_fwd_fused_single_pass, \
_layer_norm_fwd_fused_multi_pass


def pytorch_naive(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
mean = a.mean(dim=-1, keepdim=True)
var = a.var(dim=-1, keepdim=True)
rstd = 1 / torch.sqrt(var + eps)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
return out


implementations: dict[str, Callable[[torch.Tensor, torch.Tensor, torch.Tensor, float], torch.Tensor]] = {
"pytorch": lambda x, weight, bias, eps: torch.nn.functional.layer_norm(x, weight.shape, weight, bias, eps),
"triton": lambda x, weight, bias, eps: layer_norm_forward(x, weight, bias, eps),
"triton_original": lambda x, weight, bias, eps: layer_norm_forward(x, weight, bias, eps, _layer_norm_fwd_fused_multi_pass),
"triton_improved": lambda x, weight, bias, eps: layer_norm_forward(x, weight, bias, eps, _layer_norm_fwd_fused_single_pass),
"pytorch_naive": lambda x, weight, bias, eps: pytorch_naive(x, weight, bias, eps),
}


@pytest.mark.parametrize("shape", [128, 512, 1024, 2048, 4096], ids=lambda x: f"shape={x}x{x}")
@pytest.mark.parametrize("shape", [128, 512, 1024, 2048, 4096, 8192], ids=lambda x: f"shape={x}x{x}")
@pytest.mark.parametrize("cuda_graphs", [True, False], ids=["cuda_graphs", "no_cuda_graphs"])
@pytest.mark.parametrize("implementation", ["pytorch", "triton"])
@pytest.mark.parametrize("implementation", ["triton_original", "triton_improved", "pytorch", "pytorch_naive"])
def test_benchmark_layer_norm(benchmark, shape: int, cuda_graphs: bool, implementation: str):
assert implementation in implementations, f"Unknown implementation: {implementation}"

@@ -39,4 +52,4 @@ def inference(x, *args, **kwargs):

value = benchmark(inference, x, weight, bias, eps)

assert torch.allclose(value, expected, atol=1e-2)
assert torch.allclose(value, expected, atol=1e-1)
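As a usage note, the new implementation argument of layer_norm_forward lets callers pick either Triton kernel directly, outside the benchmark. A small sketch, assuming a CUDA GPU with Triton installed; the shapes and tolerance below are chosen for illustration and mirror the tests rather than coming from the repository:

import torch
from implementations.layer_norm import (
    layer_norm_forward,
    _layer_norm_fwd_fused_single_pass,
    _layer_norm_fwd_fused_multi_pass,
)

x = torch.randn(8, 2048, device="cuda", dtype=torch.float16)
weight = torch.ones(2048, device="cuda", dtype=torch.float16)
bias = torch.zeros(2048, device="cuda", dtype=torch.float16)
eps = 1e-6

# the single-pass (Welford) kernel is the default, shown here explicitly
out_single = layer_norm_forward(x, weight, bias, eps, _layer_norm_fwd_fused_single_pass)
# the original multi-pass kernel remains selectable
out_multi = layer_norm_forward(x, weight, bias, eps, _layer_norm_fwd_fused_multi_pass)

expected = torch.nn.functional.layer_norm(x, weight.shape, weight, bias, eps)
assert torch.allclose(out_single, expected, atol=1e-1)
assert torch.allclose(out_multi, expected, atol=1e-1)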
1 change: 1 addition & 0 deletions test/test_linear_layer.py
@@ -40,6 +40,7 @@ def test_benchmark(benchmark, shape: Shape, bias: bool, activation: str, contigu
batch, M, N, K = dataclasses.astuple(shape)

# order of dimensions is wrong so we force contiguous call

a = torch.randn((batch, K, M), device='cuda', dtype=torch.float16, requires_grad=False)
a = a.mT
if contiguous:
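Since the comment in this hunk hinges on tensor strides, here is a small stride illustration (shapes are arbitrary, not taken from the test): .mT only swaps the strides of the last two dimensions and returns a non-contiguous view, which is why the test optionally calls .contiguous() afterwards.

import torch

a = torch.randn(4, 64, 32, dtype=torch.float16)  # (batch, K, M), row-major
a_t = a.mT                                       # view of shape (batch, 32, 64); no data is moved

print(a_t.shape, a_t.stride(), a_t.is_contiguous())  # strides of the last two dims are swapped -> False
a_c = a_t.contiguous()                               # materializes a row-major copy
print(a_c.stride(), a_c.is_contiguous())             # (2048, 64, 1) True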
26 changes: 14 additions & 12 deletions test/test_torchdynamo_bert.py
@@ -10,7 +10,7 @@


@pytest.fixture
def model_baseline_fp32():
def model_reference_fp32():
return get_model_baseline(float_16=False)


@@ -48,35 +48,37 @@ def get_input_non_causal(shape: (int, int)) -> Dict[str, torch.Tensor]:
}


@pytest.mark.parametrize("input_shape", [(1, 16), (1, 128), (1, 256), (1, 384), (1, 512),
(8, 16), (8, 128), (8, 256), (8, 384), (8, 512),
(32, 16), (32, 128), (32, 256),
], ids=lambda x: f"{x[0]}x{x[1]}")
@pytest.mark.parametrize("shape", [(1, 16), (1, 128), (1, 256), (1, 384), (1, 512),
(8, 16), (8, 128), (8, 256), (8, 384), (8, 512),
(32, 16), (32, 128), (32, 256),
], ids=lambda x: f"{x[0]}x{x[1]}")
@pytest.mark.parametrize("implementation", implementations.keys())
def test_benchmark_implementations(benchmark, model_baseline_fp32, input_shape: (int, int), implementation: str):
def test_benchmark_implementations(benchmark, model_reference_fp32, shape: (int, int), implementation: str):
torch.manual_seed(0)
assert implementation in implementations, f"unknown implementation: {implementation}"
model_tested = implementations[implementation]

inputs = get_input_causal(input_shape) if model_tested.is_causal else get_input_non_causal(input_shape)
inputs = get_input_causal(shape) if model_tested.is_causal else get_input_non_causal(shape)

with torch.inference_mode():
expected = model_baseline_fp32(**inputs)
expected = model_reference_fp32(**inputs)
model = model_tested.model()
value = benchmark(model, **inputs)

torchdynamo.reset()

assert torch.allclose(input=value["last_hidden_state"].float(), other=expected["last_hidden_state"], rtol=1e-1, atol=1e-1)
assert torch.allclose(input=value["last_hidden_state"].float(), other=expected["last_hidden_state"], rtol=1e-1,
atol=1e-1)
assert torch.allclose(input=value["pooler_output"].float(), other=expected["pooler_output"], rtol=1e-1, atol=1e-1)


def test_support_shape_change(model_baseline_fp32):
def test_support_shape_change(model_reference_fp32):
"""Test that the model can handle shape changes without being reloaded/rebuilt."""
for name, implementation in implementations.items():
model_tested = implementation.model()
for shape in [(1, 64), (8, 256), (16, 256), (16, 64)]:
pytorch_input = get_input_causal(shape) if implementation.is_causal else get_input_non_causal(shape)
expected = model_baseline_fp32(**pytorch_input)
expected = model_reference_fp32(**pytorch_input)
result = model_tested(**pytorch_input)
assert torch.allclose(result["last_hidden_state"].float(), expected["last_hidden_state"], atol=1e-1), f"failed on {name} with shape {shape}"
assert torch.allclose(result["last_hidden_state"].float(), expected["last_hidden_state"],
atol=1e-1), f"failed on {name} with shape {shape}"