Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPT-J-6B #13022

Merged
merged 133 commits into from
Aug 31, 2021
Merged

GPT-J-6B #13022

Show file tree
Hide file tree
Changes from 123 commits
Commits
Show all changes
133 commits
Select commit Hold shift + click to select a range
f01d47f
Test GPTJ implementation
StellaAthena Aug 4, 2021
44ccee7
fix conflicts
StellaAthena Aug 4, 2021
fe991bf
Fixed conflicts
StellaAthena Aug 4, 2021
54f2b33
Update __init__.py
StellaAthena Aug 4, 2021
e2ce2a3
Update __init__.py
StellaAthena Aug 4, 2021
e59e579
change GPT_J to GPTJ
kurumuz Aug 4, 2021
e2329b4
fix missing imports and typos
kurumuz Aug 4, 2021
1bee4ee
use einops for now
kurumuz Aug 4, 2021
03b7278
Use torch ops instead of einsum
kurumuz Aug 4, 2021
8034f2c
remove einops deps
kurumuz Aug 4, 2021
f86b47b
Merge pull request #1 from kurumuz/gptj_fixes
StellaAthena Aug 4, 2021
0a344ba
Merge branch 'huggingface:master' into master
StellaAthena Aug 5, 2021
194d024
Update configuration_auto.py
StellaAthena Aug 5, 2021
06f07da
Added GPT J
StellaAthena Aug 5, 2021
979bff8
Update gptj.rst
StellaAthena Aug 5, 2021
30635c1
Update __init__.py
StellaAthena Aug 5, 2021
bae5e27
Update test_modeling_gptj.py
StellaAthena Aug 5, 2021
1bcf933
Added GPT J
StellaAthena Aug 5, 2021
12a12a7
Changed configs to match GPT2 instead of GPT Neo
StellaAthena Aug 5, 2021
4efbbec
Removed non-existent sequence model
StellaAthena Aug 5, 2021
6877889
Update configuration_auto.py
StellaAthena Aug 5, 2021
cfaaae4
Update configuration_auto.py
StellaAthena Aug 5, 2021
e9860e9
Update configuration_auto.py
StellaAthena Aug 5, 2021
e8a2333
Update modeling_gptj.py
StellaAthena Aug 5, 2021
3bd2879
Update modeling_gptj.py
StellaAthena Aug 5, 2021
8c524f7
Progress on updating configs to agree with GPT2
StellaAthena Aug 6, 2021
f0c0a31
Update modeling_gptj.py
StellaAthena Aug 6, 2021
1ad512b
num_layers -> n_layer
StellaAthena Aug 6, 2021
89b8724
layer_norm_eps -> layer_norm_epsilon
StellaAthena Aug 6, 2021
76fc4e1
attention_layers -> num_hidden_layers
StellaAthena Aug 6, 2021
6284c7e
Update modeling_gptj.py
StellaAthena Aug 6, 2021
2d5cc30
attention_pdrop -> attn_pdrop
StellaAthena Aug 6, 2021
1ddbb63
hidden_act -> activation_function
StellaAthena Aug 6, 2021
b46551d
Update configuration_gptj.py
StellaAthena Aug 6, 2021
60daf97
Update configuration_gptj.py
StellaAthena Aug 6, 2021
1c9ba25
Update configuration_gptj.py
StellaAthena Aug 6, 2021
7f52c42
Update configuration_gptj.py
StellaAthena Aug 6, 2021
33380ca
Update configuration_gptj.py
StellaAthena Aug 6, 2021
05b2b3b
Update modeling_gptj.py
StellaAthena Aug 6, 2021
d6b86f8
Update modeling_gptj.py
StellaAthena Aug 6, 2021
2b633be
Update modeling_gptj.py
StellaAthena Aug 6, 2021
dde732d
Update modeling_gptj.py
StellaAthena Aug 6, 2021
89818be
Update modeling_gptj.py
StellaAthena Aug 6, 2021
0bc0c25
Update modeling_gptj.py
StellaAthena Aug 6, 2021
e9bc670
fix layernorm and lm_head size
kurumuz Aug 6, 2021
0ac5c64
Merge branch 'huggingface:master' into master
StellaAthena Aug 6, 2021
3553d8e
Update docs/source/model_doc/gptj.rst
StellaAthena Aug 6, 2021
6b6e41d
removed claim that GPT J uses local attention
StellaAthena Aug 6, 2021
5fc60db
Removed GPTJForSequenceClassification
StellaAthena Aug 6, 2021
d70f1f8
Update src/transformers/models/gptj/configuration_gptj.py
StellaAthena Aug 6, 2021
878518a
Removed unsupported boilerplate
StellaAthena Aug 6, 2021
a48ee07
Update tests/test_modeling_gptj.py
StellaAthena Aug 6, 2021
f189b6d
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 6, 2021
6059ef6
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 6, 2021
8793237
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 6, 2021
d5d758e
Update tests/test_modeling_gptj.py
StellaAthena Aug 6, 2021
58b77c1
Update tests/test_modeling_gptj.py
StellaAthena Aug 6, 2021
513fa3e
Update tests/test_modeling_gptj.py
StellaAthena Aug 6, 2021
9876057
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 6, 2021
de6c5af
Update __init__.py
StellaAthena Aug 7, 2021
3f21e98
Update configuration_gptj.py
StellaAthena Aug 7, 2021
316cc95
Update modeling_gptj.py
StellaAthena Aug 7, 2021
4be1484
Corrected indentation
StellaAthena Aug 7, 2021
80f4658
Remove stray backslash
EricHallahan Aug 8, 2021
f4d70d2
Delete .DS_Store
leogao2 Aug 8, 2021
28caefb
Delete .DS_Store
leogao2 Aug 8, 2021
6633eba
Delete .DS_Store
leogao2 Aug 8, 2021
2f92631
Delete .DS_Store
leogao2 Aug 8, 2021
23356a0
Delete .DS_Store
leogao2 Aug 8, 2021
dda2643
Update docs to match
leogao2 Aug 8, 2021
a31f11a
Remove tf loading
leogao2 Aug 8, 2021
cbf8dd1
Remove config.jax
leogao2 Aug 8, 2021
fed0955
Remove stray `else:` statement
EricHallahan Aug 8, 2021
0ae4be5
Remove references to `load_tf_weights_in_gptj`
EricHallahan Aug 8, 2021
3c6161d
Adapt tests to match output from GPT-J 6B
EricHallahan Aug 8, 2021
dd4f02d
Apply suggestions from code review
StellaAthena Aug 9, 2021
752595f
Default `activation_function` to `gelu_new`
EricHallahan Aug 9, 2021
7a032e5
Fix part of the config documentation
EricHallahan Aug 9, 2021
455c311
Revert "Update configuration_auto.py"
EricHallahan Aug 10, 2021
49ba5cc
Revert "Update configuration_auto.py"
EricHallahan Aug 10, 2021
3ebf87b
Revert "Update configuration_auto.py"
EricHallahan Aug 10, 2021
c844dcb
Revert "Update configuration_auto.py"
EricHallahan Aug 10, 2021
9af5cff
Hyphenate GPT-J
EricHallahan Aug 10, 2021
74a9777
Undid sorting of the models alphabetically
StellaAthena Aug 10, 2021
4a40d00
Reverting previous commit
StellaAthena Aug 10, 2021
176ec56
Merge branch 'master' into master
patil-suraj Aug 13, 2021
24ac25a
fix style and quality issues
patil-suraj Aug 13, 2021
d7ac30f
Update docs/source/model_doc/gptj.rst
StellaAthena Aug 14, 2021
6857e93
Update src/transformers/__init__.py
StellaAthena Aug 14, 2021
94db694
Update tests/test_modeling_gptj.py
StellaAthena Aug 14, 2021
5fa31e0
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 14, 2021
f38e019
Update src/transformers/__init__.py
StellaAthena Aug 14, 2021
e4a5f5a
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 14, 2021
2d0a2a0
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 14, 2021
7443fcb
Update src/transformers/models/gptj/configuration_gptj.py
StellaAthena Aug 14, 2021
b3c1a20
Update src/transformers/models/gptj/configuration_gptj.py
StellaAthena Aug 14, 2021
0bedd33
Update src/transformers/models/gptj/configuration_gptj.py
StellaAthena Aug 14, 2021
2507592
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 14, 2021
cd4713f
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 14, 2021
f0a3c0a
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 14, 2021
f046728
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 14, 2021
3ae0298
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 17, 2021
27aeb5f
Replaced GPTJ-specific code with generic code
StellaAthena Aug 17, 2021
8c8ee6b
Update src/transformers/models/gptj/modeling_gptj.py
StellaAthena Aug 17, 2021
53b801a
Made the code always use rotary positional encodings
StellaAthena Aug 17, 2021
ae18ff5
Update index.rst
StellaAthena Aug 17, 2021
c4d11ca
Fix documentation
StellaAthena Aug 17, 2021
4a37a78
Combine attention classes
EricHallahan Aug 17, 2021
d5e5f84
Removed `config.rotary_dim` from tests
StellaAthena Aug 17, 2021
3522e07
Merge branch 'huggingface:master' into master
StellaAthena Aug 17, 2021
c27e587
Update test_modeling_gptj.py
StellaAthena Aug 17, 2021
9eebb6f
Update test_modeling_gptj.py
StellaAthena Aug 17, 2021
ff301d3
Fix formatting
EricHallahan Aug 17, 2021
ff1eb1d
Removed depreciated argument `layer_id` to `GPTJAttention`
StellaAthena Aug 18, 2021
4c86bbc
Update modeling_gptj.py
StellaAthena Aug 18, 2021
1f99941
Update modeling_gptj.py
StellaAthena Aug 18, 2021
b6021cf
Fix code quality
EricHallahan Aug 18, 2021
d2c85a2
Restore model functionality
EricHallahan Aug 19, 2021
223bda1
Save `lm_head.weight` in checkpoints
EricHallahan Aug 22, 2021
ad567a9
Fix crashes when loading with reduced precision
EricHallahan Aug 22, 2021
af0c01d
refactor self._attn(...)` and rename layer weights"
Aug 23, 2021
c90eb3b
make sure logits are in fp32 for sampling
patrickvonplaten Aug 23, 2021
3897823
improve docs
patrickvonplaten Aug 23, 2021
504b339
Add `GPTJForCausalLM` to `TextGenerationPipeline` whitelist
EricHallahan Aug 25, 2021
7dcc7c5
Merge branch 'huggingface:master' into master
StellaAthena Aug 25, 2021
8cbaa1f
Added GPT-J to the README
StellaAthena Aug 25, 2021
bca938d
Fix doc/readme consistency
EricHallahan Aug 25, 2021
71d3300
Merge branch 'master' into master
StellaAthena Aug 27, 2021
22f8131
Add rough parallelization support
EricHallahan Aug 27, 2021
1d69e42
Clean up loose ends
EricHallahan Aug 30, 2021
bce04cb
Merge branch 'master' into master
EricHallahan Aug 30, 2021
4784eae
Fix index.rst
EricHallahan Aug 30, 2021
3466cd0
fix merge conflicts
patrickvonplaten Aug 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| GPT-J | ❌ | ❌ | ✅ | ❌ | ❌ |
StellaAthena marked this conversation as resolved.
Show resolved Hide resolved
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Hubert | ❌ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
Expand Down Expand Up @@ -564,6 +566,7 @@ Flax), PyTorch, and/or TensorFlow.
model_doc/mt5
model_doc/gpt
model_doc/gpt2
model_doc/gptj
model_doc/gpt_neo
model_doc/hubert
model_doc/pegasus
Expand Down
102 changes: 102 additions & 0 deletions docs/source/model_doc/gptj.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

GPT-J
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The GPT-J model was released in the `kingoflolz/mesh-transformer-jax
<https://github.com/kingoflolz/mesh-transformer-jax>`__ repository by Ben Wang and Aran Komatsuzaki. It is a GPT-2-like
causal language model trained on `the Pile <https://pile.eleuther.ai/>`__ dataset.

