Make Transformers official #170

Merged 2 commits on Mar 1, 2021
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,14 @@
# Changelog

## [Unreleased](https://github.com/poets-ai/elegy/tree/HEAD)

[Full Changelog](https://github.com/poets-ai/elegy/compare/0.7.0...HEAD)

**Merged pull requests:**

- More Docs: Expand documentation for the low-level API guides. [\#168](https://github.com/poets-ai/elegy/pull/168) ([cgarciae](https://github.com/cgarciae))
- Rich Summary: uses rich to style the summary console output [\#167](https://github.com/poets-ai/elegy/pull/167) ([cgarciae](https://github.com/cgarciae))

## [0.7.0](https://github.com/poets-ai/elegy/tree/0.7.0) (2021-02-22)

[Full Changelog](https://github.com/poets-ai/elegy/compare/0.6.0...0.7.0)
@@ -40,7 +49,6 @@
**Merged pull requests:**

- fix-maybe-initialize [\#155](https://github.com/poets-ai/elegy/pull/155) ([cgarciae](https://github.com/cgarciae))
- Add simple Flax low-level API Model example to README.md [\#153](https://github.com/poets-ai/elegy/pull/153) ([sooheon](https://github.com/sooheon))

## [0.4.0](https://github.com/poets-ai/elegy/tree/0.4.0) (2021-02-01)

@@ -52,6 +60,7 @@

**Merged pull requests:**

- Add simple Flax low-level API Model example to README.md [\#153](https://github.com/poets-ai/elegy/pull/153) ([sooheon](https://github.com/sooheon))
- Update Getting Started + README [\#152](https://github.com/poets-ai/elegy/pull/152) ([cgarciae](https://github.com/cgarciae))
- Pretrained ResNet fix after \#139 [\#151](https://github.com/poets-ai/elegy/pull/151) ([alexander-g](https://github.com/alexander-g))
- Dataset: better default batch\_fn and custom batch\_fn [\#148](https://github.com/poets-ai/elegy/pull/148) ([alexander-g](https://github.com/alexander-g))
15 changes: 15 additions & 0 deletions docs/api/nn/MultiHeadAttention.md
@@ -0,0 +1,15 @@

# elegy.nn.MultiHeadAttention

::: elegy.nn.multi_head_attention.MultiHeadAttention
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/Transformer.md
@@ -0,0 +1,15 @@

# elegy.nn.Transformer

::: elegy.nn.transformers.Transformer
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/TransformerDecoder.md
@@ -0,0 +1,15 @@

# elegy.nn.TransformerDecoder

::: elegy.nn.transformers.TransformerDecoder
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/TransformerDecoderLayer.md
@@ -0,0 +1,15 @@

# elegy.nn.TransformerDecoderLayer

::: elegy.nn.transformers.TransformerDecoderLayer
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/TransformerEncoder.md
@@ -0,0 +1,15 @@

# elegy.nn.TransformerEncoder

::: elegy.nn.transformers.TransformerEncoder
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

15 changes: 15 additions & 0 deletions docs/api/nn/TransformerEncoderLayer.md
@@ -0,0 +1,15 @@

# elegy.nn.TransformerEncoderLayer

::: elegy.nn.transformers.TransformerEncoderLayer
selection:
inherited_members: true
members:
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- init

5 changes: 3 additions & 2 deletions docs/low-level-api/default-implementation.md
@@ -12,5 +12,6 @@ call_summary_step call_pred_step call_test_step call_trai
```
This structure allows you, for example, to override `test_step` and still be able to use `fit`, since `train_step` (called by `fit`) will call your `test_step` via `grad_step`. It also means that if you implement `test_step` but not `pred_step`, there is a high chance that both `predict` and `summary` will not work.

##### call_* methods
The `call_<method>` method family are _entrypoints_ that usually just redirect to their inputs to `<method>`, you choose to override these if you need to perform some some computation only when method in question is the entry point. For example if you want to change the behavior of `evaluate` without affecting the behavior of `fit` while preserving most of the default implementation you can override `call_step_step` to do the corresponding adjustments and then call `test_step`. Since `train_step` does not depend on `call_step_step` then the change will manifest during `evaluate` but not during `fit`.
#### call_* methods
The `call_<method>` family of methods are _entrypoints_ that usually just redirect their inputs to `<method>`. You should override these if you need to perform some computation only when the method in question is the entry point, i.e. when it is not called by other methods in the bottom path.
For example, if you want to change the behavior of `evaluate` without affecting the behavior of `fit` while preserving most of the default implementation, you can override `call_test_step` to make the corresponding adjustments and then call `test_step`. Since `train_step` does not depend on `call_test_step`, the change will manifest during `evaluate` but not during `fit`. A minimal sketch of this pattern is shown below.
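As a rough illustration only (the exact signatures and return values of the low-level step methods are simplified here, and `MyModel` is a hypothetical subclass), overriding `call_test_step` could look something like this:

```python
import elegy


class MyModel(elegy.Model):
    # Hedged sketch: signatures and return values are simplified; consult
    # the low-level API reference for the exact interface.
    def call_test_step(self, *args, **kwargs):
        # Only `evaluate` enters through this method; `fit` reaches
        # `test_step` via `train_step` -> `grad_step`, so training
        # behavior is left untouched.
        outputs = self.test_step(*args, **kwargs)
        # ...apply evaluation-only adjustments to `outputs` here...
        return outputs
```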
@@ -68,4 +68,5 @@ Here we implement the same `LinearClassifier` from the [basics](./basics) sectio

### Default Implementation
The default implementation of `pred_step` does the following:

* Calls `api_module.init` or `api_module.apply` depending on the state of `initializing`. `api_module`, of type `GeneralizedModule`, is a wrapper over the `module` object passed by the user to the `Model`'s constructor. A rough sketch of this dispatch is shown below.
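The following is only a hedged illustration of that dispatch, not elegy's actual implementation; the real `GeneralizedModule.init`/`.apply` also thread parameters, collections, and RNG state, and `pred_step_sketch` is a hypothetical helper name:

```python
def pred_step_sketch(api_module, x, initializing, params=None):
    # Hedged illustration only; the real signatures differ.
    if initializing:
        # First call: build and return the module's initial parameters/state.
        return api_module.init(x)
    # Later calls: regular forward pass using the existing parameters.
    return api_module.apply(params, x)
```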
@@ -133,6 +133,7 @@ There are cases however where you might want to implement a forward pass inside

### Default Implementation
The default implementation of `test_step` does the following (a rough sketch follows the list):

* Calls `pred_step` to get `y_pred`.
* Calls `api_loss.init` or `api_loss.apply` depending on the state of `initializing`. `api_loss`, of type `Losses`, computes the aggregated batch loss from the loss functions passed by the user through the `loss` argument in the `Model`'s constructor, and also computes a running mean of each individual loss, which is passed to `logs` for reporting.
* Calls `api_metrics.init` or `api_metrics.apply` depending on the state of `initializing`. `api_metrics`, of type `Metrics`, calculates the metrics passed by the user through the `metrics` argument in the `Model`'s constructor and passes their values to `logs` for reporting.
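For orientation, here is a hedged, framework-agnostic sketch of that flow (not elegy's actual source; the real implementation also threads parameters, collections, and RNG state, and `test_step_sketch` is a hypothetical name):

```python
import typing as tp

import numpy as np


def test_step_sketch(
    pred_step: tp.Callable[[np.ndarray], np.ndarray],
    losses: tp.Dict[str, tp.Callable],
    metrics: tp.Dict[str, tp.Callable],
    x: np.ndarray,
    y_true: np.ndarray,
) -> tp.Tuple[float, tp.Dict[str, float]]:
    # 1. Forward pass via `pred_step`.
    y_pred = pred_step(x)

    logs: tp.Dict[str, float] = {}

    # 2. Aggregate the batch loss from every user-provided loss function.
    total_loss = 0.0
    for name, loss_fn in losses.items():
        value = float(loss_fn(y_true, y_pred))
        logs[name] = value
        total_loss += value
    logs["loss"] = total_loss

    # 3. Compute the user metrics and report them alongside the losses.
    for name, metric_fn in metrics.items():
        logs[name] = float(metric_fn(y_true, y_pred))

    return total_loss, logs
```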
24 changes: 18 additions & 6 deletions elegy/nn/__init__.py
@@ -1,19 +1,31 @@
from .batch_normalization import BatchNormalization
from .conv import Conv1D, Conv2D, Conv3D, ConvND
from .dropout import Dropout
from .embedding import Embedding, EmbedLookupStyle
from .flatten import Flatten, Reshape
from .layer_normalization import InstanceNormalization, LayerNormalization
from .linear import Linear
from .sequential_module import Sequential, sequential

from .layer_normalization import LayerNormalization, InstanceNormalization
from .embedding import Embedding, EmbedLookupStyle
from .pool import MaxPool, AvgPool
from .moving_averages import EMAParamsTree
from . import transformers
from .multi_head_attention import MultiHeadAttention
from .pool import AvgPool, MaxPool
from .sequential_module import Sequential, sequential
from .transformers import (
Transformer,
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)

__all__ = [
"EMAParamsTree",
"BatchNormalization",
"MultiHeadAttention",
"Transformer",
"TransformerDecoder",
"TransformerDecoderLayer",
"TransformerEncoder",
"TransformerEncoderLayer",
"Conv1D",
"Conv2D",
"Conv3D",
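With these re-exports in place, the new transformer modules become importable directly from `elegy.nn`. A quick sanity check (assuming the version of elegy from this branch is installed) might look like:

```python
# Import sanity check for the new top-level exports (assumes the elegy
# version from this PR is installed).
from elegy.nn import (
    MultiHeadAttention,
    Transformer,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

print(MultiHeadAttention, Transformer)  # should resolve without an ImportError
```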
70 changes: 31 additions & 39 deletions elegy/nn/transformers.py
@@ -22,17 +22,13 @@ class TransformerEncoderLayer(Module):
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.

Args:
Arguments:
head_size: the number of expected features in the input (required).
num_heads: the number of heads in the multiheadattention models (required).
output_size: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of intermediate layer, relu or gelu (default=relu).

Examples::
>>> # encoder_layer = nn.TransformerEncoderLayer(head_size=512, num_heads=8)
>>> # src = torch.rand(10, 32, 512)
>>> # out = encoder_layer(src)
"""

def __init__(
@@ -47,7 +43,7 @@ def __init__(
super().__init__(**kwargs)
self.head_size = head_size
self.num_heads = num_heads
self.output_size = output_size
self.output_size = output_size if output_size is not None else head_size
self.dropout = dropout
self.activation = activation

@@ -59,7 +55,7 @@ def call(
) -> np.ndarray:
r"""Pass the input through the encoder layer.

Args:
Arguments:
src: the sequence to the encoder layer (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
@@ -95,12 +91,12 @@ class TransformerEncoder(Module):
r"""
TransformerEncoder is a stack of N encoder layers

Args:
Arguments:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).

Examples::
Examples:
>>> import elegy
>>> transformer_encoder = elegy.nn.transformers.TransformerEncoder(
... lambda: elegy.nn.transformers.TransformerEncoderLayer(head_size=512, num_heads=8),
@@ -131,7 +127,7 @@ def call(
) -> np.ndarray:
r"""Pass the input through the encoder layers in turn.

Args:
Arguments:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
@@ -158,32 +154,27 @@ class TransformerDecoderLayer(Module):
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.

Args:
Arguments:
head_size: the number of expected features in the input (required).
num_heads: the number of heads in the multiheadattention models (required).
output_size: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of intermediate layer, relu or gelu (default=relu).

Examples::
>>> # decoder_layer = nn.TransformerDecoderLayer(head_size=512, num_heads=8)
>>> # memory = torch.rand(10, 32, 512)
>>> # tgt = torch.rand(20, 32, 512)
>>> # out = decoder_layer(tgt, memory)
"""

def __init__(
self,
head_size: int,
num_heads: int,
output_size: int = 2048,
output_size: tp.Optional[int] = None,
dropout: float = 0.1,
activation: tp.Callable[[np.ndarray], np.ndarray] = jax.nn.relu,
):
super().__init__()
self.head_size = head_size
self.num_heads = num_heads
self.output_size = output_size
self.output_size = output_size if output_size is not None else head_size
self.dropout = dropout
self.activation = activation

@@ -198,7 +189,7 @@ def call(
) -> np.ndarray:
r"""Pass the inputs (and mask) through the decoder layer.

Args:
Arguments:
tgt: the sequence to the decoder layer (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
@@ -237,17 +228,11 @@ def call(
class TransformerDecoder(Module):
r"""TransformerDecoder is a stack of N decoder layers

Args:
Arguments:
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
num_layers: the number of sub-decoder-layers in the decoder (required).
norm: the layer normalization component (optional).

Examples::
>>> # decoder_layer = nn.TransformerDecoderLayer(head_size=512, num_heads=8)
>>> # transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> # memory = torch.rand(10, 32, 512)
>>> # tgt = torch.rand(20, 32, 512)
>>> # out = transformer_decoder(tgt, memory)
"""

def __init__(
@@ -272,7 +257,7 @@ def call(
) -> np.ndarray:
r"""Pass the inputs (and mask) through the decoder layer in turn.

Args:
Arguments:
tgt: the sequence to the decoder (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
@@ -309,7 +294,7 @@ class Transformer(Module):
Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
model with corresponding parameters.

Args:
Arguments:
head_size: the number of expected features in the encoder/decoder inputs (default=512).
num_heads: the number of heads in the multiheadattention models (default=8).
num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
@@ -320,14 +305,23 @@ class Transformer(Module):
custom_encoder: custom encoder (default=None).
custom_decoder: custom decoder (default=None).

Examples::
>>> # transformer_model = nn.Transformer(num_heads=16, num_encoder_layers=12)
>>> # src = torch.rand((10, 32, 512))
>>> # tgt = torch.rand((20, 32, 512))
>>> # out = transformer_model(src, tgt)
Examples:
>>> import elegy
>>> import numpy as np

>>> transformer_model = elegy.nn.Transformer(
... head_size=64,
... num_heads=4,
... num_encoder_layers=2,
... num_decoder_layers=2,
... )

>>> src = np.random.uniform(size=(5, 32, 64))
>>> tgt = np.random.uniform(size=(5, 32, 64))

>>> _, params, collections = transformer_model.init(rng=elegy.RNGSeq(42))(src, tgt)
>>> out, params, collections = transformer_model.apply(params, collections, rng=elegy.RNGSeq(420))(src, tgt)

Note: A full example to apply nn.Transformer module for the word language model is available in
https://github.com/pytorch/examples/tree/master/word_language_model
"""

def __init__(
Expand All @@ -336,7 +330,7 @@ def __init__(
num_heads: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
output_size: int = 2048,
output_size: tp.Optional[int] = None,
dropout: float = 0.1,
activation: tp.Callable[[np.ndarray], np.ndarray] = jax.nn.relu,
custom_encoder: tp.Optional[tp.Any] = None,
@@ -367,7 +361,7 @@ def call(
) -> np.ndarray:
r"""Take in and process masked source/target sequences.

Args:
Arguments:
src: the sequence to the encoder (required).
tgt: the sequence to the decoder (required).
src_mask: the additive mask for the src sequence (optional).
@@ -406,8 +400,6 @@ def call(
where S is the source sequence length, T is the target sequence length, N is the
batch size, E is the feature number

Examples:
>>> # output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
"""

if src.shape[0] != tgt.shape[0]: