Commit

added unit testing

blefaudeux committed Jan 3, 2022
1 parent bdec208 commit 77fa504

Showing 2 changed files with 47 additions and 20 deletions.
33 changes: 26 additions & 7 deletions tests/test_model_factory.py

@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from contextlib import nullcontext
+
 import pytest
 import torch

@@ -110,17 +112,34 @@ def test_presets(config, reversible, tie_embedding_weights, device):

     modelConfig = xFormerConfig(config, tie_embedding_weights)
     if isinstance(modelConfig.stack_configs, dict):
-        for k, blockConfig in modelConfig.stack_configs.items():
+        for _, blockConfig in modelConfig.stack_configs.items():
             assert blockConfig.layer_position
     else:
         for blockConfig in modelConfig.stack_configs:
             assert blockConfig.layer_position
 
-    model = xFormer.from_config(modelConfig).to(device)
+    context = (
+        pytest.raises(AssertionError)
+        if reversible and tie_embedding_weights
+        else nullcontext()
+    )
+
+    with context:
+        model = xFormer.from_config(modelConfig).to(device)
+
+        # Dummy inputs, test a forward
+        inputs = (torch.rand((BATCH, SEQ), device=device) * 10).abs().to(torch.int)
+
+        input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
+        input_mask[input_mask < 0.0] = -float("inf")
+        outputs = model(
+            inputs, encoder_input_mask=input_mask, decoder_input_mask=input_mask
+        )
 
-    # Dummy inputs, test a forward
-    inputs = (torch.rand((BATCH, SEQ), device=device) * 10).abs().to(torch.int)
+        # Test a BW
+        loss = torch.sum(torch.abs(outputs))
+        loss.backward()
 
-    input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
-    input_mask[input_mask < 0.0] = -float("inf")
-    _ = model(inputs, encoder_input_mask=input_mask, decoder_input_mask=input_mask)
+        # If we requested tied embedding weights, check that this is the case indeed
+        if tie_embedding_weights and not reversible:
+            assert model.encoders[0].pose_encoding == model.decoders[0].pose_encoding
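
Note on the pattern used in the updated test: the expected-failure case (reversible layers combined with tied embedding weights) and the happy path share one code path by selecting the context manager up front, either pytest.raises(AssertionError) or contextlib.nullcontext(). Below is a minimal, self-contained sketch of that selection pattern; the build() helper is hypothetical and only stands in for xFormer.from_config, it is not part of xformers.

from contextlib import nullcontext

import pytest


def build(reversible: bool, tie_embedding_weights: bool) -> str:
    # Hypothetical stand-in for xFormer.from_config: the real factory asserts
    # that reversible layers and tied embedding weights are not combined.
    assert not (reversible and tie_embedding_weights), "unsupported combination"
    return "model"


@pytest.mark.parametrize("reversible", [False, True])
@pytest.mark.parametrize("tie_embedding_weights", [False, True])
def test_build(reversible, tie_embedding_weights):
    # Pick the context up front: expect an AssertionError only for the
    # unsupported combination, otherwise run under a no-op context manager.
    context = (
        pytest.raises(AssertionError)
        if reversible and tie_embedding_weights
        else nullcontext()
    )
    with context:
        model = build(reversible, tie_embedding_weights)
        assert model == "model"
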
34 changes: 21 additions & 13 deletions xformers/factory/model_factory.py

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union

@@ -75,8 +76,7 @@ def __init__(
         decoders: List[torch.nn.Module] = []
 
         self.reversible_encoder = False
-        self.enc_pose_encoding = None
-        self.dec_pose_encoding = None
+        self.rev_enc_pose_encoding = None
 
         # Unroll the configs and build the model
         for config in stack_configs:
@@ -105,14 +105,30 @@ def __init__(
                 # WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
                 assert isinstance(config, xFormerEncoderConfig)
                 if block.pose_encoding is not None:
-                    self.enc_pose_encoding = block.pose_encoding
+                    self.rev_enc_pose_encoding = block.pose_encoding
                 self.reversible_encoder = True
 
                 f, g = xFormerEncoderBlock.get_reversible_layer(config)
                 recipient.append(torch.nn.ModuleList([f, g]))
             else:
                 recipient.append(block)  # type: ignore
 
+        # Tie embedding weights, if requested and possible
+        assert (
+            not tie_embedding_weights or not self.reversible_encoder
+        ), "Reversible layers and tied embeddings is not supported for now"
+
+        if (
+            tie_embedding_weights
+            and encoders
+            and encoders[0].pose_encoding
+            and decoders
+            and decoders[0].pose_encoding
+            and not config.reversible
+        ):
+            logging.info("Tying encoder and decoder embeddings, as requested")
+            encoders[0].pose_encoding = decoders[0].pose_encoding
+
         self.encoders: torch.nn.Module = (
             rv.ReversibleSequence(torch.nn.ModuleList(encoders))
             if self.reversible_encoder
@@ -124,14 +140,6 @@ def __init__(
         # Use Xavier init for encoding/decoding tasks
         self._reset_parameters()
 
-        # Tie embedding weights, if requested
-        if (
-            tie_embedding_weights
-            and self.enc_pose_encoding
-            and self.dec_pose_encoding
-        ):
-            self.enc_pose_encoding = self.dec_pose_encoding
-
     @classmethod
     def from_config(cls, config: xFormerConfig):
         return cls(config.stack_configs, config.tie_embedding_weights)
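
For context on what the relocated tying code does: assigning one submodule to the other attribute (encoders[0].pose_encoding = decoders[0].pose_encoding) makes both stacks hold a reference to the same module, so a single parameter tensor is trained and used in both places. The sketch below shows that sharing mechanism with plain nn.Embedding tables rather than the xformers position encodings; the TinyEncoderDecoder class and its attribute names are hypothetical.

import torch


class TinyEncoderDecoder(torch.nn.Module):
    # Hypothetical toy model: each stack starts with its own embedding table.
    def __init__(self, vocab: int = 100, dim: int = 16, tie_embedding_weights: bool = False):
        super().__init__()
        self.enc_embedding = torch.nn.Embedding(vocab, dim)
        self.dec_embedding = torch.nn.Embedding(vocab, dim)

        if tie_embedding_weights:
            # Same mechanism as `encoders[0].pose_encoding = decoders[0].pose_encoding`:
            # after the reassignment both attributes point at one module, so the
            # encoder and decoder share a single weight tensor.
            self.enc_embedding = self.dec_embedding

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        return self.enc_embedding(src).sum() + self.dec_embedding(tgt).sum()


model = TinyEncoderDecoder(tie_embedding_weights=True)
assert model.enc_embedding.weight.data_ptr() == model.dec_embedding.weight.data_ptr()
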
@@ -172,8 +180,8 @@ def forward(
                 for encoder in encoders:
                     memory = encoder(memory, input_mask=encoder_input_mask)
             else:
-                if self.enc_pose_encoding:
-                    memory = self.enc_pose_encoding(src)
+                if self.rev_enc_pose_encoding:
+                    memory = self.rev_enc_pose_encoding(src)
 
                 # Reversible Encoder
                 x = torch.cat([memory, memory], dim=-1)
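
The reversible path duplicates the encoder input along the feature dimension (torch.cat([memory, memory], dim=-1)) because reversible blocks operate on two streams whose inputs can be recomputed from their outputs. Below is a simplified sketch of that coupling, assuming the standard RevNet-style update rather than the actual internals of xformers' rv.ReversibleSequence.

import torch


def reversible_block(x: torch.Tensor, f, g) -> torch.Tensor:
    # Split the duplicated stream into two halves and apply the additive coupling.
    x1, x2 = torch.chunk(x, 2, dim=-1)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return torch.cat([y1, y2], dim=-1)


dim = 8
f = torch.nn.Linear(dim, dim)
g = torch.nn.Linear(dim, dim)

memory = torch.randn(2, 4, dim)
x = torch.cat([memory, memory], dim=-1)  # same duplication as in the diff above
y = reversible_block(x, f, g)

# The inputs can be recovered from the outputs, so activations need not be stored.
y1, y2 = torch.chunk(y, 2, dim=-1)
x2 = y2 - g(y1)
x1 = y1 - f(x2)
assert torch.allclose(torch.cat([x1, x2], dim=-1), x, atol=1e-5)
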