
Add BART DLM PyTorch pretraining example #18904

Closed
wants to merge 5 commits into from
Conversation

BramVanroy (Collaborator) commented on Sep 6, 2022

Implements a pretraining example for BART (denoising language model). The main focus is on getting the data denoising as close to the original fairseq implementation as possible, but at the dataloader level instead of the dataset level.

Heavily inspired by the fairseq implementation and the Flax implementation (see HF (Flax), fairseq, and current implementation below). Looking for some feedback. Please see Questions/Uncertainties.
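For orientation, here is a minimal sketch of what such a dataloader-level denoising collator could look like. The class name, method names, and defaults are illustrative assumptions, not the exact code in this PR:

# Minimal sketch, not the PR's actual implementation.
from dataclasses import dataclass

import torch
from transformers import PreTrainedTokenizerBase


@dataclass
class DataCollatorForBartDenoising:  # hypothetical name
    tokenizer: PreTrainedTokenizerBase
    mask_ratio: float = 0.3
    poisson_lambda: float = 3.5
    permute_sentence_ratio: float = 1.0

    def __call__(self, examples):
        # Assumes already-tokenized samples of equal length, so no padding is needed
        input_ids = torch.stack([torch.as_tensor(e["input_ids"]) for e in examples])
        labels = input_ids.clone()  # the uncorrupted sequence is the target

        # Corruption happens on every iteration, so the same sample is noised
        # differently each epoch (unlike a cached Dataset.map pipeline)
        input_ids = self.permute_sentences(input_ids)
        input_ids = self.mask_spans(input_ids)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": (input_ids != self.tokenizer.pad_token_id).long(),
        }

    def permute_sentences(self, input_ids):
        return input_ids  # placeholder; see the sketches further down

    def mask_spans(self, input_ids):
        return input_ids  # placeholder; Poisson-length span masking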

Some notes

Default values

The defaults are set to the given BART args. This differs from the Flax defaults in one respect, namely poisson_lambda, which is now set to 3.5 instead of 3.0.
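As an aside, poisson_lambda controls the span lengths sampled for text infilling. A minimal numpy-based sketch (in the style of the Flax example; not the PR's exact code):

import numpy as np

poisson_lambda = 3.5  # BART/fairseq default; the Flax example used 3.0

# Span lengths for text infilling are drawn from Poisson(lambda);
# in fairseq, zero-length spans become pure insertion noise
span_lengths = np.random.poisson(lam=poisson_lambda, size=(8,))
print(span_lengths)  # e.g. [4 2 0 5 3 3 1 6]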

HF (Flax), fairseq, and current implementation

There are some differences in implementation between fairseq, the HF FLAX example, and this PyTorch implementation.

  • argwhere in the Flax example (in this position) is not the same as what happens in fairseq. In fairseq we
    explicitly check that the previous token was not a "full stop" (padding token), whereas in HF we only check whether
    the current token is a full stop. In the current example I also explicitly check that the next token is not a full
    stop, to account for padding. (In practice that should be a non-issue, since all batches/samples should have the
    same sequence length and there should not be any padding.) See the first sketch after this list.
  • I found that the result of sentence permutation was not consistent in terms of where the separating pad token ended
    up (bug report), so I have reimplemented that method so that sentences in a sequence are still separated by a
    padding token, even after permutation (also illustrated in the first sketch below).
  • In HF Flax, the token_mask is restricted to non-special and non-padding tokens. In fairseq, by default, only the
    first and last tokens are excluded and all others can be masked. The HF approach seems sensible, so I follow it.
    get_special_tokens_mask already includes the padding token, so there is no need to add that separately (see the
    second sketch below).
  • The Flax example does not include the methods that add further noise; I have ported those as well.
  • However, I did not adapt add_insertion_noise to work well with padded sequences, so the inserted noise may occur
    ANYWHERE (see the third sketch below). It is unclear whether this is intended behavior.
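To make the first two bullets concrete, here is a rough sketch of the boundary detection and pad-preserving permutation. The function names and tensor operations are illustrative assumptions, not the PR's code, and they operate on a single 1D sequence for simplicity:

import torch


def sentence_end_indices(input_ids: torch.Tensor, full_stop_id: int) -> torch.Tensor:
    # A position ends a sentence when the current token is a "full stop"
    # (here: the padding token used as separator) and the next token is not,
    # so runs of padding do not count as multiple boundaries
    is_full_stop = input_ids == full_stop_id
    next_is_full_stop = torch.roll(is_full_stop, shifts=-1, dims=0)
    next_is_full_stop[-1] = False
    return torch.nonzero(is_full_stop & ~next_is_full_stop, as_tuple=False).squeeze(-1)


def permute_sentences(input_ids: torch.Tensor, full_stop_id: int) -> torch.Tensor:
    # Shuffle sentences while keeping the separating token attached to each
    # sentence, so sentences stay separated after permutation
    ends = sentence_end_indices(input_ids, full_stop_id)
    sentences, start = [], 0
    for end in ends.tolist():
        sentences.append(input_ids[start : end + 1])  # include the separator
        start = end + 1
    if start < input_ids.numel():  # trailing tokens without a final separator
        sentences.append(input_ids[start:])
    order = torch.randperm(len(sentences))
    return torch.cat([sentences[i] for i in order])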
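A sketch of restricting maskable positions with get_special_tokens_mask, as described in the third bullet (illustrative, not the PR's exact code):

import torch
from transformers import BartTokenizerFast

tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
batch = tokenizer(["Hello world. Another sentence."], return_tensors="pt")
input_ids = batch["input_ids"]

# With already_has_special_tokens=True, special tokens (including padding)
# are marked with 1; only positions marked 0 are candidates for masking/noising
special_mask = torch.tensor(
    [
        tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
        for ids in input_ids.tolist()
    ],
    dtype=torch.bool,
)
token_mask = ~special_mask  # True where masking is allowed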
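Finally, a simplified sketch of insertion noise at unrestricted positions, as per the last bullet. Unlike fairseq's add_insertion_noise, this version only inserts mask tokens (fairseq also replaces a fraction of the insertions with random tokens); it is an illustration, not the ported code:

import torch


def add_insertion_noise(input_ids: torch.Tensor, p: float, mask_token_id: int) -> torch.Tensor:
    # Insert p * len(input_ids) mask tokens at random positions; positions are
    # unrestricted, so noise can land on or around special tokens such as EOS
    num_insertions = int(p * input_ids.numel())
    if num_insertions == 0:
        return input_ids
    total_len = input_ids.numel() + num_insertions
    result = torch.full((total_len,), mask_token_id, dtype=input_ids.dtype)
    keep = torch.ones(total_len, dtype=torch.bool)
    keep[torch.randperm(total_len)[:num_insertions]] = False
    result[keep] = input_ids
    return result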

Alternatively, we could implement all this processing at the dataset level and use Dataset.map (a rough sketch follows the lists below). This has some advantages:

  • more true to fairseq implementation (sample level rather than batch level);
  • cached.

... and disadvantages:

  • potentially slower (not batched), although we could integrate a batched approach; as discussed above, that would be
    less true to the original fairseq implementation of add_insertion_noise;
  • every sample is always processed the same way, so in small datasets that are seen multiple times by the model, the
    same sample will always be corrupted identically. With a dataloader, that is not the case, because the processing
    occurs on every iteration rather than once before training.
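For reference, a minimal sketch of the dataset-level alternative. The denoise_batch function and the file path are hypothetical placeholders; only the Dataset.map call reflects the real datasets API:

from datasets import load_dataset

# Placeholder data files; the real example would use the script's data args
raw_datasets = load_dataset("text", data_files={"train": "train.txt"})


def denoise_batch(examples):
    # Hypothetical batched denoising: tokenize, permute sentences, mask spans.
    # Key point: this runs once and is cached, so every epoch sees the exact
    # same corrupted version of each sample.
    return examples


processed = raw_datasets["train"].map(
    denoise_batch,
    batched=True,
    num_proc=4,  # parallelize over CPU workers
)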

Questions/Uncertainties

  • Do the padding tokens still serve a purpose after permutation (teaching the model to detect sentence boundaries)? They can get masked and noised.
  • It seems that add_insertion_noise can insert noise anywhere (also in fairseq), which means that it can also overwrite special
    tokens and that sequences don't necessarily end with an EOS token. Is that a problem?
  • I have now added auxiliary scripts for config/tokenizer creation when pre-training. Should I remove those? In the Flax example, these steps are described inline but without a given script, so we could also just do that.
  • I have explicitly added (hashed) fingerprints because I have previously run into issues when using spaCy with Dataset.map (every time you load a spaCy model it has a different hash, so the processing happens again every time). I don't see a better way, but feel free to share ideas; see the sketch after this list. Maybe someone from the datasets team can chime in, too.
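To illustrate the fingerprint workaround from the last point: a sketch under the assumption that a deterministic string can stand in for the non-deterministic spaCy hash. The split_sentences function and the model name are placeholders, not the PR's code:

import spacy
from datasets import Dataset
from datasets.fingerprint import Hasher

nlp = spacy.load("en_core_web_sm")  # any pipeline with sentence segmentation


def split_sentences(examples):
    # Hypothetical spaCy-based sentence splitting; what matters here is the
    # fingerprint handling, not the body of this function
    docs = nlp.pipe(examples["text"])
    examples["sentences"] = [[sent.text for sent in doc.sents] for doc in docs]
    return examples


dataset = Dataset.from_dict({"text": ["First sentence. Second sentence."]})

# spaCy models hash differently on every load, which breaks datasets' caching.
# Passing an explicit, deterministic new_fingerprint avoids re-processing.
fingerprint = Hasher.hash("split_sentences-en_core_web_sm-v1")
dataset = dataset.map(split_sentences, batched=True, new_fingerprint=fingerprint)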

Before submitting

Who can review?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

BramVanroy closed this by deleting the head repository on Sep 9, 2022
sanchit-gandhi (Contributor) commented:

Hey @BramVanroy! Thanks for making a start on this PR. In general, we aim to mirror the original repo's functionality as closely as possible, and in this case porting from fairseq is the way to go! So it's great to see your comments regarding consistency with fairseq, and yes to all of them! If these changes are indeed required, we'll need to update the Flax example accordingly.

We can batch and parallelize the preprocessing with datasets.map by passing the batched and num_proc args. To pre-process samples on a specified number of CPU workers concurrently:

dataset = dataset.map(map_fn, batched=True, num_proc=data_args.preprocessing_num_workers)

I think this is the way to go to keep the dataset processing as close to fairseq as possible.

Adding auxiliary scripts for config/tokenizer creation is a great idea - all for it! Makes it far easier to reproduce and run the example :-)
