-
Notifications
You must be signed in to change notification settings - Fork 633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[bug] Pre-norm wrapper only normalizing the first input #233
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import torch | ||
|
||
from xformers.components import PreNorm | ||
|
||
|
||
class Passthrough(torch.nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def forward(self, *args): | ||
return args | ||
|
||
|
||
def test_pre_norm(): | ||
# Check that passing the same tensor a bunch of times skips the extra normalizations | ||
x = torch.rand((3, 3)) | ||
|
||
wrap = PreNorm(d_model=3, sublayer=Passthrough(), use_triton=False) | ||
outputs = wrap(inputs=[x, x, x]) | ||
|
||
assert id(outputs[0]) == id(outputs[1]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,14 +16,6 @@ | |
from xformers.triton.layer_norm import FusedLayerNorm | ||
|
||
|
||
def _to_tensor_list( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this was supposed to be a helper, but in the end it was masking bugs (in that layer(x, y, z) could have a different behaviour depending on the residual wraps). I think that it's better to force inputs in a single fashion |
||
inputs: Union[torch.Tensor, List[torch.Tensor]] | ||
) -> List[torch.Tensor]: | ||
if not isinstance(inputs, list): | ||
inputs = [inputs] | ||
return inputs | ||
|
||
|
||
class LayerNormStyle(str, Enum): | ||
"""Support different layer norm styles. | ||
See "On Layer Normalization in the Transformer Architecture", | ||
|
@@ -34,21 +26,37 @@ class LayerNormStyle(str, Enum): | |
Post = "post" | ||
|
||
|
||
class RequiresWrappedInputs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. classes which derive from this class only accept a single input list (makes it impossible to subtly footgun) |
||
"""Used to mark, through inheritance, | ||
the fact that this class will require inputs to be passed as a single list""" | ||
|
||
pass | ||
|
||
|
||
# CREDITS: the following is inspired by FastAI's Transformer implementation | ||
class Residual(nn.Module): | ||
"""Object-oriented handling of the residual path""" | ||
class Residual(nn.Module, RequiresWrappedInputs): | ||
""" | ||
Object-oriented handling of the residual path | ||
|
||
.. Note: the wrapped layers must accept all the inputs as a single list | ||
""" | ||
|
||
def __init__(self, layer: nn.Module): | ||
super().__init__() | ||
self.layer = layer | ||
|
||
def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs): | ||
inputs = _to_tensor_list(inputs) | ||
# PreNorm and PostNorm require all the tensors to be passed as a list | ||
self.wrap_inputs = isinstance(layer, PreNorm) or isinstance(layer, PostNorm) | ||
|
||
return inputs[0] + self.layer(*inputs, *args, **kwargs) | ||
def forward(self, inputs: List[torch.Tensor], **kwargs): | ||
if self.wrap_inputs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the trick here is that these residual/norm wrapper can wrap themselves at times. When they wrap an external layer, then the inputs are unrolled, when the sublayer is another wrap then we maintain inputs=List[Tensor] to prevent bugs like this one |
||
return inputs[0] + self.layer(inputs=inputs, **kwargs) | ||
|
||
else: | ||
return inputs[0] + self.layer(*inputs, **kwargs) | ||
|
||
|
||
class PreNorm(nn.Module): | ||
class PreNorm(nn.Module, RequiresWrappedInputs): | ||
"""Adds LayerNorm before computing attention | ||
|
||
..Note: If a list of inputs is passed, all of them get normalized""" | ||
|
@@ -61,15 +69,28 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True): | |
self.norm = nn.LayerNorm(d_model) | ||
|
||
self.sublayer = sublayer | ||
self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs) | ||
|
||
def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs): | ||
inputs = _to_tensor_list(inputs) | ||
def forward(self, inputs: List[torch.Tensor], **kwargs): | ||
assert len(inputs) > 0 | ||
|
||
x_norm = [self.norm(x_) for x_ in inputs] | ||
return self.sublayer(*x_norm, *args, **kwargs) | ||
# Perf improvement: if the inputs are all the same, only norm once | ||
ids = [id(x) for x in inputs] | ||
if ids.count(ids[0]) == len(ids): | ||
# The same tensor is passed multiple times | ||
x_norm = self.norm(inputs[0]) | ||
inputs_normed = [x_norm for _ in inputs] | ||
else: | ||
# The inputs differ, norm them all | ||
inputs_normed = [self.norm(x_) for x_ in inputs] | ||
|
||
if self.wrap_inputs: | ||
return self.sublayer(inputs=inputs_normed, **kwargs) | ||
else: | ||
return self.sublayer(*inputs_normed, **kwargs) | ||
|
||
class PostNorm(nn.Module): | ||
|
||
class PostNorm(nn.Module, RequiresWrappedInputs): | ||
"""Adds LayerNorm after computing attention""" | ||
|
||
def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True): | ||
|
@@ -80,9 +101,11 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True): | |
self.norm = nn.LayerNorm(d_model) | ||
|
||
self.sublayer = sublayer | ||
self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs) | ||
|
||
def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs): | ||
inputs = _to_tensor_list(inputs) | ||
|
||
x = self.sublayer(*inputs, *args, **kwargs) | ||
def forward(self, inputs: List[torch.Tensor], **kwargs): | ||
if self.wrap_inputs: | ||
x = self.sublayer(inputs=inputs, **kwargs) | ||
else: | ||
x = self.sublayer(*inputs, **kwargs) | ||
return self.norm(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -335,8 +335,8 @@ def forward( | |
q, k, v = x, x, x | ||
|
||
# Pre/Post norms and residual paths are already handled | ||
x = self.wrap_att(q, k, v, att_mask=att_mask) | ||
x = self.wrap_ff(x) | ||
x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. all the wraps require a single input list + kwargs, which I believe is more future proof (this normalizing bug cannot happen, or at least not as easily) |
||
x = self.wrap_ff(inputs=[x]) | ||
|
||
return x | ||
|
||
|
@@ -397,8 +397,10 @@ def forward( | |
else: | ||
target_q, target_k, target_v = target, target, target | ||
|
||
x = self.wrap_att([target_q, target_k, target_v], att_mask=decoder_att_mask) | ||
x = self.wrap_cross([x, memory, memory], att_mask=encoder_att_mask) | ||
x = self.wrap_ff(x) | ||
x = self.wrap_att( | ||
inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask | ||
) | ||
x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask) | ||
x = self.wrap_ff(inputs=[x]) | ||
|
||
return x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: extra noqa typo?