
Can I train a BART model from scratch with transformers? #5096

Closed
ScottishFold007 opened this issue Jun 18, 2020 · 21 comments

Comments

@ScottishFold007
Contributor

Can I train a BART model from scratch with transformers?

@patrickvonplaten
Contributor

Yes

@ScottishFold007
Contributor Author

> Yes

That's awesome! Can you show some code for how to do it? I'd be grateful!

@patrickvonplaten
Contributor

patrickvonplaten commented Jun 18, 2020

From the paper (https://arxiv.org/pdf/1910.13461.pdf), you can see that BART is trained to denoise input sequences that can be corrupted in almost any possible way.

One way could be for BartForConditionalGeneration:

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

tok = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration(BartConfig())

# <mask> marks the corrupted span the model has to reconstruct
input_string = "My dog is <mask> </s>"
decoder_input_string = "<s> My dog is cute"
labels_string = "My dog is cute </s>"

input_ids = tok(input_string, add_special_tokens=False, return_tensors="pt").input_ids
decoder_input_ids = tok(decoder_input_string, add_special_tokens=False, return_tensors="pt").input_ids
labels = tok(labels_string, add_special_tokens=False, return_tensors="pt").input_ids

# with labels provided, the first element of the output is the language modeling loss
loss = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)[0]
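
To actually update the randomly initialized weights from this loss, a minimal training-step sketch (the optimizer and learning rate below are illustrative, not part of the original answer) could be:

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # hypothetical choice of optimizer/lr

optimizer.zero_grad()
loss.backward()   # backpropagate the denoising loss from the snippet above
optimizer.step()  # update the randomly initialized parameters

In practice this step would sit inside a loop over a DataLoader of (noised input, original target) pairs.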

@patrickvonplaten
Contributor

Pinging @sshleifer to make sure I did not forget anything

@ScottishFold007
Contributor Author

ScottishFold007 commented Jun 18, 2020

> Pinging @sshleifer to make sure I did not forget anything

Actually, I was going to ask how to train a model from zero to one. For example, I want to train a Chinese BART model.
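
One possible starting point for a new language (a sketch only; the corpus file and vocabulary size below are made up) is to train a fresh tokenizer on your own corpus and initialize a random-weight BART whose embeddings match it:

from transformers import BartTokenizerFast, BartForConditionalGeneration, BartConfig

def line_iterator(path):
    # one paragraph of raw text per line; "chinese_corpus.txt" is a placeholder
    with open(path, encoding="utf-8") as f:
        for line in f:
            yield line.strip()

base_tok = BartTokenizerFast.from_pretrained("facebook/bart-large")
new_tok = base_tok.train_new_from_iterator(line_iterator("chinese_corpus.txt"), vocab_size=50000)

config = BartConfig(vocab_size=new_tok.vocab_size)  # embeddings sized for the new vocabulary
model = BartForConditionalGeneration(config)        # random weights, ready for pretraining

The denoising objective itself then stays the same as in the examples in this thread.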

@tomhosking
Contributor

tomhosking commented Sep 2, 2020

Here's a working example for this, including batching:

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

tok = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration(BartConfig())

input_batch = ["My dog is <mask></s>", "It loves to play in the <mask></s>"]
decoder_input_batch = ["<s>My dog is cute", "<s>It loves to play in the park"]
labels_batch = ["My dog is cute</s>", "It loves to play in the park</s>"]

input_ids = tok.batch_encode_plus(input_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids
decoder_input_ids = tok.batch_encode_plus(decoder_input_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids
labels = tok.batch_encode_plus(labels_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids

loss = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)[0]

>>> tensor(10.9981, device='cuda:0', grad_fn=<NllLossBackward>)
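
One caveat with this padded batch (an observation on top of the example, not part of the original reply): the pad positions in labels also contribute to the loss unless they are masked out. Replacing them with -100, the ignore index of the underlying cross-entropy loss, excludes them:

labels[labels == tok.pad_token_id] = -100  # positions set to -100 are ignored when computing the loss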

@ScottishFold007
Contributor Author

> Here's a working example for this, including batching: […]

If I have a text document with one paragraph per line, how do I adapt the data input code for it? Thanks!
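
One possible sketch (the file name is a placeholder, and masking a single word is only a toy stand-in for BART's span infilling discussed further down in this thread):

import random
from transformers import BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large")

def make_example(line):
    # returns (noised input, decoder input, labels) strings for one paragraph
    words = line.split()
    words[random.randrange(len(words))] = "<mask>"
    return " ".join(words) + "</s>", "<s>" + line, line + "</s>"

with open("corpus.txt", encoding="utf-8") as f:  # hypothetical file, one paragraph per line
    lines = [l.strip() for l in f if l.strip()]

input_batch, decoder_input_batch, labels_batch = map(list, zip(*[make_example(l) for l in lines]))

input_ids = tok(input_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids
# decoder_input_batch and labels_batch can be encoded the same way, as in the batched example above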

@swethmandava
Contributor

swethmandava commented Dec 17, 2020

@tomhosking the paper indicates that it uses both sentence permutation (loss is propagated from all tokens instead of only masked tokens) and text infilling (only one mask token for multiple consecutive masked tokens). Would this be a correct input?

input_batch = ["<s>It is <mask> retriever. My dog is <mask></s>", "<s>There <mask> in SF. It loves to play in the <mask></s>"]
decoder_input_batch = ["</s><s>My dog is cute. It is a golden retriever", "</s><s>It loves to play in the park. There are many parks in SF."]
labels_batch = ["<s>My dog is cute. It is a golden retriever</s>", "<s>It loves to play in the park. There are many parks in SF.</s>"]

(Note: decoder_input_batch starts with </s><s> due to shift_tokens_right #7961)
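
A sketch of what that shift does (the helper below is illustrative; transformers ships its own shift_tokens_right): the decoder inputs are just the labels shifted one position to the right with the decoder start token prepended, which for the BART checkpoints is </s>, hence the </s><s> prefix:

import torch

def shift_right(labels: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    # labels: (batch, seq_len) token ids; drop the last label token and prepend the start token
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    return shifted

decoder_input_ids = shift_right(labels, tok.eos_token_id)  # </s> is BART's decoder start token

The library version additionally replaces any -100 positions in the shifted labels with the pad id.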

@jonatasgrosman
Contributor

Sorry for the intrusion, but I think your values are almost correct, @swethmandava, except for the absence of masking:

input_batch = ["<s>It <mask> retriever. My <mask> cute </s>", ... ]
decoder_input_batch = ["</s><s>My dog is cute. It is a golden retriever", ...]
labels_batch = ["<s>My dog is cute. It is a golden retriever</s>", ...]

BTW: this </s> token at the beginning of the decoder's input is kind of weird to me, but it's inherited from the original fairseq code. If you want to train the model from scratch with random weights, I think you can go without it... or maybe this trick is important for convergence, we never know 😁

@HuipengXu

Will masking only 15% of the encoder input cause some kind of leakage, so that the language model in the decoder cannot learn correctly?

@prajdabre

If anyone wants to train their own MBART model, feel free to use this:
https://github.com/prajdabre/yanmtt

Contributions are welcome!

@jbmaxwell

> Sorry for the intrusion, but I think your values are almost correct, @swethmandava, except for the absence of masking […]

I have a non-natural language dataset where I haven't actually been including <s> and </s> since they don't add any value (and need to be removed later anyway). To work with that, should I insert a pad token at the start of the decoder_input representation (and truncate to max_length)?
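
For what it's worth, a sketch of that setup (the vocabulary size and token ids below are placeholders for a custom, non-natural-language tokenizer): the decoder start token is configurable, so it can simply be set to the pad id rather than inserted by hand:

from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=1024,           # hypothetical vocabulary for the non-NL dataset
    pad_token_id=0,            # whatever id your own tokenizer uses for padding
    decoder_start_token_id=0,  # decoder inputs then begin with the pad id
)
model = BartForConditionalGeneration(config)

Recent versions of the model build decoder_input_ids from labels with exactly this start id when they are not passed explicitly, and truncation to max_length is handled in the tokenizer as usual.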

@Haiming94

> From the paper (https://arxiv.org/pdf/1910.13461.pdf), you can see that BART is trained to denoise input sequences that can be corrupted in almost any possible way. […]

Hi, do you have a script to build the training dataset for BART pretraining? Thanks!

@BramVanroy
Collaborator

@patrickvonplaten @sshleifer Did anyone ever get around to creating a notebook/script for BART pretraining? (In a linked issue you mentioned it was on the to-do list.)

The core difficulty is having a canonical implementation of the data preprocessing (BART is more than just token masking, I believe: e.g., span masking, sentence shuffling). But a full pretraining pipeline, here or in fairseq, is also sorely missing.
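
For readers landing here, a rough sketch of those two corruptions (text infilling with Poisson-length spans collapsed into a single <mask>, plus sentence permutation), written on plain strings rather than at the token level as fairseq does; the 30% mask budget and lambda = 3 follow the paper, everything else is simplified:

import random
import numpy as np

def permute_sentences(text):
    # crude sentence split on "."; the real implementation works on full-stop tokens
    sentences = [s.strip() for s in text.split(".") if s.strip()]
    random.shuffle(sentences)
    return ". ".join(sentences) + "."

def text_infill(text, mask_ratio=0.3, poisson_lambda=3.0):
    # replace whole spans of words with a single <mask> until ~mask_ratio of the words are gone
    words = text.split()
    budget = int(round(mask_ratio * len(words)))
    while budget > 0 and len(words) > 1:
        span = max(1, min(np.random.poisson(poisson_lambda), budget, len(words) - 1))
        start = random.randrange(len(words) - span + 1)
        words[start:start + span] = ["<mask>"]  # the whole span becomes one mask token
        budget -= span
    return " ".join(words)

noised = text_infill(permute_sentences("My dog is cute. It is a golden retriever."))

This is only an approximation of the fairseq preprocessing (no 0-length span insertion, no subword/whole-word bookkeeping), which is exactly the part a canonical PyTorch example would need to get right.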

@patrickvonplaten
Contributor

Sadly not :-/ We now have one for Flax in #18297 - could you maybe try to copy the preprocessing logic into a PyTorch version?

@BramVanroy
Collaborator

@patrickvonplaten I've been porting the fairseq implementation to a PyTorch dataloader format. I found that the Flax implementation in HF lacks adding noise for 0-length spans and diverges slightly in a few other places, so it was more straightforward to start from the fairseq implementation. I am now testing the data processing especially carefully to get it as close as possible to fairseq's (although it is my belief that there's a bug in their code).

I would like to add a full PyTorch example for DLM training of BART in the coming days/weeks, but I could use some code reviews along the way to feel more comfortable. Would that be possible?

@patrickvonplaten
Contributor

Sure, happy to take a look!

@prajdabre

Hi

I remember posting this a year ago but I've written an entire toolkit for this purpose. Feel free to use it. https://github.com/prajdabre/yanmtt

I've also created a simple notebook for the same (scroll to the pretraining part): https://colab.research.google.com/drive/1ovlA_h0ggblawqR-yCgRs3uRjxFJ8K0l?usp=sharing

@BramVanroy
Collaborator

Hi Raj, thank you for this. I had come across it, but your script seems to have a lot of additional things going on, which makes it hard to extract the basics. I also found that you implement word/span masking but not the other corruptions, like adding noise or randomly swapping a masked token for a random token, so it's not completely like the original implementation (but correct me if I'm wrong!).

I think your library can be very useful as a separate library, thanks! In addition, I'll try to add a PR to transformers with a succinct example that can be used with the Trainer, with data processing close to the fairseq implementation.

@prajdabre

Hi,

My focus was more on mBART and mT5, which look only at span masking and reordering. I'm not sure if token replacement will have that big of an impact, but it can easily be implemented in one line. To my understanding, span masking is responsible for the majority of the gains. The notebook contains a more watered-down version of the masking method in my toolkit. You could take that version and build on top of it easily.
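
For completeness, one way that token replacement could be sketched on a batch of ids (the rate is illustrative, and special tokens are not excluded here, which a real implementation would want to do):

import torch

def random_replace(input_ids, vocab_size, rate=0.1):
    # replace ~rate of the positions with uniformly sampled vocabulary ids
    noise = torch.randint_like(input_ids, 0, vocab_size)
    keep = torch.rand_like(input_ids, dtype=torch.float) >= rate
    return torch.where(keep, input_ids, noise)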

@CountingMstar

Hey guys, I would like to know how to pretrain a BART model from scratch. Does anyone know about this? BART, Pegasus, or other text summarization models are okay for me.
