From 4cdfd2f56bec4a596a60add7f797568b6436dba5 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 1 Mar 2021 16:50:51 -0500 Subject: [PATCH 1/2] simple transformer test --- elegy/nn/__init__.py | 24 ++++++++++++++++++------ elegy/nn/transformers.py | 27 +++++++++++++++++++-------- elegy/nn/transformers_test.py | 24 ++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 14 deletions(-) create mode 100644 elegy/nn/transformers_test.py diff --git a/elegy/nn/__init__.py b/elegy/nn/__init__.py index 5350ecc1..747cacfd 100644 --- a/elegy/nn/__init__.py +++ b/elegy/nn/__init__.py @@ -1,19 +1,31 @@ from .batch_normalization import BatchNormalization from .conv import Conv1D, Conv2D, Conv3D, ConvND from .dropout import Dropout +from .embedding import Embedding, EmbedLookupStyle from .flatten import Flatten, Reshape +from .layer_normalization import InstanceNormalization, LayerNormalization from .linear import Linear -from .sequential_module import Sequential, sequential - -from .layer_normalization import LayerNormalization, InstanceNormalization -from .embedding import Embedding, EmbedLookupStyle -from .pool import MaxPool, AvgPool from .moving_averages import EMAParamsTree -from . import transformers +from .multi_head_attention import MultiHeadAttention +from .pool import AvgPool, MaxPool +from .sequential_module import Sequential, sequential +from .transformers import ( + Transformer, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, +) __all__ = [ "EMAParamsTree", "BatchNormalization", + "MultiHeadAttention", + "Transformer", + "TransformerDecoder", + "TransformerDecoderLayer", + "TransformerEncoder", + "TransformerEncoderLayer", "Conv1D", "Conv2D", "Conv3D", diff --git a/elegy/nn/transformers.py b/elegy/nn/transformers.py index bf442d66..2bc2568b 100644 --- a/elegy/nn/transformers.py +++ b/elegy/nn/transformers.py @@ -47,7 +47,7 @@ def __init__( super().__init__(**kwargs) self.head_size = head_size self.num_heads = num_heads - self.output_size = output_size + self.output_size = output_size if output_size is not None else head_size self.dropout = dropout self.activation = activation @@ -176,14 +176,14 @@ def __init__( self, head_size: int, num_heads: int, - output_size: int = 2048, + output_size: tp.Optional[int] = None, dropout: float = 0.1, activation: tp.Callable[[np.ndarray], np.ndarray] = jax.nn.relu, ): super().__init__() self.head_size = head_size self.num_heads = num_heads - self.output_size = output_size + self.output_size = output_size if output_size is not None else head_size self.dropout = dropout self.activation = activation @@ -321,10 +321,21 @@ class Transformer(Module): custom_decoder: custom decoder (default=None). Examples:: - >>> # transformer_model = nn.Transformer(num_heads=16, num_encoder_layers=12) - >>> # src = torch.rand((10, 32, 512)) - >>> # tgt = torch.rand((20, 32, 512)) - >>> # out = transformer_model(src, tgt) + >>> import elegy + >>> import numpy as np + + >>> transformer_model = elegy.nn.Transformer( + ... head_size=64, + ... num_heads=4, + ... num_encoder_layers=2, + ... num_decoder_layers=2, + ... ) + + >>> src = np.random.uniform(size=(5, 32, 64)) + >>> tgt = np.random.uniform(size=(5, 32, 64)) + + >>> _, params, collections = transformer_model.init(rng=elegy.RNGSeq(42))(src, tgt) + >>> out, params, collections = transformer_model.apply(params, collections, rng=elegy.RNGSeq(420))(src, tgt) Note: A full example to apply nn.Transformer module for the word language model is available in https://github.com/pytorch/examples/tree/master/word_language_model @@ -336,7 +347,7 @@ def __init__( num_heads: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, - output_size: int = 2048, + output_size: tp.Optional[int] = None, dropout: float = 0.1, activation: tp.Callable[[np.ndarray], np.ndarray] = jax.nn.relu, custom_encoder: tp.Optional[tp.Any] = None, diff --git a/elegy/nn/transformers_test.py b/elegy/nn/transformers_test.py new file mode 100644 index 00000000..502ee354 --- /dev/null +++ b/elegy/nn/transformers_test.py @@ -0,0 +1,24 @@ +from unittest import TestCase + +import elegy +import jax +import jax.numpy as jnp +import numpy as np + + +class TransformerTest(TestCase): + def test_connects(self): + transformer_model = elegy.nn.Transformer( + head_size=64, + num_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + ) + + src = np.random.uniform(size=(5, 32, 64)) + tgt = np.random.uniform(size=(5, 32, 64)) + + _, params, collections = transformer_model.init(rng=elegy.RNGSeq(42))(src, tgt) + out, params, collections = transformer_model.apply( + params, collections, rng=elegy.RNGSeq(420) + )(src, tgt) From bf5592f5b72dbaf50e05b8add8af88cf87399473 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 1 Mar 2021 17:20:51 -0500 Subject: [PATCH 2/2] fix docs --- CHANGELOG.md | 11 +++- docs/api/nn/MultiHeadAttention.md | 15 ++++++ docs/api/nn/Transformer.md | 15 ++++++ docs/api/nn/TransformerDecoder.md | 15 ++++++ docs/api/nn/TransformerDecoderLayer.md | 15 ++++++ docs/api/nn/TransformerEncoder.md | 15 ++++++ docs/api/nn/TransformerEncoderLayer.md | 15 ++++++ docs/low-level-api/default-implementation.md | 5 +- docs/low-level-api/{ => methods}/pred_step.md | 1 + docs/low-level-api/{ => methods}/test_step.md | 1 + elegy/nn/transformers.py | 43 +++++---------- mkdocs.yml | 54 ++++++++++--------- 12 files changed, 146 insertions(+), 59 deletions(-) create mode 100644 docs/api/nn/MultiHeadAttention.md create mode 100644 docs/api/nn/Transformer.md create mode 100644 docs/api/nn/TransformerDecoder.md create mode 100644 docs/api/nn/TransformerDecoderLayer.md create mode 100644 docs/api/nn/TransformerEncoder.md create mode 100644 docs/api/nn/TransformerEncoderLayer.md rename docs/low-level-api/{ => methods}/pred_step.md (99%) rename docs/low-level-api/{ => methods}/test_step.md (99%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5216823d..e405b0cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # Changelog +## [Unreleased](https://github.com/poets-ai/elegy/tree/HEAD) + +[Full Changelog](https://github.com/poets-ai/elegy/compare/0.7.0...HEAD) + +**Merged pull requests:** + +- More Docs: Expand documentation for the low-level API guides. [\#168](https://github.com/poets-ai/elegy/pull/168) ([cgarciae](https://github.com/cgarciae)) +- Rich Summary: uses rich to style the summary console output [\#167](https://github.com/poets-ai/elegy/pull/167) ([cgarciae](https://github.com/cgarciae)) + ## [0.7.0](https://github.com/poets-ai/elegy/tree/0.7.0) (2021-02-22) [Full Changelog](https://github.com/poets-ai/elegy/compare/0.6.0...0.7.0) @@ -40,7 +49,6 @@ **Merged pull requests:** - fix-maybe-initialize [\#155](https://github.com/poets-ai/elegy/pull/155) ([cgarciae](https://github.com/cgarciae)) -- Add simple Flax low-level API Model example to README.md [\#153](https://github.com/poets-ai/elegy/pull/153) ([sooheon](https://github.com/sooheon)) ## [0.4.0](https://github.com/poets-ai/elegy/tree/0.4.0) (2021-02-01) @@ -52,6 +60,7 @@ **Merged pull requests:** +- Add simple Flax low-level API Model example to README.md [\#153](https://github.com/poets-ai/elegy/pull/153) ([sooheon](https://github.com/sooheon)) - Update Getting Started + README [\#152](https://github.com/poets-ai/elegy/pull/152) ([cgarciae](https://github.com/cgarciae)) - Pretrained ResNet fix after \#139 [\#151](https://github.com/poets-ai/elegy/pull/151) ([alexander-g](https://github.com/alexander-g)) - Dataset: better default batch\_fn and custom batch\_fn [\#148](https://github.com/poets-ai/elegy/pull/148) ([alexander-g](https://github.com/alexander-g)) diff --git a/docs/api/nn/MultiHeadAttention.md b/docs/api/nn/MultiHeadAttention.md new file mode 100644 index 00000000..fc6c77b6 --- /dev/null +++ b/docs/api/nn/MultiHeadAttention.md @@ -0,0 +1,15 @@ + +# elegy.nn.MultiHeadAttention + +::: elegy.nn.multi_head_attention.MultiHeadAttention + selection: + inherited_members: true + members: + - __init__ + - call + - add_parameter + - get_parameters + - set_parameters + - reset + - init + \ No newline at end of file diff --git a/docs/api/nn/Transformer.md b/docs/api/nn/Transformer.md new file mode 100644 index 00000000..28a4d29e --- /dev/null +++ b/docs/api/nn/Transformer.md @@ -0,0 +1,15 @@ + +# elegy.nn.Transformer + +::: elegy.nn.transformers.Transformer + selection: + inherited_members: true + members: + - __init__ + - call + - add_parameter + - get_parameters + - set_parameters + - reset + - init + \ No newline at end of file diff --git a/docs/api/nn/TransformerDecoder.md b/docs/api/nn/TransformerDecoder.md new file mode 100644 index 00000000..c6d5a378 --- /dev/null +++ b/docs/api/nn/TransformerDecoder.md @@ -0,0 +1,15 @@ + +# elegy.nn.TransformerDecoder + +::: elegy.nn.transformers.TransformerDecoder + selection: + inherited_members: true + members: + - __init__ + - call + - add_parameter + - get_parameters + - set_parameters + - reset + - init + \ No newline at end of file diff --git a/docs/api/nn/TransformerDecoderLayer.md b/docs/api/nn/TransformerDecoderLayer.md new file mode 100644 index 00000000..b74615b0 --- /dev/null +++ b/docs/api/nn/TransformerDecoderLayer.md @@ -0,0 +1,15 @@ + +# elegy.nn.TransformerDecoderLayer + +::: elegy.nn.transformers.TransformerDecoderLayer + selection: + inherited_members: true + members: + - __init__ + - call + - add_parameter + - get_parameters + - set_parameters + - reset + - init + \ No newline at end of file diff --git a/docs/api/nn/TransformerEncoder.md b/docs/api/nn/TransformerEncoder.md new file mode 100644 index 00000000..5b75973f --- /dev/null +++ b/docs/api/nn/TransformerEncoder.md @@ -0,0 +1,15 @@ + +# elegy.nn.TransformerEncoder + +::: elegy.nn.transformers.TransformerEncoder + selection: + inherited_members: true + members: + - __init__ + - call + - add_parameter + - get_parameters + - set_parameters + - reset + - init + \ No newline at end of file diff --git a/docs/api/nn/TransformerEncoderLayer.md b/docs/api/nn/TransformerEncoderLayer.md new file mode 100644 index 00000000..f412403f --- /dev/null +++ b/docs/api/nn/TransformerEncoderLayer.md @@ -0,0 +1,15 @@ + +# elegy.nn.TransformerEncoderLayer + +::: elegy.nn.transformers.TransformerEncoderLayer + selection: + inherited_members: true + members: + - __init__ + - call + - add_parameter + - get_parameters + - set_parameters + - reset + - init + \ No newline at end of file diff --git a/docs/low-level-api/default-implementation.md b/docs/low-level-api/default-implementation.md index bf701206..a58a7034 100644 --- a/docs/low-level-api/default-implementation.md +++ b/docs/low-level-api/default-implementation.md @@ -12,5 +12,6 @@ call_summary_step call_pred_step call_test_step call_trai ``` This structure allows you to for example override `test_step` and still be able to use use `fit` since `train_step` (called by `fit`) will call your `test_step` via `grad_step`. It also means that if you implement `test_step` but not `pred_step` there is a high chance both `predict` and `summary` will not work. -##### call_* methods -The `call_` method family are _entrypoints_ that usually just redirect to their inputs to ``, you choose to override these if you need to perform some some computation only when method in question is the entry point. For example if you want to change the behavior of `evaluate` without affecting the behavior of `fit` while preserving most of the default implementation you can override `call_step_step` to do the corresponding adjustments and then call `test_step`. Since `train_step` does not depend on `call_step_step` then the change will manifest during `evaluate` but not during `fit`. \ No newline at end of file +#### call_* methods +The `call_` method family are _entrypoints_ that usually just redirect to their inputs to ``, you choose to override these if you need to perform some some computation only when method in question is the entry point i.e. when its not called by other methods in the bottom path. +For example if you want to change the behavior of `evaluate` without affecting the behavior of `fit` while preserving most of the default implementation you can override `call_step_step` to do the corresponding adjustments and then call `test_step`. Since `train_step` does not depend on `call_step_step` then the change will manifest during `evaluate` but not during `fit`. \ No newline at end of file diff --git a/docs/low-level-api/pred_step.md b/docs/low-level-api/methods/pred_step.md similarity index 99% rename from docs/low-level-api/pred_step.md rename to docs/low-level-api/methods/pred_step.md index 3cc947d8..cd90ee98 100644 --- a/docs/low-level-api/pred_step.md +++ b/docs/low-level-api/methods/pred_step.md @@ -68,4 +68,5 @@ Here we implement the same `LinearClassifier` from the [basics](./basics) sectio ### Default Implementation The default implementation of `pred_step` does the following: + * Calls `api_module.init` or `api_module.apply` depending on state of `initializing`. `api_module` of type `GeneralizedModule` is a wrapper over the `module` object passed by the user to the `Model`s constructor. \ No newline at end of file diff --git a/docs/low-level-api/test_step.md b/docs/low-level-api/methods/test_step.md similarity index 99% rename from docs/low-level-api/test_step.md rename to docs/low-level-api/methods/test_step.md index 958d0139..88509907 100644 --- a/docs/low-level-api/test_step.md +++ b/docs/low-level-api/methods/test_step.md @@ -133,6 +133,7 @@ There are cases however where you might want to implement a forward pass inside ### Default Implementation The default implementation of `pred_step` does the following: + * Call `pred_step` to get `y_pred`. * Calls `api_loss.init` or `api_loss.apply` depending on state of `initializing`. `api_loss` of type `Losses` computes the aggregated batch loss from the loss functions passed by the user through the `loss` argument in the `Model`s constructor, and also computes a running mean of each loss individually which is passed for reporting to `logs`. * Calls `api_metrics.init` or `api_metrics.apply` depending on state of `initializing`. `api_metrics` of type `Metrics` calculates the metrics passed by the user through the `metrics` argument in the `Model`s constructor and passes their values to `logs` for reporting. \ No newline at end of file diff --git a/elegy/nn/transformers.py b/elegy/nn/transformers.py index 2bc2568b..ed861e2d 100644 --- a/elegy/nn/transformers.py +++ b/elegy/nn/transformers.py @@ -22,17 +22,13 @@ class TransformerEncoderLayer(Module): Neural Information Processing Systems, pages 6000-6010. Users may modify or implement in a different way during application. - Args: + Arguments: head_size: the number of expected features in the input (required). num_heads: the number of heads in the multiheadattention models (required). output_size: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). activation: the activation function of intermediate layer, relu or gelu (default=relu). - Examples:: - >>> # encoder_layer = nn.TransformerEncoderLayer(head_size=512, num_heads=8) - >>> # src = torch.rand(10, 32, 512) - >>> # out = encoder_layer(src) """ def __init__( @@ -59,7 +55,7 @@ def call( ) -> np.ndarray: r"""Pass the input through the encoder layer. - Args: + Arguments: src: the sequence to the encoder layer (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). @@ -95,12 +91,12 @@ class TransformerEncoder(Module): r""" TransformerEncoder is a stack of N encoder layers - Args: + Arguments: encoder_layer: an instance of the TransformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). norm: the layer normalization component (optional). - Examples:: + Examples: >>> import elegy >>> transformer_encoder = elegy.nn.transformers.TransformerEncoder( ... lambda: elegy.nn.transformers.TransformerEncoderLayer(head_size=512, num_heads=8), @@ -131,7 +127,7 @@ def call( ) -> np.ndarray: r"""Pass the input through the encoder layers in turn. - Args: + Arguments: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). @@ -158,18 +154,13 @@ class TransformerDecoderLayer(Module): Neural Information Processing Systems, pages 6000-6010. Users may modify or implement in a different way during application. - Args: + Arguments: head_size: the number of expected features in the input (required). num_heads: the number of heads in the multiheadattention models (required). output_size: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). activation: the activation function of intermediate layer, relu or gelu (default=relu). - Examples:: - >>> # decoder_layer = nn.TransformerDecoderLayer(head_size=512, num_heads=8) - >>> # memory = torch.rand(10, 32, 512) - >>> # tgt = torch.rand(20, 32, 512) - >>> # out = decoder_layer(tgt, memory) """ def __init__( @@ -198,7 +189,7 @@ def call( ) -> np.ndarray: r"""Pass the inputs (and mask) through the decoder layer. - Args: + Arguments: tgt: the sequence to the decoder layer (required). memory: the sequence from the last layer of the encoder (required). tgt_mask: the mask for the tgt sequence (optional). @@ -237,17 +228,11 @@ def call( class TransformerDecoder(Module): r"""TransformerDecoder is a stack of N decoder layers - Args: + Arguments: decoder_layer: an instance of the TransformerDecoderLayer() class (required). num_layers: the number of sub-decoder-layers in the decoder (required). norm: the layer normalization component (optional). - Examples:: - >>> # decoder_layer = nn.TransformerDecoderLayer(head_size=512, num_heads=8) - >>> # transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) - >>> # memory = torch.rand(10, 32, 512) - >>> # tgt = torch.rand(20, 32, 512) - >>> # out = transformer_decoder(tgt, memory) """ def __init__( @@ -272,7 +257,7 @@ def call( ) -> np.ndarray: r"""Pass the inputs (and mask) through the decoder layer in turn. - Args: + Arguments: tgt: the sequence to the decoder (required). memory: the sequence from the last layer of the encoder (required). tgt_mask: the mask for the tgt sequence (optional). @@ -309,7 +294,7 @@ class Transformer(Module): Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. - Args: + Arguments: head_size: the number of expected features in the encoder/decoder inputs (default=512). num_heads: the number of heads in the multiheadattention models (default=8). num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). @@ -320,7 +305,7 @@ class Transformer(Module): custom_encoder: custom encoder (default=None). custom_decoder: custom decoder (default=None). - Examples:: + Examples: >>> import elegy >>> import numpy as np @@ -337,8 +322,6 @@ class Transformer(Module): >>> _, params, collections = transformer_model.init(rng=elegy.RNGSeq(42))(src, tgt) >>> out, params, collections = transformer_model.apply(params, collections, rng=elegy.RNGSeq(420))(src, tgt) - Note: A full example to apply nn.Transformer module for the word language model is available in - https://github.com/pytorch/examples/tree/master/word_language_model """ def __init__( @@ -378,7 +361,7 @@ def call( ) -> np.ndarray: r"""Take in and process masked source/target sequences. - Args: + Arguments: src: the sequence to the encoder (required). tgt: the sequence to the decoder (required). src_mask: the additive mask for the src sequence (optional). @@ -417,8 +400,6 @@ def call( where S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - Examples: - >>> # output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) """ if src.shape[0] != tgt.shape[0]: diff --git a/mkdocs.yml b/mkdocs.yml index 461291d5..74c1c75e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -9,34 +9,32 @@ nav: - Low Level API: getting-started/low-level-api.ipynb - High Level API: - Intro: basic-api/modules-losses-metrics.md - - Data Sources: na.md - - Module: - - Flax: na.md - - Haiku: na.md - - Elegy: na.md - - Pure Jax: na.md - - Supporting Other Modules: na.md - - Losses: na.md - - Metrics: na.md - - Dependency Injection: na.md - - Hooks: na.md - - Optimizer: - - Optax Optimizer: na.md - - Elegy Optimizer: na.md - - Monitoring Learning Rate: na.md - - Supporting Other Optimizers: na.md - - Callbacks: na.md - - Serialization: na.md + # - Data Sources: na.md + # - Module: + # - Flax: na.md + # - Haiku: na.md + # - Elegy: na.md + # - Pure Jax: na.md + # - Supporting Other Modules: na.md + # - Losses: na.md + # - Metrics: na.md + # - Dependency Injection: na.md + # - Hooks: na.md + # - Optimizer: + # - Optax Optimizer: na.md + # - Elegy Optimizer: na.md + # - Monitoring Learning Rate: na.md + # - Supporting Other Optimizers: na.md + # - Callbacks: na.md + # - Serialization: na.md - Low Level API: - Basics: low-level-api/basics.md - States: low-level-api/states.md - - pred_step: low-level-api/pred_step.md - - test_step: na.md - - grad_step: na.md - - train_step: na.md - - Supporting the High Level API: low-level-api/supporting-the-high-level-api.md - - Implementing a Model from Scratch: na.md -- Elegy Module: module-system.md + - Methods: + - pred_step: low-level-api/methods/pred_step.md + - test_step: low-level-api/methods/test_step.md + - Default Implementation: low-level-api/default-implementation.md +# - Elegy Module: module-system.md - Contributing: contributing.md - API Reference: GeneralizedModule: api/GeneralizedModule.md @@ -167,8 +165,14 @@ nav: LayerNormalization: api/nn/LayerNormalization.md Linear: api/nn/Linear.md MaxPool: api/nn/MaxPool.md + MultiHeadAttention: api/nn/MultiHeadAttention.md Reshape: api/nn/Reshape.md Sequential: api/nn/Sequential.md + Transformer: api/nn/Transformer.md + TransformerDecoder: api/nn/TransformerDecoder.md + TransformerDecoderLayer: api/nn/TransformerDecoderLayer.md + TransformerEncoder: api/nn/TransformerEncoder.md + TransformerEncoderLayer: api/nn/TransformerEncoderLayer.md sequential: api/nn/sequential.md regularizers: GlobalL1: api/regularizers/GlobalL1.md