Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add T5 to AutoConverter #4477

Merged
merged 6 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 119 additions & 1 deletion paddlenlp/transformers/t5/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from paddle import Tensor
from paddle.distributed.fleet.utils import recompute

from ...utils.converter import StateDictNameMapping
from ...utils.log import logger
from ..activations import ACT2FN
from ..model_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Expand All @@ -34,7 +36,6 @@
convert_encoder_output,
)
from ..model_utils import PretrainedModel, register_base_model
from ..nezha.modeling import ACT2FN
from .configuration import (
T5_PRETRAINED_INIT_CONFIGURATION,
T5_PRETRAINED_RESOURCE_FILES_MAP,
Expand Down Expand Up @@ -571,6 +572,123 @@ class T5PretrainedModel(PretrainedModel):
pretrained_init_configuration = T5_PRETRAINED_INIT_CONFIGURATION
pretrained_resource_files_map = T5_PRETRAINED_RESOURCE_FILES_MAP

@classmethod
def _get_name_mappings(cls, config: T5Config) -> list[StateDictNameMapping]:
mappings: list[StateDictNameMapping] = []
model_mappings = [
["shared.weight", "shared.weight"],
["encoder.embed_tokens.weight", "encoder.embed_tokens.weight"],
["encoder.final_layer_norm.weight", "encoder.final_layer_norm.weight"],
["decoder.embed_tokens.weight", "decoder.embed_tokens.weight"],
["decoder.final_layer_norm.weight", "decoder.final_layer_norm.weight"],
[
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
],
[
"decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
"decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
],
]
for layer_index in range(config.num_hidden_layers):
for att_head in ["q", "k", "v", "o"]:
model_mappings.extend(
[
[
f"encoder.block.{layer_index}.layer.0.SelfAttention.{att_head}.weight",
f"encoder.block.{layer_index}.layer.0.SelfAttention.{att_head}.weight",
"transpose",
],
[
f"decoder.block.{layer_index}.layer.0.SelfAttention.{att_head}.weight",
f"decoder.block.{layer_index}.layer.0.SelfAttention.{att_head}.weight",
"transpose",
],
[
f"decoder.block.{layer_index}.layer.1.EncDecAttention.{att_head}.weight",
f"decoder.block.{layer_index}.layer.1.EncDecAttention.{att_head}.weight",
"transpose",
],
]
)

layer_mappings = [
[
f"encoder.block.{layer_index}.layer.1.DenseReluDense.wo.weight",
f"encoder.block.{layer_index}.layer.1.DenseReluDense.wo.weight",
"transpose",
],
[
f"decoder.block.{layer_index}.layer.2.DenseReluDense.wo.weight",
f"decoder.block.{layer_index}.layer.2.DenseReluDense.wo.weight",
"transpose",
],
[
f"encoder.block.{layer_index}.layer.0.layer_norm.weight",
f"encoder.block.{layer_index}.layer.0.layer_norm.weight",
],
[
f"encoder.block.{layer_index}.layer.1.layer_norm.weight",
f"encoder.block.{layer_index}.layer.1.layer_norm.weight",
],
[
f"decoder.block.{layer_index}.layer.0.layer_norm.weight",
f"decoder.block.{layer_index}.layer.0.layer_norm.weight",
],
[
f"decoder.block.{layer_index}.layer.1.layer_norm.weight",
f"decoder.block.{layer_index}.layer.1.layer_norm.weight",
],
[
f"decoder.block.{layer_index}.layer.2.layer_norm.weight",
f"decoder.block.{layer_index}.layer.2.layer_norm.weight",
],
]

if config.feed_forward_proj == "relu":
layer_mappings.extend(
[
[
f"encoder.block.{layer_index}.layer.1.DenseReluDense.wi.weight",
f"encoder.block.{layer_index}.layer.1.DenseReluDense.wi.weight",
"transpose",
],
[
f"decoder.block.{layer_index}.layer.2.DenseReluDense.wi.weight",
f"decoder.block.{layer_index}.layer.2.DenseReluDense.wi.weight",
"transpose",
],
]
)
elif config.feed_forward_proj == "gated-gelu":
for i in range(2):
layer_mappings.extend(
[
[
f"encoder.block.{layer_index}.layer.1.DenseReluDense.wi_{i}.weight",
f"encoder.block.{layer_index}.layer.1.DenseReluDense.wi_{i}.weight",
"transpose",
],
[
f"decoder.block.{layer_index}.layer.2.DenseReluDense.wi_{i}.weight",
f"decoder.block.{layer_index}.layer.2.DenseReluDense.wi_{i}.weight",
"transpose",
],
]
)

model_mappings.extend(layer_mappings)

if cls.__name__ != "T5Model":
for mapping in model_mappings:
mapping[1] = "t5." + mapping[1]

if config.architectures is not None and "T5ForConditionalGeneration" in config.architectures:
model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])

