GPT-J #12243

Closed
wants to merge 67 commits into from

Conversation

StellaAthena
Contributor

@StellaAthena commented Jun 18, 2021

This is a work-in-progress focused on reconciling styles and may break without warning. If you want to use GPT-J with the HF interface, you can do that by installing transformers from here. The purpose of this PR is to make progress on converting that repo to the style HF prefers.

What does this PR do?

This is my attempt to reconcile #12106 with the HF style guidelines as described by @sgugger. The original PR was created by @finetuneanon and @kurumuz.

This implementation has not been thoroughly tested yet, but I wanted to get something out as a starting point for continuing the conversation before too much momentum is lost. I need to reread HF documentation a bit more to figure out the things that are wrong, or hopefully one of you lovely people can help me out.

For comparison, a frozen version of the code in the original PR can be found here.

Before submitting

Who can review?

@patrickvonplaten @patil-suraj @sgugger

@StellaAthena
Contributor Author

The main thing I'm uncertain about is how to handle unimplemented functionality. GPT-J uses the same tokenizer as GPT-2, so I removed the tokenizer definition. Is that correct, or no? Relatedly, there were many types of modeling that GPT-J was not designed for, and @finetuneanon's PR just deleted the boilerplate for them. Is this correct?

Contributor

@patil-suraj left a comment

Amazing, thanks a lot for the PR Stella!

I left a few comments in the modeling file.

Regarding your questions:

  • "I'm uncertain about how to handle unimplemented functionality"
    The modeling template adds all types of head models (ForMLM, ForMultipleChoice); any such functionality that is not needed for GPT-J can be removed.
  • "GPT-J uses the same tokenizer as GPT-2, so I removed the tokenizer definition. Is that correct, or no?"
    Yes, we don't need to add a new tokenizer in this case. We can define the tokenizer association in the tokenization_auto.py file, as is done for GPTNeo (see the sketch after this list):
    (GPTNeoConfig, (GPT2Tokenizer, GPT2TokenizerFast)),
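
For illustration, the association in tokenization_auto.py could look roughly like the sketch below. GPTJConfig is an assumed name for the config class this PR would add, so its entry is left commented out; only the GPT-Neo imports and entry shown here exist today.

from collections import OrderedDict

from transformers import GPT2Tokenizer, GPT2TokenizerFast, GPTNeoConfig

# from transformers import GPTJConfig  # hypothetical config class added by this PR

TOKENIZER_MAPPING = OrderedDict(
    [
        # ... existing entries ...
        (GPTNeoConfig, (GPT2Tokenizer, GPT2TokenizerFast)),
        # GPT-J reuses GPT-2's tokenizer, so the new entry would only map the config:
        # (GPTJConfig, (GPT2Tokenizer, GPT2TokenizerFast)),
    ]
)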

Another important thing is to add tests for the model. We could reuse GPT-2's tests from test_modeling_gpt2.py.
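
As a rough illustration, here is the kind of smoke test such suites boil down to, written against GPT-2's existing classes; a GPT-J version would swap in the new config and model classes once they exist. This is a sketch, not the full HF test suite.

import unittest

import torch
from transformers import GPT2Config, GPT2Model


class TinyCausalModelSmokeTest(unittest.TestCase):
    def test_forward_output_shape(self):
        # A tiny randomly initialized config keeps the test fast; a GPT-J test
        # would do the same with the new GPTJConfig / GPTJModel classes.
        config = GPT2Config(vocab_size=99, n_positions=64, n_embd=32, n_layer=2, n_head=4)
        model = GPT2Model(config)
        model.eval()

        input_ids = torch.randint(0, config.vocab_size, (2, 7))
        with torch.no_grad():
            outputs = model(input_ids)

        # Base models return hidden states of shape (batch, seq_len, hidden_size).
        self.assertEqual(outputs.last_hidden_state.shape, (2, 7, 32))


if __name__ == "__main__":
    unittest.main()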

Also, sorry to ask this again, but could we not modify generation in this PR, since it seems it's not related to GPT-J.

But great that you took over the PR, let us know if there's anything else we can help with :)

Comment on lines +60 to +66
def load_tf_weights_in_gptj(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
Contributor

This can be removed, there is no TF model.

Comment on lines +172 to +185
def fixed_pos_embedding(dim=None, seq_len=None):
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim))
    sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len), inv_freq).float()
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), axis=-1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(x, sincos, offset=0):
    sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
    return (x * cos) + (rotate_every_two(x) * sin)
Contributor

It would be nice to not add the einops dependency. Also, we could add this as a static method in the attention class so that it can be tested easily. We could probably reuse this implementation:

def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
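
For illustration, one way to drop the einops calls in the snippet above is to use plain torch ops; this is a sketch only, not the referenced implementation or the final PR code, and it assumes x has shape (batch, seq_len, heads, head_dim) as in the original helpers.

import torch


def rotate_every_two(x):
    # Build the (-x2, x1) interleaving along the last dimension without einops.rearrange.
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(start_dim=-2)  # same result as rearrange(x, '... d j -> ... (d j)')


def apply_rotary_pos_emb(x, sincos, offset=0):
    # repeat_interleave duplicates each sin/cos value, matching
    # repeat(t, "n d -> () n () (d j)", j=2); None-indexing adds the broadcast dims.
    sin, cos = (
        t[offset : x.shape[1] + offset, :].repeat_interleave(2, dim=-1)[None, :, None, :]
        for t in sincos
    )
    return (x * cos) + (rotate_every_two(x) * sin)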

Comment on lines +190 to +196
class GPTJAttentionMixin:
    """
    A few attention related utilities for attention modules in GPT Neo, to be used as a mixin.
    """
    def _split_heads(self, tensor, num_heads, attn_head_size, rotary):
        """
        Splits hidden_size dim into attn_head_size and num_heads
Contributor

Since there is no local attention, all of this can go in the GPTJSelfAttention class.

Comment on lines +202 to +203
if len(tensor.shape) == 5:
    return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
Contributor

This can also be removed, no local attention here :/

Comment on lines +213 to +214
if len(tensor.shape) == 5:
    tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
Contributor

same comment as above.

Comment on lines +257 to +261
if attention_type == "local":
    self.register_buffer(
        "bias",
        bias ^ torch.tril(bias, -config.window_size),
    )
Contributor

same comment as above

Comment on lines +367 to +368
if self.attention_type in ["global", "local"]:
    self.attention = GPTJSelfAttention(self.attention_type, config)
Contributor

We could remove all this GPTNeo related code from here.

@StellaAthena
Contributor Author

Also, sorry to ask this again, but could we not modify generation in this PR, since it seems it's not related to GPT-J.

Damn. It looks like I messed something up.... this was supposed to not include @finetuneanon's commits. I might close this and create a replacement PR with the correct commit history.

@sualehasif

Mmm, I was wondering how this has been going. I would love to try a stable version of this!

@patil-suraj
Contributor

Hey @sualehasif

A stable version will be available in a week, stay tuned!

@mittalpatel

Damn. It looks like I messed something up.... this was supposed to not include @finetuneanon's commits. I might close this and create a replacement PR with the correct commit history.

@StellaAthena any idea when you would be adding a new PR? We are also running some experiments, so maybe we could help.

@patil-suraj
Contributor

patil-suraj commented Jun 24, 2021

@mittalpatel

I'm taking over the PR. But feel free to post your findings :)

@StellaAthena
Contributor Author

In #12106 @finetuneanon reports the results of some evaluations of the ported model on EleutherAI's evaluation harness. The numbers were a little lower than what we had found using the original implementation, but both he and I felt this was likely due to FP16. I can now confirm that the ported model achieves the same performance as the original model when evaluated in FP32: the absolute differences in performance on LAMBADA, HellaSwag, PiQA, and Winogrande are all less than 0.5% in FP32.

@finetunej

Cool, that's good to know.

@StellaAthena
Contributor Author

@patil-suraj can you mark this as a draft, as it is not ready to merge in its current state?

@patil-suraj patil-suraj marked this pull request as draft July 3, 2021 14:10
@socialbrim

socialbrim commented Jul 9, 2021

Hey @sualehasif

A stable version will be available in a week, stay tuned!

Hi @patil-suraj, thanks so much for working on this. Is there any progress on the integration into Hugging Face Transformers?

@willfrey
Contributor

Just chiming in here: any .py files with dashes in their names won't be importable :) So I'd suggest changing gpt-j to gptj or gpt_j in the .py file path names.
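
A quick illustration of why (standalone snippet, not part of the PR):

# A dash cannot appear in a Python identifier, so "import gpt-j" is a SyntaxError,
# while "gpt_j" and "gptj" are both valid module names.
print("gpt-j".isidentifier())  # False
print("gpt_j".isidentifier())  # True
print("gptj".isidentifier())   # True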

@calclavia

Any updates on this and any help required?

@OhadRubin

@patil-suraj What is the status of this?
I would really like to use this model, and I don't feel like messing around with forks to get this to work.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@wolfgangmeyers

I would still love to see this happen.

@StellaAthena
Contributor Author

I would still love to see this happen.

This is going to happen any day now; see #13022.
