Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tokenizer] Add BertTokenizerFast, support register new tokenizer #9353

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion paddlenlp/transformers/auto/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ...utils.import_utils import import_module
from ...utils.log import logger
from ..configuration_utils import PretrainedConfig
from ..tokenizer_utils import PretrainedTokenizer
from ..tokenizer_utils_base import TOKENIZER_CONFIG_FILE
from ..tokenizer_utils_fast import PretrainedTokenizerFast
from .configuration import (
Expand All @@ -45,7 +46,13 @@
[
("albert", (("AlbertChineseTokenizer", "AlbertEnglishTokenizer"), None)),
("bart", "BartTokenizer"),
("bert", "BertTokenizer"),
(
"bert",
(
"BertTokenizer",
"BertTokenizerFast" if is_tokenizers_available() else None,
),
),
("blenderbot", "BlenderbotTokenizer"),
("bloom", "BloomTokenizer"),
("clip", "CLIPTokenizer"),
Expand Down Expand Up @@ -459,3 +466,46 @@
"- or a correct model-identifier of community-contributed pretrained models,\n"
"- or the correct path to a directory containing relevant tokenizer files.\n"
)

def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
"""
Register a new tokenizer in this mapping.


Args:
config_class ([`PretrainedConfig`]):
The configuration corresponding to the model to register.
slow_tokenizer_class ([`PretrainedTokenizer`], *optional*):
The slow tokenizer to register.
fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*):
The fast tokenizer to register.
"""
if slow_tokenizer_class is None and fast_tokenizer_class is None:
raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class")
if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PretrainedTokenizerFast):
raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PretrainedTokenizer):
raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")

if (
slow_tokenizer_class is not None
and fast_tokenizer_class is not None
and issubclass(fast_tokenizer_class, PretrainedTokenizerFast)
and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
):
raise ValueError(
"The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
"consistent with the slow tokenizer class you passed (fast tokenizer has "
f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those "
"so they match!"
)

# Avoid resetting a set slow/fast tokenizer if we are passing just the other ones.
if config_class in TOKENIZER_MAPPING._extra_content:
existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
if slow_tokenizer_class is None:
slow_tokenizer_class = existing_slow
if fast_tokenizer_class is None:
fast_tokenizer_class = existing_fast

Check warning on line 509 in paddlenlp/transformers/auto/tokenizer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/auto/tokenizer.py#L509

Added line #L509 was not covered by tests

TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok)
164 changes: 164 additions & 0 deletions paddlenlp/transformers/bert/tokenizer_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import List, Optional, Tuple

from tokenizers import normalizers

from ..tokenizer_utils_fast import PretrainedTokenizerFast
from .tokenizer import BertTokenizer

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}


class BertTokenizerFast(PretrainedTokenizerFast):
r"""

This tokenizer inherits from [`PretrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.

Args:
vocab_file (`str`):
File containing the vocabulary.
do_lower_case (`bool`, *optional*, defaults to `True`):
Whether or not to lowercase the input when tokenizing.
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
The token used for padding, for example when batching sequences of different lengths.
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
clean_text (`bool`, *optional*, defaults to `True`):
Whether or not to clean the text before tokenization by removing any control characters and replacing all
whitespaces by the classic one.
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese.
strip_accents (`bool`, *optional*):
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
value for `lowercase` (as in the original BERT).
wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
The prefix for subwords.
"""

resource_files_names = VOCAB_FILES_NAMES
slow_tokenizer_class = BertTokenizer

def __init__(
self,
vocab_file=None,
tokenizer_file=None,
do_lower_case=True,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs,
):
super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,
do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)

normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
if (
normalizer_state.get("lowercase", do_lower_case) != do_lower_case
or normalizer_state.get("strip_accents", strip_accents) != strip_accents
or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
):
normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
normalizer_state["lowercase"] = do_lower_case
normalizer_state["strip_accents"] = strip_accents
normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)

Check warning on line 105 in paddlenlp/transformers/bert/tokenizer_fast.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/bert/tokenizer_fast.py#L101-L105

Added lines #L101 - L105 were not covered by tests

self.do_lower_case = do_lower_case

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:

- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`

Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

Check warning on line 126 in paddlenlp/transformers/bert/tokenizer_fast.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/bert/tokenizer_fast.py#L126

Added line #L126 was not covered by tests

if token_ids_1 is not None:
output += token_ids_1 + [self.sep_token_id]

Check warning on line 129 in paddlenlp/transformers/bert/tokenizer_fast.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/bert/tokenizer_fast.py#L128-L129

Added lines #L128 - L129 were not covered by tests

return output

Check warning on line 131 in paddlenlp/transformers/bert/tokenizer_fast.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/bert/tokenizer_fast.py#L131

Added line #L131 was not covered by tests

def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
pair mask has the following format:

```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```

If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
48 changes: 47 additions & 1 deletion paddlenlp/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import tokenizers
Expand All @@ -28,7 +29,7 @@
pre_tokenizers,
processors,
)
from tokenizers.models import BPE, Unigram
from tokenizers.models import BPE, Unigram, WordPiece

from paddlenlp.utils.import_utils import (
is_protobuf_available,
Expand Down Expand Up @@ -330,6 +331,50 @@ def converted(self) -> Tokenizer:
return tokenizer


class BertConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.vocab
tokenizer = Tokenizer(
WordPiece(
OrderedDict([(vocab._idx_to_token[i], i) for i in range(len(vocab))]),
unk_token=str(self.original_tokenizer.unk_token),
)
)

tokenize_chinese_chars = False
strip_accents = False
do_lower_case = False
if hasattr(self.original_tokenizer, "basic_tokenizer"):
tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

tokenizer.normalizer = normalizers.BertNormalizer(
clean_text=True,
handle_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
lowercase=do_lower_case,
)
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

cls = str(self.original_tokenizer.cls_token)
sep = str(self.original_tokenizer.sep_token)
cls_token_id = self.original_tokenizer.cls_token_id
sep_token_id = self.original_tokenizer.sep_token_id

tokenizer.post_processor = processors.TemplateProcessing(
single=f"{cls}:0 $A:0 {sep}:0",
pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
special_tokens=[
(cls, cls_token_id),
(sep, sep_token_id),
],
)
tokenizer.decoder = decoders.WordPiece(prefix="##")

return tokenizer


class LlamaConverter(SpmConverter):
handle_byte_fallback = True

Expand Down Expand Up @@ -399,6 +444,7 @@ def pre_tokenizer(self, replacement, add_prefix_space):

SLOW_TO_FAST_CONVERTERS = {
"LlamaTokenizer": LlamaConverter,
"BertTokenizer": BertConverter,
}


Expand Down
3 changes: 3 additions & 0 deletions paddlenlp/transformers/tokenizer_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1851,6 +1851,9 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):

# Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
tokenizer_class = self.__class__.__name__
# Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast`
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
tokenizer_class = tokenizer_class[:-4]
tokenizer_config["tokenizer_class"] = tokenizer_class

with io.open(tokenizer_config_file, "w", encoding="utf-8") as f:
Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/transformers/tokenizer_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ def train_new_from_iterator(
Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library.

Returns:
[`PreTrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on
[`PretrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on
`text_iterator`.

"""
Expand Down
10 changes: 2 additions & 8 deletions tests/transformers/auto/test_confiugration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from paddlenlp.transformers import AutoConfig
from paddlenlp.transformers.auto.configuration import CONFIG_MAPPING
from paddlenlp.transformers.bert.configuration import BertConfig
from paddlenlp.transformers.configuration_utils import PretrainedConfig
from paddlenlp.utils.env import CONFIG_NAME

from ...utils.test_module.custom_configuration import CustomConfig


class AutoConfigTest(unittest.TestCase):
def test_built_in_model_class_config(self):
Expand Down Expand Up @@ -90,13 +91,6 @@ def test_load_from_legacy_config(self):
self.assertEqual(auto_config.hidden_size, number)

def test_new_config_registration(self):
class CustomConfig(PretrainedConfig):
model_type = "custom"

def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)

try:
AutoConfig.register("custom", CustomConfig)
# Wrong model type will raise an error
Expand Down
Loading
Loading