
Feat/layernorm #36

Merged: 70 commits, merged Sep 15, 2022

Commits (70)
a45997c
feat: add attention
gaetansnl Aug 11, 2022
0745dd2
fix: use tuple in triton
gaetansnl Aug 11, 2022
d06fb8a
docs: attention
gaetansnl Aug 11, 2022
2c4f4bf
feat: add torchdynamo end to end fusion
gaetansnl Aug 12, 2022
0644ae6
feat: causal masked attention
gaetansnl Aug 12, 2022
e6e359b
feat: benchmark dynamo backends
gaetansnl Aug 12, 2022
2ba239c
Merge branch 'main' into feat/torchdynamo-fused
gaetansnl Aug 16, 2022
189ca47
fix: renaming
gaetansnl Aug 16, 2022
0e13b3c
feat: add support for arbitrary stride
gaetansnl Aug 16, 2022
a0f880d
fix: move output outside kernel
gaetansnl Aug 16, 2022
a6dcc70
feat: module replacement example
gaetansnl Aug 17, 2022
b142c1b
fix: missing benchmark for masked attention
gaetansnl Aug 18, 2022
6aceec7
feat: add pattern and fix fx bug
gaetansnl Aug 19, 2022
a4f920f
fix: refactoring
gaetansnl Aug 19, 2022
5157a8b
feat: add layer_norm
gaetansnl Aug 19, 2022
90f989c
fix: show speedup in benchmark display
gaetansnl Aug 23, 2022
f7a0ffa
fix: update torchdyname and matcher
gaetansnl Aug 24, 2022
bf5ecfa
fix: update matcher
gaetansnl Aug 24, 2022
ce73e78
fix: cuda graph
gaetansnl Aug 25, 2022
9590e78
fix: small seq_length
gaetansnl Aug 25, 2022
3534cb5
feat: viz server
gaetansnl Aug 26, 2022
26e5720
fix: bug in matcher and add complete graph report
gaetansnl Aug 26, 2022
1576ef5
fix: compatibility with pytorch stable
gaetansnl Aug 31, 2022
f5729df
fix: add credit, rename variables, add doc
gaetansnl Aug 31, 2022
be04b88
fix: add test for shape change
gaetansnl Aug 31, 2022
7e2cb62
fix: attention renaming
gaetansnl Aug 31, 2022
672ea04
fix: add license
gaetansnl Aug 31, 2022
f5c1ae2
feat: add stride management on linear layer + replace cuda graph + ne…
pommedeterresautee Sep 3, 2022
c8bb059
feat: remove M, N masking
pommedeterresautee Sep 3, 2022
3998bc9
feat: improve autotune
pommedeterresautee Sep 3, 2022
dbf7c93
fix: command line
pommedeterresautee Sep 5, 2022
95fac33
fix: refactoring benchmarks (fix cuda graphs API, tests as dict, add …
pommedeterresautee Sep 5, 2022
368f3e5
fix: linear layer is now working (remove trick on max contiguous), re…
pommedeterresautee Sep 5, 2022
804335b
fix: refactoring of the linear benchmark with sizes similar to bert o…
pommedeterresautee Sep 5, 2022
e93632b
feat: add more instructions
pommedeterresautee Sep 6, 2022
a392150
fix: restore cuda graphs warmup (and remove todo)
pommedeterresautee Sep 6, 2022
efdabc9
feat: change linear implementation
pommedeterresautee Sep 6, 2022
d464f6b
feat: change linear implementation
pommedeterresautee Sep 6, 2022
e9d6b04
feat: replace GELU + layernorm by more precise version, fix all preci…
pommedeterresautee Sep 6, 2022
a76b710
Merge remote-tracking branch 'origin/main' into fix/refactoring_bench…
pommedeterresautee Sep 6, 2022
b91456c
fix: plural
pommedeterresautee Sep 6, 2022
d20619c
fix: some doc
pommedeterresautee Sep 6, 2022
87cf360
fix: layer norm unit test
pommedeterresautee Sep 6, 2022
b112b4b
feat: add cuda graph layer norm unit test
pommedeterresautee Sep 6, 2022
479254f
fix: remove TODO
pommedeterresautee Sep 6, 2022
afc8d41
feat: add split K support
pommedeterresautee Sep 7, 2022
106cb24
Merge branch 'main' into feat/tools
pommedeterresautee Sep 8, 2022
ed6d1a1
Merge branch 'feat/tools' into fix/refactoring_benchmarks
pommedeterresautee Sep 8, 2022
df3daa8
fix: remove split k
pommedeterresautee Sep 8, 2022
cf8e7d0
Merge remote-tracking branch 'origin/fix/refactoring_benchmarks' into…
pommedeterresautee Sep 8, 2022
7d91384
feat: refactoring layernorm test
pommedeterresautee Sep 9, 2022
f2130a3
fix: add back bias and activation tests + refactoring
pommedeterresautee Sep 9, 2022
119e776
feat: make test display understandable
pommedeterresautee Sep 9, 2022
38ee01b
fix: remove benchmark display
gaetansnl Sep 9, 2022
94dda21
fix: remove benchmark display
gaetansnl Sep 9, 2022
4bffeef
Merge branch 'feat/tools' into fix/refactoring_benchmarks
pommedeterresautee Sep 9, 2022
8f91f79
fix: avoid OOM on reference implementation
pommedeterresautee Sep 9, 2022
61cbc1f
fix: get input
pommedeterresautee Sep 9, 2022
f39219c
Merge remote-tracking branch 'origin/fix/refactoring_benchmarks' into…
pommedeterresautee Sep 9, 2022
10703a9
fix: remove some OOM test for reference implementation
pommedeterresautee Sep 9, 2022
dabf536
fix: add tests
pommedeterresautee Sep 9, 2022
bd83caa
feat: new layernorm single pass variance computation implementation
pommedeterresautee Sep 12, 2022
85ae2d8
Merge branch 'main' into fix/refactoring_benchmarks
pommedeterresautee Sep 13, 2022
f3564b0
fix: rename variables
pommedeterresautee Sep 13, 2022
7b06621
Merge branch 'fix/refactoring_benchmarks' into feat/layernorm
pommedeterresautee Sep 13, 2022
c2b8ade
feat: add naive implem of layernorm
pommedeterresautee Sep 13, 2022
cca52e8
Merge branch 'main' into feat/layernorm
pommedeterresautee Sep 13, 2022
1100ef6
fix: store mean/var in layernorm (for bw pass)
pommedeterresautee Sep 13, 2022
2cedbfa
fix: following review comments
pommedeterresautee Sep 14, 2022
7f3aefa
fix: add manual seed
pommedeterresautee Sep 15, 2022
88 changes: 82 additions & 6 deletions implementations/layer_norm.py
@@ -2,12 +2,87 @@

import triton
import triton.language as tl
from triton import JITFunction


# CREDITS: Initially inspired by the Triton tutorial


@triton.jit
def _layer_norm_fwd_fused(
def _layer_norm_fwd_fused_single_pass(
Contributor:

I think it's missing documentation and naming: what is A? What size? Stride of what, and which dimension?

Member Author:

added

Out,
A,
Weight,
Bias,
Mean, std,
stride, N, eps,
BLOCK_SIZE: tl.constexpr,
):
"""
Layernorm based on Welford's variance computation algorithm.
https://changyaochen.github.io/welford/
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance

:param Out: output tensor
:param A: input tensor
:param Weight: weights applied to the normalized input
:param Bias: bias added to the normalized input
:param Mean: save mean tensor for backward
:param std: save reciprocal standard deviation (rstd) tensor for backward
:param stride: stride of the input tensor
:param N: number of elements per row in the input tensor
:param eps: epsilon value to avoid division by zero
:param BLOCK_SIZE: number of columns processed per loop iteration (Triton block size)
:return: None
"""
# position of elements processed by this program
_idx = tl.program_id(0)
out_ptr = Out + _idx * stride
a_ptr = A + _idx * stride
# compute mean
mean = 0.0
var = 0.0
for start_n_offset in range(0, N, BLOCK_SIZE):
end_n_offset = min((start_n_offset + BLOCK_SIZE), N)
nb_block_cols = end_n_offset - start_n_offset
column_offset = start_n_offset + tl.arange(0, BLOCK_SIZE)
mask = column_offset < N
# eviction policy below has little impact now because of the new implementation. Kept as is.
a = tl.load(a_ptr + column_offset, mask=mask, other=0., eviction_policy="evict_last").to(tl.float32)

block_mean = tl.sum(a, axis=0) / nb_block_cols
# mean is 0 or has a mask applied to it, no need to mask delta_mean!
delta_mean = block_mean - mean
delta_mean_sqr = delta_mean * delta_mean

block_delta = tl.sum((a - block_mean) * a, axis=0)
# mean has a mask!
mean += tl.sum((a - mean) * mask, axis=0) / end_n_offset
var += block_delta + delta_mean_sqr * (start_n_offset * nb_block_cols) / end_n_offset

var = var / N
rstd = 1 / tl.sqrt(var + eps)
Contributor:

What does rstd mean? Root std?

Member Author:

changed


# write-back mean/rstd for backward pass
tl.store(Mean + _idx, mean)
tl.store(std + _idx, rstd)

# multiply by weight and add bias
for off in range(0, N, BLOCK_SIZE):
column_offset = off + tl.arange(0, BLOCK_SIZE)
mask = column_offset < N
weight = tl.load(Weight + column_offset, mask=mask)
bias = tl.load(Bias + column_offset, mask=mask)
# eviction policy helps to keep weights in cache (reused by other threads)
a = tl.load(a_ptr + column_offset, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# write-back
tl.store(out_ptr + column_offset, out, mask=mask)


@triton.jit
def _layer_norm_fwd_fused_multi_pass(
Out,
A,
Weight,
@@ -49,18 +124,19 @@ def _layer_norm_fwd_fused(
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# # write-back
# write-back
tl.store(Out + cols, out, mask=mask)


def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, implementation: JITFunction = _layer_norm_fwd_fused_single_pass):
# allocate output
out = torch.empty_like(a)
# reshape input data into 2D tensor
a_arg = a.reshape(-1, a.shape[-1])
M, N = a_arg.shape
# tensors below for backward pass
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
std = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // a.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
@@ -69,12 +145,12 @@ def layer_norm_forward(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
eps = min(eps, 1e-6) # >= 1e-5 may decrease Bert accuracy
_layer_norm_fwd_fused[(M,)](
implementation[(M,)](
out,
a_arg,
weight,
bias,
mean, rstd,
mean, std,
a_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
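As a note for readers of this diff: the chunked mean/variance update in _layer_norm_fwd_fused_single_pass is Chan's parallel formulation of Welford's algorithm, where each block contributes its own mean and sum of squared deviations (M2) and the running statistics are merged with a delta^2 * n_seen * n_block / n_total correction. Below is a minimal PyTorch sketch of that merge for a single row; it is illustrative only (not part of this PR), and block_size and the tolerances are arbitrary choices.

import torch

def chunked_mean_var(row: torch.Tensor, block_size: int = 128):
    """Single pass over `row` in chunks, merging per-chunk statistics (Welford/Chan)."""
    n_seen, mean, m2 = 0, 0.0, 0.0
    for start in range(0, row.numel(), block_size):
        block = row[start:start + block_size].double()
        n_block = block.numel()
        n_total = n_seen + n_block
        block_mean = block.mean()
        delta = block_mean - mean
        # Chan's merge. The kernel computes the block M2 as sum((a - block_mean) * a),
        # which equals sum((a - block_mean) ** 2) because sum(a - block_mean) == 0.
        m2 += ((block - block_mean) ** 2).sum() + delta * delta * n_seen * n_block / n_total
        mean = mean + delta * n_block / n_total
        n_seen = n_total
    return mean, m2 / row.numel()  # biased variance, as layer norm uses

row = torch.randn(4096)
mean, var = chunked_mean_var(row)
assert torch.allclose(mean.float(), row.mean(), atol=1e-4)
assert torch.allclose(var.float(), row.var(unbiased=False), atol=1e-4)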
23 changes: 18 additions & 5 deletions test/test_layer_norm.py
@@ -4,17 +4,30 @@
import pytest

from implementations.cuda_graph import cuda_graphs_wrapper
from implementations.layer_norm import layer_norm_forward
from implementations.layer_norm import layer_norm_forward, _layer_norm_fwd_fused_single_pass, \
_layer_norm_fwd_fused_multi_pass


def pytorch_naive(a: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
mean = a.mean(dim=-1, keepdim=True)
var = a.var(dim=-1, unbiased=False, keepdim=True)  # layer norm uses the biased variance
rstd = 1 / torch.sqrt(var + eps)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
return out


implementations: dict[str, Callable[[torch.Tensor, torch.Tensor, torch.Tensor, float], torch.Tensor]] = {
"pytorch": lambda x, weight, bias, eps: torch.nn.functional.layer_norm(x, weight.shape, weight, bias, eps),
"triton": lambda x, weight, bias, eps: layer_norm_forward(x, weight, bias, eps),
"triton_original": lambda x, weight, bias, eps: layer_norm_forward(x, weight, bias, eps, _layer_norm_fwd_fused_multi_pass),
"triton_improved": lambda x, weight, bias, eps: layer_norm_forward(x, weight, bias, eps, _layer_norm_fwd_fused_single_pass),
"pytorch_naive": lambda x, weight, bias, eps: pytorch_naive(x, weight, bias, eps),
}


@pytest.mark.parametrize("shape", [128, 512, 1024, 2048, 4096], ids=lambda x: f"shape={x}x{x}")
@pytest.mark.parametrize("shape", [128, 512, 1024, 2048, 4096, 8192], ids=lambda x: f"shape={x}x{x}")
@pytest.mark.parametrize("cuda_graphs", [True, False], ids=["cuda_graphs", "no_cuda_graphs"])
@pytest.mark.parametrize("implementation", ["pytorch", "triton"])
@pytest.mark.parametrize("implementation", ["triton_original", "triton_improved", "pytorch", "pytorch_naive"])
def test_benchmark_layer_norm(benchmark, shape: int, cuda_graphs: bool, implementation: str):
assert implementation in implementations, f"Unknown implementation: {implementation}"

@@ -39,4 +52,4 @@ def inference(x, *args, **kwargs):

value = benchmark(inference, x, weight, bias, eps)

assert torch.allclose(value, expected, atol=1e-2)
assert torch.allclose(value, expected, atol=1e-1)
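For a quick standalone check outside pytest, the two Triton variants can be exercised directly through layer_norm_forward; the sketch below assumes a CUDA device, the module layout of this repository, and reuses the 1e-1 fp16 tolerance from the test above.

import torch
from implementations.layer_norm import (
    layer_norm_forward,
    _layer_norm_fwd_fused_single_pass,
    _layer_norm_fwd_fused_multi_pass,
)

torch.manual_seed(0)
x = torch.randn((8, 512), device="cuda", dtype=torch.float16)
weight = torch.randn(512, device="cuda", dtype=torch.float16)
bias = torch.randn(512, device="cuda", dtype=torch.float16)
eps = 1e-5

expected = torch.nn.functional.layer_norm(x, weight.shape, weight, bias, eps)
for kernel in (_layer_norm_fwd_fused_single_pass, _layer_norm_fwd_fused_multi_pass):
    out = layer_norm_forward(x, weight, bias, eps, kernel)
    assert torch.allclose(out.float(), expected.float(), atol=1e-1)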
1 change: 1 addition & 0 deletions test/test_linear_layer.py
@@ -40,6 +40,7 @@ def test_benchmark(benchmark, shape: Shape, bias: bool, activation: str, contigu
batch, M, N, K = dataclasses.astuple(shape)

# order of dimensions is wrong so we force contiguous call

a = torch.randn((batch, K, M), device='cuda', dtype=torch.float16, requires_grad=False)
a = a.mT
if contiguous:
26 changes: 14 additions & 12 deletions test/test_torchdynamo_bert.py
@@ -10,7 +10,7 @@


@pytest.fixture
def model_baseline_fp32():
def model_reference_fp32():
return get_model_baseline(float_16=False)


@@ -48,35 +48,37 @@ def get_input_non_causal(shape: (int, int)) -> Dict[str, torch.Tensor]:
}


@pytest.mark.parametrize("input_shape", [(1, 16), (1, 128), (1, 256), (1, 384), (1, 512),
(8, 16), (8, 128), (8, 256), (8, 384), (8, 512),
(32, 16), (32, 128), (32, 256),
], ids=lambda x: f"{x[0]}x{x[1]}")
@pytest.mark.parametrize("shape", [(1, 16), (1, 128), (1, 256), (1, 384), (1, 512),
(8, 16), (8, 128), (8, 256), (8, 384), (8, 512),
(32, 16), (32, 128), (32, 256),
], ids=lambda x: f"{x[0]}x{x[1]}")
@pytest.mark.parametrize("implementation", implementations.keys())
def test_benchmark_implementations(benchmark, model_baseline_fp32, input_shape: (int, int), implementation: str):
def test_benchmark_implementations(benchmark, model_reference_fp32, shape: (int, int), implementation: str):
torch.manual_seed(0)
assert implementation in implementations, f"unknown implementation: {implementation}"
model_tested = implementations[implementation]

inputs = get_input_causal(input_shape) if model_tested.is_causal else get_input_non_causal(input_shape)
inputs = get_input_causal(shape) if model_tested.is_causal else get_input_non_causal(shape)

with torch.inference_mode():
expected = model_baseline_fp32(**inputs)
expected = model_reference_fp32(**inputs)
model = model_tested.model()
value = benchmark(model, **inputs)

torchdynamo.reset()

assert torch.allclose(input=value["last_hidden_state"].float(), other=expected["last_hidden_state"], rtol=1e-1, atol=1e-1)
assert torch.allclose(input=value["last_hidden_state"].float(), other=expected["last_hidden_state"], rtol=1e-1,
atol=1e-1)
assert torch.allclose(input=value["pooler_output"].float(), other=expected["pooler_output"], rtol=1e-1, atol=1e-1)


def test_support_shape_change(model_baseline_fp32):
def test_support_shape_change(model_reference_fp32):
"""Test that the model can handle shape changes without being reloaded/rebuilt."""
for name, implementation in implementations.items():
model_tested = implementation.model()
for shape in [(1, 64), (8, 256), (16, 256), (16, 64)]:
pytorch_input = get_input_causal(shape) if implementation.is_causal else get_input_non_causal(shape)
expected = model_baseline_fp32(**pytorch_input)
expected = model_reference_fp32(**pytorch_input)
result = model_tested(**pytorch_input)
assert torch.allclose(result["last_hidden_state"].float(), expected["last_hidden_state"], atol=1e-1), f"failed on {name} with shape {shape}"
assert torch.allclose(result["last_hidden_state"].float(), expected["last_hidden_state"],
atol=1e-1), f"failed on {name} with shape {shape}"