
Reformer #3351

Merged: 162 commits, May 7, 2020
Conversation

@patrickvonplaten (Contributor) commented Mar 19, 2020

Add the Reformer

Paper: https://arxiv.org/pdf/2001.04451.pdf

First steps to take:

  • Copy Bert PT code to Reformer PT file.
  • Replace self-attention with LSH attention
  • Make forward pass work for Bert Layer

Forward-Pass: get 1-to-1 the same outputs as the original Trax code for the forward pass

  • for LSH attention layer
  • for Bert Layer with RevNet
  • for different attention masks
  • for feed forward chunking layer
  • for whole Reformer model
  • for sinusoidal position encodings
  • for axial position encodings
  • for local blocking attention (chunked attention)
  • for pretrained weights from the official Reformer model: a ReformerLM model was trained with https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb, the weights were loaded into https://huggingface.co/patrickvonplaten/reformer-crime-and-punish, and it was checked that a single forward pass is identical (see the sketch after this list). predict_mem_len had to be adapted to make the functions equal.
  • Add optional attention mask
  • Add support for fp16
  • Speed up incremental generation. This will not be trivial, since the buckets have to be ordered correctly and there is a chunk length parameter.
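
The pretrained-weights check above can be reproduced in a few lines once this branch is installed. A minimal sketch, assuming the ReformerModelWithLMHead / ReformerTokenizer classes added in this PR and the patrickvonplaten/reformer-crime-and-punish checkpoint mentioned above:

```python
import torch
from transformers import ReformerModelWithLMHead, ReformerTokenizer

# Checkpoint and class names as introduced in this PR; the checkpoint may
# later be superseded by an official one.
model_id = "patrickvonplaten/reformer-crime-and-punish"
model = ReformerModelWithLMHead.from_pretrained(model_id)
tokenizer = ReformerTokenizer.from_pretrained(model_id)

input_ids = tokenizer.encode("A few months later", return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids)[0]  # (batch, seq_len, vocab_size)
print(logits.shape)
```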

Backpropagation:

  • Make backpropagation work
  • Check that backpropagation works with chunked feed forward layers
  • Implement RevResLayers for backprop (see the sketch after this list)
  • Check that RevResLayers backprop works on CPU
  • Check that RevResLayers backprop works on GPU
  • Get the same gradients as the original trax code
  • Train the model on the Crime and Punishment text and check that the model performs reasonably afterwards
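
For reviewers unfamiliar with reversible residual layers, here is a minimal, self-contained sketch of the RevNet-style coupling (an illustration, not the PR's actual ReformerLayer code): activations are not stored during the forward pass but reconstructed from the layer outputs during backprop.

```python
import torch
import torch.nn as nn

class ReversibleBlock(nn.Module):
    """Toy reversible residual block: y1 = x1 + F(x2), y2 = x2 + G(y1)."""

    def __init__(self, dim):
        super().__init__()
        self.f = nn.Linear(dim, dim)
        self.g = nn.Linear(dim, dim)

    def forward(self, x1, x2):
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

    def reconstruct_inputs(self, y1, y2):
        # Invert the coupling instead of storing x1/x2, which is what makes
        # memory usage independent of the number of layers.
        with torch.no_grad():
            x2 = y2 - self.g(y1)
            x1 = y1 - self.f(x2)
        return x1, x2

block = ReversibleBlock(4)
x1, x2 = torch.randn(2, 4), torch.randn(2, 4)
y1, y2 = block(x1, x2)
r1, r2 = block.reconstruct_inputs(y1, y2)
assert torch.allclose(r1, x1, atol=1e-6) and torch.allclose(r2, x2, atol=1e-6)
```

A real implementation additionally wraps this reconstruction in a custom torch.autograd.Function so that gradients flow through the recomputed activations, which is exactly the error-prone part discussed further below.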

Tokenizer

Optimize time and memory efficiency

Pretrained Models

  • Check if pretrained model on C4 is added soon: google/trax@b1f0c17

  • Add Reformer / Bert in trax

Useful code resources:

Useful blog/paper resources:

Previous Discussions:

Update

The code is clean and ready for review now.
Small ToDos before merging:

  • Fill in TODOs in docs
  • Check whether more pre-trained weights can be used
  • Train on fp16 once
  • Update notebook showing how to use Reformer

Review

I added quite a few docstrings to explain the new methods introduced by the Reformer (Axial Position Encoding, LSH Attention, Local Attention, Feed Forward chunking), so it might be best to go through the docstrings first. They are easier to read if you switch to this branch and build the docs locally.
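
As a quick illustration of one of these concepts, feed forward chunking applies the position-wise feed forward layer to slices of the sequence dimension and concatenates the results; the output is unchanged, only peak memory goes down. A minimal sketch (not the PR's actual chunking helper):

```python
import torch
import torch.nn as nn

def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_size: int):
    # Apply `ff` chunk-by-chunk along the sequence dimension (dim=1).
    # Because `ff` acts on each position independently, the result is
    # identical to ff(hidden_states); only peak memory changes.
    if chunk_size <= 0:
        return ff(hidden_states)
    chunks = hidden_states.split(chunk_size, dim=1)  # (batch, seq, hidden)
    return torch.cat([ff(chunk) for chunk in chunks], dim=1)

ff = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
x = torch.randn(2, 8, 16)
assert torch.allclose(chunked_feed_forward(ff, x, chunk_size=2), ff(x), atol=1e-6)
```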

@patrickvonplaten force-pushed the reformer_add_model branch 2 times, most recently from 60e5d9c to 2f3afad on April 6, 2020 11:33
@patrickvonplaten (Contributor, Author) commented Apr 7, 2020

Memory complexity, ReformerLayer vs. BertLayer:

[Figure: Figure_Memory_GPU_bs_1]

@patrickvonplaten (Contributor, Author):

Time complexity ReformerLayer vs. BertLayer:

[Figure: Figure_Memory_GPU_bs_1_time]

@patrickvonplaten force-pushed the reformer_add_model branch 2 times, most recently from a4b0cce to 19d6f70 on April 14, 2020 11:59
@patrickvonplaten (Contributor, Author) commented Apr 28, 2020

Experiment

I tested training the Reformer model with 0.5M tokens per sample on the novel "Crime and Punishment" using conventional LM training. I essentially translated the official trax notebook (https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb) into Hugging Face code: https://colab.research.google.com/drive/1jR6hA2CQXDbucJXdiDXhmxmyoQmM2Pws

The only differences to the official notebook are:

  • The gradient is accumulated over 8 samples before each update, whereas the official notebook uses 8 TPU cores, computes the gradients in parallel, and then averages them.

  • The learning rate is 0.005 instead of 0.01 (because even at 0.005 the gradients already seem to become too big).
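
The two differences above roughly correspond to the following Trainer arguments used in the notebook; a minimal sketch (the output directory is a placeholder, and only the two settings discussed above are meaningful here):

```python
from transformers import TrainingArguments

# 8 gradient-accumulation steps stand in for the 8 TPU cores of the official
# notebook, and the learning rate is halved from 0.01 to 0.005.
training_args = TrainingArguments(
    output_dir="./reformer-crime-and-punishment",  # placeholder path
    gradient_accumulation_steps=8,
    learning_rate=0.005,
)
print(training_args.gradient_accumulation_steps, training_args.learning_rate)
```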

Results

My training loss also starts around 6.2 and goes down smoothly in the beginning.
At some point, though, the gradients seem to explode and the loss goes up again, even at a learning rate of "only" 0.005.

The attached plots:

Loss

[Figure: eval_loss]

Accuracy

[Figure: eval_accuracy]

Learning rate (cosine scheduler)

[Figure: learning_rate]

When lowering the learning rate further, e.g. to 0.0005, the loss keeps going down but only reaches about 2.3 in the end.

Comparison

Training in the official trax notebook is very smooth: the loss starts at around 6.2 and goes down smoothly to 0.8, while the accuracy reaches >80% in the end, at a learning rate of 0.01.

Analysis

  • Integration tests confirm that the forward pass is identical to the trax implementation. Things that are not fully tested for the backward pass are:

    • Dropout: the dropout used in the official trax library does not seem to correspond to the "usual" nn.Dropout in PyTorch; it sometimes drops only specific dimensions or whole matrices. It is tested, though, that the dropout used here is deterministic for both the "normal" forward pass and the forward pass that recalculates the activations during the backward pass, by re-setting the random seed used for the first forward pass (see the sketch after this list). Nevertheless, there could still be small bugs.

    • Reversible Layers: because the Reformer uses reversible layers, I had to fiddle with a customized backward function here. This is IMO quite prone to errors. I checked multiple times that everything is correct from a logical point of view and compared my code with https://github.com/RobinBruegger/RevTorch and https://github.com/lucidrains/reformer-pytorch, which implement a similar / the same architecture. IMO, it is quite hard to test this for correctness. One could also write the whole code without reversible layers and then check whether the gradients are the same (which actually does not seem like a bad idea to me).

    • Attention mask: the official trax code does not seem to use a user-specified attention mask for the LSH attention layer, only for the local attention layer. I tested via integration tests that the attention mask is correct for local attention, and checked that the attention mask for the LSH layer works correctly (input with mask gives the same result as input without mask), but maybe the LSH attention mask has to be removed. I don't really see a reason why, though.

    • Initialization: The initialization scheme used in the trax library is different from what is normally done in transformers, so there are small changes in my code. But I doubt that this is an issue, especially since the training looks very similar in the beginning.

    • Training parameters: it might also simply be due to different training / optimization parameters. Maybe there are some under-the-hood training settings that I didn't notice (special gradient clipping, ...).
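
Regarding the dropout point above, the determinism check boils down to re-seeding the RNG before the recomputation pass. A minimal CPU-only sketch of the idea (illustrative only; the PR's actual seed handling also has to cover the CUDA generator):

```python
import torch
import torch.nn.functional as F

def dropout_with_seed(x, p, seed):
    # Re-seeding before each call reproduces the exact same dropout mask, so
    # the forward pass that recomputes activations inside backward() drops
    # the same units as the original forward pass.
    torch.manual_seed(seed)
    return F.dropout(x, p=p, training=True)

seed = 1234  # in the real model, a fresh seed is drawn for every forward pass
x = torch.randn(4, 8)
first = dropout_with_seed(x, 0.5, seed)
recomputed = dropout_with_seed(x, 0.5, seed)
assert torch.equal(first, recomputed)
```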

@flozi00 (Contributor) commented Apr 28, 2020

https://colab.research.google.com/drive/1jR6hA2CQXDbucJXdiDXhmxmyoQmM2Pws

I tried to train the model for a longer time, but I'm getting an error:

Forward got an unexpected keyword "lm_labels" after calling trainer.train()

P: Fixed the typo. I will switch the model to half precision soon so that the memory will be sufficient :-)

@flozi00 (Contributor) commented Apr 28, 2020

I get some good results with the following parameters: https://gist.github.com/flozi00/b491b41a9865733e5f8bb4032c313540

The best eval loss is about 1.654, but it is now increasing again, the same as yours.
I will have another look in a few hours.

[Screenshot: Screenshot_1]

@patrickvonplaten (Contributor, Author) commented Apr 28, 2020

> I get some good results with the following parameters: https://gist.github.com/flozi00/b491b41a9865733e5f8bb4032c313540
>
> The best eval loss is about 1.654, but it is now increasing again, the same as yours.
> I will have another look in a few hours.
>
> [Screenshot: Screenshot_1]

Awesome, that's already much better than what I got! If you manage to get it under 1 (loss) / above 75% (accuracy), that would be great. Also feel free to change the hyperparameters as you wish, especially the Adam betas and co.

I also added support for fp16, so the notebook now only needs 8 GB of RAM.

(You might have to reset the environment and re-install the GitHub branch, though.)

@flozi00 (Contributor) commented Apr 28, 2020

Sounds great.
I'm trying to decrease the sequence length, because increasing the number of hashes or heads gives a memory error.
I'm training on a 24 GB GPU.

I read that 4 hashes are good and 8 gives the best quality.

I have trained on a few configurations now, and every time the loss goes down to ~1 but then increases to 4 very fast and stays there for at least 1000 steps.
Any idea why?

@patrickvonplaten (Contributor, Author):

> Sounds great.
> I'm trying to decrease the sequence length, because increasing the number of hashes or heads gives a memory error.
> I'm training on a 24 GB GPU.
>
> I read that 4 hashes are good and 8 gives the best quality.
>
> I have trained on a few configurations now, and every time the loss goes down to ~1 but then increases to 4 very fast and stays there for at least 1000 steps.
> Any idea why?

My guess is that since it's such a small dataset (0.5M tokens is tiny), the model needs very well-calibrated hyperparameters. When the learning rate is low enough, this actually does not happen anymore, but then the loss only gets down to about ~2. But I ran very few experiments and didn't do any hyperparameter search.
Also, I use slightly different dropouts than were used in the official code, so maybe using weight decay instead of dropout could work better.

I will check that the gradients are correct in the next few days; then this should hopefully be ready soon.

@nkitaev commented Apr 28, 2020

@patrickvonplaten I'm excited to see a lot of progress here!

The loss curves above could be due to poor hyperparameter choice, but they're also very similar to what you see when the reverse pass of the network doesn't match the forward pass. For example, failing to cache hash bucket assignments (for exact re-use in the backward pass) leads to a failure mode with loss rebounds very similar to the figures you posted above. I also once had a bug where the wrong random seed was used for dropout in the backward pass, which IIRC manifested itself in the same way.
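
To make that failure mode concrete: the bucket assignment depends on random rotations, so the recomputation pass has to reuse the cached rotations (or bucket ids) rather than re-hashing. A minimal sketch of angular LSH with cached rotations (illustrative only, not the trax or PR implementation):

```python
import torch

def hash_vectors(vectors, num_buckets, rotations=None):
    # Angular LSH as in the Reformer paper: project onto random rotations and
    # take the argmax over [xR, -xR]. Passing the cached `rotations` back in
    # during the backward-pass recomputation guarantees identical buckets.
    if rotations is None:
        rotations = torch.randn(vectors.shape[-1], num_buckets // 2)
    rotated = vectors @ rotations
    buckets = torch.argmax(torch.cat([rotated, -rotated], dim=-1), dim=-1)
    return buckets, rotations

queries = torch.randn(16, 64)
buckets, cached_rotations = hash_vectors(queries, num_buckets=8)
recomputed, _ = hash_vectors(queries, num_buckets=8, rotations=cached_rotations)
assert torch.equal(buckets, recomputed)
```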

@patrickvonplaten (Contributor, Author):

> @patrickvonplaten I'm excited to see a lot of progress here!
>
> The loss curves above could be due to poor hyperparameter choice, but they're also very similar to what you see when the reverse pass of the network doesn't match the forward pass. For example, failing to cache hash bucket assignments (for exact re-use in the backward pass) leads to a failure mode with loss rebounds very similar to the figures you posted above. I also once had a bug where the wrong random seed was used for dropout in the backward pass, which IIRC manifested itself in the same way.

Thanks for taking a look @nkitaev. I just found a bug in backward(). I now get exactly the same gradients as your trax code. I will retrain tonight and should get better results :-)

@patrickvonplaten patrickvonplaten changed the title [WIP] check and possibly add the Reformer [WIP] Reformer Apr 30, 2020
@sshleifer sshleifer self-requested a review May 1, 2020 16:40
@codecov-io commented May 1, 2020

Codecov Report

Merging #3351 into master will not change coverage.
The diff coverage is n/a.


@@           Coverage Diff           @@
##           master    #3351   +/-   ##
=======================================
  Coverage   79.13%   79.13%           
=======================================
  Files         117      117           
  Lines       19517    19517           
=======================================
  Hits        15444    15444           
  Misses       4073     4073           

Powered by Codecov. Last update 7b5bec3...7b5bec3.

@patrickvonplaten patrickvonplaten changed the title [WIP] Reformer Reformer May 1, 2020
@patrickvonplaten (Contributor, Author):

Training looks good now on Crime and Punishment. To verify that training works, the model was trained for over 1200 steps with little regularization.

Eval loss

[Figure: eval_loss-1]

Eval accuracy

[Figure: eval_accuracy-1]

@patrickvonplaten patrickvonplaten merged commit dca3469 into master May 7, 2020
@patrickvonplaten patrickvonplaten deleted the reformer_add_model branch May 7, 2020 08:17
@theblackcat102 (Contributor):

@patrickvonplaten Based on your merge, it seems like the input size for each batch is fixed so that it matches the product of the axial position embedding shape? Am I correct?

@patrickvonplaten (Contributor, Author):

> @patrickvonplaten Based on your merge, it seems like the input size for each batch is fixed so that it matches the product of the axial position embedding shape? Am I correct?

For training, yes, that's correct. For inference, the input size can also be smaller. Also check out: https://huggingface.co/transformers/model_doc/reformer.html
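
Concretely, with axial position embeddings the (padded) training sequence length has to equal the product of the axial shape. A minimal sketch, using the parameter names from the Reformer docs linked above (the values are the documented defaults):

```python
from transformers import ReformerConfig

# Training inputs must be padded to axial_pos_shape[0] * axial_pos_shape[1]
# tokens; shorter inputs only work at inference/generation time.
config = ReformerConfig(
    axial_pos_embds=True,
    axial_pos_shape=(64, 64),       # factorized sequence length: 64 * 64 = 4096
    axial_pos_embds_dim=(64, 192),  # must sum to the hidden size
)
required_train_length = config.axial_pos_shape[0] * config.axial_pos_shape[1]
print(required_train_length)  # 4096
```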

mfuntowicz pushed a commit that referenced this pull request May 8, 2020
* first copy & past commit from Bert and morgans LSH code

* add easy way to compare to trax original code

* translate most of function

* make trax lsh self attention deterministic with numpy seed + copy paste code

* add same config

* add same config

* make layer init work

* implemented hash_vectors function for lsh attention

* continue reformer translation

* hf LSHSelfAttentionLayer gives same output as trax layer

* refactor code

* refactor code

* refactor code

* refactor

* refactor + add reformer config

* delete bogus file

* split reformer attention layer into two layers

* save intermediate step

* save intermediate step

* make test work

* add complete reformer block layer

* finish reformer layer

* implement causal and self mask

* clean reformer test and refactor code

* fix merge conflicts

* fix merge conflicts

* update init

* fix device for GPU

* fix chunk length init for tests

* include morgans optimization

* improve memory a bit

* improve comment

* factorize num_buckets

* better testing parameters

* make whole model work

* make lm model work

* add t5 copy paste tokenizer

* add chunking feed forward

* clean config

* add improved assert statements

* make tokenizer work

* improve test

* correct typo

* extend config

* add complexer test

* add new axial position embeddings

* add local block attention layer

* clean tests

* refactor

* better testing

* save intermediate progress

* clean test file

* make shorter input length work for model

* allow variable input length

* refactor

* make forward pass for pretrained model work

* add generation possibility

* finish dropout and init

* make style

* refactor

* add first version of RevNet Layers

* make forward pass work and add convert file

* make uploaded model forward pass work

* make uploaded model forward pass work

* refactor code

* add namedtuples and cache buckets

* correct head masks

* refactor

* made reformer more flexible

* make style

* remove set max length

* add attention masks

* fix up tests

* fix lsh attention mask

* make random seed optional for the moment

* improve memory in reformer

* add tests

* make style

* make sure masks work correctly

* detach gradients

* save intermediate

* correct backprob through gather

* make style

* change back num hashes

* rename to labels

* fix rotation shape

* fix detach

* update

* fix trainer

* fix backward dropout

* make reformer more flexible

* fix conflict

* fix

* fix

* add tests for fixed seed in reformer layer

* fix trainer typo

* fix typo in activations

* add fp16 tests

* add fp16 training

* support fp16

* correct gradient bug in reformer

* add fast gelu

* re-add dropout for embedding dropout

* better naming

* better naming

* renaming

* finalize test branch

* finalize tests

* add more tests

* finish tests

* fix

* fix type trainer

* fix fp16 tests

* fix tests

* fix tests

* fix tests

* fix issue with dropout

* fix dropout seeds

* correct random seed on gpu

* finalize random seed for dropout

* finalize random seed for dropout

* remove duplicate line

* correct half precision bug

* make style

* refactor

* refactor

* docstring

* remove sinusoidal position encodings for reformer

* move chunking to modeling_utils

* make style

* clean config

* make style

* fix tests

* fix auto tests

* pretrained models

* fix docstring

* update conversion file

* Update pretrained_models.rst

* fix rst

* fix rst

* update copyright

* fix test path

* fix test path

* fix small issue in test

* include reformer in generation tests

* add docs for axial position encoding

* finish docs

* Update convert_reformer_trax_checkpoint_to_pytorch.py

* remove isort

* include sams comments

* remove wrong comment in utils

* correct typos

* fix typo

* Update reformer.rst

* applied morgans optimization

* make style

* make gpu compatible

* remove bogus file

* big test refactor

* add example for chunking

* fix typo

* add to README
tianleiwu pushed a commit to tianleiwu/transformers that referenced this pull request May 8, 2020
@prajwal-PHAI:
@patrickvonplaten, I wanted to train a Reformer language model on a custom dataset.
What are the steps, and are there any sample notebooks available for this?

@LysandreJik (Member):

Hi @prajwal-PHAI, there are a lot of community notebooks covering T5 finetuning.

@prajwal-PHAI:
Thanks @LysandreJik.
I was running into errors loading other datasets that were not in the nlp library.

@srulikbd:
Hey, thanks for your amazing work!
I'm running into an error while trying the Colab example:
https://colab.research.google.com/drive/1jR6hA2CQXDbucJXdiDXhmxmyoQmM2Pws#scrollTo=WskGtnXsnWdu

The problem is that it doesn't recognize the apex package:

ImportError                               Traceback (most recent call last)
in ()
     11
     12 # train
---> 13 trainer.train()

/usr/local/lib/python3.6/dist-packages/transformers/trainer.py in train(self, model_path)
    384         if self.args.fp16:
    385             if not is_apex_available():
--> 386                 raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
    387             model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
    388

ImportError: Please install apex from https://www.github.com/nvidia/apex to use fp16 training.

I did install it, though... does anyone know what to do?

@leo-liuzy:
Linking a related GitHub issue: #16972. cc @patrickvonplaten
