fix resize token embedding method #2763

Merged
8 commits merged into PaddlePaddle:develop on Jul 14, 2022

Conversation

wj-Mcat
Contributor

@wj-Mcat wj-Mcat commented Jul 8, 2022

PR types

Bug fixes

PR changes

Models

Description

According to the unit test (shared later in this conversation), there are some latent bugs in the `resize_token_embeddings` method.
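Conceptually, `resize_token_embeddings` should rebuild the input embedding with the requested vocabulary size and copy over the overlapping rows. A rough sketch of that idea (the names below are illustrative, not the exact implementation):

from paddle import nn

def resize_embedding_sketch(old_embeddings: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    # Build a fresh embedding table with the requested vocabulary size.
    old_num_tokens, embedding_dim = old_embeddings.weight.shape
    new_embeddings = nn.Embedding(new_num_tokens, embedding_dim)
    # Copy the overlapping rows so existing token vectors are preserved;
    # any newly added rows keep their default initialization.
    num_to_copy = min(old_num_tokens, new_num_tokens)
    new_weight = new_embeddings.weight.numpy()
    new_weight[:num_to_copy] = old_embeddings.weight.numpy()[:num_to_copy]
    new_embeddings.weight.set_value(new_weight)
    return new_embeddings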

@wj-Mcat
Contributor Author

wj-Mcat commented Jul 8, 2022

ping @yingyibiao

Comment on lines 520 to 521
if not new_num_tokens or new_num_tokens == len(old_embeddings):
    return old_embeddings
Contributor


The type of old_embeddings is Embedding, which should not have a len() method.

Contributor Author

@wj-Mcat wj-Mcat Jul 11, 2022


OK, I'll replace it with `weight.shape[0]`.
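i.e. something like:

# Embedding has no __len__, so compare against the first dimension of the weight.
if not new_num_tokens or new_num_tokens == old_embeddings.weight.shape[0]:
    return old_embeddings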

try:
    old_embeddings = self.get_input_embeddings()
except NotImplementedError:
    raise NotImplementedError(
Contributor


It is not a good choice to catch a NotImplementedError and then throw another NotImplementedError; maybe all we need to do is add an error message to `get_input_embeddings`.

Contributor Author


Yes, we should add an error message to the `get_input_embeddings` method.
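For example, something along these lines (just a sketch, assuming the usual base-model delegation in `PretrainedModel`):

def get_input_embeddings(self):
    # Task-specific head models delegate to their base model.
    base_model = getattr(self, self.base_model_prefix, self)
    if base_model is not self:
        return base_model.get_input_embeddings()
    # Raise here with a descriptive message instead of letting callers
    # re-wrap a bare NotImplementedError.
    raise NotImplementedError(
        f'model of {type(self)} has not implemented the `get_input_embeddings` method')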

except NotImplementedError:
    raise NotImplementedError(
        f'model of {type(self)} has not implemented the `get_input_embedding` or `set_input_embedding` '
        'method, please use the another model to call `resize_token_embeddings` method'
Contributor


The phrase `use the another model` is not quite what we want here; the message should probably suggest that users implement these methods themselves.
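e.g. a wording along these lines (only a suggestion):

raise NotImplementedError(
    f'model of {type(self)} has not implemented the `get_input_embeddings` '
    'or `set_input_embeddings` method, please implement them in your model '
    'before calling `resize_token_embeddings`')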

@wj-Mcat
Contributor Author

wj-Mcat commented Jul 11, 2022

I have addressed all of the comments above, please review. Thanks.

@wj-Mcat
Contributor Author

wj-Mcat commented Jul 11, 2022

I'll share the unit test code with you:

import numpy as np
import os
from typing import Type, List, Tuple
import shutil
import unittest
from multiprocessing import Process
from tempfile import TemporaryDirectory
from parameterized import parameterized

from paddle import nn
from paddlenlp.transformers.model_utils import PretrainedModel, MODEL_HOME
from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer

from paddlenlp.transformers import BertTokenizer, BasicTokenizer, WordpieceTokenizer

from paddlenlp.transformers.bert.modeling import BertForPretraining
from paddlenlp.transformers.gpt.modeling import GPTForPretraining
from paddlenlp.transformers.tinybert.modeling import TinyBertForPretraining

from paddlenlp.transformers.bert.tokenizer import BertTokenizer
from paddlenlp.transformers.gpt.tokenizer import GPTTokenizer, GPTChineseTokenizer
from paddlenlp.transformers.tinybert.tokenizer import TinyBertTokenizer

from tests.common_test import CpuCommonTest, CommonTest
from tests.util import slow, assert_raises


def get_pretrained_models_params() -> List[Tuple[str, Type[PretrainedModel]]]:
    """get all of pretrained model names in some PretrainedModels

    Returns:
        List[Tuple[str, Type[PretrainedModel]]]: the parameters of unit test method
    """
    # from paddlenlp.transformers.electra.modeling import ElectraForTotalPretraining
    from paddlenlp.transformers.ctrl.modeling import CTRLModel 
    model_types: List[PretrainedModel] = [
        BertForPretraining, GPTForPretraining, TinyBertForPretraining, CTRLModel
    ]
    name_class_tuples: List[Tuple[str, Type[PretrainedModel]]] = []
    for ModelType in model_types:
        for model_name in ModelType.pretrained_resource_files_map.get(
                'model_state', {}).keys():
            name_class_tuples.append([model_name, ModelType])
    return name_class_tuples


class TestPretrainedFromPretrained(CpuCommonTest):
    """module for test pretrained model"""

    @parameterized.expand(get_pretrained_models_params())
    def test_resize_token_embedding(self, model_name: str,
                                    PretrainedModelClass: Type[PretrainedModel]):

        cache_dir = os.path.join(MODEL_HOME, model_name)

        model: PretrainedModelClass = PretrainedModelClass.from_pretrained(
            model_name)
        
        vocab_size = model.base_model.config['vocab_size']
        model.resize_token_embeddings(vocab_size + 10)
        assert model.base_model.config["vocab_size"] == vocab_size + 10 

You may need to modify a few things to make it run in your code context. Hope that helps.


# Update vocab_size
self.vocab_size = new_num_tokens
Contributor


Let's keep this as well, to stay consistent with HF.
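i.e. roughly the HF-style pattern (the helper name below is illustrative):

def resize_token_embeddings(self, new_num_tokens):
    # _resize_token_embeddings is a hypothetical helper that swaps in the
    # resized embedding table; the point is to keep vocab_size in sync too.
    new_embeddings = self._resize_token_embeddings(new_num_tokens)
    self.vocab_size = new_num_tokens
    self.base_model.config['vocab_size'] = new_num_tokens
    return new_embeddings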

Contributor Author


Done, and the merge conflicts have been resolved as well.

@guoshengCS
Contributor

This conflicts with the latest code and still needs to be resolved.

@wj-Mcat
Contributor Author

wj-Mcat commented Jul 12, 2022

Done @guoshengCS

@wj-Mcat
Contributor Author

wj-Mcat commented Jul 12, 2022

ping @yingyibiao

Contributor

@yingyibiao yingyibiao left a comment


LGTM

@guoshengCS guoshengCS merged commit 48d52c5 into PaddlePaddle:develop Jul 14, 2022