Make Transformers official (#170)
* simple transformer test

* fix docs
cgarciae authored Mar 1, 2021
1 parent d0ff8f0 commit 4ed4c86
Showing 14 changed files with 207 additions and 73 deletions.
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,14 @@
# Changelog

## [Unreleased](https://github.com/poets-ai/elegy/tree/HEAD)

[Full Changelog](https://github.com/poets-ai/elegy/compare/0.7.0...HEAD)

**Merged pull requests:**

- More Docs: Expand documentation for the low-level API guides. [\#168](https://github.com/poets-ai/elegy/pull/168) ([cgarciae](https://github.com/cgarciae))
- Rich Summary: uses rich to style the summary console output [\#167](https://github.com/poets-ai/elegy/pull/167) ([cgarciae](https://github.com/cgarciae))

## [0.7.0](https://github.com/poets-ai/elegy/tree/0.7.0) (2021-02-22)

[Full Changelog](https://github.com/poets-ai/elegy/compare/0.6.0...0.7.0)
@@ -40,7 +49,6 @@
**Merged pull requests:**

- fix-maybe-initialize [\#155](https://github.com/poets-ai/elegy/pull/155) ([cgarciae](https://github.com/cgarciae))
- Add simple Flax low-level API Model example to README.md [\#153](https://github.com/poets-ai/elegy/pull/153) ([sooheon](https://github.com/sooheon))

## [0.4.0](https://github.com/poets-ai/elegy/tree/0.4.0) (2021-02-01)

Expand All @@ -52,6 +60,7 @@

**Merged pull requests:**

- Add simple Flax low-level API Model example to README.md [\#153](https://github.com/poets-ai/elegy/pull/153) ([sooheon](https://github.com/sooheon))
- Update Getting Started + README [\#152](https://github.com/poets-ai/elegy/pull/152) ([cgarciae](https://github.com/cgarciae))
- Pretrained ResNet fix after \#139 [\#151](https://github.com/poets-ai/elegy/pull/151) ([alexander-g](https://github.com/alexander-g))
- Dataset: better default batch\_fn and custom batch\_fn [\#148](https://github.com/poets-ai/elegy/pull/148) ([alexander-g](https://github.com/alexander-g))
15 changes: 15 additions & 0 deletions docs/api/nn/MultiHeadAttention.md
@@ -0,0 +1,15 @@

# elegy.nn.MultiHeadAttention

::: elegy.nn.multi_head_attention.MultiHeadAttention
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/Transformer.md
@@ -0,0 +1,15 @@

# elegy.nn.Transformer

::: elegy.nn.transformers.Transformer
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/TransformerDecoder.md
@@ -0,0 +1,15 @@

# elegy.nn.TransformerDecoder

::: elegy.nn.transformers.TransformerDecoder
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/TransformerDecoderLayer.md
@@ -0,0 +1,15 @@

# elegy.nn.TransformerDecoderLayer

::: elegy.nn.transformers.TransformerDecoderLayer
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/TransformerEncoder.md
@@ -0,0 +1,15 @@

# elegy.nn.TransformerEncoder

::: elegy.nn.transformers.TransformerEncoder
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/TransformerEncoderLayer.md
@@ -0,0 +1,15 @@

# elegy.nn.TransformerEncoderLayer

::: elegy.nn.transformers.TransformerEncoderLayer
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

5 changes: 3 additions & 2 deletions docs/low-level-api/default-implementation.md
@@ -12,5 +12,6 @@ call_summary_step call_pred_step call_test_step call_trai
```
This structure allows you to, for example, override `test_step` and still be able to use `fit`, since `train_step` (called by `fit`) will call your `test_step` via `grad_step`. It also means that if you implement `test_step` but not `pred_step`, there is a high chance that both `predict` and `summary` will not work.

##### call_* methods
The `call_<method>` method family are _entrypoints_ that usually just redirect to their inputs to `<method>`, you choose to override these if you need to perform some some computation only when method in question is the entry point. For example if you want to change the behavior of `evaluate` without affecting the behavior of `fit` while preserving most of the default implementation you can override `call_step_step` to do the corresponding adjustments and then call `test_step`. Since `train_step` does not depend on `call_step_step` then the change will manifest during `evaluate` but not during `fit`.
#### call_* methods
The `call_<method>` method family are _entrypoints_ that usually just redirect their inputs to `<method>`; you override these if you need to perform some computation only when the method in question is the entry point, i.e. when it is not called by other methods in the bottom path.
For example, if you want to change the behavior of `evaluate` without affecting the behavior of `fit` while preserving most of the default implementation, you can override `call_test_step` to make the corresponding adjustments and then call `test_step`. Since `train_step` does not depend on `call_test_step`, the change will manifest during `evaluate` but not during `fit`.
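To make the override pattern above concrete, here is a minimal sketch; it assumes the `call_test_step`/`test_step` methods listed in this guide and uses `*args, **kwargs` as placeholders, since their exact signatures are not shown in this diff:

```python
import elegy


class MyModel(elegy.Model):
    # Overriding call_test_step only affects `evaluate`, because `fit` goes
    # through `train_step`, which calls `test_step` directly and never passes
    # through this entrypoint.
    def call_test_step(self, *args, **kwargs):
        # ...evaluate-only adjustments would go here (hypothetical)...
        return self.test_step(*args, **kwargs)
```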
@@ -68,4 +68,5 @@ Here we implement the same `LinearClassifier` from the [basics](./basics) sectio

### Default Implementation
The default implementation of `pred_step` does the following:

* Calls `api_module.init` or `api_module.apply` depending on the state of `initializing`. `api_module`, of type `GeneralizedModule`, is a wrapper over the `module` object passed by the user to the `Model`'s constructor.
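As a rough, non-authoritative sketch of that flow (only `api_module`, `initializing`, `init`, and `apply` come from the description above; every other name is a placeholder):

```python
def default_pred_step_sketch(api_module, x, params, collections, initializing, rng):
    # `api_module` is the GeneralizedModule wrapper around the user's `module`
    if initializing:
        # first call: build the wrapped module's parameters and collections
        y_pred, params, collections = api_module.init(rng=rng)(x)
    else:
        # later calls: plain forward pass with the existing parameters/collections
        y_pred, params, collections = api_module.apply(params, collections, rng=rng)(x)
    return y_pred, params, collections
```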
Expand Up @@ -133,6 +133,7 @@ There are cases however where you might want to implement a forward pass inside

### Default Implementation
The default implementation of `test_step` does the following:

* Calls `pred_step` to get `y_pred`.
* Calls `api_loss.init` or `api_loss.apply` depending on the state of `initializing`. `api_loss`, of type `Losses`, computes the aggregated batch loss from the loss functions passed by the user through the `loss` argument in the `Model`'s constructor, and also computes a running mean of each individual loss, which is passed to `logs` for reporting.
* Calls `api_metrics.init` or `api_metrics.apply` depending on the state of `initializing`. `api_metrics`, of type `Metrics`, calculates the metrics passed by the user through the `metrics` argument in the `Model`'s constructor and passes their values to `logs` for reporting.
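A rough sketch of that sequence (illustrative only; `pred_step`, `api_loss`, `api_metrics`, `initializing`, and `logs` come from the description above, everything else is a placeholder):

```python
def default_test_step_sketch(pred_step, api_loss, api_metrics, x, y_true, initializing):
    logs = {}

    # 1. forward pass through the (default or user-overridden) pred_step
    y_pred = pred_step(x)

    # 2. aggregated batch loss, plus a running mean of each individual loss,
    #    reported through `logs` (init vs. apply selected by `initializing`)
    loss_fn = api_loss.init if initializing else api_loss.apply
    loss, loss_logs = loss_fn(y_true=y_true, y_pred=y_pred)
    logs.update(loss_logs)

    # 3. metrics configured via the `metrics` argument, also reported to `logs`
    metrics_fn = api_metrics.init if initializing else api_metrics.apply
    logs.update(metrics_fn(y_true=y_true, y_pred=y_pred))

    return loss, logs
```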
24 changes: 18 additions & 6 deletions elegy/nn/__init__.py
@@ -1,19 +1,31 @@
from .batch_normalization import BatchNormalization
from .conv import Conv1D, Conv2D, Conv3D, ConvND
from .dropout import Dropout
from .embedding import Embedding, EmbedLookupStyle
from .flatten import Flatten, Reshape
from .layer_normalization import InstanceNormalization, LayerNormalization
from .linear import Linear
from .sequential_module import Sequential, sequential

from .layer_normalization import LayerNormalization, InstanceNormalization
from .embedding import Embedding, EmbedLookupStyle
from .pool import MaxPool, AvgPool
from .moving_averages import EMAParamsTree
from . import transformers
from .multi_head_attention import MultiHeadAttention
from .pool import AvgPool, MaxPool
from .sequential_module import Sequential, sequential
from .transformers import (
Transformer,
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)

__all__ = [
"EMAParamsTree",
"BatchNormalization",
"MultiHeadAttention",
"Transformer",
"TransformerDecoder",
"TransformerDecoderLayer",
"TransformerEncoder",
"TransformerEncoderLayer",
"Conv1D",
"Conv2D",
"Conv3D",
70 changes: 31 additions & 39 deletions elegy/nn/transformers.py
@@ -22,17 +22,13 @@ class TransformerEncoderLayer(Module):
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
Arguments:
head_size: the number of expected features in the input (required).
num_heads: the number of heads in the multiheadattention models (required).
output_size: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of intermediate layer, relu or gelu (default=relu).
Examples::
>>> # encoder_layer = nn.TransformerEncoderLayer(head_size=512, num_heads=8)
>>> # src = torch.rand(10, 32, 512)
>>> # out = encoder_layer(src)
"""

def __init__(
Expand All @@ -47,7 +43,7 @@ def __init__(
super().__init__(**kwargs)
self.head_size = head_size
self.num_heads = num_heads
self.output_size = output_size
self.output_size = output_size if output_size is not None else head_size
self.dropout = dropout
self.activation = activation

Expand All @@ -59,7 +55,7 @@ def call(
) -> np.ndarray:
r"""Pass the input through the encoder layer.
Args:
Arguments:
src: the sequence to the encoder layer (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
@@ -95,12 +91,12 @@ class TransformerEncoder(Module):
r"""
TransformerEncoder is a stack of N encoder layers
Args:
Arguments:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
Examples:
>>> import elegy
>>> transformer_encoder = elegy.nn.transformers.TransformerEncoder(
... lambda: elegy.nn.transformers.TransformerEncoderLayer(head_size=512, num_heads=8),
@@ -131,7 +127,7 @@ def call(
) -> np.ndarray:
r"""Pass the input through the encoder layers in turn.
Args:
Arguments:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Expand All @@ -158,32 +154,27 @@ class TransformerDecoderLayer(Module):
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
Arguments:
head_size: the number of expected features in the input (required).
num_heads: the number of heads in the multiheadattention models (required).
output_size: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of intermediate layer, relu or gelu (default=relu).
Examples::
>>> # decoder_layer = nn.TransformerDecoderLayer(head_size=512, num_heads=8)
>>> # memory = torch.rand(10, 32, 512)
>>> # tgt = torch.rand(20, 32, 512)
>>> # out = decoder_layer(tgt, memory)
"""

def __init__(
self,
head_size: int,
num_heads: int,
output_size: int = 2048,
output_size: tp.Optional[int] = None,
dropout: float = 0.1,
activation: tp.Callable[[np.ndarray], np.ndarray] = jax.nn.relu,
):
super().__init__()
self.head_size = head_size
self.num_heads = num_heads
self.output_size = output_size
self.output_size = output_size if output_size is not None else head_size
self.dropout = dropout
self.activation = activation

Expand All @@ -198,7 +189,7 @@ def call(
) -> np.ndarray:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
Arguments:
tgt: the sequence to the decoder layer (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
@@ -237,17 +228,11 @@ def call(
class TransformerDecoder(Module):
r"""TransformerDecoder is a stack of N decoder layers
Args:
Arguments:
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
num_layers: the number of sub-decoder-layers in the decoder (required).
norm: the layer normalization component (optional).
Examples::
>>> # decoder_layer = nn.TransformerDecoderLayer(head_size=512, num_heads=8)
>>> # transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> # memory = torch.rand(10, 32, 512)
>>> # tgt = torch.rand(20, 32, 512)
>>> # out = transformer_decoder(tgt, memory)
"""

def __init__(
Expand All @@ -272,7 +257,7 @@ def call(
) -> np.ndarray:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
Arguments:
tgt: the sequence to the decoder (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
@@ -309,7 +294,7 @@ class Transformer(Module):
Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
model with corresponding parameters.
Args:
Arguments:
head_size: the number of expected features in the encoder/decoder inputs (default=512).
num_heads: the number of heads in the multiheadattention models (default=8).
num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
Expand All @@ -320,14 +305,23 @@ class Transformer(Module):
custom_encoder: custom encoder (default=None).
custom_decoder: custom decoder (default=None).
Examples::
>>> # transformer_model = nn.Transformer(num_heads=16, num_encoder_layers=12)
>>> # src = torch.rand((10, 32, 512))
>>> # tgt = torch.rand((20, 32, 512))
>>> # out = transformer_model(src, tgt)
Examples:
>>> import elegy
>>> import numpy as np
>>> transformer_model = elegy.nn.Transformer(
... head_size=64,
... num_heads=4,
... num_encoder_layers=2,
... num_decoder_layers=2,
... )
>>> src = np.random.uniform(size=(5, 32, 64))
>>> tgt = np.random.uniform(size=(5, 32, 64))
>>> _, params, collections = transformer_model.init(rng=elegy.RNGSeq(42))(src, tgt)
>>> out, params, collections = transformer_model.apply(params, collections, rng=elegy.RNGSeq(420))(src, tgt)
Note: A full example applying the nn.Transformer module to a word language model is available at
https://github.com/pytorch/examples/tree/master/word_language_model
"""

def __init__(
Expand All @@ -336,7 +330,7 @@ def __init__(
num_heads: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
output_size: int = 2048,
output_size: tp.Optional[int] = None,
dropout: float = 0.1,
activation: tp.Callable[[np.ndarray], np.ndarray] = jax.nn.relu,
custom_encoder: tp.Optional[tp.Any] = None,
@@ -367,7 +361,7 @@ def call(
) -> np.ndarray:
r"""Take in and process masked source/target sequences.
Args:
Arguments:
src: the sequence to the encoder (required).
tgt: the sequence to the decoder (required).
src_mask: the additive mask for the src sequence (optional).
@@ -406,8 +400,6 @@ def call(
where S is the source sequence length, T is the target sequence length, N is the
batch size, E is the feature number
Examples:
>>> # output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
"""

if src.shape[0] != tgt.shape[0]:
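One behavioral change visible in this file: `output_size` is now optional and falls back to `head_size` when omitted. A hedged usage sketch, mirroring the `init`/`apply` pattern from the new `Transformer` docstring (shapes, seeds, and the exact call sequence are illustrative):

```python
import numpy as np
import elegy

# output_size is omitted, so per this commit it defaults to head_size (64)
encoder_layer = elegy.nn.TransformerEncoderLayer(head_size=64, num_heads=4)

src = np.random.uniform(size=(10, 8, 64))  # (sequence, batch, features)
_, params, collections = encoder_layer.init(rng=elegy.RNGSeq(42))(src)
out, params, collections = encoder_layer.apply(params, collections, rng=elegy.RNGSeq(420))(src)
```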