【Hackathon 4th No.102】Add support for a new model architecture (Electra) to AutoConverter #5658

Merged (2 commits) on Apr 14, 2023
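This PR registers name mappings (_get_name_mappings) for the Electra family so that Hugging Face (PyTorch) Electra checkpoints can be converted when loaded through PaddleNLP. A minimal usage sketch, mirroring the tests added below (the tiny hf-internal-testing/tiny-random-ElectraModel checkpoint is the one those tests use):

import numpy as np
import paddle
from paddlenlp.transformers import ElectraModel

# Loading a Hugging Face (PyTorch) checkpoint triggers the converter, which
# renames (and, where needed, transposes) weights via the name mappings added here.
model = ElectraModel.from_pretrained("hf-internal-testing/tiny-random-ElectraModel", from_hf_hub=True)
model.eval()
logits = model(paddle.to_tensor(np.random.randint(100, 200, [1, 20])))[0]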
132 changes: 131 additions & 1 deletion paddlenlp/transformers/electra/modeling.py
@@ -14,14 +14,15 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import Tensor
from paddle.nn import TransformerEncoder, TransformerEncoderLayer

from ...utils.converter import StateDictNameMapping
from .. import PretrainedModel, register_base_model
from ..activations import get_activation
from ..model_outputs import (
@@ -160,6 +161,135 @@ class ElectraPretrainedModel(PretrainedModel):
pretrained_resource_files_map = ELECTRA_PRETRAINED_RESOURCE_FILES_MAP
config_class = ElectraConfig

@classmethod
def _get_name_mappings(cls, config: ElectraConfig) -> List[StateDictNameMapping]:
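# Each entry maps a Hugging Face (PyTorch) parameter name to its PaddleNLP
# counterpart; an optional third element names an action such as "transpose",
# needed because torch and paddle Linear layers store weights in opposite
# [out, in] / [in, out] layouts.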
model_mappings = [
["embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"],
["embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"],
["embeddings.token_type_embeddings.weight", "embeddings.token_type_embeddings.weight"],
["embeddings.LayerNorm.weight", "embeddings.layer_norm.weight"],
["embeddings.LayerNorm.bias", "embeddings.layer_norm.bias"],
["embeddings_project.weight", "embeddings_project.weight", "transpose"],
["embeddings_project.bias", "embeddings_project.bias"],
]

for layer_index in range(config.num_hidden_layers):
layer_mappings = [
[
f"encoder.layer.{layer_index}.attention.self.query.weight",
f"encoder.layers.{layer_index}.self_attn.q_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.self.query.bias",
f"encoder.layers.{layer_index}.self_attn.q_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.self.key.weight",
f"encoder.layers.{layer_index}.self_attn.k_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.self.key.bias",
f"encoder.layers.{layer_index}.self_attn.k_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.self.value.weight",
f"encoder.layers.{layer_index}.self_attn.v_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.self.value.bias",
f"encoder.layers.{layer_index}.self_attn.v_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.output.dense.weight",
f"encoder.layers.{layer_index}.self_attn.out_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.output.dense.bias",
f"encoder.layers.{layer_index}.self_attn.out_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.output.LayerNorm.weight",
f"encoder.layers.{layer_index}.norm1.weight",
],
[
f"encoder.layer.{layer_index}.attention.output.LayerNorm.bias",
f"encoder.layers.{layer_index}.norm1.bias",
],
[
f"encoder.layer.{layer_index}.intermediate.dense.weight",
f"encoder.layers.{layer_index}.linear1.weight",
"transpose",
],
[f"encoder.layer.{layer_index}.intermediate.dense.bias", f"encoder.layers.{layer_index}.linear1.bias"],
[
f"encoder.layer.{layer_index}.output.dense.weight",
f"encoder.layers.{layer_index}.linear2.weight",
"transpose",
],
[f"encoder.layer.{layer_index}.output.dense.bias", f"encoder.layers.{layer_index}.linear2.bias"],
[f"encoder.layer.{layer_index}.output.LayerNorm.weight", f"encoder.layers.{layer_index}.norm2.weight"],
[f"encoder.layer.{layer_index}.output.LayerNorm.bias", f"encoder.layers.{layer_index}.norm2.bias"],
]
model_mappings.extend(layer_mappings)

# prepend the "electra." base-model prefix unless the architecture is the bare ElectraModel
if "ElectraModel" not in config.architectures:
for mapping in model_mappings:
mapping[0] = "electra." + mapping[0]
mapping[1] = "electra." + mapping[1]

# downstream mappings
if "ElectraForQuestionAnswering" in config.architectures:
model_mappings.extend(
[["qa_outputs.weight", "classifier.weight", "transpose"], ["qa_outputs.bias", "classifier.bias"]]
)

if "ElectraForMultipleChoice" in config.architectures:
model_mappings.extend(
[
["sequence_summary.summary.weight", "sequence_summary.dense.weight", "transpose"],
["sequence_summary.summary.bias", "sequence_summary.dense.bias"],
["classifier.weight", "classifier.weight", "transpose"],
["classifier.bias", "classifier.bias"],
]
)

if "ElectraForSequenceClassification" in config.architectures:
model_mappings.extend(
[
["classifier.dense.weight", "classifier.dense.weight", "transpose"],
["classifier.dense.bias", "classifier.dense.bias"],
["classifier.out_proj.weight", "classifier.out_proj.weight", "transpose"],
["classifier.out_proj.bias", "classifier.out_proj.bias"],
]
)

if "ElectraForTokenClassification" in config.architectures:
model_mappings.extend(
[
["classifier.weight", "classifier.weight", "transpose"],
["classifier.bias", "classifier.bias"],
]
)

# TODO: need to tie weights
if "ElectraForMaskedLM" in config.architectures:
model_mappings.extend(
[
["generator_predictions.LayerNorm.weight", "generator_predictions.layer_norm.weight", "transpose"],
["generator_predictions.LayerNorm.bias", "generator_predictions.layer_norm.bias"],
["generator_predictions.dense.weight", "generator_predictions.dense.weight", "transpose"],
["generator_predictions.dense.bias", "generator_predictions.dense.bias"],
["generator_lm_head.bias", "generator_lm_head_bias"],
]
)

return [StateDictNameMapping(*mapping) for mapping in model_mappings]

def init_weights(self):
"""
Initializes and tie weights if needed.
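For reference, the "transpose" action attached to the linear-layer weights above is needed because torch's nn.Linear stores its weight as [out_features, in_features] while paddle's nn.Linear expects [in_features, out_features]. A rough sketch of what the converter does per mapping entry (an illustration of the idea, not the actual StateDictNameMapping implementation):

import numpy as np

def apply_mapping(torch_state_dict, mapping):
    # mapping is [source_name, target_name] or [source_name, target_name, "transpose"]
    source_name, target_name, *actions = mapping
    tensor = np.asarray(torch_state_dict[source_name])
    if "transpose" in actions:
        # torch nn.Linear weight: [out_features, in_features]
        # paddle nn.Linear weight: [in_features, out_features]
        tensor = tensor.T
    return target_name, tensor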
149 changes: 147 additions & 2 deletions tests/transformers/electra/test_modeling.py
@@ -13,10 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import tempfile
import unittest

import numpy as np
import paddle
from parameterized import parameterized_class
from parameterized import parameterized, parameterized_class

from paddlenlp.transformers import (
ElectraConfig,
@@ -31,7 +34,7 @@
ElectraModel,
ElectraPretrainedModel,
)
from tests.testing_utils import slow
from tests.testing_utils import require_package, slow
from tests.transformers.test_modeling_common import (
ModelTesterMixin,
floats_tensor,
@@ -481,6 +484,148 @@ def test_model_from_pretrained(self):
self.assertIsNotNone(model)


class ElectraModelCompatibilityTest(unittest.TestCase):
model_id = "hf-internal-testing/tiny-random-ElectraModel"

@require_package("transformers", "torch")
def test_electra_converter(self):
with tempfile.TemporaryDirectory() as tempdir:
# 1. create input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the paddle model
from paddlenlp.transformers import ElectraModel

paddle_model = ElectraModel.from_pretrained(self.model_id, from_hf_hub=True, cache_dir=tempdir)
paddle_model.eval()
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]

# 3. forward the torch model
import torch
from transformers import ElectraModel

torch_model = ElectraModel.from_pretrained(self.model_id, cache_dir=tempdir)
torch_model.eval()
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 4. compare results
self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
rtol=1e-4,
)
)

@require_package("transformers", "torch")
def test_electra_converter_from_local_dir_with_enable_torch(self):
with tempfile.TemporaryDirectory() as tempdir:
# 1. forward the torch model
from transformers import ElectraModel

torch_model = ElectraModel.from_pretrained(self.model_id)
torch_model.save_pretrained(tempdir)

# 2. forward the paddle model
from paddlenlp.transformers import ElectraModel, model_utils

model_utils.ENABLE_TORCH_CHECKPOINT = False
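# with conversion disabled, loading a directory that only contains a torch checkpoint should raise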

with self.assertRaises(ValueError) as error:
ElectraModel.from_pretrained(tempdir)
self.assertIn("conversion is been disabled", str(error.exception))
model_utils.ENABLE_TORCH_CHECKPOINT = True

@require_package("transformers", "torch")
def test_electra_converter_from_local_dir(self):
with tempfile.TemporaryDirectory() as tempdir:

# 1. create common input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the torch model
import torch
from transformers import ElectraModel

torch_model = ElectraModel.from_pretrained(self.model_id)
torch_model.eval()
torch_model.save_pretrained(tempdir)
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 3. forward the paddle model
from paddlenlp.transformers import ElectraModel

paddle_model = ElectraModel.from_pretrained(tempdir)
paddle_model.eval()
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
rtol=1e-4,
)
)

@parameterized.expand(
[
("ElectraModel",),
# ("ElectraForMaskedLM",), TODO: need to tie weights
# ("ElectraForPretraining",), TODO: need to tie weights
("ElectraForMultipleChoice",),
("ElectraForQuestionAnswering",),
("ElectraForSequenceClassification",),
("ElectraForTokenClassification",),
]
)
@require_package("transformers", "torch")
def test_electra_classes_from_local_dir(self, class_name, pytorch_class_name=None):
pytorch_class_name = pytorch_class_name or class_name
with tempfile.TemporaryDirectory() as tempdir:

# 1. create common input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the torch model
import torch
import transformers

torch_model_class = getattr(transformers, pytorch_class_name)
torch_model = torch_model_class.from_pretrained(self.model_id)
torch_model.eval()

if "MultipleChoice" in class_name:
# construct input for MultipleChoice Model
torch_model.config.num_choices = random.randint(2, 10)
input_ids = (
paddle.to_tensor(input_ids)
.unsqueeze(1)
.expand([-1, torch_model.config.num_choices, -1])
.cpu()
.numpy()
)

torch_model.save_pretrained(tempdir)
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 3. forward the paddle model
from paddlenlp import transformers

paddle_model_class = getattr(transformers, class_name)
paddle_model = paddle_model_class.from_pretrained(tempdir)
paddle_model.eval()

paddle_logit = paddle_model(paddle.to_tensor(input_ids), return_dict=False)[0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
atol=1e-3,
)
)


class ElectraModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):