using the same eps in layernorm as default torch (#221)
blefaudeux authored Mar 1, 2022
1 parent 19af415 commit 2926d46
Showing 3 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [0.0.x] - TBD
 ### Fixed
 - Expose bias flag for feedforwards, same default as Timm [#220]
+- Update eps value for layernorm, same default as torch [#221]
 
 ## [0.0.9] - 2022-02-09
 ### Added
4 changes: 2 additions & 2 deletions tests/test_triton_layernorm.py
@@ -53,8 +53,8 @@ def test_layernorm_parity(shape, amp):
     eps = 1e-5
 
     # Initialize the two layers, weights are 1 and 0 by default, no randomness
-    torch_layernorm = torch.nn.LayerNorm(X.shape[-1], eps).to("cuda")
-    triton_layernorm = FusedLayerNorm(X.shape[-1], eps).to("cuda")
+    torch_layernorm = torch.nn.LayerNorm(X.shape[-1], eps=eps).to("cuda")
+    triton_layernorm = FusedLayerNorm(X.shape[-1], affine=True, eps=eps).to("cuda")
 
     with autocast(enabled=amp):
         assert torch.allclose(X, X_)  # sanity checking, else all hell breaks loose
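Why the fixed test uses keyword arguments: FusedLayerNorm's second positional parameter is affine, not eps (see its signature in the next file's diff), so the old positional call silently bound the float to affine and left eps at its default. A minimal sketch of the mix-up, assuming FusedLayerNorm is importable from xformers.triton:

    from xformers.triton import FusedLayerNorm  # import path assumed

    eps = 1e-5

    # Positional call: 1e-5 binds to `affine` (a truthy value), so the layer
    # silently keeps its default eps instead of the one the test intended.
    buggy = FusedLayerNorm(512, eps)

    # Keyword call, as in the fixed test: every argument lands where intended.
    fixed = FusedLayerNorm(512, affine=True, eps=eps)

torch.nn.LayerNorm tolerates the positional style because its second parameter really is eps, which is why only the triton layer was mis-configured.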
4 changes: 2 additions & 2 deletions xformers/triton/layer_norm.py
@@ -32,7 +32,7 @@ class FusedLayerNorm(nn.Module):
     """
 
-    def __init__(self, normalized_shape, affine=True, eps=1e-06):
+    def __init__(self, normalized_shape, affine=True, eps=1e-05):
         super().__init__()
         if affine:
             self.weight = nn.Parameter(torch.ones(normalized_shape))
@@ -49,7 +49,7 @@ def layer_norm(
     x: torch.Tensor,
     weight: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    eps: float = 1e-06,
+    eps: float = 1e-05,
 ) -> torch.Tensor:
 
     global _triton_registered_warnings
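With the defaults aligned, the torch and triton layers should agree even when eps is left implicit, matching what the test above asserts with an explicit eps. A minimal parity check, assuming a CUDA machine with triton installed and the same import path as above:

    import torch
    from xformers.triton import FusedLayerNorm  # import path assumed

    x = torch.randn(8, 384, 512, device="cuda")

    torch_ln = torch.nn.LayerNorm(x.shape[-1]).to("cuda")  # eps defaults to 1e-05
    triton_ln = FusedLayerNorm(x.shape[-1]).to("cuda")     # now also 1e-05

    # Both layers initialize weight to 1 and bias to 0, so with matching eps
    # the outputs should agree up to kernel-level floating-point tolerance.
    assert torch.allclose(torch_ln(x), triton_ln(x), atol=1e-4)

Mismatched eps values (1e-06 vs 1e-05) mostly surface on rows with very small variance, where eps dominates the denominator of the normalization.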
