diff --git a/CHANGELOG.md b/CHANGELOG.md
index cfeac690a7..6ffe4de687 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [0.0.x] - TBD
 ### Fixed
 - Expose bias flag for feedforwards, same default as Timm [#220]
+- Update eps value for layernorm, same default as torch [#221]
 
 ## [0.0.9] - 2022-02-09
 ### Added
diff --git a/tests/test_triton_layernorm.py b/tests/test_triton_layernorm.py
index 4d55658c33..0af7bbe82a 100644
--- a/tests/test_triton_layernorm.py
+++ b/tests/test_triton_layernorm.py
@@ -53,8 +53,8 @@ def test_layernorm_parity(shape, amp):
     eps = 1e-5
 
     # Initialize the two layers, weights are 1 and 0 by default, no randomness
-    torch_layernorm = torch.nn.LayerNorm(X.shape[-1], eps).to("cuda")
-    triton_layernorm = FusedLayerNorm(X.shape[-1], eps).to("cuda")
+    torch_layernorm = torch.nn.LayerNorm(X.shape[-1], eps=eps).to("cuda")
+    triton_layernorm = FusedLayerNorm(X.shape[-1], affine=True, eps=eps).to("cuda")
 
     with autocast(enabled=amp):
         assert torch.allclose(X, X_)  # sanity checking, else all hell breaks loose
diff --git a/xformers/triton/layer_norm.py b/xformers/triton/layer_norm.py
index 89c99b23b3..03d368fc11 100644
--- a/xformers/triton/layer_norm.py
+++ b/xformers/triton/layer_norm.py
@@ -32,7 +32,7 @@ class FusedLayerNorm(nn.Module):
     """
 
-    def __init__(self, normalized_shape, affine=True, eps=1e-05):
+    def __init__(self, normalized_shape, affine=True, eps=1e-06):
         super().__init__()
 
         if affine:
             self.weight = nn.Parameter(torch.ones(normalized_shape))
@@ -49,7 +49,7 @@ def layer_norm(
     x: torch.Tensor,
     weight: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    eps: float = 1e-05,
+    eps: float = 1e-06,
 ) -> torch.Tensor:
 
     global _triton_registered_warnings
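
For reference, a minimal parity check in the style of the updated test (a sketch only: the shapes and tolerances are illustrative, and it assumes an xformers build with Triton support and a CUDA device). Passing eps as a keyword, as the test now does, keeps both layers on the same value regardless of their respective defaults:

    import torch

    from xformers.triton.layer_norm import FusedLayerNorm

    x = torch.randn(8, 384, device="cuda")

    # Use an explicit eps keyword so torch.nn.LayerNorm and FusedLayerNorm
    # normalize with the same value, independent of their defaults.
    eps = 1e-5
    torch_ln = torch.nn.LayerNorm(x.shape[-1], eps=eps).to("cuda")
    fused_ln = FusedLayerNorm(x.shape[-1], affine=True, eps=eps).to("cuda")

    # Both layers start with weight=1 and bias=0, so the outputs should match
    # up to a small numerical tolerance (values here are illustrative).
    assert torch.allclose(torch_ln(x), fused_ln(x), rtol=1e-3, atol=1e-3)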