From 337eda19a361a3b4e43614657d07e36a50451ba1 Mon Sep 17 00:00:00 2001 From: lvdongyi Date: Thu, 17 Oct 2024 08:45:56 +0000 Subject: [PATCH 1/4] Add , support register tokenizer, fix some typo --- paddlenlp/transformers/auto/tokenizer.py | 52 +++++- paddlenlp/transformers/bert/tokenizer_fast.py | 164 ++++++++++++++++++ .../transformers/convert_slow_tokenizer.py | 48 ++++- .../transformers/tokenizer_utils_base.py | 3 + .../transformers/tokenizer_utils_fast.py | 2 +- tests/transformers/auto/test_confiugration.py | 9 +- tests/transformers/auto/test_tokenizer.py | 73 ++++++++ tests/transformers/bert/test_tokenizer.py | 30 +++- tests/utils/test_module/__init__.py | 13 ++ .../utils/test_module/custom_configuration.py | 23 +++ tests/utils/test_module/custom_tokenizer.py | 19 ++ .../test_module/custom_tokenizer_fast.py | 22 +++ 12 files changed, 444 insertions(+), 14 deletions(-) create mode 100644 paddlenlp/transformers/bert/tokenizer_fast.py create mode 100644 tests/utils/test_module/__init__.py create mode 100644 tests/utils/test_module/custom_configuration.py create mode 100644 tests/utils/test_module/custom_tokenizer.py create mode 100644 tests/utils/test_module/custom_tokenizer_fast.py diff --git a/paddlenlp/transformers/auto/tokenizer.py b/paddlenlp/transformers/auto/tokenizer.py index c056346b1746..6fd8b5fcf2b0 100644 --- a/paddlenlp/transformers/auto/tokenizer.py +++ b/paddlenlp/transformers/auto/tokenizer.py @@ -24,6 +24,7 @@ from ...utils.import_utils import import_module from ...utils.log import logger from ..configuration_utils import PretrainedConfig +from ..tokenizer_utils import PretrainedTokenizer from ..tokenizer_utils_base import TOKENIZER_CONFIG_FILE from ..tokenizer_utils_fast import PretrainedTokenizerFast from .configuration import ( @@ -45,7 +46,13 @@ [ ("albert", (("AlbertChineseTokenizer", "AlbertEnglishTokenizer"), None)), ("bart", "BartTokenizer"), - ("bert", "BertTokenizer"), + ( + "bert", + ( + "BertTokenizer", + "BertTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("blenderbot", "BlenderbotTokenizer"), ("bloom", "BloomTokenizer"), ("clip", "CLIPTokenizer"), @@ -459,3 +466,46 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): "- or a correct model-identifier of community-contributed pretrained models,\n" "- or the correct path to a directory containing relevant tokenizer files.\n" ) + + def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False): + """ + Register a new tokenizer in this mapping. + + + Args: + config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + slow_tokenizer_class ([`PretrainedTokenizer`], *optional*): + The slow tokenizer to register. + fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*): + The fast tokenizer to register. 
+ """ + if slow_tokenizer_class is None and fast_tokenizer_class is None: + raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class") + if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PretrainedTokenizerFast): + raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.") + if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PretrainedTokenizer): + raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.") + + if ( + slow_tokenizer_class is not None + and fast_tokenizer_class is not None + and issubclass(fast_tokenizer_class, PretrainedTokenizerFast) + and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class + ): + raise ValueError( + "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not " + "consistent with the slow tokenizer class you passed (fast tokenizer has " + f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those " + "so they match!" + ) + + # Avoid resetting a set slow/fast tokenizer if we are passing just the other ones. + if config_class in TOKENIZER_MAPPING._extra_content: + existing_slow, existing_fast = TOKENIZER_MAPPING[config_class] + if slow_tokenizer_class is None: + slow_tokenizer_class = existing_slow + if fast_tokenizer_class is None: + fast_tokenizer_class = existing_fast + + TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok) diff --git a/paddlenlp/transformers/bert/tokenizer_fast.py b/paddlenlp/transformers/bert/tokenizer_fast.py new file mode 100644 index 000000000000..ba11e48f1b37 --- /dev/null +++ b/paddlenlp/transformers/bert/tokenizer_fast.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import List, Optional, Tuple + +from tokenizers import normalizers + +from ..tokenizer_utils_fast import PretrainedTokenizerFast +from .tokenizer import BertTokenizer + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +class BertTokenizerFast(PretrainedTokenizerFast): + r""" + + This tokenizer inherits from [`PretrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. 
It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese. + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. + """ + + resource_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = BertTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + self.do_lower_case = do_lower_case + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1 is not None: + output += token_ids_1 + [self.sep_token_id] + + return output + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) diff --git a/paddlenlp/transformers/convert_slow_tokenizer.py b/paddlenlp/transformers/convert_slow_tokenizer.py index adc3c52130e6..3cbd4a07cd9c 100644 --- a/paddlenlp/transformers/convert_slow_tokenizer.py +++ b/paddlenlp/transformers/convert_slow_tokenizer.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import OrderedDict from typing import Dict, List, Optional, Tuple import tokenizers @@ -28,7 +29,7 @@ pre_tokenizers, processors, ) -from tokenizers.models import BPE, Unigram +from tokenizers.models import BPE, Unigram, WordPiece from paddlenlp.utils.import_utils import ( is_protobuf_available, @@ -330,6 +331,50 @@ def converted(self) -> Tokenizer: return tokenizer +class BertConverter(Converter): + def converted(self) -> Tokenizer: + vocab = self.original_tokenizer.vocab + tokenizer = Tokenizer( + WordPiece( + OrderedDict([(vocab._idx_to_token[i], i) for i in range(len(vocab))]), + unk_token=str(self.original_tokenizer.unk_token), + ) + ) + + tokenize_chinese_chars = False + strip_accents = False + do_lower_case = False + if hasattr(self.original_tokenizer, "basic_tokenizer"): + tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars + strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents + do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case + + tokenizer.normalizer = normalizers.BertNormalizer( + clean_text=True, + handle_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + lowercase=do_lower_case, + ) + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + + cls = str(self.original_tokenizer.cls_token) + sep = str(self.original_tokenizer.sep_token) + cls_token_id = self.original_tokenizer.cls_token_id + sep_token_id = self.original_tokenizer.sep_token_id + + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls}:0 $A:0 {sep}:0", + pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1", + special_tokens=[ + (cls, cls_token_id), + (sep, sep_token_id), + ], + ) + tokenizer.decoder = decoders.WordPiece(prefix="##") + + return tokenizer + + class LlamaConverter(SpmConverter): handle_byte_fallback = True 
@@ -399,6 +444,7 @@ def pre_tokenizer(self, replacement, add_prefix_space): SLOW_TO_FAST_CONVERTERS = { "LlamaTokenizer": LlamaConverter, + "BertTokenizer": BertConverter, } diff --git a/paddlenlp/transformers/tokenizer_utils_base.py b/paddlenlp/transformers/tokenizer_utils_base.py index e044e7e5830b..98edc25ac1e2 100644 --- a/paddlenlp/transformers/tokenizer_utils_base.py +++ b/paddlenlp/transformers/tokenizer_utils_base.py @@ -1851,6 +1851,9 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True): # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained tokenizer_class = self.__class__.__name__ + # Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast` + if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast": + tokenizer_class = tokenizer_class[:-4] tokenizer_config["tokenizer_class"] = tokenizer_class with io.open(tokenizer_config_file, "w", encoding="utf-8") as f: diff --git a/paddlenlp/transformers/tokenizer_utils_fast.py b/paddlenlp/transformers/tokenizer_utils_fast.py index 60fd432bb9d8..eecb7ef965ff 100644 --- a/paddlenlp/transformers/tokenizer_utils_fast.py +++ b/paddlenlp/transformers/tokenizer_utils_fast.py @@ -750,7 +750,7 @@ def train_new_from_iterator( Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library. Returns: - [`PreTrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on + [`PretrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on `text_iterator`. """ diff --git a/tests/transformers/auto/test_confiugration.py b/tests/transformers/auto/test_confiugration.py index eea37fd93a86..f5553bafa2c3 100644 --- a/tests/transformers/auto/test_confiugration.py +++ b/tests/transformers/auto/test_confiugration.py @@ -23,8 +23,8 @@ from paddlenlp.transformers import AutoConfig from paddlenlp.transformers.auto.configuration import CONFIG_MAPPING from paddlenlp.transformers.bert.configuration import BertConfig -from paddlenlp.transformers.configuration_utils import PretrainedConfig from paddlenlp.utils.env import CONFIG_NAME +from tests.utils.test_module.custom_configuration import CustomConfig class AutoConfigTest(unittest.TestCase): @@ -90,13 +90,6 @@ def test_load_from_legacy_config(self): self.assertEqual(auto_config.hidden_size, number) def test_new_config_registration(self): - class CustomConfig(PretrainedConfig): - model_type = "custom" - - def __init__(self, attribute=1, **kwargs): - self.attribute = attribute - super().__init__(**kwargs) - try: AutoConfig.register("custom", CustomConfig) # Wrong model type will raise an error diff --git a/tests/transformers/auto/test_tokenizer.py b/tests/transformers/auto/test_tokenizer.py index dfd7cecde848..77e54c4801ff 100644 --- a/tests/transformers/auto/test_tokenizer.py +++ b/tests/transformers/auto/test_tokenizer.py @@ -19,7 +19,15 @@ import paddlenlp from paddlenlp.transformers import AutoTokenizer +from paddlenlp.transformers.auto.configuration import CONFIG_MAPPING, AutoConfig +from paddlenlp.transformers.auto.tokenizer import TOKENIZER_MAPPING +from paddlenlp.transformers.bert.configuration import BertConfig +from paddlenlp.transformers.bert.tokenizer import BertTokenizer +from paddlenlp.transformers.bert.tokenizer_fast import BertTokenizerFast from paddlenlp.utils.env import TOKENIZER_CONFIG_NAME +from tests.utils.test_module.custom_configuration import CustomConfig +from tests.utils.test_module.custom_tokenizer import 
CustomTokenizer +from tests.utils.test_module.custom_tokenizer_fast import CustomTokenizerFast class AutoTokenizerTest(unittest.TestCase): @@ -35,3 +43,68 @@ def test_from_pretrained_cache_dir(self): self.assertTrue(os.path.exists(os.path.join(tempdir, model_name, TOKENIZER_CONFIG_NAME))) # check against double appending model_name in cache_dir self.assertFalse(os.path.exists(os.path.join(tempdir, model_name, model_name))) + + def test_new_tokenizer_registration(self): + try: + AutoConfig.register("custom", CustomConfig) + + AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer) + # Trying to register something existing in the PaddleNLP library will raise an error + with self.assertRaises(ValueError): + AutoTokenizer.register(BertConfig, slow_tokenizer_class=BertTokenizer) + + tokenizer = CustomTokenizer.from_pretrained("julien-c/bert-xsmall-dummy") + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer.save_pretrained(tmp_dir) + + new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir) + self.assertIsInstance(new_tokenizer, CustomTokenizer) + + finally: + if "custom" in CONFIG_MAPPING._extra_content: + del CONFIG_MAPPING._extra_content["custom"] + if CustomConfig in TOKENIZER_MAPPING._extra_content: + del TOKENIZER_MAPPING._extra_content[CustomConfig] + + def test_new_tokenizer_fast_registration(self): + try: + AutoConfig.register("custom", CustomConfig) + + # Can register in two steps + AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer) + self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, None)) + AutoTokenizer.register(CustomConfig, fast_tokenizer_class=CustomTokenizerFast) + self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, CustomTokenizerFast)) + + del TOKENIZER_MAPPING._extra_content[CustomConfig] + # Can register in one step + AutoTokenizer.register( + CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFast + ) + self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, CustomTokenizerFast)) + + # Trying to register something existing in the PaddleNLP library will raise an error + with self.assertRaises(ValueError): + AutoTokenizer.register(BertConfig, fast_tokenizer_class=BertTokenizerFast) + + # We pass through a llama tokenizer fast cause there is no converter slow to fast for our new toknizer + # and that model does not have a tokenizer.json + with tempfile.TemporaryDirectory() as tmp_dir: + llama_tokenizer = BertTokenizerFast.from_pretrained("julien-c/bert-xsmall-dummy", from_hf_hub=True) + llama_tokenizer.save_pretrained(tmp_dir) + tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir) + + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer.save_pretrained(tmp_dir, legacy_format=True) + + new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, use_fast=True) + self.assertIsInstance(new_tokenizer, CustomTokenizerFast) + + new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, use_fast=False) + self.assertIsInstance(new_tokenizer, CustomTokenizer) + + finally: + if "custom" in CONFIG_MAPPING._extra_content: + del CONFIG_MAPPING._extra_content["custom"] + if CustomConfig in TOKENIZER_MAPPING._extra_content: + del TOKENIZER_MAPPING._extra_content[CustomConfig] diff --git a/tests/transformers/bert/test_tokenizer.py b/tests/transformers/bert/test_tokenizer.py index 41ba41c2528e..1b5780945b09 100644 --- a/tests/transformers/bert/test_tokenizer.py +++ b/tests/transformers/bert/test_tokenizer.py @@ -24,6 +24,7 @@ _is_punctuation, _is_whitespace, ) 
+from paddlenlp.transformers.bert.tokenizer_fast import BertTokenizerFast from ...testing_utils import slow from ...transformers.test_tokenizer_common import ( @@ -35,6 +36,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = BertTokenizer + rust_tokenizer_class = BertTokenizerFast space_between_special_tokens = True from_pretrained_filter = filter_non_english test_seq2seq = False @@ -206,6 +208,7 @@ def test_offsets_with_special_characters(self): for tokenizer, pretrained_name, kwargs in self.tokenizers_list: with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) sentence = f"A, naïve {tokenizer.mask_token} AllenNLP sentence." tokens = tokenizer.encode( @@ -254,6 +257,21 @@ def test_offsets_with_special_characters(self): ) self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"]) + tokens = tokenizer_r.encode_plus( + sentence, + return_attention_mask=False, + return_token_type_ids=False, + return_offsets_mapping=True, + add_special_tokens=True, + ) + + do_lower_case = tokenizer_r.do_lower_case if hasattr(tokenizer_r, "do_lower_case") else False + + self.assertEqual( + [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"]) + ) + self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"]) + def test_change_tokenize_chinese_chars(self): list_of_commun_chinese_char = ["的", "人", "有"] text_with_chinese_char = "".join(list_of_commun_chinese_char) @@ -262,16 +280,22 @@ def test_change_tokenize_chinese_chars(self): if pretrained_name == "squeezebert-uncased": continue kwargs["tokenize_chinese_chars"] = True - tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) - ids_without_spe_char_p = tokenizer.encode( + ids_without_spe_char_p = tokenizer_p.encode( + text_with_chinese_char, return_token_type_ids=None, add_special_tokens=False + )["input_ids"] + ids_without_spe_char_r = tokenizer_r.encode( text_with_chinese_char, return_token_type_ids=None, add_special_tokens=False )["input_ids"] - tokens_without_spe_char_p = tokenizer.convert_ids_to_tokens(ids_without_spe_char_p) + tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p) + tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r) # it is expected that each Chinese character is not preceded by "##" self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char) + self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char) # not yet supported in bert tokenizer """ diff --git a/tests/utils/test_module/__init__.py b/tests/utils/test_module/__init__.py new file mode 100644 index 000000000000..fd05a9208165 --- /dev/null +++ b/tests/utils/test_module/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/utils/test_module/custom_configuration.py b/tests/utils/test_module/custom_configuration.py new file mode 100644 index 000000000000..95dfaf5dacd5 --- /dev/null +++ b/tests/utils/test_module/custom_configuration.py @@ -0,0 +1,23 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddlenlp.transformers.configuration_utils import PretrainedConfig + + +class CustomConfig(PretrainedConfig): + model_type = "custom" + + def __init__(self, attribute=1, **kwargs): + self.attribute = attribute + super().__init__(**kwargs) diff --git a/tests/utils/test_module/custom_tokenizer.py b/tests/utils/test_module/custom_tokenizer.py new file mode 100644 index 000000000000..511f3b3684a4 --- /dev/null +++ b/tests/utils/test_module/custom_tokenizer.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddlenlp.transformers.bert.tokenizer import BertTokenizer + + +class CustomTokenizer(BertTokenizer): + pass diff --git a/tests/utils/test_module/custom_tokenizer_fast.py b/tests/utils/test_module/custom_tokenizer_fast.py new file mode 100644 index 000000000000..095e0e597d91 --- /dev/null +++ b/tests/utils/test_module/custom_tokenizer_fast.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from paddlenlp.transformers.bert.tokenizer_fast import BertTokenizerFast + +from .custom_tokenizer import CustomTokenizer + + +class CustomTokenizerFast(BertTokenizerFast): + slow_tokenizer_class = CustomTokenizer + pass From 0bca74c917800b06220f2446eec18a63f863511b Mon Sep 17 00:00:00 2001 From: lvdongyi Date: Mon, 4 Nov 2024 15:21:45 +0800 Subject: [PATCH 2/4] add more tests --- tests/transformers/auto/test_confiugration.py | 3 +- tests/transformers/auto/test_tokenizer.py | 25 +++++++++-- tests/transformers/bert/test_tokenizer.py | 44 ++++++++++++++----- tests/transformers/test_tokenizer_common.py | 20 +++++++++ .../test_module/custom_tokenizer_fast.py | 5 +++ 5 files changed, 81 insertions(+), 16 deletions(-) diff --git a/tests/transformers/auto/test_confiugration.py b/tests/transformers/auto/test_confiugration.py index f5553bafa2c3..f2e1bd488531 100644 --- a/tests/transformers/auto/test_confiugration.py +++ b/tests/transformers/auto/test_confiugration.py @@ -24,7 +24,8 @@ from paddlenlp.transformers.auto.configuration import CONFIG_MAPPING from paddlenlp.transformers.bert.configuration import BertConfig from paddlenlp.utils.env import CONFIG_NAME -from tests.utils.test_module.custom_configuration import CustomConfig + +from ...utils.test_module.custom_configuration import CustomConfig class AutoConfigTest(unittest.TestCase): diff --git a/tests/transformers/auto/test_tokenizer.py b/tests/transformers/auto/test_tokenizer.py index 77e54c4801ff..f7ed62c184ab 100644 --- a/tests/transformers/auto/test_tokenizer.py +++ b/tests/transformers/auto/test_tokenizer.py @@ -25,9 +25,13 @@ from paddlenlp.transformers.bert.tokenizer import BertTokenizer from paddlenlp.transformers.bert.tokenizer_fast import BertTokenizerFast from paddlenlp.utils.env import TOKENIZER_CONFIG_NAME -from tests.utils.test_module.custom_configuration import CustomConfig -from tests.utils.test_module.custom_tokenizer import CustomTokenizer -from tests.utils.test_module.custom_tokenizer_fast import CustomTokenizerFast + +from ...utils.test_module.custom_configuration import CustomConfig +from ...utils.test_module.custom_tokenizer import CustomTokenizer +from ...utils.test_module.custom_tokenizer_fast import ( + CustomTokenizerFast, + CustomTokenizerFast2, +) class AutoTokenizerTest(unittest.TestCase): @@ -68,6 +72,18 @@ def test_new_tokenizer_registration(self): def test_new_tokenizer_fast_registration(self): try: + # Trying to register nothing + with self.assertRaises(ValueError): + AutoTokenizer.register(CustomConfig) + # Trying to register tokenizer with wrong type + with self.assertRaises(ValueError): + AutoTokenizer.register(CustomConfig, fast_tokenizer_class=CustomTokenizer) + with self.assertRaises(ValueError): + AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizerFast) + with self.assertRaises(ValueError): + AutoTokenizer.register( + CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFast2 + ) AutoConfig.register("custom", CustomConfig) # Can register in two steps @@ -86,6 +102,8 @@ def test_new_tokenizer_fast_registration(self): # Trying to register something existing in the PaddleNLP library will raise an error with self.assertRaises(ValueError): AutoTokenizer.register(BertConfig, fast_tokenizer_class=BertTokenizerFast) + with self.assertRaises(ValueError): + AutoTokenizer.register(BertConfig, slow_tokenizer_class=BertTokenizer) # We pass through a llama tokenizer fast cause there is no converter slow to fast for our new toknizer # and that model does not 
have a tokenizer.json @@ -102,7 +120,6 @@ def test_new_tokenizer_fast_registration(self): new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, use_fast=False) self.assertIsInstance(new_tokenizer, CustomTokenizer) - finally: if "custom" in CONFIG_MAPPING._extra_content: del CONFIG_MAPPING._extra_content["custom"] diff --git a/tests/transformers/bert/test_tokenizer.py b/tests/transformers/bert/test_tokenizer.py index 1b5780945b09..8a5d6eadc607 100644 --- a/tests/transformers/bert/test_tokenizer.py +++ b/tests/transformers/bert/test_tokenizer.py @@ -191,19 +191,34 @@ def test_clean_text(self): @slow def test_sequence_builders(self): - tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") + tokenizer_p = self.tokenizer_class.from_pretrained("bert-base-uncased") + tokenizer_r = self.rust_tokenizer_class.from_pretrained("bert-base-uncased") - text = tokenizer.encode("sequence builders", return_token_type_ids=None, add_special_tokens=False)["input_ids"] - text_2 = tokenizer.encode("multi-sequence build", return_token_type_ids=None, add_special_tokens=False)[ + text = tokenizer_p.encode("sequence builders", return_token_type_ids=None, add_special_tokens=False)[ + "input_ids" + ] + text_2 = tokenizer_p.encode("multi-sequence build", return_token_type_ids=None, add_special_tokens=False)[ "input_ids" ] - encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) - encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + encoded_sentence = tokenizer_p.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer_p.build_inputs_with_special_tokens(text, text_2) assert encoded_sentence == [101] + text + [102] assert encoded_pair == [101] + text + [102] + text_2 + [102] + text_r = tokenizer_r.encode("sequence builders", return_token_type_ids=None, add_special_tokens=False)[ + "input_ids" + ] + text_2_r = tokenizer_r.encode("multi-sequence build", return_token_type_ids=None, add_special_tokens=False)[ + "input_ids" + ] + + encoded_sentence_r = tokenizer_r.build_inputs_with_special_tokens(text) + encoded_pair_r = tokenizer_r.build_inputs_with_special_tokens(text_r, text_2_r) + assert encoded_sentence_r == [101] + text + [102] + assert encoded_pair_r == [101] + text_r + [102] + text_2_r + [102] + def test_offsets_with_special_characters(self): for tokenizer, pretrained_name, kwargs in self.tokenizers_list: with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): @@ -297,15 +312,22 @@ def test_change_tokenize_chinese_chars(self): self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char) self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char) - # not yet supported in bert tokenizer - """ kwargs["tokenize_chinese_chars"] = False - tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) - ids_without_spe_char_p = tokenizer.encode(text_with_chinese_char, return_token_type_ids=None,add_special_tokens=False)["input_ids"] - tokens_without_spe_char_p = tokenizer.convert_ids_to_tokens(ids_without_spe_char_p) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + ids_without_spe_char_p = tokenizer_p.encode( + text_with_chinese_char, return_token_type_ids=None, add_special_tokens=False + )["input_ids"] + ids_without_spe_char_r = tokenizer_r.encode( + text_with_chinese_char, return_token_type_ids=None, add_special_tokens=False + )["input_ids"] + + tokens_without_spe_char_p = 
tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p) + tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r) # it is expected that only the first Chinese character is not preceded by "##". expected_tokens = [ f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char) ] + self.assertListEqual(tokens_without_spe_char_r, expected_tokens) self.assertListEqual(tokens_without_spe_char_p, expected_tokens) - """ diff --git a/tests/transformers/test_tokenizer_common.py b/tests/transformers/test_tokenizer_common.py index 0b693213d374..6f63911a8b63 100644 --- a/tests/transformers/test_tokenizer_common.py +++ b/tests/transformers/test_tokenizer_common.py @@ -2249,6 +2249,26 @@ def test_special_tokens_initialization_with_non_empty_additional_special_tokens( ), ) + def test_create_token_type_ids(self): + if not hasattr(self, "rust_tokenizer_class"): + self.skipTest(reason="Rust tokenizer not available for this tokenizer") + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + input_simple = [1, 2, 3] + input_pair = [1, 2, 3] + + # Generate output + output_r = tokenizer_r.create_token_type_ids_from_sequences(input_simple) + output_p = tokenizer_p.create_token_type_ids_from_sequences(input_simple) + self.assertEqual(output_p, output_r) + + # Generate pair output + output_r = tokenizer_r.create_token_type_ids_from_sequences(input_simple, input_pair) + output_p = tokenizer_p.create_token_type_ids_from_sequences(input_simple, input_pair) + self.assertEqual(output_p, output_r) + class TrieTest(unittest.TestCase): def test_trie(self): diff --git a/tests/utils/test_module/custom_tokenizer_fast.py b/tests/utils/test_module/custom_tokenizer_fast.py index 095e0e597d91..3b4a5c5fdf4c 100644 --- a/tests/utils/test_module/custom_tokenizer_fast.py +++ b/tests/utils/test_module/custom_tokenizer_fast.py @@ -20,3 +20,8 @@ class CustomTokenizerFast(BertTokenizerFast): slow_tokenizer_class = CustomTokenizer pass + + +class CustomTokenizerFast2(BertTokenizerFast): + slow_tokenizer_class = None + pass From 0eba022962e3408e28c2ddb1e88dcf7a7b67272e Mon Sep 17 00:00:00 2001 From: lvdongyi Date: Mon, 4 Nov 2024 15:34:15 +0800 Subject: [PATCH 3/4] CustomTokenizerFast2->CustomTokenizerFastWithoutSlow --- tests/transformers/auto/test_tokenizer.py | 4 ++-- tests/utils/test_module/custom_tokenizer_fast.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/transformers/auto/test_tokenizer.py b/tests/transformers/auto/test_tokenizer.py index f7ed62c184ab..ee4b249d31a7 100644 --- a/tests/transformers/auto/test_tokenizer.py +++ b/tests/transformers/auto/test_tokenizer.py @@ -30,7 +30,7 @@ from ...utils.test_module.custom_tokenizer import CustomTokenizer from ...utils.test_module.custom_tokenizer_fast import ( CustomTokenizerFast, - CustomTokenizerFast2, + CustomTokenizerFastWithoutSlow, ) @@ -82,7 +82,7 @@ def test_new_tokenizer_fast_registration(self): AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizerFast) with self.assertRaises(ValueError): AutoTokenizer.register( - CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFast2 + CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFastWithoutSlow ) 
AutoConfig.register("custom", CustomConfig) diff --git a/tests/utils/test_module/custom_tokenizer_fast.py b/tests/utils/test_module/custom_tokenizer_fast.py index 3b4a5c5fdf4c..dff3b4418e4d 100644 --- a/tests/utils/test_module/custom_tokenizer_fast.py +++ b/tests/utils/test_module/custom_tokenizer_fast.py @@ -22,6 +22,6 @@ class CustomTokenizerFast(BertTokenizerFast): pass -class CustomTokenizerFast2(BertTokenizerFast): +class CustomTokenizerFastWithoutSlow(BertTokenizerFast): slow_tokenizer_class = None pass From 851af3850bae238bbbc98e6db83ab4b05330ec63 Mon Sep 17 00:00:00 2001 From: lvdongyi Date: Mon, 4 Nov 2024 07:43:38 +0000 Subject: [PATCH 4/4] lint --- tests/transformers/auto/test_tokenizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/transformers/auto/test_tokenizer.py b/tests/transformers/auto/test_tokenizer.py index ee4b249d31a7..54c568113023 100644 --- a/tests/transformers/auto/test_tokenizer.py +++ b/tests/transformers/auto/test_tokenizer.py @@ -82,7 +82,9 @@ def test_new_tokenizer_fast_registration(self): AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizerFast) with self.assertRaises(ValueError): AutoTokenizer.register( - CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFastWithoutSlow + CustomConfig, + slow_tokenizer_class=CustomTokenizer, + fast_tokenizer_class=CustomTokenizerFastWithoutSlow, ) AutoConfig.register("custom", CustomConfig)
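
The series above adds `AutoTokenizer.register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False)`, a `BertTokenizerFast` backed by the new `BertConverter`, and test-only custom classes exercising the registration path. Below is a minimal usage sketch of the new registration API, closely mirroring `tests/transformers/auto/test_tokenizer.py`. The `MyConfig`/`MyTokenizer`/`MyTokenizerFast` names and the `bert-base-uncased` checkpoint are illustrative assumptions for this sketch, not part of the patches, and the snippet assumes the series is applied with the `tokenizers` package installed.

```python
# Sketch only: mirrors the registration flow covered by the new AutoTokenizer tests.
import tempfile

from paddlenlp.transformers import AutoConfig, AutoTokenizer
from paddlenlp.transformers.bert.tokenizer import BertTokenizer
from paddlenlp.transformers.bert.tokenizer_fast import BertTokenizerFast
from paddlenlp.transformers.configuration_utils import PretrainedConfig


class MyConfig(PretrainedConfig):
    # model_type must match the string passed to AutoConfig.register below.
    model_type = "my-model"


class MyTokenizer(BertTokenizer):
    pass


class MyTokenizerFast(BertTokenizerFast):
    # The fast class must point back at its slow counterpart; otherwise
    # AutoTokenizer.register raises a ValueError about the mismatch.
    slow_tokenizer_class = MyTokenizer


# Register the configuration first, then the slow/fast tokenizer pair for it.
AutoConfig.register("my-model", MyConfig)
AutoTokenizer.register(
    MyConfig,
    slow_tokenizer_class=MyTokenizer,
    fast_tokenizer_class=MyTokenizerFast,
)

# Illustrative round trip: save with the custom slow tokenizer and reload through
# AutoTokenizer, which should now resolve to the registered class (the checkpoint
# name is an assumption; any BERT-style vocabulary directory would do).
tokenizer = MyTokenizer.from_pretrained("bert-base-uncased")
with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer.save_pretrained(tmp_dir)
    reloaded = AutoTokenizer.from_pretrained(tmp_dir, use_fast=False)
    print(type(reloaded).__name__)  # expected: MyTokenizer
```

As in the tests, registering a tokenizer for a config that PaddleNLP already maps (for example `BertConfig`) raises a `ValueError` unless `exist_ok=True` is passed, and at least one of `slow_tokenizer_class` or `fast_tokenizer_class` must be provided.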