
Reformer #3351

Merged: 162 commits, May 7, 2020

Commits (162)
ee0ce08
first copy & past commit from Bert and morgans LSH code
patrickvonplaten Mar 19, 2020
3259115
add easy way to compare to trax original code
patrickvonplaten Mar 20, 2020
25d162e
translate most of function
patrickvonplaten Mar 23, 2020
dc07c08
make trax lsh self attention deterministic with numpy seed + copy pas…
patrickvonplaten Mar 23, 2020
09d4230
add same config
patrickvonplaten Mar 23, 2020
9386450
add same config
patrickvonplaten Mar 23, 2020
3fb1182
fix merge conflicts
patrickvonplaten Apr 6, 2020
ebb9d3f
make layer init work
patrickvonplaten Mar 31, 2020
c910e03
implemented hash_vectors function for lsh attention
patrickvonplaten Apr 1, 2020
b956933
continue reformer translation
patrickvonplaten Apr 1, 2020
a449c2e
hf LSHSelfAttentionLayer gives same output as trax layer
patrickvonplaten Apr 2, 2020
35491d8
refactor code
patrickvonplaten Apr 2, 2020
c9a0919
refactor code
patrickvonplaten Apr 2, 2020
4ef4739
refactor code
patrickvonplaten Apr 2, 2020
f580074
refactor
patrickvonplaten Apr 2, 2020
4176564
refactor + add reformer config
patrickvonplaten Apr 2, 2020
8e02fe7
delete bogus file
patrickvonplaten Apr 2, 2020
2af8377
split reformer attention layer into two layers
patrickvonplaten Apr 3, 2020
6fe9478
fix merge conflicts
patrickvonplaten Apr 6, 2020
9e6e1af
save intermediate step
patrickvonplaten Apr 3, 2020
1855074
save intermediate step
patrickvonplaten Apr 3, 2020
1a4e61a
make test work
patrickvonplaten Apr 3, 2020
da6bfe4
add complete reformer block layer
patrickvonplaten Apr 3, 2020
2825d24
finish reformer layer
patrickvonplaten Apr 4, 2020
45e6635
implement causal and self mask
patrickvonplaten Apr 5, 2020
b5ed5d4
clean reformer test and refactor code
patrickvonplaten Apr 5, 2020
ddb2f09
update init
patrickvonplaten Apr 6, 2020
cbb5ab9
fix device for GPU
patrickvonplaten Apr 6, 2020
f17fd5b
fix chunk length init for tests
patrickvonplaten Apr 6, 2020
eca8cce
include morgans optimization
patrickvonplaten Apr 6, 2020
db2ebb1
improve memory a bit
patrickvonplaten Apr 6, 2020
04aa067
improve comment
patrickvonplaten Apr 6, 2020
4aec75e
factorize num_buckets
patrickvonplaten Apr 7, 2020
d030e39
better testing parameters
patrickvonplaten Apr 7, 2020
d318089
make whole model work
patrickvonplaten Apr 9, 2020
4f0b114
make lm model work
patrickvonplaten Apr 9, 2020
6c8bad6
add t5 copy paste tokenizer
patrickvonplaten Apr 10, 2020
b71ef16
add chunking feed forward
patrickvonplaten Apr 10, 2020
99427c6
clean config
patrickvonplaten Apr 14, 2020
4ffa925
add improved assert statements
patrickvonplaten Apr 14, 2020
b116e3c
make tokenizer work
patrickvonplaten Apr 14, 2020
79a0bab
improve test
patrickvonplaten Apr 14, 2020
aceb586
correct typo
patrickvonplaten Apr 14, 2020
a4814bd
extend config
patrickvonplaten Apr 14, 2020
5eeeb25
add complexer test
patrickvonplaten Apr 14, 2020
0ee5db4
add new axial position embeddings
patrickvonplaten Apr 15, 2020
938aa8b
add local block attention layer
patrickvonplaten Apr 15, 2020
4d7c23b
clean tests
patrickvonplaten Apr 15, 2020
50276de
refactor
patrickvonplaten Apr 15, 2020
37a2b00
better testing
patrickvonplaten Apr 15, 2020
07c0c72
save intermediate progress
patrickvonplaten Apr 15, 2020
060a691
clean test file
patrickvonplaten Apr 16, 2020
ace301f
make shorter input length work for model
patrickvonplaten Apr 16, 2020
80d18db
allow variable input length
patrickvonplaten Apr 16, 2020
86f4ac4
refactor
patrickvonplaten Apr 16, 2020
e571849
make forward pass for pretrained model work
patrickvonplaten Apr 16, 2020
d5e1363
add generation possibility
patrickvonplaten Apr 17, 2020
562d530
finish dropout and init
patrickvonplaten Apr 17, 2020
c98eafe
make style
patrickvonplaten Apr 17, 2020
9c9fab9
refactor
patrickvonplaten Apr 17, 2020
a188a39
add first version of RevNet Layers
patrickvonplaten Apr 17, 2020
8047573
make forward pass work and add convert file
patrickvonplaten Apr 18, 2020
31a596b
make uploaded model forward pass work
patrickvonplaten Apr 18, 2020
bae0700
make uploaded model forward pass work
patrickvonplaten Apr 18, 2020
831dcec
refactor code
patrickvonplaten Apr 18, 2020
57ee09c
add namedtuples and cache buckets
patrickvonplaten Apr 19, 2020
2d23fad
correct head masks
patrickvonplaten Apr 19, 2020
0c35bbf
refactor
patrickvonplaten Apr 19, 2020
232463e
made reformer more flexible
patrickvonplaten Apr 19, 2020
2648a94
make style
patrickvonplaten Apr 19, 2020
902408b
remove set max length
patrickvonplaten Apr 21, 2020
8ed63ab
add attention masks
patrickvonplaten Apr 22, 2020
513bb43
fix up tests
patrickvonplaten Apr 22, 2020
db60c23
fix conflict
patrickvonplaten Apr 30, 2020
9f359af
fix lsh attention mask
patrickvonplaten Apr 23, 2020
48097a0
make random seed optional for the moment
patrickvonplaten Apr 23, 2020
650e00c
improve memory in reformer
patrickvonplaten Apr 23, 2020
ccba9ac
add tests
patrickvonplaten Apr 23, 2020
f83721e
make style
patrickvonplaten Apr 23, 2020
125c86d
make sure masks work correctly
patrickvonplaten Apr 24, 2020
2beda9c
detach gradients
patrickvonplaten Apr 24, 2020
12e35e1
save intermediate
patrickvonplaten Apr 24, 2020
8b058e2
correct backprob through gather
patrickvonplaten Apr 24, 2020
69258b8
make style
patrickvonplaten Apr 24, 2020
44c3a7c
change back num hashes
patrickvonplaten Apr 25, 2020
48fff07
rename to labels
patrickvonplaten Apr 25, 2020
55842be
fix rotation shape
patrickvonplaten Apr 25, 2020
71426c0
fix detach
patrickvonplaten Apr 25, 2020
dfbcf8f
update
patrickvonplaten Apr 25, 2020
0ea564c
fix trainer
patrickvonplaten Apr 25, 2020
af3456c
fix backward dropout
patrickvonplaten Apr 26, 2020
002f19c
make reformer more flexible
patrickvonplaten Apr 26, 2020
7de3f4f
fix
patrickvonplaten May 7, 2020
6111bd5
fix
patrickvonplaten May 7, 2020
0c75149
add tests for fixed seed in reformer layer
patrickvonplaten Apr 26, 2020
7a03bc7
fix trainer typo
patrickvonplaten Apr 26, 2020
37943f3
fix typo in activations
patrickvonplaten Apr 26, 2020
0f751f5
add fp16 tests
patrickvonplaten Apr 28, 2020
8df5dcd
add fp16 training
patrickvonplaten Apr 28, 2020
51426b5
support fp16
patrickvonplaten Apr 28, 2020
b37fd3b
correct gradient bug in reformer
patrickvonplaten Apr 29, 2020
e3e05ef
add fast gelu
patrickvonplaten Apr 29, 2020
c3e32b4
re-add dropout for embedding dropout
patrickvonplaten Apr 29, 2020
52ee5ed
better naming
patrickvonplaten Apr 29, 2020
ece19ee
better naming
patrickvonplaten Apr 29, 2020
e661832
renaming
patrickvonplaten Apr 29, 2020
f1a6355
finalize test branch
patrickvonplaten Apr 29, 2020
ea1126e
finalize tests
patrickvonplaten Apr 30, 2020
d4bc3c6
add more tests
patrickvonplaten Apr 30, 2020
94086ac
finish tests
patrickvonplaten Apr 30, 2020
01f4074
fix
patrickvonplaten May 7, 2020
9dafbc2
fix type trainer
patrickvonplaten Apr 30, 2020
de08a57
fix fp16 tests
patrickvonplaten Apr 30, 2020
aa570dc
fix tests
patrickvonplaten Apr 30, 2020
a681d19
fix tests
patrickvonplaten Apr 30, 2020
320c045
fix tests
patrickvonplaten Apr 30, 2020
482c6cd
fix issue with dropout
patrickvonplaten Apr 30, 2020
d7905dd
fix dropout seeds
patrickvonplaten Apr 30, 2020
764e06e
correct random seed on gpu
patrickvonplaten Apr 30, 2020
a3e0f59
finalize random seed for dropout
patrickvonplaten Apr 30, 2020
c48f88a
finalize random seed for dropout
patrickvonplaten Apr 30, 2020
ce87cb6
remove duplicate line
patrickvonplaten Apr 30, 2020
d418dd0
correct half precision bug
patrickvonplaten May 1, 2020
3248e67
make style
patrickvonplaten May 1, 2020
6fe0648
refactor
patrickvonplaten May 1, 2020
c3031b8
refactor
patrickvonplaten May 1, 2020
6c2be30
docstring
patrickvonplaten May 1, 2020
3d266fb
remove sinusoidal position encodings for reformer
patrickvonplaten May 1, 2020
1be343f
move chunking to modeling_utils
patrickvonplaten May 1, 2020
a10eb2e
make style
patrickvonplaten May 1, 2020
f31b570
clean config
patrickvonplaten May 1, 2020
b2a660f
make style
patrickvonplaten May 1, 2020
dfc1f64
fix tests
patrickvonplaten May 1, 2020
2e95c17
fix auto tests
patrickvonplaten May 1, 2020
b95f6ae
pretrained models
patrickvonplaten May 1, 2020
a6f69cb
fix docstring
patrickvonplaten May 1, 2020
59868f3
update conversion file
patrickvonplaten May 1, 2020
a81c3e0
Update pretrained_models.rst
patrickvonplaten May 1, 2020
c0ddf94
fix rst
patrickvonplaten May 1, 2020
62a8eb0
fix rst
patrickvonplaten May 1, 2020
47e5fc8
update copyright
patrickvonplaten May 1, 2020
b6576c8
fix test path
patrickvonplaten May 1, 2020
a111720
fix test path
patrickvonplaten May 1, 2020
ff5e783
fix small issue in test
patrickvonplaten May 1, 2020
f7f949b
include reformer in generation tests
patrickvonplaten May 2, 2020
91472b8
add docs for axial position encoding
patrickvonplaten May 2, 2020
6ed2fa8
finish docs
patrickvonplaten May 2, 2020
963bb5e
Update convert_reformer_trax_checkpoint_to_pytorch.py
patrickvonplaten May 2, 2020
425b185
remove isort
patrickvonplaten May 2, 2020
3336d8f
include sams comments
patrickvonplaten May 3, 2020
54eb629
remove wrong comment in utils
patrickvonplaten May 3, 2020
e4e1e59
correct typos
patrickvonplaten May 3, 2020
5f5c89b
fix typo
patrickvonplaten May 3, 2020
7fdf16b
Update reformer.rst
patrickvonplaten May 4, 2020
7ccec6a
applied morgans optimization
patrickvonplaten May 4, 2020
3978afa
make style
patrickvonplaten May 4, 2020
01b1006
make gpu compatible
patrickvonplaten May 4, 2020
e983a69
remove bogus file
patrickvonplaten May 4, 2020
9ce32f0
big test refactor
patrickvonplaten May 4, 2020
67f02c0
add example for chunking
patrickvonplaten May 4, 2020
4e7252a
fix typo
patrickvonplaten May 4, 2020
ca4dab3
add to README
patrickvonplaten May 7, 2020

Files changed
5 changes: 3 additions & 2 deletions README.md
@@ -163,8 +163,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training
16. **[BART](https://huggingface.co/transformers/model_doc/bart.html)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/pdf/1910.13461.pdf) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
17. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
18. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
18. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
19. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
19. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
20. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
21. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.

These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).

11 changes: 11 additions & 0 deletions docs/source/glossary.rst
@@ -143,3 +143,14 @@ positional embeddings.

Absolute positional embeddings are selected in the range ``[0, config.max_position_embeddings - 1]``. Some models
use other types of positional embeddings, such as sinusoidal position embeddings or relative position embeddings.


Feed Forward Chunking
--------------------------

In transformers, two feed forward layers usually follow the self attention layer in each residual attention block. The intermediate embedding size of the feed forward layers is often bigger than the hidden size of the model (*e.g.* 3072 vs. 768 for ``bert-base-uncased``).

For an input of size ``[batch_size, sequence_length]``, the memory required to store the intermediate feed forward embeddings of size ``[batch_size, sequence_length, config.intermediate_size]`` can account for a large fraction of the overall memory use. The authors of `Reformer: The Efficient Transformer <https://arxiv.org/abs/2001.04451>`_ noticed that, since the computation is independent of the ``sequence_length`` dimension, it is mathematically equivalent to compute the output embeddings of both feed forward layers for each position individually, ``[batch_size, config.hidden_size]_0, ..., [batch_size, config.hidden_size]_n`` with ``n = sequence_length``, and to concatenate them afterwards to ``[batch_size, sequence_length, config.hidden_size]``. This trades increased computation time for reduced memory use while yielding a mathematically **equivalent** result.

For models employing the function :func:`~transformers.apply_chunking_to_forward`, the ``chunk_size`` defines the number of output embeddings that are computed in parallel and thus defines the trade-off between memory and time complexity.
If ``chunk_size`` is set to 0, no feed forward chunking is done.
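
A minimal sketch of the idea in plain PyTorch (the layer sizes below are those of ``bert-base-uncased``, and ``chunk_size=8`` is an arbitrary illustration value; inside the library this logic is provided by :func:`~transformers.apply_chunking_to_forward`)::

    import torch
    import torch.nn as nn

    hidden_size, intermediate_size = 768, 3072
    dense_in = nn.Linear(hidden_size, intermediate_size)
    dense_out = nn.Linear(intermediate_size, hidden_size)

    def feed_forward(hidden_states):
        # standard two-layer feed forward block
        return dense_out(torch.relu(dense_in(hidden_states)))

    def chunked_feed_forward(hidden_states, chunk_size):
        # split the sequence_length dimension into chunks, apply the feed forward
        # block to each chunk and concatenate the outputs again
        if chunk_size == 0:
            return feed_forward(hidden_states)
        chunks = hidden_states.split(chunk_size, dim=1)  # dim 1 is sequence_length
        return torch.cat([feed_forward(chunk) for chunk in chunks], dim=1)

    hidden_states = torch.rand(1, 64, hidden_size)  # [batch_size, sequence_length, hidden_size]
    assert torch.allclose(feed_forward(hidden_states), chunked_feed_forward(hidden_states, chunk_size=8), atol=1e-6)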
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -107,3 +107,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
model_doc/t5
model_doc/electra
model_doc/dialogpt
model_doc/reformer
6 changes: 6 additions & 0 deletions docs/source/main_classes/model.rst
@@ -14,6 +14,12 @@ The base class ``PreTrainedModel`` implements the common methods for loading/sav
.. autoclass:: transformers.PreTrainedModel
:members:

``Helper Functions``
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: transformers.apply_chunking_to_forward


``TFPreTrainedModel``
~~~~~~~~~~~~~~~~~~~~~

114 changes: 114 additions & 0 deletions docs/source/model_doc/reformer.rst
@@ -0,0 +1,114 @@
Reformer
----------------------------------------------------
**DISCLAIMER:** This model is still a work in progress. If you see something strange,
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`_.

Overview
~~~~~~~~
The Reformer model was presented in `Reformer: The Efficient Transformer <https://arxiv.org/abs/2001.04451>`_ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
Here is the abstract:

*Large Transformer models routinely achieve state-of-the-art results on a number of tasks but training these models can be prohibitively costly, especially on long sequences. We introduce two techniques to improve the efficiency of Transformers. For one, we replace dot-product attention by one that uses locality-sensitive hashing, changing its complexity from O(L^2) to O(Llog(L)), where L is the length of the sequence. Furthermore, we use reversible residual layers instead of the standard residuals, which allows storing activations only once in the training process instead of N times, where N is the number of layers. The resulting model, the Reformer, performs on par with Transformer models while being much more memory-efficient and much faster on long sequences.*

The authors' code can be found `here <https://github.com/google/trax/tree/master/trax/models/reformer>`_.

Axial Positional Encodings
~~~~~~~~~~~~~~~~~~~~~~~~~~
Axial Positional Encodings were first implemented in Google's `trax library <https://github.com/google/trax/blob/4d99ad4965bab1deba227539758d59f0df0fef48/trax/layers/research/position_encodings.py#L29>`_ and developed by the authors of this model's paper. In models that process very long input sequences, the conventional position id encodings store an embedding vector of size :math:`d` (the ``config.hidden_size``) for every position :math:`1, \ldots, n_s`, with :math:`n_s` being ``config.max_embedding_size``. *E.g.*, having a sequence length of :math:`n_s = 2^{19} \approx 0.5M` and a ``config.hidden_size`` of :math:`d = 2^{10} \approx 1000` would result in a position encoding matrix:

.. math::
X_{i,j}, \text{ with } i \in \left[1,\ldots, d\right] \text{ and } j \in \left[1,\ldots, n_s\right]

which alone has over 500M parameters to store. Axial positional encodings factorize :math:`X_{i,j}` into two matrices:

.. math::
X^{1}_{i,j}, \text{ with } i \in \left[1,\ldots, d^1\right] \text{ and } j \in \left[1,\ldots, n_s^1\right]

and

.. math::
X^{2}_{i,j}, \text{ with } i \in \left[1,\ldots, d^2\right] \text{ and } j \in \left[1,\ldots, n_s^2\right]

with:

.. math::
d = d^1 + d^2 \text{ and } n_s = n_s^1 \times n_s^2 .

Therefore the following holds:

.. math::
X_{i,j} = \begin{cases}
X^{1}_{i, k}, & \text{if }\ i < d^1 \text{ with } k = j \mod n_s^1 \\
X^{2}_{i - d^1, l}, & \text{if } i \ge d^1 \text{ with } l = \lfloor\frac{j}{n_s^1}\rfloor
\end{cases}

Intuitively, this means that a position embedding vector :math:`x_j \in \mathbb{R}^{d}` is now the composition of two factorized embedding vectors: :math:`x^1_{k}` and :math:`x^2_{l}`, where the ``config.max_embedding_size`` dimension :math:`j` is factorized into :math:`k \text{ and } l`.
This design ensures that each position embedding vector :math:`x_j` is unique.

Using the above example again, axial position encoding with :math:`d^1 = 2^9, d^2 = 2^9, n_s^1 = 2^9, n_s^2 = 2^{10}` can drastically reduce the number of parameters from roughly 500M to :math:`2^{18} + 2^{19} \approx 780000`.

In practice, the parameter ``config.axial_pos_embds_dim`` is set to a list :math:`(d^1, d^2)`, whose sum has to be equal to ``config.hidden_size``, and ``config.axial_pos_shape`` is set to a list :math:`(n_s^1, n_s^2)`, whose product has to be equal to ``config.max_embedding_size``, which during training has to be equal to the sequence length of the ``input_ids``.
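
As a rough, self-contained illustration (not the model code, and the sizes below are arbitrary), the factorized lookup described above can be written as::

    import torch

    d1, d2 = 64, 192    # d1 + d2 = config.hidden_size = 256
    n1, n2 = 128, 512   # n1 * n2 = config.max_embedding_size = 65536

    x1 = torch.rand(n1, d1)  # first factorized embedding matrix
    x2 = torch.rand(n2, d2)  # second factorized embedding matrix

    def axial_position_embedding(j):
        # factorize position j into (k, l) and concatenate the two sub-embeddings
        k, l = j % n1, j // n1
        return torch.cat([x1[k], x2[l]])

    # every position 0 <= j < n1 * n2 gets a unique vector of size d1 + d2,
    # while only n1 * d1 + n2 * d2 parameters are stored
    print(axial_position_embedding(12345).shape)  # torch.Size([256])

The actual implementation builds the same lookup in a vectorized way for a whole batch of ``input_ids``.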



LSH Self Attention
~~~~~~~~~~~~~~~~~~~~
In locality-sensitive hashing (LSH) self attention, the key and query projection weights are tied, so the key and query embedding vectors are also identical.
LSH self attention uses the locality sensitive
hashing mechanism proposed in `Practical and Optimal LSH for Angular Distance <https://arxiv.org/abs/1509.02897>`_ to assign each of the tied key query embedding vectors to one of ``config.num_buckets`` possible buckets. The premise is that the more "similar" two key query embedding vectors are to each other (in terms of *cosine similarity*), the more likely they are to be assigned to the same bucket.
The accuracy of the LSH mechanism can be improved by increasing ``config.num_hashes`` (or the ``num_hashes`` argument of the forward function directly), so that the output of the LSH self attention better approximates the output of "normal" full self attention.
The buckets are then sorted and chunked into query key embedding vector chunks, each of length ``config.lsh_chunk_length``. Within each chunk, the query embedding vectors attend to their key vectors (which are tied to themselves) and to the key embedding vectors of ``config.lsh_num_chunks_before`` previous neighboring chunks and ``config.lsh_num_chunks_after`` following neighboring chunks.
For more information, see the `original Paper <https://arxiv.org/abs/2001.04451>`_ or this great `blog post <https://www.pragmatic.ml/reformer-deep-dive/>`_.

Note that ``config.num_buckets`` can also be factorized into a ``list``:math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})` they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory.

It is recommended to leave ``config.num_buckets=None`` so that, depending on the sequence length, a good value for ``num_buckets`` is calculated on the fly.

Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.
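
The hashing step itself is small. Here is a self-contained sketch of the angular LSH scheme referenced above (random rotations followed by an ``argmax``), leaving out the sorting, chunking, masking and seeding logic of the actual layer::

    import torch

    def hash_vectors(vectors, num_buckets, num_hashes):
        # vectors: [batch, seq_len, head_dim] -- the tied query/key embedding vectors
        batch, seq_len, head_dim = vectors.shape
        # one random rotation per hash round: [head_dim, num_hashes, num_buckets // 2]
        rotations = torch.randn(head_dim, num_hashes, num_buckets // 2)
        rotated = torch.einsum("bsd,dhr->bhsr", vectors, rotations)
        # angular LSH: concatenate the rotated vectors with their negation and take the argmax
        buckets = torch.argmax(torch.cat([rotated, -rotated], dim=-1), dim=-1)
        return buckets  # [batch, num_hashes, seq_len], values in [0, num_buckets)

    tied_query_key = torch.randn(1, 4096, 64)
    print(hash_vectors(tied_query_key, num_buckets=32, num_hashes=4).shape)  # torch.Size([1, 4, 4096])

Vectors that point in similar directions tend to produce the same ``argmax`` and therefore land in the same bucket.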


Local Self Attention
~~~~~~~~~~~~~~~~~~~~
Local self attention is essentially a "normal" self attention layer with key, query and value projections, but chunked so that in each chunk of length ``config.local_chunk_length`` the query embedding vectors only attend to the key embedding vectors in their chunk and to the key embedding vectors of ``config.local_num_chunks_before`` previous neighboring chunks and ``config.local_num_chunks_after`` following neighboring chunks.

Using Local self attention, the memory and time complexity of the query-key matmul operation, which usually represents the memory and time bottleneck in a transformer model, can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times c)`, with :math:`n_s` being the sequence length and :math:`c` being the chunk length ``config.local_chunk_length`` times the number of attended chunks, so that the cost grows only linearly with the sequence length.
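
For intuition, here is a stripped-down sketch of the chunked attention pattern (no masking, no projections, one neighboring chunk before and none after; chunk boundaries simply wrap around)::

    import torch

    def local_attention(q, k, v, chunk_length, num_chunks_before=1):
        # q, k, v: [batch, seq_len, head_dim]; seq_len must be divisible by chunk_length
        batch, seq_len, head_dim = q.shape
        n_chunks = seq_len // chunk_length
        q = q.view(batch, n_chunks, chunk_length, head_dim)
        k = k.view(batch, n_chunks, chunk_length, head_dim)
        v = v.view(batch, n_chunks, chunk_length, head_dim)
        # let every chunk also attend to `num_chunks_before` previous chunks
        k = torch.cat([torch.roll(k, shifts=i, dims=1) for i in range(num_chunks_before, -1, -1)], dim=2)
        v = torch.cat([torch.roll(v, shifts=i, dims=1) for i in range(num_chunks_before, -1, -1)], dim=2)
        scores = torch.einsum("bcqd,bckd->bcqk", q, k) / head_dim ** 0.5
        out = torch.einsum("bcqk,bckd->bcqd", scores.softmax(dim=-1), v)
        return out.reshape(batch, seq_len, head_dim)

    q = k = v = torch.randn(1, 1024, 64)
    print(local_attention(q, k, v, chunk_length=64).shape)  # torch.Size([1, 1024, 64])

Each query only ever multiplies against ``chunk_length * (1 + num_chunks_before)`` keys, which is what keeps the cost linear in the sequence length.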


Training
~~~~~~~~~~~~~~~~~~~~
During training, we must ensure that the sequence length is set to a value that is divisible by the least common multiple of ``config.lsh_chunk_length`` and ``config.local_chunk_length`` and that the parameters of the Axial Positional Encodings are correctly set as described above. Reformer is so memory efficient that the model can easily be trained on sequences as long as 64000 tokens.
For training, the ``ReformerModelWithLMHead`` should be used as follows:

::

input_ids = tokenizer.encode('This is a sentence from the training data', return_tensors='pt')
loss = model(input_ids, labels=input_ids)[0]
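
A more self-contained variant of the snippet above; the checkpoint identifier is an assumption based on the converted Crime and Punishment model listed in ``pretrained_models.rst``::

    from transformers import ReformerModelWithLMHead, ReformerTokenizer

    model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')
    tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')

    input_ids = tokenizer.encode('This is a sentence from the training data', return_tensors='pt')
    loss = model(input_ids, labels=input_ids)[0]
    loss.backward()  # activations are recomputed by the reversible layers instead of being stored

Keep in mind that for real training runs the sequence length and the axial position parameters have to satisfy the constraints described above.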


ReformerConfig
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ReformerConfig
:members:


ReformerTokenizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ReformerTokenizer
:members:


ReformerModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ReformerModel
:members:


ReformerModelWithLMHead
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ReformerModelWithLMHead
:members:
3 changes: 3 additions & 0 deletions docs/source/pretrained_models.rst
@@ -296,3 +296,6 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters |
| | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters |
| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
13 changes: 12 additions & 1 deletion src/transformers/__init__.py
@@ -47,6 +47,7 @@
from .configuration_marian import MarianConfig
from .configuration_mmbt import MMBTConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
@@ -138,6 +139,7 @@
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
@@ -159,7 +161,7 @@

# Modeling
if is_torch_available():
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering, apply_chunking_to_forward
from .modeling_auto import (
AutoModel,
AutoModelForPreTraining,
@@ -190,6 +192,7 @@
BertForQuestionAnswering,
load_tf_weights_in_bert,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertLayer,
)
from .modeling_openai import (
OpenAIGPTPreTrainedModel,
@@ -320,6 +323,14 @@
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
)

from .modeling_reformer import (
ReformerAttention,
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
)

# Optimization
from .optimization import (
AdamW,
6 changes: 6 additions & 0 deletions src/transformers/activations.py
@@ -34,12 +34,18 @@ def gelu_new(x):
else:
gelu = F.gelu


def gelu_fast(x):
Contributor comment: Is this different than the builtins? Maybe add docstring?
Contributor comment: Also if it isn't faster, does it need a new name?

return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))


ACT2FN = {
"relu": F.relu,
"swish": swish,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
}
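
# Note: gelu_fast is algebraically the same tanh approximation as gelu_new above
# (0.7978845608 is sqrt(2 / pi), with the polynomial factored so that torch.pow is
# avoided); it therefore differs slightly from the exact, erf-based F.gelu builtin.
# Illustrative check (a sketch, not part of this file):
#
#     x = torch.linspace(-3, 3, steps=7)
#     print((gelu_fast(x) - gelu_new(x)).abs().max())  # float rounding only
#     print((gelu_fast(x) - F.gelu(x)).abs().max())    # small but non-zero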


3 changes: 3 additions & 0 deletions src/transformers/configuration_auto.py
@@ -29,6 +29,7 @@
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import ReformerConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
@@ -73,6 +74,7 @@
("camembert", CamembertConfig,),
("xlm-roberta", XLMRobertaConfig,),
("bart", BartConfig,),
("reformer", ReformerConfig,),
("roberta", RobertaConfig,),
("flaubert", FlaubertConfig,),
("bert", BertConfig,),
@@ -130,6 +132,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
- contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model)
- contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model)
- contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model)
- contains `bert`: :class:`~transformers.BertConfig` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model)
- contains `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model)