From 71dafa679f4dd3cb428698821931113b39160b5e Mon Sep 17 00:00:00 2001 From: Weiguo Zhu Date: Fri, 1 Nov 2024 14:33:33 +0800 Subject: [PATCH] [Tokenizer] Support adding special tokens to Qwen tokenizer (#9344) * update qwen tokenizer * add test case --- paddlenlp/transformers/qwen/tokenizer.py | 33 ++++++++++++++-- paddlenlp/utils/import_utils.py | 2 - tests/transformers/qwen/test_tokenizer.py | 46 +++++++++++++++++++++++ 3 files changed, 76 insertions(+), 5 deletions(-) create mode 100644 tests/transformers/qwen/test_tokenizer.py diff --git a/paddlenlp/transformers/qwen/tokenizer.py b/paddlenlp/transformers/qwen/tokenizer.py index ca682b40f17c..02fe5926d5fc 100644 --- a/paddlenlp/transformers/qwen/tokenizer.py +++ b/paddlenlp/transformers/qwen/tokenizer.py @@ -125,14 +125,41 @@ def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str] ids.append(self.mergeable_ranks.get(token)) return ids + def _update_tiktoken(self, tokens: List[str], special_tokens: bool = False) -> int: + if special_tokens: + added_tokens = [] + for token in tokens: + if token in self.special_tokens: + continue + + token_id = len(self.mergeable_ranks) + len(self.special_tokens) + self.special_tokens[token] = token_id + self.decoder[token_id] = token + + added_tokens.append(token) + + import tiktoken + + self.tokenizer = tiktoken.Encoding( + "Qwen", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + + return len(added_tokens) + else: + raise ValueError("Adding regular tokens is not supported") + def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: if not special_tokens and new_tokens: raise ValueError("Adding regular tokens is not supported") + new_tokens_str = [] for token in new_tokens: surface_form = token.content if isinstance(token, AddedToken) else token - if surface_form not in SPECIAL_TOKENS: - raise ValueError("Adding unknown special tokens is not supported") - return 0 + new_tokens_str.append(surface_form) + + return self._update_tiktoken(new_tokens_str, special_tokens) def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: """ diff --git a/paddlenlp/utils/import_utils.py b/paddlenlp/utils/import_utils.py index 2c3796214a7f..33d4ee831002 100644 --- a/paddlenlp/utils/import_utils.py +++ b/paddlenlp/utils/import_utils.py @@ -52,7 +52,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False - logger.debug(f"Detected {pkg_name} version: {package_version}") if return_version: return package_exists, package_version else: @@ -96,7 +95,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False - logger.debug(f"Detected {pkg_name} version: {package_version}") if return_version: return package_exists, package_version else: diff --git a/tests/transformers/qwen/test_tokenizer.py b/tests/transformers/qwen/test_tokenizer.py new file mode 100644 index 000000000000..35412a2757db --- /dev/null +++ b/tests/transformers/qwen/test_tokenizer.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from paddlenlp.transformers import QWenTokenizer + + +class Qwen2TokenizationTest(unittest.TestCase): + from_pretrained_id = "qwen/qwen-7b" + tokenizer_class = QWenTokenizer + test_slow_tokenizer = True + space_between_special_tokens = False + from_pretrained_kwargs = None + test_seq2seq = False + + def setUp(self): + super().setUp() + + def get_tokenizer(self, **kwargs): + return QWenTokenizer.from_pretrained(self.from_pretrained_id, **kwargs) + + def test_add_special_tokens(self): + tokenizer = self.get_tokenizer() + origin_tokens_len = len(tokenizer) + + add_tokens_num = tokenizer.add_special_tokens({"additional_special_tokens": [""]}) + assert add_tokens_num == 1 + assert len(tokenizer) == origin_tokens_len + 1 + + add_tokens_num = tokenizer.add_special_tokens({"unk_token": ""}) + assert add_tokens_num == 1 + assert len(tokenizer) == origin_tokens_len + 2