[refactor] Better split responsibilities between building the block and the model (facebookresearch#189)
blefaudeux authored Jul 8, 2021
1 parent 77883ce commit 7ead2e9
Showing 2 changed files with 37 additions and 27 deletions.
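At a glance: after this change, each entry in the model configuration describes one block and may carry an optional num_layers repetition factor, which is handled at the model level rather than the block level. A minimal sketch of the intended shape, assuming only the fields visible in this diff (the nested sub-configs are elided and the values are illustrative):

from typing import Any, Dict

# Hypothetical block definition; only the block_type / num_layers semantics
# are taken from this commit, the other fields are illustrative.
encoder_stack: Dict[str, Any] = {
    "block_type": "encoder",
    "dim_model": 384,   # illustrative value
    "num_layers": 4,    # optional: omit it and the stack is a single layer
    # "position_encoding_config", "multi_head_config" and "feedforward_config"
    # would also go here; their fields are not part of this diff.
}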
7 changes: 0 additions & 7 deletions xformers/factory/block_factory.py
@@ -102,7 +102,6 @@ def ln_factory(sublayer: nn.Module):
 @dataclass(init=False)  # handle constructors explicitly to force type changes
 class xFormerBlockConfig:
     dim_model: int
-    num_layers: int
     feedforward_config: FeedforwardConfig
     position_encoding_config: Optional[PositionEmbeddingConfig]
     block_type: BlockType
@@ -114,11 +113,9 @@ def __init__(
         feedforward_config: Dict[str, Any],
         position_encoding_config: Optional[Dict[str, Any]],
         block_type: BlockType,
-        num_layers: int = 1,
         layer_norm_style: LayerNormStyle = LayerNormStyle("post"),
     ):
         self.dim_model = dim_model
-        self.num_layers = num_layers
         self.block_type = block_type
         self.layer_norm_style = layer_norm_style

Expand Down Expand Up @@ -148,7 +145,6 @@ def __init__(
feedforward_config: Dict[str, Any],
position_encoding_config: Optional[Dict[str, Any]],
multi_head_config: Dict[str, Any],
num_layers: int = 1,
layer_norm_style: str = "post",
*args,
**kwargs,
@@ -158,7 +154,6 @@
             feedforward_config=feedforward_config,
             position_encoding_config=position_encoding_config,
             layer_norm_style=LayerNormStyle(layer_norm_style),
-            num_layers=num_layers,
             block_type=BlockType("encoder"),
         )

@@ -177,7 +172,6 @@ def __init__(
         position_encoding_config: Optional[Dict[str, Any]],
         multi_head_config_masked: Dict[str, Any],
         multi_head_config_cross: Dict[str, Any],
-        num_layers: int = 1,
         layer_norm_style: str = "post",
         *args,
         **kwargs,
@@ -187,7 +181,6 @@
             feedforward_config=feedforward_config,
             position_encoding_config=position_encoding_config,
             layer_norm_style=LayerNormStyle(layer_norm_style),
-            num_layers=num_layers,
             block_type=BlockType("decoder"),
         )

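The net effect on the block side is that a block config now describes exactly one block. A minimal runnable sketch of why a leftover num_layers key is harmless here, using a hypothetical stand-in class with the same *args/**kwargs catch-all as the constructors above:

class BlockConfigSketch:
    # Stand-in for the configs above: unknown keys such as "num_layers" land
    # in **kwargs and are dropped instead of being stored on the config.
    def __init__(self, dim_model: int, layer_norm_style: str = "post",
                 *args, **kwargs):
        self.dim_model = dim_model
        self.layer_norm_style = layer_norm_style

cfg = BlockConfigSketch(dim_model=384, num_layers=4)
assert not hasattr(cfg, "num_layers")  # repetition now lives at the model level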
57 changes: 37 additions & 20 deletions xformers/factory/model_factory.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Optional, Union, cast
+from typing import Any, Dict, List, Optional, Union
 
 import torch
@@ -13,24 +13,41 @@
 
 
 @dataclass(init=False)
-class xFormerConfig:
-    block_configs: List[Union[xFormerEncoderConfig, xFormerDecoderConfig]]
-
-    def __init__(self, block_configs):
-        typed_configs = []
-
-        for config in block_configs:
-            if config["block_type"] == BlockType.Encoder:
-                typed_configs.append(xFormerEncoderConfig(**config))
-            else:
-                typed_configs.append(xFormerDecoderConfig(**config))
-
-        self.block_configs = typed_configs
+class xFormerStackConfig:
+    """
+    A stack is defined by a given block definition and an optional repetition factor
+    """
+
+    block_config: Union[xFormerEncoderConfig, xFormerDecoderConfig]
+    num_layers: int
+
+    def __init__(self, block_config: Dict[str, Any]):
+        if block_config["block_type"] == BlockType.Encoder:
+            self.block_config = xFormerEncoderConfig(**block_config)
+        else:
+            self.block_config = xFormerDecoderConfig(**block_config)
+
+        # Convenience: make num_layers optional, so that a stack can be
+        # defined by a block definition alone, with no repetition
+        if "num_layers" in block_config.keys():
+            self.num_layers = block_config["num_layers"]
+        else:
+            self.num_layers = 1
+
+
+@dataclass(init=False)
+class xFormerConfig:
+    stack_configs: List[xFormerStackConfig]
+
+    def __init__(self, block_configs: List[Dict[str, Any]]):
+        # Type all the configurations. Possible typos are caught here
+        self.stack_configs = [xFormerStackConfig(config) for config in block_configs]


 class xFormer(torch.nn.Module):
     def __init__(
-        self, block_configs: List[Union[xFormerEncoderConfig, xFormerDecoderConfig]]
+        self, stack_configs: List[xFormerStackConfig]
     ):
         """
         Given a serialized configuration, generate the corresponding model.
@@ -42,17 +59,17 @@ def __init__(
         encoders: List[torch.nn.Module] = []
         decoders: List[torch.nn.Module] = []
 
-        for config in block_configs:
-            if type(config) is xFormerEncoderConfig:
-                config = cast(xFormerEncoderConfig, config)
-                for i in range(config.num_layers):
+        for stack in stack_configs:
+            config = stack.block_config
+
+            if isinstance(config, xFormerEncoderConfig):
+                for i in range(stack.num_layers):
                     if i > 0:
                         config.position_encoding_config = None
                     encoders.append(xFormerEncoderBlock.from_config(config))
 
-            elif type(config) is xFormerDecoderConfig:
-                config = cast(xFormerDecoderConfig, config)
-                for i in range(config.num_layers):
+            elif isinstance(config, xFormerDecoderConfig):
+                for i in range(stack.num_layers):
                     if i > 0:
                         config.position_encoding_config = None
                     decoders.append(xFormerDecoderBlock.from_config(config))
@@ -65,7 +82,7 @@
 
     @classmethod
     def from_config(cls, config: xFormerConfig):
-        return cls(config.block_configs)
+        return cls(config.stack_configs)
 
     def forward(
         self,
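Taken together, the new flow types each dict into an xFormerStackConfig and repeats the block num_layers times, clearing the position encoding after the first layer. A self-contained sketch of that dispatch-and-repeat logic (hypothetical helper name, runnable without the real block classes):

from typing import Any, Dict, List

def expand_stacks(block_definitions: List[Dict[str, Any]]) -> List[str]:
    # Mirrors xFormer.__init__ above: one stack per definition, repeated
    # "num_layers" times, defaulting to a single layer when the key is absent.
    blocks: List[str] = []
    for definition in block_definitions:
        for i in range(definition.get("num_layers", 1)):
            # In the real code, layers with i > 0 get their
            # position_encoding_config set to None, so positions are only
            # embedded once per stack.
            blocks.append(f"{definition['block_type']}_{i}")
    return blocks

assert expand_stacks(
    [{"block_type": "encoder", "num_layers": 2}, {"block_type": "decoder"}]
) == ["encoder_0", "encoder_1", "decoder_0"]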
