diff --git a/paddlenlp/transformers/t5/modeling.py b/paddlenlp/transformers/t5/modeling.py index b9c26b751737..1e8195d4c424 100644 --- a/paddlenlp/transformers/t5/modeling.py +++ b/paddlenlp/transformers/t5/modeling.py @@ -25,7 +25,9 @@ from paddle import Tensor from paddle.distributed.fleet.utils import recompute +from ...utils.converter import StateDictNameMapping from ...utils.log import logger +from ..activations import ACT2FN from ..model_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -34,7 +36,6 @@ convert_encoder_output, ) from ..model_utils import PretrainedModel, register_base_model -from ..nezha.modeling import ACT2FN from .configuration import ( T5_PRETRAINED_INIT_CONFIGURATION, T5_PRETRAINED_RESOURCE_FILES_MAP, @@ -571,6 +572,123 @@ class T5PretrainedModel(PretrainedModel): pretrained_init_configuration = T5_PRETRAINED_INIT_CONFIGURATION pretrained_resource_files_map = T5_PRETRAINED_RESOURCE_FILES_MAP + @classmethod + def _get_name_mappings(cls, config: T5Config) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["shared.weight", "shared.weight"], + ["encoder.embed_tokens.weight", "encoder.embed_tokens.weight"], + ["encoder.final_layer_norm.weight", "encoder.final_layer_norm.weight"], + ["decoder.embed_tokens.weight", "decoder.embed_tokens.weight"], + ["decoder.final_layer_norm.weight", "decoder.final_layer_norm.weight"], + [ + "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", + "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", + ], + [ + "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", + "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", + ], + ] + for layer_index in range(config.num_hidden_layers): + for att_head in ["q", "k", "v", "o"]: + model_mappings.extend( + [ + [ + f"encoder.block.{layer_index}.layer.0.SelfAttention.{att_head}.weight", + f"encoder.block.{layer_index}.layer.0.SelfAttention.{att_head}.weight", + "transpose", + ], + [ + f"decoder.block.{layer_index}.layer.0.SelfAttention.{att_head}.weight", + f"decoder.block.{layer_index}.layer.0.SelfAttention.{att_head}.weight", + "transpose", + ], + [ + f"decoder.block.{layer_index}.layer.1.EncDecAttention.{att_head}.weight", + f"decoder.block.{layer_index}.layer.1.EncDecAttention.{att_head}.weight", + "transpose", + ], + ] + ) + + layer_mappings = [ + [ + f"encoder.block.{layer_index}.layer.1.DenseReluDense.wo.weight", + f"encoder.block.{layer_index}.layer.1.DenseReluDense.wo.weight", + "transpose", + ], + [ + f"decoder.block.{layer_index}.layer.2.DenseReluDense.wo.weight", + f"decoder.block.{layer_index}.layer.2.DenseReluDense.wo.weight", + "transpose", + ], + [ + f"encoder.block.{layer_index}.layer.0.layer_norm.weight", + f"encoder.block.{layer_index}.layer.0.layer_norm.weight", + ], + [ + f"encoder.block.{layer_index}.layer.1.layer_norm.weight", + f"encoder.block.{layer_index}.layer.1.layer_norm.weight", + ], + [ + f"decoder.block.{layer_index}.layer.0.layer_norm.weight", + f"decoder.block.{layer_index}.layer.0.layer_norm.weight", + ], + [ + f"decoder.block.{layer_index}.layer.1.layer_norm.weight", + f"decoder.block.{layer_index}.layer.1.layer_norm.weight", + ], + [ + f"decoder.block.{layer_index}.layer.2.layer_norm.weight", + f"decoder.block.{layer_index}.layer.2.layer_norm.weight", + ], + ] + + if config.feed_forward_proj == "relu": + layer_mappings.extend( + [ + [ + f"encoder.block.{layer_index}.layer.1.DenseReluDense.wi.weight", + f"encoder.block.{layer_index}.layer.1.DenseReluDense.wi.weight", + "transpose", + ], + [ + f"decoder.block.{layer_index}.layer.2.DenseReluDense.wi.weight", + f"decoder.block.{layer_index}.layer.2.DenseReluDense.wi.weight", + "transpose", + ], + ] + ) + elif config.feed_forward_proj == "gated-gelu": + for i in range(2): + layer_mappings.extend( + [ + [ + f"encoder.block.{layer_index}.layer.1.DenseReluDense.wi_{i}.weight", + f"encoder.block.{layer_index}.layer.1.DenseReluDense.wi_{i}.weight", + "transpose", + ], + [ + f"decoder.block.{layer_index}.layer.2.DenseReluDense.wi_{i}.weight", + f"decoder.block.{layer_index}.layer.2.DenseReluDense.wi_{i}.weight", + "transpose", + ], + ] + ) + + model_mappings.extend(layer_mappings) + + if cls.__name__ != "T5Model": + for mapping in model_mappings: + mapping[1] = "t5." + mapping[1] + + if config.architectures is not None and "T5ForConditionalGeneration" in config.architectures: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping) for mapping in model_mappings] + return mappings + @property def dummy_inputs(self): DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] diff --git a/paddlenlp/utils/serialization.py b/paddlenlp/utils/serialization.py index 7b7b35dac403..050931c1ec1c 100644 --- a/paddlenlp/utils/serialization.py +++ b/paddlenlp/utils/serialization.py @@ -34,9 +34,10 @@ def __init__(self, key: str, n_bytes: int, dtype: str): self.nbytes = n_bytes self.dtype = dtype self.size = None + self.stride = None def __repr__(self): - return f"size: {self.size} key: {self.key}, nbytes: {self.nbytes}, dtype: {self.dtype}" + return f"size: {self.size} key: {self.key}, nbytes: {self.nbytes}, dtype: {self.dtype}, stride: {self.stride}" class SerializationError(Exception): @@ -123,6 +124,7 @@ def get_data_iostream(file: str, file_name="data.pkl"): def _rebuild_tensor_stage(storage, storage_offset, size, stride, requires_grad, backward_hooks): if isinstance(storage, TensorMeta): storage.size = size + storage.stride = stride return storage @@ -237,9 +239,21 @@ def extract_maybe_dict(result): file_handler.seek(padding_offset, 1) # save the tensor info in result to re-use memory - stage1_key_to_tensor[key] = np.frombuffer( - file_handler.read(tensor_meta.nbytes), dtype=tensor_meta.dtype - ).reshape(tensor_meta.size) + np_buffer = np.frombuffer(file_handler.read(tensor_meta.nbytes), dtype=tensor_meta.dtype) + + # if a tensor has shape [M, N] and stride is [1, N], it's column-wise / fortran-style + # if a tensor has shape [M, N] and stride is [M, 1], it's row-wise / C-style + # defautls to C-style + if ( + tensor_meta.stride is not None + and len(tensor_meta.stride) > 1 + and tensor_meta.stride[0] == 1 + and tensor_meta.stride[1] > 1 + ): + order = "F" + else: + order = "C" + stage1_key_to_tensor[key] = np_buffer.reshape(tensor_meta.size, order=order) def persistent_load_stage2(saved_id): assert isinstance(saved_id, tuple) diff --git a/tests/transformers/t5/test_modeling.py b/tests/transformers/t5/test_modeling.py index 575c5b4742c1..c927720cf88f 100644 --- a/tests/transformers/t5/test_modeling.py +++ b/tests/transformers/t5/test_modeling.py @@ -14,6 +14,7 @@ # limitations under the License. import random +import tempfile import unittest import numpy as np @@ -28,7 +29,7 @@ ) from paddlenlp.transformers.t5.configuration import T5Config from paddlenlp.transformers.t5.modeling import T5_PRETRAINED_MODEL_ARCHIVE_LIST -from tests.testing_utils import slow +from tests.testing_utils import require_package, slow from ..test_generation_utils import GenerationTesterMixin from ..test_modeling_common import ModelTesterMixin, ids_tensor @@ -703,6 +704,125 @@ def test_model_name_list(self): pass +class T5CompatibilityTest(unittest.TestCase): + @require_package("transformers", "torch") + def test_t5_converter(self): + with tempfile.TemporaryDirectory() as tempdir: + model_id = "hf-internal-testing/tiny-random-T5Model" + # 1. create commmon input + input_ids = np.array([[i for i in range(10)]]) + + # 2. forward the paddle model + from paddlenlp.transformers import T5Model + + paddle_model = T5Model.from_pretrained(model_id, from_hf_hub=True, cache_dir=tempdir) + paddle_model.eval() + paddle_logit = paddle_model( + input_ids=paddle.to_tensor(input_ids), decoder_input_ids=paddle.to_tensor(input_ids) + )[0][0] + + # 3. forward the torch model + import torch + from transformers import T5Model + + torch_model = T5Model.from_pretrained(model_id, cache_dir=tempdir) + torch_model.eval() + torch_logit = torch_model( + input_ids=torch.tensor(input_ids), decoder_input_ids=torch.tensor(input_ids), return_dict=False + )[0][0] + self.assertTrue( + np.allclose( + paddle_logit.detach().cpu().numpy()[:4, :4], torch_logit.detach().cpu().numpy()[:4, :4], rtol=1e-4 + ) + ) + + @require_package("transformers", "torch") + def test_t5_converter_from_local_dir_with_enable_torch(self): + with tempfile.TemporaryDirectory() as tempdir: + model_id = "hf-internal-testing/tiny-random-T5Model" + # 1. forward the torch model + from transformers import T5Model + + torch_model = T5Model.from_pretrained(model_id) + torch_model.save_pretrained(tempdir) + + # 2. forward the paddle model + from paddlenlp.transformers import T5Model, model_utils + + model_utils.ENABLE_TORCH_CHECKPOINT = False + + with self.assertRaises(ValueError) as error: + T5Model.from_pretrained(tempdir) + self.assertIn("conversion is been disabled" in str(error.exception)) + model_utils.ENABLE_TORCH_CHECKPOINT = True + + @require_package("transformers", "torch") + def test_t5_converter_from_local_dir(self): + with tempfile.TemporaryDirectory() as tempdir: + model_id = "hf-internal-testing/tiny-random-T5Model" + # 1. create commmon input + input_ids = np.array([[i for i in range(10)]]) + + # 2. forward the torch model + import torch + from transformers import T5Model + + torch_model = T5Model.from_pretrained(model_id) + torch_model.eval() + torch_model.save_pretrained(tempdir) + torch_logit = torch_model( + input_ids=torch.tensor(input_ids), decoder_input_ids=torch.tensor(input_ids), return_dict=False + )[0][0] + + # 2. forward the paddle model + from paddlenlp.transformers import T5Model + + paddle_model = T5Model.from_pretrained(tempdir) + paddle_model.eval() + paddle_logit = paddle_model( + input_ids=paddle.to_tensor(input_ids), decoder_input_ids=paddle.to_tensor(input_ids) + )[0][0] + + self.assertTrue( + np.allclose( + paddle_logit.detach().cpu().numpy()[:4, :4], torch_logit.detach().cpu().numpy()[:4, :4], rtol=1e-4 + ) + ) + + @require_package("transformers", "torch") + def test_t5_for_conditional_generation(self): + with tempfile.TemporaryDirectory() as tempdir: + model_id = "hf-internal-testing/tiny-random-T5Model" + # 1. create commmon input + input_ids = np.array([[i for i in range(10)]]) + + # 2. forward the torch model + import torch + from transformers import T5ForConditionalGeneration + + torch_model = T5ForConditionalGeneration.from_pretrained(model_id) + torch_model.eval() + torch_model.save_pretrained(tempdir) + torch_logit = torch_model( + input_ids=torch.tensor(input_ids), decoder_input_ids=torch.tensor(input_ids), return_dict=False + )[0][0] + + # 2. forward the paddle model + from paddlenlp.transformers import T5ForConditionalGeneration + + paddle_model = T5ForConditionalGeneration.from_pretrained(tempdir) + paddle_model.eval() + paddle_logit = paddle_model( + input_ids=paddle.to_tensor(input_ids), decoder_input_ids=paddle.to_tensor(input_ids) + )[0][0] + + self.assertTrue( + np.allclose( + paddle_logit.detach().cpu().numpy()[:4, :4], torch_logit.detach().cpu().numpy()[:4, :4], rtol=1e-4 + ) + ) + + class T5ModelIntegrationTests(unittest.TestCase): def model(self): return T5ForConditionalGeneration.from_pretrained("t5-base")