[Add] Add Electra to AutoConverter (#5658)
megemini authored Apr 14, 2023
1 parent 638c795 commit 848f5c1
Showing 2 changed files with 278 additions and 3 deletions.
132 changes: 131 additions & 1 deletion paddlenlp/transformers/electra/modeling.py
@@ -14,14 +14,15 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import Tensor
from paddle.nn import TransformerEncoder, TransformerEncoderLayer

from ...utils.converter import StateDictNameMapping
from .. import PretrainedModel, register_base_model
from ..activations import get_activation
from ..model_outputs import (
@@ -160,6 +161,135 @@ class ElectraPretrainedModel(PretrainedModel):
pretrained_resource_files_map = ELECTRA_PRETRAINED_RESOURCE_FILES_MAP
config_class = ElectraConfig

@classmethod
def _get_name_mappings(cls, config: ElectraConfig) -> List[StateDictNameMapping]:
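        # Each mapping is [torch_key, paddle_key] with an optional third element naming an
        # action; "transpose" is used for nn.Linear weights because torch stores them as
        # [out_features, in_features] while paddle stores them as [in_features, out_features].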
model_mappings = [
["embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"],
["embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"],
["embeddings.token_type_embeddings.weight", "embeddings.token_type_embeddings.weight"],
["embeddings.LayerNorm.weight", "embeddings.layer_norm.weight"],
["embeddings.LayerNorm.bias", "embeddings.layer_norm.bias"],
["embeddings_project.weight", "embeddings_project.weight", "transpose"],
["embeddings_project.bias", "embeddings_project.bias"],
]

for layer_index in range(config.num_hidden_layers):
layer_mappings = [
[
f"encoder.layer.{layer_index}.attention.self.query.weight",
f"encoder.layers.{layer_index}.self_attn.q_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.self.query.bias",
f"encoder.layers.{layer_index}.self_attn.q_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.self.key.weight",
f"encoder.layers.{layer_index}.self_attn.k_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.self.key.bias",
f"encoder.layers.{layer_index}.self_attn.k_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.self.value.weight",
f"encoder.layers.{layer_index}.self_attn.v_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.self.value.bias",
f"encoder.layers.{layer_index}.self_attn.v_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.output.dense.weight",
f"encoder.layers.{layer_index}.self_attn.out_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.output.dense.bias",
f"encoder.layers.{layer_index}.self_attn.out_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.output.LayerNorm.weight",
f"encoder.layers.{layer_index}.norm1.weight",
],
[
f"encoder.layer.{layer_index}.attention.output.LayerNorm.bias",
f"encoder.layers.{layer_index}.norm1.bias",
],
[
f"encoder.layer.{layer_index}.intermediate.dense.weight",
f"encoder.layers.{layer_index}.linear1.weight",
"transpose",
],
[f"encoder.layer.{layer_index}.intermediate.dense.bias", f"encoder.layers.{layer_index}.linear1.bias"],
[
f"encoder.layer.{layer_index}.output.dense.weight",
f"encoder.layers.{layer_index}.linear2.weight",
"transpose",
],
[f"encoder.layer.{layer_index}.output.dense.bias", f"encoder.layers.{layer_index}.linear2.bias"],
[f"encoder.layer.{layer_index}.output.LayerNorm.weight", f"encoder.layers.{layer_index}.norm2.weight"],
[f"encoder.layer.{layer_index}.output.LayerNorm.bias", f"encoder.layers.{layer_index}.norm2.bias"],
]
model_mappings.extend(layer_mappings)

# prepend the base-model prefix "electra." unless the architecture is the bare ElectraModel
if "ElectraModel" not in config.architectures:
for mapping in model_mappings:
mapping[0] = "electra." + mapping[0]
mapping[1] = "electra." + mapping[1]

# downstream mappings
if "ElectraForQuestionAnswering" in config.architectures:
model_mappings.extend(
[["qa_outputs.weight", "classifier.weight", "transpose"], ["qa_outputs.bias", "classifier.bias"]]
)

if "ElectraForMultipleChoice" in config.architectures:
model_mappings.extend(
[
["sequence_summary.summary.weight", "sequence_summary.dense.weight", "transpose"],
["sequence_summary.summary.bias", "sequence_summary.dense.bias"],
["classifier.weight", "classifier.weight", "transpose"],
["classifier.bias", "classifier.bias"],
]
)

if "ElectraForSequenceClassification" in config.architectures:
model_mappings.extend(
[
["classifier.dense.weight", "classifier.dense.weight", "transpose"],
["classifier.dense.bias", "classifier.dense.bias"],
["classifier.out_proj.weight", "classifier.out_proj.weight", "transpose"],
["classifier.out_proj.bias", "classifier.out_proj.bias"],
]
)

if "ElectraForTokenClassification" in config.architectures:
model_mappings.extend(
[
["classifier.weight", "classifier.weight", "transpose"],
["classifier.bias", "classifier.bias"],
]
)

# TODO: need to tie weights
if "ElectraForMaskedLM" in config.architectures:
model_mappings.extend(
[
["generator_predictions.LayerNorm.weight", "generator_predictions.layer_norm.weight", "transpose"],
["generator_predictions.LayerNorm.bias", "generator_predictions.layer_norm.bias"],
["generator_predictions.dense.weight", "generator_predictions.dense.weight", "transpose"],
["generator_predictions.dense.bias", "generator_predictions.dense.bias"],
["generator_lm_head.bias", "generator_lm_head_bias"],
]
)

return [StateDictNameMapping(*mapping) for mapping in model_mappings]

def init_weights(self):
"""
Initializes and ties weights if needed.
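For reference, the mappings above are consumed when a PyTorch ELECTRA checkpoint is loaded into PaddleNLP: each entry renames a torch parameter key to its paddle counterpart and optionally applies an action such as "transpose". The helper below is a minimal illustrative sketch of that idea, assuming the state dict values are array-like; it is not PaddleNLP's actual converter implementation.

import numpy as np


def convert_state_dict(torch_state_dict, name_mappings):
    """Rename torch keys to paddle keys and apply the optional action (e.g. "transpose")."""
    paddle_state_dict = {}
    for mapping in name_mappings:
        torch_key, paddle_key = mapping[0], mapping[1]
        action = mapping[2] if len(mapping) > 2 else None
        if torch_key not in torch_state_dict:
            continue
        weight = np.asarray(torch_state_dict[torch_key])
        if action == "transpose":
            # torch nn.Linear stores weights as [out_features, in_features];
            # paddle nn.Linear expects [in_features, out_features].
            weight = weight.T
        paddle_state_dict[paddle_key] = weight
    return paddle_state_dict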
149 changes: 147 additions & 2 deletions tests/transformers/electra/test_modeling.py
@@ -13,10 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import tempfile
import unittest

import numpy as np
import paddle
from parameterized import parameterized_class
from parameterized import parameterized, parameterized_class

from paddlenlp.transformers import (
ElectraConfig,
@@ -31,7 +34,7 @@
ElectraModel,
ElectraPretrainedModel,
)
from tests.testing_utils import slow
from tests.testing_utils import require_package, slow
from tests.transformers.test_modeling_common import (
ModelTesterMixin,
floats_tensor,
@@ -481,6 +484,148 @@ def test_model_from_pretrained(self):
self.assertIsNotNone(model)


class ElectraModelCompatibilityTest(unittest.TestCase):
model_id = "hf-internal-testing/tiny-random-ElectraModel"

@require_package("transformers", "torch")
def test_electra_converter(self):
with tempfile.TemporaryDirectory() as tempdir:
# 1. create input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the paddle model
from paddlenlp.transformers import ElectraModel

paddle_model = ElectraModel.from_pretrained(self.model_id, from_hf_hub=True, cache_dir=tempdir)
paddle_model.eval()
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]

# 3. forward the torch model
import torch
from transformers import ElectraModel

torch_model = ElectraModel.from_pretrained(self.model_id, cache_dir=tempdir)
torch_model.eval()
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 4. compare results
self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
rtol=1e-4,
)
)

@require_package("transformers", "torch")
def test_electra_converter_from_local_dir_with_enable_torch(self):
with tempfile.TemporaryDirectory() as tempdir:
# 1. forward the torch model
from transformers import ElectraModel

torch_model = ElectraModel.from_pretrained(self.model_id)
torch_model.save_pretrained(tempdir)

# 2. forward the paddle model
from paddlenlp.transformers import ElectraModel, model_utils

model_utils.ENABLE_TORCH_CHECKPOINT = False

with self.assertRaises(ValueError) as error:
ElectraModel.from_pretrained(tempdir)
self.assertIn("conversion is been disabled", str(error.exception))
model_utils.ENABLE_TORCH_CHECKPOINT = True

@require_package("transformers", "torch")
def test_electra_converter_from_local_dir(self):
with tempfile.TemporaryDirectory() as tempdir:

# 1. create common input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the torch model
import torch
from transformers import ElectraModel

torch_model = ElectraModel.from_pretrained(self.model_id)
torch_model.eval()
torch_model.save_pretrained(tempdir)
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 3. forward the paddle model
from paddlenlp.transformers import ElectraModel

paddle_model = ElectraModel.from_pretrained(tempdir)
paddle_model.eval()
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
rtol=1e-4,
)
)

@parameterized.expand(
[
("ElectraModel",),
# ("ElectraForMaskedLM",), TODO: need to tie weights
# ("ElectraForPretraining",), TODO: need to tie weights
("ElectraForMultipleChoice",),
("ElectraForQuestionAnswering",),
("ElectraForSequenceClassification",),
("ElectraForTokenClassification",),
]
)
@require_package("transformers", "torch")
def test_electra_classes_from_local_dir(self, class_name, pytorch_class_name=None):
pytorch_class_name = pytorch_class_name or class_name
with tempfile.TemporaryDirectory() as tempdir:

# 1. create common input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the torch model
import torch
import transformers

torch_model_class = getattr(transformers, pytorch_class_name)
torch_model = torch_model_class.from_pretrained(self.model_id)
torch_model.eval()

if "MultipleChoice" in class_name:
# construct input for MultipleChoice Model
torch_model.config.num_choices = random.randint(2, 10)
input_ids = (
paddle.to_tensor(input_ids)
.unsqueeze(1)
.expand([-1, torch_model.config.num_choices, -1])
.cpu()
.numpy()
)

torch_model.save_pretrained(tempdir)
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 3. forward the paddle model
from paddlenlp import transformers

paddle_model_class = getattr(transformers, class_name)
paddle_model = paddle_model_class.from_pretrained(tempdir)
paddle_model.eval()

paddle_logit = paddle_model(paddle.to_tensor(input_ids), return_dict=False)[0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
atol=1e-3,
)
)


class ElectraModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):
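As exercised by the tests above, an ELECTRA checkpoint saved in the PyTorch/Transformers format can now be loaded into PaddleNLP directly, either from the Hugging Face Hub or from a local directory. A typical call, using the tiny test checkpoint referenced above, looks like the following sketch.

from paddlenlp.transformers import ElectraModel

# Downloads the PyTorch checkpoint from the Hugging Face Hub and converts its
# weights on the fly using the name mappings added in this commit.
model = ElectraModel.from_pretrained(
    "hf-internal-testing/tiny-random-ElectraModel", from_hf_hub=True
)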
