From 23bc729051326c795903dae84b51fb492a084257 Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Thu, 10 Mar 2022 21:51:28 -0800 Subject: [PATCH 1/2] bugfix --- CHANGELOG.md | 1 + tests/test_residual.py | 27 ++++++++++++++ xformers/components/residual.py | 62 ++++++++++++++++++++----------- xformers/factory/block_factory.py | 12 +++--- 4 files changed, 76 insertions(+), 26 deletions(-) create mode 100644 tests/test_residual.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ffe4de687..5f9c8e7a6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Expose bias flag for feedforwards, same default as Timm [#220] - Update eps value for layernormm, same default as torch [#221] +- PreNorm bugfix, only one input was normalized [#233] ## [0.0.9] - 2022-02-09 ### Added diff --git a/tests/test_residual.py b/tests/test_residual.py new file mode 100644 index 0000000000..3586c9f85b --- /dev/null +++ b/tests/test_residual.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +from xformers.components import PreNorm + + +class Passthrough(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, *args): + return args + + +def test_pre_norm(): + # Check that passing the same tensor a bunch of times skips the extra normalizations + x = torch.rand((3, 3)) + + wrap = PreNorm(d_model=3, sublayer=Passthrough(), use_triton=False) + outputs = wrap(inputs=[x, x, x]) + + assert id(outputs[0]) == id(outputs[1]) diff --git a/xformers/components/residual.py b/xformers/components/residual.py index e97f358cd1..99d07b7001 100644 --- a/xformers/components/residual.py +++ b/xformers/components/residual.py @@ -16,14 +16,6 @@ from xformers.triton.layer_norm import FusedLayerNorm -def _to_tensor_list( - inputs: Union[torch.Tensor, List[torch.Tensor]] -) -> List[torch.Tensor]: - if not isinstance(inputs, list): - inputs = [inputs] - return inputs - - class LayerNormStyle(str, Enum): """Support different layer norm styles. See "On Layer Normalization in the Transformer Architecture", @@ -36,16 +28,25 @@ class LayerNormStyle(str, Enum): # CREDITS: the following is inspired by FastAI's Transformer implementation class Residual(nn.Module): - """Object-oriented handling of the residual path""" + """ + Object-oriented handling of the residual path + + .. 
Note: the wrapped layers must accept all the inputs as a single list + """ def __init__(self, layer: nn.Module): super().__init__() self.layer = layer - def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs): - inputs = _to_tensor_list(inputs) + # PreNorm and PostNorm require all the tensors to be passed as a list + self.wrap_inputs = isinstance(layer, PreNorm) or isinstance(layer, PostNorm) + + def forward(self, inputs: List[torch.Tensor], **kwargs): + if self.wrap_inputs: + return inputs[0] + self.layer(inputs=inputs, **kwargs) - return inputs[0] + self.layer(*inputs, *args, **kwargs) + else: + return inputs[0] + self.layer(*inputs, **kwargs) class PreNorm(nn.Module): @@ -61,12 +62,27 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True): self.norm = nn.LayerNorm(d_model) self.sublayer = sublayer + self.wrap_inputs = isinstance(sublayer, PostNorm) or isinstance( + sublayer, Residual + ) + + def forward(self, inputs: List[torch.Tensor], **kwargs): + assert len(inputs) > 0 + + # Perf improvement: if the inputs are all the same, only norm once + ids = [id(x) for x in inputs] + if ids.count(ids[0]) == len(ids): + # The same tensor is passed multiple times + x_norm = self.norm(inputs[0]) + inputs_normed = [x_norm for _ in inputs] + else: + # The inputs differ, norm them all + inputs_normed = [self.norm(x_) for x_ in inputs] - def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs): - inputs = _to_tensor_list(inputs) - - x_norm = [self.norm(x_) for x_ in inputs] - return self.sublayer(*x_norm, *args, **kwargs) + if self.wrap_inputs: + return self.sublayer(inputs=inputs_normed, **kwargs) + else: + return self.sublayer(*inputs_normed, **kwargs) class PostNorm(nn.Module): @@ -80,9 +96,13 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True): self.norm = nn.LayerNorm(d_model) self.sublayer = sublayer + self.wrap_inputs = isinstance(sublayer, PreNorm) or isinstance( + sublayer, Residual + ) - def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs): - inputs = _to_tensor_list(inputs) - - x = self.sublayer(*inputs, *args, **kwargs) + def forward(self, inputs: List[torch.Tensor], **kwargs): + if self.wrap_inputs: + x = self.sublayer(inputs=inputs, **kwargs) + else: + x = self.sublayer(*inputs, **kwargs) return self.norm(x) diff --git a/xformers/factory/block_factory.py b/xformers/factory/block_factory.py index 7320441cef..51baeef857 100644 --- a/xformers/factory/block_factory.py +++ b/xformers/factory/block_factory.py @@ -335,8 +335,8 @@ def forward( q, k, v = x, x, x # Pre/Post norms and residual paths are already handled - x = self.wrap_att(q, k, v, att_mask=att_mask) - x = self.wrap_ff(x) + x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask) + x = self.wrap_ff(inputs=[x]) return x @@ -397,8 +397,10 @@ def forward( else: target_q, target_k, target_v = target, target, target - x = self.wrap_att([target_q, target_k, target_v], att_mask=decoder_att_mask) - x = self.wrap_cross([x, memory, memory], att_mask=encoder_att_mask) - x = self.wrap_ff(x) + x = self.wrap_att( + inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask + ) + x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask) + x = self.wrap_ff(inputs=[x]) return x From fb0b4798a4160bcd2b0b2754fc24df583d51832d Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Thu, 10 Mar 2022 23:01:04 -0800 Subject: [PATCH 2/2] fixing reversible layers --- xformers/components/__init__.py 
| 6 +++++- xformers/components/residual.py | 21 ++++++++++++--------- xformers/components/reversible.py | 14 ++++++++++++-- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/xformers/components/__init__.py b/xformers/components/__init__.py index bfe4e12810..fba562ad8a 100644 --- a/xformers/components/__init__.py +++ b/xformers/components/__init__.py @@ -15,7 +15,11 @@ from .in_proj_container import InProjContainer, InProjParams # noqa from .multi_head_dispatch import MultiHeadDispatch # noqa from .multi_head_dispatch import MultiHeadDispatchConfig -from .residual import LayerNormStyle, PostNorm, PreNorm, Residual # noqa +from .residual import LayerNormStyle # noqa; noqa +from .residual import PostNorm # noqa +from .residual import PreNorm # noqa +from .residual import RequiresWrappedInputs # noqa +from .residual import Residual # noqa # automatically import any Python files in the directory import_all_modules(str(Path(__file__).parent), "xformers.components") diff --git a/xformers/components/residual.py b/xformers/components/residual.py index 99d07b7001..2cade55af3 100644 --- a/xformers/components/residual.py +++ b/xformers/components/residual.py @@ -26,8 +26,15 @@ class LayerNormStyle(str, Enum): Post = "post" +class RequiresWrappedInputs: + """Used to mark, through inheritance, + the fact that this class will require inputs to be passed as a single list""" + + pass + + # CREDITS: the following is inspired by FastAI's Transformer implementation -class Residual(nn.Module): +class Residual(nn.Module, RequiresWrappedInputs): """ Object-oriented handling of the residual path @@ -49,7 +56,7 @@ def forward(self, inputs: List[torch.Tensor], **kwargs): return inputs[0] + self.layer(*inputs, **kwargs) -class PreNorm(nn.Module): +class PreNorm(nn.Module, RequiresWrappedInputs): """Adds LayerNorm before computing attention ..Note: If a list of inputs is passed, all of them get normalized""" @@ -62,9 +69,7 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True): self.norm = nn.LayerNorm(d_model) self.sublayer = sublayer - self.wrap_inputs = isinstance(sublayer, PostNorm) or isinstance( - sublayer, Residual - ) + self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs) def forward(self, inputs: List[torch.Tensor], **kwargs): assert len(inputs) > 0 @@ -85,7 +90,7 @@ def forward(self, inputs: List[torch.Tensor], **kwargs): return self.sublayer(*inputs_normed, **kwargs) -class PostNorm(nn.Module): +class PostNorm(nn.Module, RequiresWrappedInputs): """Adds LayerNorm after computing attention""" def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True): @@ -96,9 +101,7 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True): self.norm = nn.LayerNorm(d_model) self.sublayer = sublayer - self.wrap_inputs = isinstance(sublayer, PreNorm) or isinstance( - sublayer, Residual - ) + self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs) def forward(self, inputs: List[torch.Tensor], **kwargs): if self.wrap_inputs: diff --git a/xformers/components/reversible.py b/xformers/components/reversible.py index 55fc9fb8f5..38e538d4d3 100644 --- a/xformers/components/reversible.py +++ b/xformers/components/reversible.py @@ -11,6 +11,8 @@ from torch.autograd.function import Function from torch.utils.checkpoint import get_device_states, set_device_states +from xformers.components import RequiresWrappedInputs + # CREDITS: Code adapted from # https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py 
# https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py, @@ -26,6 +28,7 @@ def __init__(self, net: nn.Module): self.cuda_in_fwd: bool = False self.gpu_devices: List[int] = [] self.gpu_states: List[torch.Tensor] = [] + self.wrap_inputs = isinstance(net, RequiresWrappedInputs) def record_rng(self, *args): self.cpu_state = torch.get_rng_state() @@ -38,7 +41,10 @@ def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwar self.record_rng(*args) if not set_rng: - return self.net(*args, **kwargs) + if self.wrap_inputs: + return self.net(inputs=args, **kwargs) + else: + return self.net(*args, **kwargs) rng_devices: List[int] = [] if self.cuda_in_fwd: @@ -48,7 +54,11 @@ def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwar torch.set_rng_state(self.cpu_state) if self.cuda_in_fwd: set_device_states(self.gpu_devices, self.gpu_states) - return self.net(*args, **kwargs) + + if self.wrap_inputs: + return self.net(inputs=args, **kwargs) + else: + return self.net(*args, **kwargs) class ReversibleBlock(nn.Module):
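
Illustrative usage sketch (not part of the patch): the snippet below shows how the patched wrappers compose once both commits are applied. `Sum` is a hypothetical stand-in for an attention sublayer; the `PreNorm`, `PostNorm` and `Residual` signatures are taken from the diffs above, and the wrapping mirrors what tests/test_residual.py exercises.

import torch

from xformers.components import PostNorm, PreNorm, Residual


class Sum(torch.nn.Module):
    # Toy sublayer taking three tensors, standing in for an attention block.
    def forward(self, q, k, v):
        return q + k + v


d_model = 8
x = torch.rand(2, 4, d_model)

# Residual(PreNorm(...)) mirrors the "pre" LayerNormStyle wiring in block_factory.py:
# PreNorm now normalizes every input (the bug fixed in PATCH 1/2), normalizing only
# once when all inputs are the same tensor, and Residual adds inputs[0] back on top.
pre = Residual(PreNorm(d_model=d_model, sublayer=Sum(), use_triton=False))
y = pre(inputs=[x, x, x])

# PostNorm runs the sublayer first, then applies LayerNorm to its output.
post = PostNorm(d_model=d_model, sublayer=Sum(), use_triton=False)
z = post(inputs=[x, x, x])

assert y.shape == z.shape == x.shape

Sublayers that themselves require list inputs are detected through the new RequiresWrappedInputs marker, which is why block_factory.py and reversible.py now pass inputs=[...] rather than positional tensors.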