mappings = [StateDictNameMapping(*mapping) for mapping in model_mappings]
return mappings

@property
def dummy_inputs(self):
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
Expand Down
22 changes: 18 additions & 4 deletions paddlenlp/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ def __init__(self, key: str, n_bytes: int, dtype: str):
self.nbytes = n_bytes
self.dtype = dtype
self.size = None
self.stride = None

def __repr__(self):
return f"size: {self.size} key: {self.key}, nbytes: {self.nbytes}, dtype: {self.dtype}"
return f"size: {self.size} key: {self.key}, nbytes: {self.nbytes}, dtype: {self.dtype}, stride: {self.stride}"


class SerializationError(Exception):
Expand Down Expand Up @@ -123,6 +124,7 @@ def get_data_iostream(file: str, file_name="data.pkl"):
def _rebuild_tensor_stage(storage, storage_offset, size, stride, requires_grad, backward_hooks):
if isinstance(storage, TensorMeta):
storage.size = size
storage.stride = stride
return storage


Expand Down Expand Up @@ -237,9 +239,21 @@ def extract_maybe_dict(result):
file_handler.seek(padding_offset, 1)

# save the tensor info in result to re-use memory
stage1_key_to_tensor[key] = np.frombuffer(
file_handler.read(tensor_meta.nbytes), dtype=tensor_meta.dtype
).reshape(tensor_meta.size)
np_buffer = np.frombuffer(file_handler.read(tensor_meta.nbytes), dtype=tensor_meta.dtype)

# if a tensor has shape [M, N] and stride is [1, N], it's column-wise / fortran-style
# if a tensor has shape [M, N] and stride is [M, 1], it's row-wise / C-style
# defautls to C-style
if (
tensor_meta.stride is not None
and len(tensor_meta.stride) > 1
and tensor_meta.stride[0] == 1
and tensor_meta.stride[1] > 1
):
order = "F"
else:
order = "C"
stage1_key_to_tensor[key] = np_buffer.reshape(tensor_meta.size, order=order)

def persistent_load_stage2(saved_id):
assert isinstance(saved_id, tuple)
Expand Down
122 changes: 121 additions & 1 deletion tests/transformers/t5/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import random
import tempfile
import unittest

import numpy as np
Expand All @@ -28,7 +29,7 @@
)
from paddlenlp.transformers.t5.configuration import T5Config
from paddlenlp.transformers.t5.modeling import T5_PRETRAINED_MODEL_ARCHIVE_LIST
from tests.testing_utils import slow
from tests.testing_utils import require_package, slow

from ..test_generation_utils import GenerationTesterMixin
from ..test_modeling_common import ModelTesterMixin, ids_tensor
Expand Down Expand Up @@ -703,6 +704,125 @@ def test_model_name_list(self):
pass


class T5CompatibilityTest(unittest.TestCase):
@require_package("transformers", "torch")
def test_t5_converter(self):
with tempfile.TemporaryDirectory() as tempdir:
model_id = "hf-internal-testing/tiny-random-T5Model"
# 1. create commmon input
input_ids = np.array([[i for i in range(10)]])

