GPT-J-6B (#13022)

* Test GPTJ implementation

* Fixed conflicts

* Update __init__.py

* Update __init__.py

* change GPT_J to GPTJ

* fix missing imports and typos

* use einops for now
(need to change to torch ops later)

* Use torch ops instead of einsum

* remove einops deps

* Update configuration_auto.py

* Added GPT J

* Update gptj.rst

* Update __init__.py

* Update test_modeling_gptj.py

* Added GPT J

* Changed configs to match GPT2 instead of GPT Neo

* Removed non-existent sequence model

* Update configuration_auto.py

* Update configuration_auto.py

* Update configuration_auto.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Progress on updating configs to agree with GPT2

* Update modeling_gptj.py

* num_layers -> n_layer

* layer_norm_eps -> layer_norm_epsilon

* attention_layers -> num_hidden_layers

* Update modeling_gptj.py

* attention_pdrop -> attn_pdrop

* hidden_act -> activation_function

* Update configuration_gptj.py

* Update configuration_gptj.py

* Update configuration_gptj.py

* Update configuration_gptj.py

* Update configuration_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* fix layernorm and lm_head size
delete attn_type

* Update docs/source/model_doc/gptj.rst

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* removed claim that GPT J uses local attention

* Removed GPTJForSequenceClassification

* Update src/transformers/models/gptj/configuration_gptj.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Removed unsupported boilerplate

* Update tests/test_modeling_gptj.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update tests/test_modeling_gptj.py

Co-authored-by: Eric Hallahan <eric@hallahans.name>

* Update tests/test_modeling_gptj.py

Co-authored-by: Eric Hallahan <eric@hallahans.name>

* Update tests/test_modeling_gptj.py

Co-authored-by: Eric Hallahan <eric@hallahans.name>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update __init__.py

* Update configuration_gptj.py

* Update modeling_gptj.py

* Corrected indentation

* Remove stray backslash

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

* Update docs to match

* Remove tf loading

* Remove config.jax

* Remove stray `else:` statement

* Remove references to `load_tf_weights_in_gptj`

* Adapt tests to match output from GPT-J 6B

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Default `activation_function` to `gelu_new`

- Specify the approximate formulation of GELU to ensure parity with the default setting of `jax.nn.gelu()`
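
For reference, a minimal sketch of the tanh-based "approximate" GELU that `gelu_new` refers to; the function name below is illustrative rather than the library's own implementation:

    import math

    import torch

    def gelu_approximate(x: torch.Tensor) -> torch.Tensor:
        # Tanh approximation of GELU -- the formulation jax.nn.gelu() applies by default.
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3.0))))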

* Fix part of the config documentation

* Revert "Update configuration_auto.py"

This reverts commit e9860e9.

* Revert "Update configuration_auto.py"

This reverts commit cfaaae4.

* Revert "Update configuration_auto.py"

This reverts commit 6877889.

* Revert "Update configuration_auto.py"

This reverts commit 194d024.

* Hyphenate GPT-J

* Undid sorting of the models alphabetically

* Reverting previous commit

* fix style and quality issues

* Update docs/source/model_doc/gptj.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/test_modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/configuration_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/configuration_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/configuration_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Replaced GPTJ-specific code with generic code

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Made the code always use rotary positional encodings
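
A minimal sketch of the rotary position encoding idea referred to here: adjacent channel pairs of the query/key vectors are rotated by position-dependent angles. Names and shapes are illustrative, not the exact modeling code:

    import torch

    def fixed_pos_embedding(dim, seq_len, base=10000.0):
        # Position-dependent angles, one frequency per pair of channels (sketch).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
        return torch.sin(angles), torch.cos(angles)

    def rotate_every_two(x):
        # (x1, x2) -> (-x2, x1) for each adjacent pair of channels.
        x1, x2 = x[..., ::2], x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    def apply_rotary(x, sin, cos):
        # Broadcast sin/cos over the paired layout and rotate queries or keys.
        sin = torch.repeat_interleave(sin, 2, dim=-1)
        cos = torch.repeat_interleave(cos, 2, dim=-1)
        return (x * cos) + (rotate_every_two(x) * sin)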

* Update index.rst

* Fix documentation

* Combine attention classes

- Condense all attention operations into `GPTJAttention`
- Replicate GPT-2 and improve code clarity by renaming `GPTJAttention.attn_pdrop` and `GPTJAttention.resid_pdrop` to `GPTJAttention.attn_dropout` and `GPTJAttention.resid_dropout`
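
Illustrative skeleton of what the rename amounts to (not the full attention module): the dropout probabilities keep their config names while the dropout modules take the clearer GPT-2-style names:

    import torch.nn as nn

    class GPTJAttentionSketch(nn.Module):
        # Illustrative only: nn.Dropout modules stored as attn_dropout / resid_dropout,
        # built from the attn_pdrop / resid_pdrop probabilities.
        def __init__(self, attn_pdrop: float = 0.0, resid_pdrop: float = 0.0):
            super().__init__()
            self.attn_dropout = nn.Dropout(attn_pdrop)
            self.resid_dropout = nn.Dropout(resid_pdrop)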

* Removed `config.rotary_dim` from tests

* Update test_modeling_gptj.py

* Update test_modeling_gptj.py

* Fix formatting

* Removed deprecated argument `layer_id` to `GPTJAttention`

* Update modeling_gptj.py

* Update modeling_gptj.py

* Fix code quality

* Restore model functionality

* Save `lm_head.weight` in checkpoints

* Fix crashes when loading with reduced precision

* refactor `self._attn(...)` and rename layer weights

* make sure logits are in fp32 for sampling

* improve docs

* Add `GPTJForCausalLM` to `TextGenerationPipeline` whitelist
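
A minimal usage sketch of what that whitelist enables; the checkpoint name follows the docs above, and running this downloads the full 6B model:

    from transformers import pipeline

    # GPT-J can now back the text-generation pipeline directly.
    generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B")
    print(generator("The meaning of life is", max_length=30, do_sample=True)[0]["generated_text"])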

* Added GPT-J to the README

* Fix doc/readme consistency

* Add rough parallelization support

- Remove unused imports and variables
- Clean up docstrings
- Port experimental parallelization code from GPT-2 into GPT-J
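
A minimal sketch of the experimental interface being ported (requires multiple GPUs; the device map below is hypothetical, splitting GPT-J's 28 transformer blocks across two devices):

    from transformers import GPTJForCausalLM

    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

    # Hypothetical split of the 28 blocks over two GPUs.
    device_map = {0: list(range(0, 14)), 1: list(range(14, 28))}
    model.parallelize(device_map)    # experimental model parallelism, as in GPT-2

    # ... run generation as usual, with inputs placed on the first device ...

    model.deparallelize()            # move the model back to CPU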

* Clean up loose ends

* Fix index.rst

Co-authored-by: kurumuz <kurumuz1@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Eric Hallahan <eric@hallahans.name>
Co-authored-by: Leo Gao <54557097+leogao2@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: your_github_username <your_github_email>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
8 people authored Aug 31, 2021
1 parent e53af03 commit c02cd95
Showing 13 changed files with 1,938 additions and 38 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -240,6 +240,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[Funnel Transformer](https://huggingface.co/transformers/model_doc/funnel.html)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
1. **[GPT](https://huggingface.co/transformers/model_doc/gpt.html)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
1. **[GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
1. **[GPT-J](https://huggingface.co/transformers/model_doc/gptj.html)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki.
1. **[GPT Neo](https://huggingface.co/transformers/model_doc/gpt_neo.html)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
1. **[Hubert](https://huggingface.co/transformers/model_doc/hubert.html)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
1. **[I-BERT](https://huggingface.co/transformers/model_doc/ibert.html)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer
81 changes: 43 additions & 38 deletions docs/source/index.rst

Large diffs are not rendered by default.

102 changes: 102 additions & 0 deletions docs/source/model_doc/gptj.rst
@@ -0,0 +1,102 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

GPT-J
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The GPT-J model was released in the `kingoflolz/mesh-transformer-jax
<https://github.com/kingoflolz/mesh-transformer-jax>`__ repository by Ben Wang and Aran Komatsuzaki. It is a GPT-2-like
causal language model trained on `the Pile <https://pile.eleuther.ai/>`__ dataset.

This model was contributed by `Stella Biderman <https://huggingface.co/stellaathena>`__.

Tips:

- Running `GPT-J <https://huggingface.co/EleutherAI/gpt-j-6B>`__ in float32 precision on GPU requires at least 24 GB of
  RAM. On GPUs with less than 24 GB RAM, one should therefore load the model in half-precision:

.. code-block::

    >>> from transformers import GPTJForCausalLM
    >>> import torch

    >>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)

Generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :meth:`~transformers.generation_utils.GenerationMixin.generate` method can be used to generate text using the
GPT-J model.

.. code-block::

    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
    >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    >>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
    ...          "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
    ...          "researchers was the fact that the unicorns spoke perfect English."

    >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    >>> gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100)
    >>> gen_text = tokenizer.batch_decode(gen_tokens)[0]

...or in float16 precision:

.. code-block::

    >>> from transformers import GPTJForCausalLM, AutoTokenizer
    >>> import torch

    >>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)
    >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    >>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
    ...          "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
    ...          "researchers was the fact that the unicorns spoke perfect English."

    >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    >>> gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100)
    >>> gen_text = tokenizer.batch_decode(gen_tokens)[0]

GPTJConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJConfig
:members:

GPTJModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJModel
:members: forward


GPTJForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJForCausalLM
:members: forward


GPTJForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJForSequenceClassification
:members: forward
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
@@ -213,6 +213,7 @@
"models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
"models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
"models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
"models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
"models.herbert": ["HerbertTokenizer"],
"models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
"models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
@@ -824,6 +825,15 @@
"load_tf_weights_in_gpt_neo",
]
)
_import_structure["models.gptj"].extend(
[
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
"GPTJForSequenceClassification",
"GPTJModel",
"GPTJPreTrainedModel",
]
)
_import_structure["models.hubert"].extend(
[
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1966,6 +1976,7 @@
from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
from .models.herbert import HerbertTokenizer
from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
@@ -2486,6 +2497,13 @@
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,
)
from .models.gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
GPTJForSequenceClassification,
GPTJModel,
GPTJPreTrainedModel,
)
from .models.hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -50,6 +50,7 @@
funnel,
gpt2,
gpt_neo,
gptj,
herbert,
hubert,
ibert,
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -26,6 +26,7 @@
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("gptj", "GPTJConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("beit", "BeitConfig"),
("rembert", "RemBertConfig"),
@@ -96,6 +97,7 @@
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -158,6 +160,7 @@
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("gptj", "GPT-J"),
("beit", "BeiT"),
("rembert", "RemBERT"),
("layoutlmv2", "LayoutLMv2"),
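
These mappings are what let the Auto classes resolve the new `gptj` model type; a minimal sketch of the effect (fetches the config from the Hub):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")
    print(type(config).__name__)  # GPTJConfig, resolved through the "gptj" mapping above
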
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -28,6 +28,7 @@
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("gptj", "GPTJModel"),
("layoutlmv2", "LayoutLMv2Model"),
("beit", "BeitModel"),
("rembert", "RemBertModel"),
@@ -135,6 +136,7 @@
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
@@ -183,6 +185,7 @@
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
@@ -286,6 +289,7 @@
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("gptj", "GPTJForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("rembert", "RemBertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
52 changes: 52 additions & 0 deletions src/transformers/models/gptj/__init__.py
@@ -0,0 +1,52 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_torch_available


_import_structure = {
    "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
}

if is_torch_available():
    _import_structure["modeling_gptj"] = [
        "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
        "GPTJForCausalLM",
        "GPTJForSequenceClassification",
        "GPTJModel",
        "GPTJPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig

    if is_torch_available():
        from .modeling_gptj import (
            GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
            GPTJForCausalLM,
            GPTJForSequenceClassification,
            GPTJModel,
            GPTJPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
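
The `_LazyModule` pattern above defers the torch-only imports until a symbol is actually requested. A small usage sketch; the tiny configuration values are illustrative and follow the GPT-2-style parameter names adopted in the commits above:

    from transformers import GPTJConfig, GPTJForCausalLM

    # Deliberately tiny, illustrative configuration -- the released GPT-J 6B is far larger.
    config = GPTJConfig(vocab_size=1024, n_positions=128, n_embd=64, n_layer=2, n_head=4, rotary_dim=16)
    model = GPTJForCausalLM(config)
    print(sum(p.numel() for p in model.parameters()), "parameters")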