【Hackathon 4th No.102】Add AutoConverter support for a new model architecture: AlbertModel #5626

Merged (8 commits), Apr 14, 2023
115 changes: 114 additions & 1 deletion paddlenlp/transformers/albert/modeling.py
@@ -15,13 +15,14 @@
"""Modeling classes for ALBERT model."""

import math
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Layer

from ...utils.converter import StateDictNameMapping
from ...utils.env import CONFIG_NAME
from .. import PretrainedModel, register_base_model
from ..activations import ACT2FN
@@ -357,6 +358,118 @@ class AlbertPretrainedModel(PretrainedModel):
pretrained_init_configuration = ALBERT_PRETRAINED_INIT_CONFIGURATION
pretrained_resource_files_map = ALBERT_PRETRAINED_RESOURCE_FILES_MAP

@classmethod
def _get_name_mappings(cls, config: AlbertConfig) -> List[StateDictNameMapping]:
mappings: list[StateDictNameMapping] = []
Contributor
This variable definition is unused here, so it can be deleted.

model_mappings = [
["embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"],
["embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"],
["embeddings.token_type_embeddings.weight", "embeddings.token_type_embeddings.weight"],
["embeddings.LayerNorm.weight", "embeddings.layer_norm.weight"],
["embeddings.LayerNorm.bias", "embeddings.layer_norm.bias"],
["encoder.embedding_hidden_mapping_in.weight", "encoder.embedding_hidden_mapping_in.weight", "transpose"],
["encoder.embedding_hidden_mapping_in.bias", "encoder.embedding_hidden_mapping_in.bias"],
["pooler.weight", "pooler.weight", "transpose"],
["pooler.bias", "pooler.bias"],
]
for group_index in range(config.num_hidden_groups):
group_mappings = [
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.full_layer_layer_norm.weight",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.full_layer_layer_norm.weight",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.full_layer_layer_norm.bias",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.full_layer_layer_norm.bias",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.query.weight",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.query.weight",
"transpose",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.query.bias",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.query.bias",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.key.weight",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.key.weight",
"transpose",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.key.bias",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.key.bias",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.value.weight",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.value.weight",
"transpose",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.value.bias",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.value.bias",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.dense.weight",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.dense.weight",
"transpose",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.dense.bias",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.dense.bias",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.LayerNorm.weight",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.layer_norm.weight",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.LayerNorm.bias",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.attention.layer_norm.bias",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.ffn.weight",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.ffn.weight",
"transpose",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.ffn.bias",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.ffn.bias",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.ffn_output.weight",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.ffn_output.weight",
"transpose",
],
[
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.ffn_output.bias",
f"encoder.albert_layer_groups.{group_index}.albert_layers.0.ffn_output.bias",
],
]
model_mappings.extend(group_mappings)

# add base-model prefixes when the checkpoint architecture is not the bare AlbertModel
if "AlbertModel" not in config.architectures:
for mapping in model_mappings:
mapping[0] = "albert." + mapping[0]
mapping[1] = "transformer." + mapping[1]

# downstream mappings
if "AlbertForQuestionAnswering" in config.architectures:
model_mappings.extend(
[["qa_outputs.weight", "classifier.weight", "transpose"], ["qa_outputs.bias", "classifier.bias"]]
)
if (
"AlbertForMultipleChoice" in config.architectures
or "AlbertForSequenceClassification" in config.architectures
or "AlbertForTokenClassification" in config.architectures
):
model_mappings.extend(
[["classifier.weight", "classifier.weight", "transpose"], ["classifier.bias", "classifier.bias"]]
)

mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
return mappings
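
For context, mappings like the ones returned above drive the torch-to-paddle checkpoint conversion roughly as sketched below. This is an illustrative sketch only: the helper apply_name_mappings and the attribute names source_name / target_name / action are assumptions for illustration, not the actual paddlenlp.utils.converter API.

import numpy as np

def apply_name_mappings(torch_state_dict, mappings):
    # Hypothetical helper: rename each torch-side key to its paddle-side
    # counterpart, transposing nn.Linear weights, which the two frameworks
    # store with opposite orientations.
    paddle_state_dict = {}
    for m in mappings:
        tensor = torch_state_dict[m.source_name].numpy()
        if m.action == "transpose":
            tensor = np.ascontiguousarray(tensor.T)
        paddle_state_dict[m.target_name] = tensor
    return paddle_state_dict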

def init_weights(self):
# Initialize weights
self.apply(self._init_weights)
149 changes: 147 additions & 2 deletions tests/transformers/albert/test_modeling.py
@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import tempfile
import unittest

import numpy as np
import paddle
from paddle import Tensor
from parameterized import parameterized_class
from parameterized import parameterized, parameterized_class

from paddlenlp.transformers import (
AlbertConfig,
@@ -30,7 +33,7 @@
AlbertPretrainedModel,
)

from ...testing_utils import slow
from ...testing_utils import require_package, slow
from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask


@@ -337,6 +340,148 @@ def test_model_from_pretrained(self):
self.assertIsNotNone(model)


class AlbertModelCompatibilityTest(unittest.TestCase):
model_id = "hf-internal-testing/tiny-random-AlbertModel"

@require_package("transformers", "torch")
def test_albert_converter(self):
with tempfile.TemporaryDirectory() as tempdir:
# 1. create input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the paddle model
from paddlenlp.transformers import AlbertModel

paddle_model = AlbertModel.from_pretrained(self.model_id, from_hf_hub=True, cache_dir=tempdir)
paddle_model.eval()
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]

# 3. forward the torch model
import torch
from transformers import AlbertModel

torch_model = AlbertModel.from_pretrained(self.model_id, cache_dir=tempdir)
torch_model.eval()
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 4. compare results
self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
rtol=1e-4,
)
)

@require_package("transformers", "torch")
def test_albert_converter_from_local_dir_with_enable_torch(self):
with tempfile.TemporaryDirectory() as tempdir:
# 1. create the torch model and save it locally
from transformers import AlbertModel

torch_model = AlbertModel.from_pretrained(self.model_id)
torch_model.save_pretrained(tempdir)

# 2. forward the paddle model
from paddlenlp.transformers import AlbertModel, model_utils

model_utils.ENABLE_TORCH_CHECKPOINT = False

with self.assertRaises(ValueError) as error:
AlbertModel.from_pretrained(tempdir)
self.assertIn("conversion is been disabled", str(error.exception))
model_utils.ENABLE_TORCH_CHECKPOINT = True

@require_package("transformers", "torch")
def test_albert_converter_from_local_dir(self):
with tempfile.TemporaryDirectory() as tempdir:

# 1. create common input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the torch model
import torch
from transformers import AlbertModel

torch_model = AlbertModel.from_pretrained(self.model_id)
torch_model.eval()
torch_model.save_pretrained(tempdir)
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 3. forward the paddle model
from paddlenlp.transformers import AlbertModel

paddle_model = AlbertModel.from_pretrained(tempdir)
paddle_model.eval()
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
rtol=1e-4,
)
)

@parameterized.expand(
[
("AlbertModel",),
# ("AlbertForMaskedLM",), TODO: need to tie weights
# ("AlbertForPretraining",), TODO: need to tie weights
Comment on lines +429 to +430
Contributor
Once #5623 is merged, tie_weights can be used here (see the weight-tying sketch after the last test in this class).

("AlbertForMultipleChoice",),
# ("AlbertForQuestionAnswering",), TODO: transformers NOT add the last pool layer before qa_outputs
Contributor
The pooler mapping can be controlled via architectures.

("AlbertForSequenceClassification",),
("AlbertForTokenClassification",),
]
)
@require_package("transformers", "torch")
def test_albert_classes_from_local_dir(self, class_name, pytorch_class_name=None):
pytorch_class_name = pytorch_class_name or class_name
with tempfile.TemporaryDirectory() as tempdir:

# 1. create common input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the torch model
import torch
import transformers

torch_model_class = getattr(transformers, pytorch_class_name)
torch_model = torch_model_class.from_pretrained(self.model_id)
torch_model.eval()

if "MultipleChoice" in class_name:
# construct input for MultipleChoice Model
torch_model.config.num_choices = random.randint(2, 10)
input_ids = (
paddle.to_tensor(input_ids)
.unsqueeze(1)
.expand([-1, torch_model.config.num_choices, -1])
.cpu()
.numpy()
)

torch_model.save_pretrained(tempdir)
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 3. forward the paddle model
from paddlenlp import transformers

paddle_model_class = getattr(transformers, class_name)
paddle_model = paddle_model_class.from_pretrained(tempdir)
paddle_model.eval()

paddle_logit = paddle_model(paddle.to_tensor(input_ids), return_dict=False)[0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
atol=1e-3,
)
)
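
Regarding the tie_weights follow-up noted in the review comment above: the masked-LM decoder reuses the word-embedding matrix, so the converter must map that tensor once rather than duplicate it. Below is a minimal sketch of the idea under assumed names; TiedLMHead is hypothetical and is not PaddleNLP's actual prediction head.

import paddle
import paddle.nn as nn

class TiedLMHead(nn.Layer):
    def __init__(self, hidden_size, vocab_size, embedding_weight):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        # Share the embedding parameter instead of allocating a new decoder matrix.
        self.decoder_weight = embedding_weight
        self.decoder_bias = self.create_parameter([vocab_size], is_bias=True)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        # Project back to vocabulary space with the tied embedding matrix.
        return paddle.matmul(hidden_states, self.decoder_weight, transpose_y=True) + self.decoder_bias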


class AlbertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):