# 2. forward the paddle model
from paddlenlp.transformers import T5Model

paddle_model = T5Model.from_pretrained(model_id, from_hf_hub=True, cache_dir=tempdir)
paddle_model.eval()
paddle_logit = paddle_model(
input_ids=paddle.to_tensor(input_ids), decoder_input_ids=paddle.to_tensor(input_ids)
)[0][0]

# 3. forward the torch model
import torch
from transformers import T5Model

torch_model = T5Model.from_pretrained(model_id, cache_dir=tempdir)
torch_model.eval()
torch_logit = torch_model(
input_ids=torch.tensor(input_ids), decoder_input_ids=torch.tensor(input_ids), return_dict=False
)[0][0]
self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().numpy()[:4, :4], torch_logit.detach().cpu().numpy()[:4, :4], rtol=1e-4
)
)

@require_package("transformers", "torch")
def test_t5_converter_from_local_dir_with_enable_torch(self):
with tempfile.TemporaryDirectory() as tempdir:
model_id = "hf-internal-testing/tiny-random-T5Model"
# 1. forward the torch model
from transformers import T5Model

torch_model = T5Model.from_pretrained(model_id)
torch_model.save_pretrained(tempdir)

# 2. forward the paddle model
from paddlenlp.transformers import T5Model, model_utils

model_utils.ENABLE_TORCH_CHECKPOINT = False

with self.assertRaises(ValueError) as error:
T5Model.from_pretrained(tempdir)
self.assertIn("conversion is been disabled" in str(error.exception))
model_utils.ENABLE_TORCH_CHECKPOINT = True

@require_package("transformers", "torch")
def test_t5_converter_from_local_dir(self):
with tempfile.TemporaryDirectory() as tempdir:
model_id = "hf-internal-testing/tiny-random-T5Model"
# 1. create commmon input
input_ids = np.array([[i for i in range(10)]])

# 2. forward the torch model
import torch
from transformers import T5Model

torch_model = T5Model.from_pretrained(model_id)
torch_model.eval()
torch_model.save_pretrained(tempdir)
torch_logit = torch_model(
input_ids=torch.tensor(input_ids), decoder_input_ids=torch.tensor(input_ids), return_dict=False
)[0][0]

# 2. forward the paddle model
from paddlenlp.transformers import T5Model

paddle_model = T5Model.from_pretrained(tempdir)
paddle_model.eval()
paddle_logit = paddle_model(
input_ids=paddle.to_tensor(input_ids), decoder_input_ids=paddle.to_tensor(input_ids)
)[0][0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().numpy()[:4, :4], torch_logit.detach().cpu().numpy()[:4, :4], rtol=1e-4
)
)

@require_package("transformers", "torch")
def test_t5_for_conditional_generation(self):
with tempfile.TemporaryDirectory() as tempdir:
model_id = "hf-internal-testing/tiny-random-T5Model"
# 1. create commmon input
input_ids = np.array([[i for i in range(10)]])

# 2. forward the torch model
import torch
from transformers import T5ForConditionalGeneration

torch_model = T5ForConditionalGeneration.from_pretrained(model_id)
torch_model.eval()
torch_model.save_pretrained(tempdir)
torch_logit = torch_model(
input_ids=torch.tensor(input_ids), decoder_input_ids=torch.tensor(input_ids), return_dict=False
)[0][0]

# 2. forward the paddle model
from paddlenlp.transformers import T5ForConditionalGeneration

paddle_model = T5ForConditionalGeneration.from_pretrained(tempdir)
paddle_model.eval()
paddle_logit = paddle_model(
input_ids=paddle.to_tensor(input_ids), decoder_input_ids=paddle.to_tensor(input_ids)
)[0][0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().numpy()[:4, :4], torch_logit.detach().cpu().numpy()[:4, :4], rtol=1e-4
)
)


class T5ModelIntegrationTests(unittest.TestCase):
def model(self):
return T5ForConditionalGeneration.from_pretrained("t5-base")
Expand Down