This model was contributed by `Stella Biderman <https://huggingface.co/stellaathena>`__.
patrickvonplaten marked this conversation as resolved.
Show resolved Hide resolved

Tips:

- Running [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B) in float32 precision on GPU requires at least 24 GB of
RAM. On GPUs with less than 24 GB RAM, one should therefore load the model in half-precision:

.. code-block::

>>> from transformers import GPTJForCausalLM
>>> import torch

>>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)

Generation
_______________________________________________________________________________________________________________________

The :meth:`~transformers.generation_utils.GenerationMixin.generate` method can be used to generate text using GPT-J
model.

.. code-block::

>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

>>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
... "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
... "researchers was the fact that the unicorns spoke perfect English."

>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

>>> gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100,)
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]

...or in float16 precision:

.. code-block::

>>> from transformers import GPTJForCausalLM, AutoTokenizer
>>> import torch

>>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

>>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
... "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
... "researchers was the fact that the unicorns spoke perfect English."

>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

>>> gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100,)
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]


GPTJConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJConfig
:members:

GPTJModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJModel
:members: forward


GPTJForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJForCausalLM
:members: forward


GPTJForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJForSequenceClassification
:members: forward
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
"models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
"models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
"models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
"models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
"models.herbert": ["HerbertTokenizer"],
"models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
"models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
Expand Down Expand Up @@ -814,6 +815,15 @@
"load_tf_weights_in_gpt_neo",
]
)
_import_structure["models.gptj"].extend(
[
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
"GPTJForSequenceClassification",
"GPTJModel",
"GPTJPreTrainedModel",
]
)
_import_structure["models.hubert"].extend(
[
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -1898,6 +1908,7 @@
from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
from .models.herbert import HerbertTokenizer
from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
Expand Down Expand Up @@ -2420,6 +2431,13 @@
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,
)
from .models.gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
GPTJForSequenceClassification,
GPTJModel,
GPTJPreTrainedModel,
)
from .models.hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
funnel,
gpt2,
gpt_neo,
gptj,
herbert,
hubert,
ibert,
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("gptj", "GPTJConfig"),
("beit", "BeitConfig"),
("rembert", "RemBertConfig"),
("visual_bert", "VisualBertConfig"),
Expand Down Expand Up @@ -94,6 +95,7 @@
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
Expand Down Expand Up @@ -155,6 +157,7 @@
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("gptj", "GPT-J"),
("beit", "BeiT"),
("rembert", "RemBERT"),
("visual_bert", "VisualBert"),
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("gptj", "GPTJModel"),
("beit", "BeitModel"),
("rembert", "RemBertModel"),
("visual_bert", "VisualBertModel"),
Expand Down Expand Up @@ -134,6 +135,7 @@
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
Expand Down Expand Up @@ -182,6 +184,7 @@
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
Expand Down Expand Up @@ -285,6 +288,7 @@
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("gptj", "GPTJForSequenceClassification"),
("rembert", "RemBertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("roformer", "RoFormerForSequenceClassification"),
Expand Down
52 changes: 52 additions & 0 deletions src/transformers/models/gptj/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_torch_available


_import_structure = {
"configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
}

if is_torch_available():
_import_structure["modeling_gptj"] = [
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
"GPTJForSequenceClassification",
"GPTJModel",
"GPTJPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig

if is_torch_available():
from .modeling_gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
GPTJForSequenceClassification,
GPTJModel,
GPTJPreTrainedModel,
)

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
Loading