Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] Embedding weight tying (#169) #172

Merged
merged 3 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Much faster fused dropout [#164]

### Added
- Embedding weight tying option [#172]

## [0.0.7] - 2021-11-30
### Fixed
- Dropout setting not properly passed in many attentions [#123]
Expand Down
38 changes: 29 additions & 9 deletions tests/test_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import nullcontext

import pytest
import torch

Expand Down Expand Up @@ -99,27 +101,45 @@

@pytest.mark.parametrize("config", [test_configs_list, test_configs_dict])
@pytest.mark.parametrize("reversible", [True, False])
@pytest.mark.parametrize("tie_embedding_weights", [True, False])
@pytest.mark.parametrize("device", DEVICES)
def test_presets(config, reversible, device):
def test_presets(config, reversible, tie_embedding_weights, device):
# Build the model
if isinstance(config, list):
config[0]["reversible"] = reversible
else:
config["encoder"]["reversible"] = reversible

modelConfig = xFormerConfig(config)
modelConfig = xFormerConfig(config, tie_embedding_weights)
if isinstance(modelConfig.stack_configs, dict):
for k, blockConfig in modelConfig.stack_configs.items():
for _, blockConfig in modelConfig.stack_configs.items():
assert blockConfig.layer_position
else:
for blockConfig in modelConfig.stack_configs:
assert blockConfig.layer_position

model = xFormer.from_config(modelConfig).to(device)
context = (
pytest.raises(AssertionError)
if reversible and tie_embedding_weights
else nullcontext()
)

with context:
model = xFormer.from_config(modelConfig).to(device)

# Dummy inputs, test a forward
inputs = (torch.rand((BATCH, SEQ), device=device) * 10).abs().to(torch.int)

input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
input_mask[input_mask < 0.0] = -float("inf")
outputs = model(
inputs, encoder_input_mask=input_mask, decoder_input_mask=input_mask
)

# Dummy inputs, test a forward
inputs = (torch.rand((BATCH, SEQ), device=device) * 10).abs().to(torch.int)
# Test a BW
loss = torch.sum(torch.abs(outputs))
loss.backward()

input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
input_mask[input_mask < 0.0] = -float("inf")
_ = model(inputs, encoder_input_mask=input_mask, decoder_input_mask=input_mask)
# If we requested tied embedding weights, check that this is the case indeed
if tie_embedding_weights and not reversible:
assert model.encoders[0].pose_encoding == model.decoders[0].pose_encoding
19 changes: 19 additions & 0 deletions xformers/factory/block_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ def ln_factory(sublayer: nn.Module):

@dataclass(init=False) # handle constructors explicitly to force type changes
class xFormerBlockConfig:
"""
The configuration structure to define a Transformer block.
This base class is applicable to both encoder and decoder definitions.

This completely defines each of the blocks, for instance in terms of dimensions,
position encoding, pre or post layer norms or reversibility.
"""

dim_model: int
feedforward_config: FeedforwardConfig
position_encoding_config: Optional[PositionEmbeddingConfig]
Expand Down Expand Up @@ -145,6 +153,10 @@ def __init__(

@dataclass(init=False)
class xFormerEncoderConfig(xFormerBlockConfig):
"""
The configuration structure for an encoder block
"""

multi_head_config: Dict[str, Any]
use_triton: bool

Expand Down Expand Up @@ -192,6 +204,13 @@ def __init__(

@dataclass(init=False)
class xFormerDecoderConfig(xFormerBlockConfig):
"""
The configuration structure for a decoder block.

This specifically defines the masked and cross attention mechanisms,
on top of the settings defining all blocks.
"""

multi_head_config_masked: Dict[str, Any] # prior to encoder output
multi_head_config_cross: Dict[str, Any] # cross attention, takes encoder output

Expand Down
47 changes: 40 additions & 7 deletions xformers/factory/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

Expand All @@ -21,10 +22,24 @@

@dataclass(init=False)
class xFormerConfig:
"""
The configuration structure to define a full Transformer.
This can include a stack of encoder layers, and a stack of decoder layers.

It is optionally possible to share the embedding weights in between
the encoder and decoder positional encoding, as proposed for instance by
`Using the Output Embedding to Improve Language Models`, Press et al.

.. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
"""

stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
tie_embedding_weights: bool = False

def __init__(
self, stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]]
self,
stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
tie_embedding_weights: bool = False,
):
# Type all the configurations. Possible typos are caught here
if isinstance(stack_configs, dict):
Expand All @@ -42,13 +57,16 @@ def __init__(
else:
self.stack_configs.append(xFormerDecoderConfig(**config))

self.tie_embedding_weights = tie_embedding_weights


class xFormer(torch.nn.Module):
def __init__(
self,
stack_configs: Union[
xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
],
tie_embedding_weights: bool = False,
):
"""
Given a serialized configuration, generate the corresponding model.
Expand All @@ -69,8 +87,7 @@ def __init__(
decoders: List[torch.nn.Module] = []

self.reversible_encoder = False
self.enc_pose_encoding = None
self.dec_pose_encoding = None
self.rev_enc_pose_encoding = None

# Unroll the configs and build the model
for config in stack_configs:
Expand Down Expand Up @@ -99,14 +116,30 @@ def __init__(
# WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
assert isinstance(config, xFormerEncoderConfig)
if block.pose_encoding is not None:
self.enc_pose_encoding = block.pose_encoding
self.rev_enc_pose_encoding = block.pose_encoding
self.reversible_encoder = True

f, g = xFormerEncoderBlock.get_reversible_layer(config)
recipient.append(torch.nn.ModuleList([f, g]))
else:
recipient.append(block) # type: ignore

# Tie embedding weights, if requested and possible
assert (
not tie_embedding_weights or not self.reversible_encoder
), "Reversible layers and tied embeddings is not supported for now"

if (
tie_embedding_weights
and encoders
and encoders[0].pose_encoding
and decoders
and decoders[0].pose_encoding
and not config.reversible
):
logging.info("Tying encoder and decoder embeddings, as requested")
encoders[0].pose_encoding = decoders[0].pose_encoding

self.encoders: torch.nn.Module = (
rv.ReversibleSequence(torch.nn.ModuleList(encoders))
if self.reversible_encoder
Expand All @@ -120,7 +153,7 @@ def __init__(

@classmethod
def from_config(cls, config: xFormerConfig):
return cls(config.stack_configs)
return cls(config.stack_configs, config.tie_embedding_weights)

def _reset_parameters(self):
r"""Initiate parameters in the transformer model
Expand Down Expand Up @@ -158,8 +191,8 @@ def forward(
for encoder in encoders:
memory = encoder(memory, input_mask=encoder_input_mask)
else:
if self.enc_pose_encoding:
memory = self.enc_pose_encoding(src)
if self.rev_enc_pose_encoding:
memory = self.rev_enc_pose_encoding(src)

# Reversible Encoder
x = torch.cat([memory, memory], dim=-1)
Expand Down