Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 7th No.43】完善 TokenizerFast 功能支持 part 5 #9594

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion paddlenlp/transformers/auto/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@
("ctrl", "CTRLTokenizer"),
("distilbert", "DistilBertTokenizer"),
("electra", "ElectraTokenizer"),
("ernie", "ErnieTokenizer"),
(
"ernie",
("ErnieTokenizer", "ErnieTokenizerFast" if is_tokenizers_available() else None),
),
("ernie_m", "ErnieMTokenizer"),
("fnet", "FNetTokenizer"),
("funnel", "FunnelTokenizer"),
Expand Down
44 changes: 44 additions & 0 deletions paddlenlp/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,49 @@ def converted(self) -> Tokenizer:
return tokenizer


class ErnieConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.vocab
tokenizer = Tokenizer(
WordPiece(
OrderedDict([(vocab._idx_to_token[i], i) for i in range(len(vocab))]),
unk_token=str(self.original_tokenizer.unk_token),
)
)
tokenize_chinese_chars = False
strip_accents = False
do_lower_case = False
if hasattr(self.original_tokenizer, "basic_tokenizer"):
tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

tokenizer.normalizer = normalizers.BertNormalizer(
clean_text=True,
handle_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
lowercase=do_lower_case,
)
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

cls = str(self.original_tokenizer.cls_token)
sep = str(self.original_tokenizer.sep_token)
cls_token_id = self.original_tokenizer.cls_token_id
sep_token_id = self.original_tokenizer.sep_token_id

tokenizer.post_processor = processors.TemplateProcessing(
single=f"{cls}:0 $A:0 {sep}:0",
pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
special_tokens=[
(cls, cls_token_id),
(sep, sep_token_id),
],
)
tokenizer.decoder = decoders.WordPiece(prefix="##")

return tokenizer


class GPTConverter(Converter):
def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
if not vocab:
Expand Down Expand Up @@ -612,6 +655,7 @@ def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]]

SLOW_TO_FAST_CONVERTERS = {
"BertTokenizer": BertConverter,
"ErnieTokenizer": ErnieConverter,
"GemmaTokenizer": GemmaConverter,
"GPTTokenizer": GPTConverter,
"LlamaTokenizer": LlamaConverter,
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/transformers/ernie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from .configuration import *
from .modeling import *
from .tokenizer import *
from .tokenizer_fast import *
147 changes: 147 additions & 0 deletions paddlenlp/transformers/ernie/tokenizer_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Ernie."""

import json
from typing import List, Optional, Tuple

from tokenizers import normalizers

from ..tokenizer_utils_fast import PretrainedTokenizerFast
from .tokenizer import ErnieTokenizer

VOCAB_FILES_NAMES = {
"vocab_file": "vocab.txt",
"tokenizer_file": "tokenizer.json",
}


class ErnieTokenizerFast(PretrainedTokenizerFast):
"""
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.

Args:
vocab_file (`str`, *optional*):
Path to the vocabulary file.
merges_file (`str`, *optional*):
Path to the merges file.
tokenizer_file (`str`, *optional*):
Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
contains everything needed to load the tokenizer.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead. Not applicable to this tokenizer.
bos_token (`str`, *optional*):
The beginning of sequence token. Not applicable for this tokenizer.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding, for example when batching sequences of different lengths.
"""

vocab_files_names = VOCAB_FILES_NAMES
resource_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
slow_tokenizer_class = ErnieTokenizer

def __init__(
self,
vocab_file=None,
tokenizer_file=None,
do_lower_case=True,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
**kwargs
):
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs,
)

normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
if normalizer_state.get("lowercase", do_lower_case) != do_lower_case:
normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
normalizer_state["lowercase"] = do_lower_case
self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)

Check warning on line 87 in paddlenlp/transformers/ernie/tokenizer_fast.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ernie/tokenizer_fast.py#L85-L87

Added lines #L85 - L87 were not covered by tests

self.do_lower_case = do_lower_case

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:

- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`

Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

Check warning on line 108 in paddlenlp/transformers/ernie/tokenizer_fast.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ernie/tokenizer_fast.py#L108

Added line #L108 was not covered by tests

if token_ids_1 is not None:
output += token_ids_1 + [self.sep_token_id]

Check warning on line 111 in paddlenlp/transformers/ernie/tokenizer_fast.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ernie/tokenizer_fast.py#L110-L111

Added lines #L110 - L111 were not covered by tests

return output

Check warning on line 113 in paddlenlp/transformers/ernie/tokenizer_fast.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ernie/tokenizer_fast.py#L113

Added line #L113 was not covered by tests

def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
pair mask has the following format:

```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```

If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

# Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)

Check warning on line 147 in paddlenlp/transformers/ernie/tokenizer_fast.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ernie/tokenizer_fast.py#L146-L147

Added lines #L146 - L147 were not covered by tests
2 changes: 2 additions & 0 deletions tests/transformers/ernie/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ErnieTokenizer,
WordpieceTokenizer,
)
from paddlenlp.transformers.ernie.tokenizer_fast import ErnieTokenizerFast

from ...testing_utils import slow
from ...transformers.test_tokenizer_common import (
Expand All @@ -32,6 +33,7 @@
class ErnieTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

tokenizer_class = ErnieTokenizer
rust_tokenizer_class = ErnieTokenizerFast
space_between_special_tokens = True
from_pretrained_filter = filter_non_english
test_seq2seq = True
Expand Down