From bdec2088cb04e2bdac6be4e46f5aed25d53850e3 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Mon, 3 Jan 2022 14:35:26 -0800
Subject: [PATCH 1/3] tentative implementation of #169

---
 tests/test_model_factory.py       |  5 +++--
 xformers/factory/model_factory.py | 18 ++++++++++++++++--
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/tests/test_model_factory.py b/tests/test_model_factory.py
index 9e8bcefcb9..2042783529 100644
--- a/tests/test_model_factory.py
+++ b/tests/test_model_factory.py
@@ -99,15 +99,16 @@
 
 @pytest.mark.parametrize("config", [test_configs_list, test_configs_dict])
 @pytest.mark.parametrize("reversible", [True, False])
+@pytest.mark.parametrize("tie_embedding_weights", [True, False])
 @pytest.mark.parametrize("device", DEVICES)
-def test_presets(config, reversible, device):
+def test_presets(config, reversible, tie_embedding_weights, device):
     # Build the model
     if isinstance(config, list):
         config[0]["reversible"] = reversible
     else:
         config["encoder"]["reversible"] = reversible
 
-    modelConfig = xFormerConfig(config)
+    modelConfig = xFormerConfig(config, tie_embedding_weights)
     if isinstance(modelConfig.stack_configs, dict):
         for k, blockConfig in modelConfig.stack_configs.items():
             assert blockConfig.layer_position
diff --git a/xformers/factory/model_factory.py b/xformers/factory/model_factory.py
index 9fcd0c3794..07ca158723 100644
--- a/xformers/factory/model_factory.py
+++ b/xformers/factory/model_factory.py
@@ -22,9 +22,12 @@
 @dataclass(init=False)
 class xFormerConfig:
     stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
+    tie_embedding_weights: bool = False
 
     def __init__(
-        self, stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]]
+        self,
+        stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
+        tie_embedding_weights: bool = False,
     ):
         # Type all the configurations. Possible typos are caught here
         if isinstance(stack_configs, dict):
@@ -42,6 +45,8 @@ def __init__(
             else:
                 self.stack_configs.append(xFormerDecoderConfig(**config))
 
+        self.tie_embedding_weights = tie_embedding_weights
+
 
 class xFormer(torch.nn.Module):
     def __init__(
@@ -49,6 +54,7 @@ def __init__(
         self,
         stack_configs: Union[
             xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
         ],
+        tie_embedding_weights: bool = False,
     ):
         """
         Given a serialized configuration, generate the corresponding model.
@@ -118,9 +124,17 @@ def __init__(
         # Use Xavier init for encoding/decoding tasks
         self._reset_parameters()
 
+        # Tie embedding weights, if requested
+        if (
+            tie_embedding_weights
+            and self.enc_pose_encoding
+            and self.dec_pose_encoding
+        ):
+            self.enc_pose_encoding = self.dec_pose_encoding
+
     @classmethod
     def from_config(cls, config: xFormerConfig):
-        return cls(config.stack_configs)
+        return cls(config.stack_configs, config.tie_embedding_weights)
 
     def _reset_parameters(self):
         r"""Initiate parameters in the transformer model

From 77fa504a3a29d16cff4ffc1035a47b132c9f3d11 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Mon, 3 Jan 2022 15:08:09 -0800
Subject: [PATCH 2/3] added unit testing

---
 tests/test_model_factory.py       | 33 +++++++++++++++++++++++-------
 xformers/factory/model_factory.py | 34 +++++++++++++++++++------------
 2 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/tests/test_model_factory.py b/tests/test_model_factory.py
index 2042783529..247c8ceb81 100644
--- a/tests/test_model_factory.py
+++ b/tests/test_model_factory.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from contextlib import nullcontext
+
 import pytest
 import torch
 
@@ -110,17 +112,34 @@ def test_presets(config, reversible, tie_embedding_weights, device):
 
     modelConfig = xFormerConfig(config, tie_embedding_weights)
     if isinstance(modelConfig.stack_configs, dict):
-        for k, blockConfig in modelConfig.stack_configs.items():
+        for _, blockConfig in modelConfig.stack_configs.items():
             assert blockConfig.layer_position
     else:
         for blockConfig in modelConfig.stack_configs:
             assert blockConfig.layer_position
 
-    model = xFormer.from_config(modelConfig).to(device)
+    context = (
+        pytest.raises(AssertionError)
+        if reversible and tie_embedding_weights
+        else nullcontext()
+    )
+
+    with context:
+        model = xFormer.from_config(modelConfig).to(device)
+
+        # Dummy inputs, test a forward
+        inputs = (torch.rand((BATCH, SEQ), device=device) * 10).abs().to(torch.int)
+
+        input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
+        input_mask[input_mask < 0.0] = -float("inf")
+        outputs = model(
+            inputs, encoder_input_mask=input_mask, decoder_input_mask=input_mask
+        )
 
-    # Dummy inputs, test a forward
-    inputs = (torch.rand((BATCH, SEQ), device=device) * 10).abs().to(torch.int)
+        # Test a BW
+        loss = torch.sum(torch.abs(outputs))
+        loss.backward()
 
-    input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
-    input_mask[input_mask < 0.0] = -float("inf")
-    _ = model(inputs, encoder_input_mask=input_mask, decoder_input_mask=input_mask)
+        # If we requested tied embedding weights, check that this is the case indeed
+        if tie_embedding_weights and not reversible:
+            assert model.encoders[0].pose_encoding == model.decoders[0].pose_encoding
diff --git a/xformers/factory/model_factory.py b/xformers/factory/model_factory.py
index 07ca158723..5ec514afef 100644
--- a/xformers/factory/model_factory.py
+++ b/xformers/factory/model_factory.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
@@ -75,8 +76,7 @@ def __init__(
         decoders: List[torch.nn.Module] = []
 
         self.reversible_encoder = False
-        self.enc_pose_encoding = None
-        self.dec_pose_encoding = None
+        self.rev_enc_pose_encoding = None
 
         # Unroll the configs and build the model
         for config in stack_configs:
@@ -105,7 +105,7 @@ def __init__(
                 # WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
                 assert isinstance(config, xFormerEncoderConfig)
                 if block.pose_encoding is not None:
-                    self.enc_pose_encoding = block.pose_encoding
+                    self.rev_enc_pose_encoding = block.pose_encoding
                 self.reversible_encoder = True
 
                 f, g = xFormerEncoderBlock.get_reversible_layer(config)
@@ -113,6 +113,22 @@ def __init__(
             else:
                 recipient.append(block)  # type: ignore
 
+        # Tie embedding weights, if requested and possible
+        assert (
+            not tie_embedding_weights or not self.reversible_encoder
+        ), "Reversible layers and tied embeddings is not supported for now"
+
+        if (
+            tie_embedding_weights
+            and encoders
+            and encoders[0].pose_encoding
+            and decoders
+            and decoders[0].pose_encoding
+            and not config.reversible
+        ):
+            logging.info("Tying encoder and decoder embeddings, as requested")
+            encoders[0].pose_encoding = decoders[0].pose_encoding
+
         self.encoders: torch.nn.Module = (
             rv.ReversibleSequence(torch.nn.ModuleList(encoders))
             if self.reversible_encoder
@@ -124,14 +140,6 @@ def __init__(
         # Use Xavier init for encoding/decoding tasks
         self._reset_parameters()
 
-        # Tie embedding weights, if requested
-        if (
-            tie_embedding_weights
-            and self.enc_pose_encoding
-            and self.dec_pose_encoding
-        ):
-            self.enc_pose_encoding = self.dec_pose_encoding
-
     @classmethod
     def from_config(cls, config: xFormerConfig):
         return cls(config.stack_configs, config.tie_embedding_weights)
@@ -172,8 +180,8 @@ def forward(
             for encoder in encoders:
                 memory = encoder(memory, input_mask=encoder_input_mask)
         else:
-            if self.enc_pose_encoding:
-                memory = self.enc_pose_encoding(src)
+            if self.rev_enc_pose_encoding:
+                memory = self.rev_enc_pose_encoding(src)
 
             # Reversible Encoder
             x = torch.cat([memory, memory], dim=-1)

From 4c68c1158b955fcd815db28ec7ea7963d8e0840d Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Tue, 4 Jan 2022 13:35:16 -0800
Subject: [PATCH 3/3] Improve on the doc

---
 CHANGELOG.md                      |  3 +++
 xformers/factory/block_factory.py | 19 +++++++++++++++++++
 xformers/factory/model_factory.py | 11 +++++++++++
 3 files changed, 33 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e82850914c..b8902f02f3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 - Much faster fused dropout [#164]
 
+### Added
+- Embedding weight tying option [#172]
+
 ## [0.0.7] - 2021-11-30
 ### Fixed
 - Dropout setting not properly passed in many attentions [#123]
diff --git a/xformers/factory/block_factory.py b/xformers/factory/block_factory.py
index bc7aad42a0..7320441cef 100644
--- a/xformers/factory/block_factory.py
+++ b/xformers/factory/block_factory.py
@@ -93,6 +93,14 @@ def ln_factory(sublayer: nn.Module):
 
 @dataclass(init=False)  # handle constructors explicitly to force type changes
 class xFormerBlockConfig:
+    """
+    The configuration structure to define a Transformer block.
+    This base class is applicable to both encoder and decoder definitions.
+
+    This completely defines each of the blocks, for instance in terms of dimensions,
+    position encoding, pre or post layer norms or reversibility.
+    """
+
     dim_model: int
     feedforward_config: FeedforwardConfig
     position_encoding_config: Optional[PositionEmbeddingConfig]
@@ -145,6 +153,10 @@ def __init__(
 
 @dataclass(init=False)
 class xFormerEncoderConfig(xFormerBlockConfig):
+    """
+    The configuration structure for an encoder block
+    """
+
     multi_head_config: Dict[str, Any]
     use_triton: bool
 
@@ -192,6 +204,13 @@ def __init__(
 
 @dataclass(init=False)
 class xFormerDecoderConfig(xFormerBlockConfig):
+    """
+    The configuration structure for a decoder block.
+
+    This specifically defines the masked and cross attention mechanisms,
+    on top of the settings defining all blocks.
+    """
+
     multi_head_config_masked: Dict[str, Any]  # prior to encoder output
     multi_head_config_cross: Dict[str, Any]  # cross attention, takes encoder output
 
diff --git a/xformers/factory/model_factory.py b/xformers/factory/model_factory.py
index 5ec514afef..8ef8cd60d9 100644
--- a/xformers/factory/model_factory.py
+++ b/xformers/factory/model_factory.py
@@ -22,6 +22,17 @@
 
 @dataclass(init=False)
 class xFormerConfig:
+    """
+    The configuration structure to define a full Transformer.
+    This can include a stack of encoder layers, and a stack of decoder layers.
+
+    It is optionally possible to share the embedding weights in between
+    the encoder and decoder positional encoding, as proposed for instance by
+    `Using the Output Embedding to Improve Language Models`, Press et al.
+
+    .. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
+    """
+
     stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
     tie_embedding_weights: bool = False
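
For reference, a minimal usage sketch of the tie_embedding_weights option introduced by this series follows. The encoder and decoder dictionaries are illustrative assumptions modelled on the factory test presets of this period (field names such as block_type, num_layers, position_encoding_config, multi_head_config and feedforward_config are not defined by these patches), so treat this as a sketch rather than a canonical recipe.

# Illustrative sketch only: the config fields below are assumptions based on the
# factory test presets of this era and may need adjusting for other versions.
import torch

from xformers.factory.model_factory import xFormer, xFormerConfig

BATCH, SEQ, EMB, VOCAB = 2, 64, 96, 32

attention = {"name": "scaled_dot_product", "dropout": 0.0, "causal": False, "seq_len": SEQ}

encoder_config = {
    "reversible": False,  # tied embeddings plus reversible encoders is rejected (see patch 2)
    "block_type": "encoder",
    "num_layers": 2,
    "dim_model": EMB,
    "position_encoding_config": {
        "name": "vocab",  # learnable vocabulary + position embedding
        "seq_len": SEQ,
        "vocab_size": VOCAB,
        "dim_model": EMB,
    },
    "multi_head_config": {
        "num_heads": 4,
        "residual_dropout": 0.0,
        "dim_model": EMB,
        "attention": attention,
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0.0,
        "activation": "relu",
        "hidden_layer_multiplier": 2,
        "dim_model": EMB,
    },
}

decoder_config = {
    "block_type": "decoder",
    "num_layers": 2,
    "dim_model": EMB,
    "position_encoding_config": {
        "name": "vocab",
        "seq_len": SEQ,
        "vocab_size": VOCAB,
        "dim_model": EMB,
    },
    "multi_head_config_masked": {
        "num_heads": 4,
        "residual_dropout": 0.0,
        "dim_model": EMB,
        "attention": {**attention, "causal": True},
    },
    "multi_head_config_cross": {
        "num_heads": 4,
        "residual_dropout": 0.0,
        "dim_model": EMB,
        "attention": attention,
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0.0,
        "activation": "relu",
        "hidden_layer_multiplier": 2,
        "dim_model": EMB,
    },
}

# Request weight tying at the config level; xFormer.from_config forwards the flag.
config = xFormerConfig([encoder_config, decoder_config], tie_embedding_weights=True)
model = xFormer.from_config(config)

# The encoder and decoder stacks now share a single position/vocab embedding module,
# mirroring the check in the new unit test.
assert model.encoders[0].pose_encoding is model.decoders[0].pose_encoding

# Forward with a batch of token ids, following the unit test's calling convention.
tokens = torch.randint(0, VOCAB, (BATCH, SEQ))
outputs = model(tokens)

Note that requesting tie_embedding_weights=True together with a reversible encoder is explicitly rejected by the assertion added in patch 2, which the updated unit test covers by wrapping the build in pytest.raises(AssertionError).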