
[Tokenizer] Support reading Tiktoken tokenizer.model. #9215

Merged: 28 commits, Oct 30, 2024. Changes shown are from 5 of the 28 commits.

Commits
0fd7240
add support of tiktoken tokenizer, refactor some code
lvdongyi Sep 24, 2024
d1ee434
Merge branch 'PaddlePaddle:develop' into dev-refactor-pretrained
lvdongyi Sep 27, 2024
9004ac9
add support of tiktoken tokenizer, refactor some code
lvdongyi Sep 27, 2024
d004c33
clean code & add blobfile to requirements.txt
lvdongyi Sep 27, 2024
0b61d11
Don't allow multiple Class in a
lvdongyi Sep 28, 2024
aad6750
update docstring, add a RuntimeError when AutoTokenizer failed to loa…
lvdongyi Sep 28, 2024
04dff4d
update albert_english/__init__.py and mbart/__init__.py
lvdongyi Sep 28, 2024
6475a83
fix typo, rm redundent notations
lvdongyi Sep 28, 2024
dea3ad4
some changes...
lvdongyi Oct 11, 2024
f5ae794
AutoTokenizer will not load TokenzierFast by default
lvdongyi Oct 11, 2024
ce684a1
Add test for external config
lvdongyi Oct 11, 2024
75368d5
revert unnecrssary changes
lvdongyi Oct 12, 2024
469ffbf
Update test_modeling_common.py
lvdongyi Oct 12, 2024
ee33fba
fix
lvdongyi Oct 12, 2024
92e4e0e
Merge branch 'PaddlePaddle:develop' into dev-20240927-support-tiktoken
lvdongyi Oct 12, 2024
f0f4113
Merge branch 'PaddlePaddle:develop' into dev-20240927-support-tiktoken
lvdongyi Oct 15, 2024
353fb41
rm redundent print
lvdongyi Oct 17, 2024
d279d8d
revert some changes
lvdongyi Oct 17, 2024
e367332
fix problem in TOKENIZER_MAPPING_NAMES
lvdongyi Oct 17, 2024
a422932
try fix
lvdongyi Oct 18, 2024
7ff5a17
Merge branch 'PaddlePaddle:develop' into dev-20240927-support-tiktoken
lvdongyi Oct 21, 2024
d46655c
update
lvdongyi Oct 21, 2024
3412f50
fix
lvdongyi Oct 22, 2024
19521f9
rm redundent comment, resolve complicate
lvdongyi Oct 23, 2024
99299b0
Merge branch 'PaddlePaddle:develop' into dev-20240927-support-tiktoken
lvdongyi Oct 23, 2024
5c169fb
Merge branch 'PaddlePaddle:develop' into dev-20240927-support-tiktoken
lvdongyi Oct 25, 2024
d2d7eeb
add case of built-in tokenizers to handle CI error
lvdongyi Oct 25, 2024
5579695
Merge branch 'PaddlePaddle:develop' into dev-20240927-support-tiktoken
lvdongyi Oct 30, 2024
2 changes: 1 addition & 1 deletion paddlenlp/transformers/albert_chinese/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion paddlenlp/transformers/albert_english/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
41 changes: 37 additions & 4 deletions paddlenlp/transformers/auto/configuration.py
@@ -213,7 +213,11 @@
     for key, cls in CONFIG_MAPPING_NAMES.items():
         if cls == config:
             return key
+    # if key not found check in extra content
+    for key, cls in CONFIG_MAPPING._extra_content.items():
+        if cls.__name__ == config:
+            return key
     return None

[Codecov: added lines L217-L220 not covered by tests]

class _LazyConfigMapping(OrderedDict):
@@ -230,33 +234,35 @@
         if key in self._extra_content:
             return self._extra_content[key]
         if key not in self._mapping:
+            raise KeyError(key)

[Codecov: added line L237 not covered by tests]

         value = self._mapping[key]
         module_name = model_type_to_module_name(key)
         if module_name not in self._modules:
-            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
+            self._modules[module_name] = importlib.import_module(
+                f".{module_name}.configuration", "paddlenlp.transformers"
+            )
         if hasattr(self._modules[module_name], value):
             return getattr(self._modules[module_name], value)

         # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the
         # object at the top level.
-        transformers_module = importlib.import_module("transformers")
+        transformers_module = importlib.import_module("paddlenlp")
         return getattr(transformers_module, value)

[Codecov: added lines L249-L250 not covered by tests]

    def keys(self):
        return list(self._mapping.keys()) + list(self._extra_content.keys())

    def values(self):
        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())

[Codecov: added line L256 not covered by tests]

    def items(self):
        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())

[Codecov: added line L259 not covered by tests]

    def __iter__(self):
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))

[Codecov: added line L262 not covered by tests]

    def __contains__(self, item):
        return item in self._mapping or item in self._extra_content

[Codecov: added line L265 not covered by tests]

    def register(self, key, value, exist_ok=False):
        """
@@ -436,14 +442,24 @@
             from_hf_hub=from_hf_hub,
             from_aistudio=from_aistudio,
         )
-        if config_file is not None and os.path.exists(config_file):
+        config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
+        if "model_type" in config_dict:
+            try:
+                config_class = CONFIG_MAPPING[config_dict["model_type"]]
+            except KeyError:
+                raise ValueError(
[Codecov: added lines L449-L450 not covered by tests]
f"The checkpoint you are trying to load has model type `{config_dict['model_type']}` "
"but Transformers does not recognize this architecture. This could be because of an "
"issue with the checkpoint, or because your version of Transformers is out of date."
)
return config_class.from_dict(config_dict, **unused_kwargs)
elif "model_type" not in config_dict and config_file is not None and os.path.exists(config_file):
config_class = cls._get_config_class_from_config(pretrained_model_name_or_path, config_file)
logger.info("We are using %s to load '%s'." % (config_class, pretrained_model_name_or_path))
if config_class is cls:
return cls.from_file(config_file)
return config_class.from_pretrained(config_file, *model_args, **kwargs)
elif config_file is None:
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
# Fallback: use pattern matching on the string.
# We go from longer names to shorter names to catch roberta before bert (for instance)
for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
Expand All @@ -457,3 +473,20 @@
"- or a correct model-identifier of community-contributed pretrained models,\n"
"- or the correct path to a directory containing relevant config files.\n"
)

+    @staticmethod
+    def register(model_type, config, exist_ok=False):
+        """
+        Register a new configuration for this class.
+
+        Args:
+            model_type (`str`): The model type like "bert" or "gpt".
+            config ([`PretrainedConfig`]): The config to register.
+        """
+        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
+            raise ValueError(
+                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
+                f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
+                "match!"
+            )
+        CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
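As a minimal sketch of the intended usage of the new AutoConfig.register staticmethod (the custom class below is hypothetical; the round trip mirrors the test_new_config_registration test added later in this PR):

import tempfile

from paddlenlp.transformers import AutoConfig
from paddlenlp.transformers.configuration_utils import PretrainedConfig

class CustomConfig(PretrainedConfig):
    model_type = "custom"  # must match the first argument of register()

    def __init__(self, attribute=1, **kwargs):
        self.attribute = attribute
        super().__init__(**kwargs)

# Registering under a mismatched model_type raises ValueError (see the check above).
AutoConfig.register("custom", CustomConfig)

# Once registered, the config round-trips through the auto-API.
config = CustomConfig(attribute=2)
with tempfile.TemporaryDirectory() as tmp_dir:
    config.save_pretrained(tmp_dir)
    reloaded = AutoConfig.from_pretrained(tmp_dir)
    assert isinstance(reloaded, CustomConfig)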
21 changes: 9 additions & 12 deletions paddlenlp/transformers/auto/tokenizer.py
@@ -19,30 +19,27 @@
 from collections import OrderedDict
 from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

-from paddlenlp.transformers.auto.configuration import (
+from ...utils import is_tokenizers_available
+from ...utils.download import resolve_file_path
+from ...utils.import_utils import import_module
+from ...utils.log import logger
+from ..configuration_utils import PretrainedConfig
+from ..tokenizer_utils_base import TOKENIZER_CONFIG_FILE
+from ..tokenizer_utils_fast import PretrainedTokenizerFast
+from .configuration import (
     CONFIG_MAPPING_NAMES,
     AutoConfig,
     config_class_to_model_type,
     model_type_to_module_name,
 )
-from paddlenlp.transformers.configuration_utils import PretrainedConfig
-from paddlenlp.transformers.tokenizer_utils_base import TOKENIZER_CONFIG_FILE
-from paddlenlp.transformers.tokenizer_utils_fast import PretrainedTokenizerFast
-
-from ...utils import is_tokenizers_available
-from ...utils.download import resolve_file_path
-from ...utils.import_utils import import_module
-from ...utils.log import logger
 from .factory import _LazyAutoMapping

__all__ = [
"AutoTokenizer",
]

 if TYPE_CHECKING:
     # This significantly improves completion suggestion performance when
     # the transformers package is used with Microsoft's Pylance language server.
+    TOKENIZER_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()

[Codecov: added line L42 not covered by tests]
 else:
     TOKENIZER_MAPPING_NAMES = OrderedDict(
         [

@@ -141,7 +138,7 @@

 def tokenizer_class_from_name(class_name: str):
     if class_name == "PretrainedTokenizerFast":
+        return PretrainedTokenizerFast

[Codecov: added line L141 not covered by tests]

     for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
         if class_name in tokenizers:

@@ -156,21 +153,21 @@
             module = importlib.import_module(f".{module_name}.tokenizer", "paddlenlp.transformers")

             return getattr(module, class_name)
+        except AttributeError:
+            raise ValueError(f"Tokenizer class {class_name} is not currently imported.")

[Codecov: added lines L156-L157 not covered by tests]

+    for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
+        for tokenizer in tokenizers:
+            if getattr(tokenizer, "__name__", None) == class_name:
+                return tokenizer

[Codecov: added lines L159-L162 not covered by tests]

+    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
+    # init and we return the proper dummy to get an appropriate error message.
+    main_module = importlib.import_module("paddlenlp")
+    if hasattr(main_module, class_name):
+        return getattr(main_module, class_name)

[Codecov: added lines L166-L168 not covered by tests]

+    return None

[Codecov: added line L170 not covered by tests]


def get_tokenizer_config(
@@ -370,11 +367,11 @@
config = kwargs.pop("config", None)
kwargs["_from_auto"] = True

use_fast = kwargs.pop("use_fast", True)
use_fast = kwargs.pop("use_fast", False)
tokenizer_type = kwargs.pop("tokenizer_type", None)
if tokenizer_type is not None:
# TODO: Support tokenizer_type
raise NotImplementedError("tokenizer_type is not supported yet.")

[Codecov: added line L374 not covered by tests]

         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
         config_tokenizer_class = tokenizer_config.get("tokenizer_class")

@@ -385,13 +382,13 @@
         if config_tokenizer_class is not None:
             tokenizer_class = None
             if use_fast and not config_tokenizer_class.endswith("Fast"):
+                tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)

[Codecov: added lines L385-L386 not covered by tests]
             if tokenizer_class is None:
                 tokenizer_class_candidate = config_tokenizer_class
                 tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
             if tokenizer_class is None:
+                raise ValueError(

[Codecov: added line L391 not covered by tests]
f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
)
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
Expand All @@ -406,7 +403,7 @@
                 tokenizer_class_fast = tokenizer_class_py[1]
                 tokenizer_class_py = tokenizer_class_py[0]
             else:
+                tokenizer_class_fast = None

[Codecov: added line L406 not covered by tests]
         else:
             tokenizer_class_fast = None
         if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):

@@ -415,11 +412,11 @@
         if tokenizer_class_py is not None:
             return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
         else:
+            raise ValueError(

[Codecov: added line L415 not covered by tests]
"This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
"in order to use this tokenizer."
)
raise RuntimeError(

[Codecov: added line L419 not covered by tests]
f"Can't load tokenizer for '{pretrained_model_name_or_path}'.\n"
f"Please make sure that '{pretrained_model_name_or_path}' is:\n"
"- a correct model-identifier of built-in pretrained models,\n"
Expand Down
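The user-visible effect of the use_fast default flip above, as a hedged sketch (the tiny test model id is the one used elsewhere in this PR's tests):

from paddlenlp.transformers import AutoTokenizer

# Default is now the Python tokenizer; previously use_fast defaulted to True.
slow = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-bert")

# Fast (PretrainedTokenizerFast-based) tokenizers must be requested explicitly.
fast = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-bert", use_fast=True)

This is why the tests below add use_fast=True wherever a PretrainedTokenizerFast is asserted, and drop use_fast=False where the slow tokenizer is now the default.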
2 changes: 1 addition & 1 deletion paddlenlp/transformers/fnet/tokenizer.py
@@ -17,8 +17,8 @@

import sentencepiece as spm

-from ..albert.tokenizer import AddedToken
 from ..albert_english.tokenizer import AlbertEnglishTokenizer
+from ..tokenizer_utils_base import AddedToken

__all__ = ["FNetTokenizer"]

2 changes: 1 addition & 1 deletion paddlenlp/transformers/mbart50/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
4 changes: 2 additions & 2 deletions paddlenlp/utils/download/__init__.py
@@ -142,7 +142,7 @@ def resolve_file_path(
             elif index < len(filenames) - 1:
                 continue
             else:
-                pass
+                raise FileNotFoundError(f"please make sure one of the {filenames} under the dir {repo_id}")

         # check cache
         for filename in filenames:

@@ -272,7 +272,7 @@ def resolve_file_path(
f"'{log_endpoint}' for available revisions."
)
except EntryNotFoundError:
return None
raise EnvironmentError(f"Does not appear one of the {filenames} in {repo_id}.")
Review thread on this change:

Collaborator: Shouldn't this error type be EntryNotFoundError?

Contributor (author): It was already like this before my change.

Collaborator: It was probably written incorrectly from the start; this error can be fixed.

Contributor (author): If we were going to raise EntryNotFoundError here, then there would be no need to catch EntryNotFoundError with the except clause in the first place; presumably the original approach had its reasons.

except HTTPError as err:
raise EnvironmentError(f"There was a specific connection error when trying to load {repo_id}:\n{err}")
except ValueError:
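To make the thread above concrete: a minimal sketch of the error-translation pattern in resolve_file_path, assuming huggingface_hub is installed (hf_hub_download and EntryNotFoundError are real huggingface_hub names; the helper itself is a simplified stand-in, not PaddleNLP's actual code):

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

def fetch_first_existing(repo_id, filenames, revision=None):
    # Simplified stand-in for resolve_file_path's hub branch.
    for filename in filenames:
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
        except EntryNotFoundError:
            continue  # this candidate does not exist in the repo; try the next one
    # None of the candidates existed. The PR raises a generic EnvironmentError here,
    # which is what the reviewer questioned: re-raising EntryNotFoundError would keep
    # the specific type, but catching it above only to re-raise it would be odd.
    raise EnvironmentError(f"Does not appear one of the {filenames} in {repo_id}.")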
31 changes: 31 additions & 0 deletions tests/transformers/auto/test_confiugration.py
@@ -21,6 +21,9 @@
import unittest

 from paddlenlp.transformers import AutoConfig
+from paddlenlp.transformers.auto.configuration import CONFIG_MAPPING
+from paddlenlp.transformers.bert.configuration import BertConfig
+from paddlenlp.transformers.configuration_utils import PretrainedConfig
 from paddlenlp.utils.env import CONFIG_NAME


@@ -86,6 +89,34 @@ def test_load_from_legacy_config(self):
             auto_config = AutoConfig.from_pretrained(tempdir)
             self.assertEqual(auto_config.hidden_size, number)

+    def test_new_config_registration(self):
+        class CustomConfig(PretrainedConfig):
+            model_type = "custom"
+
+            def __init__(self, attribute=1, **kwargs):
+                self.attribute = attribute
+                super().__init__(**kwargs)
+
+        try:
+            AutoConfig.register("custom", CustomConfig)
+            # Wrong model type will raise an error
+            with self.assertRaises(ValueError):
+                AutoConfig.register("model", CustomConfig)
+            # Trying to register something existing in the Transformers library will raise an error
+            with self.assertRaises(ValueError):
+                AutoConfig.register("bert", BertConfig)
+
+            # Now that the config is registered, it can be used as any other config with the auto-API
+            config = CustomConfig()
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                config.save_pretrained(tmp_dir)
+                new_config = AutoConfig.from_pretrained(tmp_dir)
+                self.assertIsInstance(new_config, CustomConfig)
+
+        finally:
+            if "custom" in CONFIG_MAPPING._extra_content:
+                del CONFIG_MAPPING._extra_content["custom"]

     def test_from_pretrained_cache_dir(self):
         model_id = "__internal_testing__/tiny-random-bert"
         with tempfile.TemporaryDirectory() as tempdir:
6 changes: 4 additions & 2 deletions tests/transformers/llama/test_tokenizer.py
@@ -22,7 +22,7 @@
from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer
from paddlenlp.transformers.tokenizer_utils_fast import PretrainedTokenizerFast

-from ...transformers.test_tokenizer_common import TokenizerTesterMixin
+from ..test_tokenizer_common import TokenizerTesterMixin

VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
@@ -258,13 +258,14 @@ def test_tiktoken_llama(self):
             add_bos_token=True,
             add_eos_token=True,
             from_hf_hub=True,
+            use_fast=True,
         )
         self.assertTrue(isinstance(tiktoken_tokenizer, PretrainedTokenizerFast))
         tokens = tiktoken_tokenizer.encode(test_text, add_special_tokens=True)["input_ids"]
         self.assertEqual(tokens, test_tokens)
         tmpdirname = tempfile.mkdtemp()
         tiktoken_tokenizer.save_pretrained(tmpdirname)
-        tokenizer_reload = AutoTokenizer.from_pretrained(tmpdirname)
+        tokenizer_reload = AutoTokenizer.from_pretrained(tmpdirname, use_fast=True)
         self.assertTrue(isinstance(tokenizer_reload, PretrainedTokenizerFast))
         tokens = tokenizer_reload.encode(test_text, add_special_tokens=True)["input_ids"]
         self.assertEqual(tokens, test_tokens)
@@ -279,6 +280,7 @@ def test_tiktoken_llama(self):
             add_bos_token=True,
             add_eos_token=True,
             from_hf_hub=True,
+            use_fast=True,
         )
         tokens = tiktoken_tokenizer.encode(test_text, add_special_tokens=True)["input_ids"]
         self.assertEqual(tokens, test_tokens)
8 changes: 4 additions & 4 deletions tests/transformers/test_chat_template.py
@@ -97,7 +97,7 @@ def test_inference_template(self):

 class ChatTemplateIntegrationTest(unittest.TestCase):
     def test_linlyai_chinese_llama_2_chat_template(self):
-        tokenizer = AutoTokenizer.from_pretrained("linly-ai/chinese-llama-2-7b", use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained("linly-ai/chinese-llama-2-7b")
         query = "你好"
         final_query = tokenizer.apply_chat_template(query, tokenize=False)
         expected_query = f"<s>### Instruction:{query} ### Response:"

@@ -110,7 +110,7 @@
         self.assertEqual(final_query, expected_query)

     def test_linlyai_chinese_llama_2_chat_template_with_none_saved(self):
-        tokenizer = AutoTokenizer.from_pretrained("linly-ai/chinese-llama-2-7b", use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained("linly-ai/chinese-llama-2-7b")
         tokenizer.chat_template = None
         with tempfile.TemporaryDirectory() as tempdir:
             tokenizer.save_pretrained(tempdir)

@@ -182,7 +182,7 @@ def get_common_prefix(self, tokenizer):

     def test_prefix(self):
         prompt = "欢迎使用 PaddleNLP 大模型开发套件"
-        tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         result = tokenizer.apply_chat_template(prompt, tokenize=False)

         result_ids = tokenizer(result, add_special_tokens=False)["input_ids"]

@@ -230,7 +230,7 @@ def test_must_have_system(self):

     def test_at_least_one_turn(self):
         query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥", "你可以选择不同的菜系"]]
-        tokenizer = AutoTokenizer.from_pretrained("linly-ai/chinese-llama-2-7b", use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained("linly-ai/chinese-llama-2-7b")
         # tokenizer.init_chat_template(self.chat_template_config_file)

         # get all query sentence
5 changes: 3 additions & 2 deletions tests/transformers/test_modeling_common.py
@@ -929,8 +929,9 @@ def tearDown(self):
@unittest.skip("Paddle enable PIR API in Python")
def test_to_static_use_top_k(self):
tokenizer = self.TokenizerClass.from_pretrained(self.internal_testing_model)
if "LlamaTokenizer" in tokenizer.__class__.__name__:
if tokenizer.__class__.__name__ == "LlamaTokenizer":
tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "<pad>"

model = self.CausalLMClass.from_pretrained(self.internal_testing_model)
model_kwargs = tokenizer(
self.article,
Expand Down Expand Up @@ -1009,7 +1010,7 @@ def test_to_static_use_top_k(self):
@unittest.skip("Paddle enable PIR API in Python")
def test_to_static_use_top_p(self):
tokenizer = self.TokenizerClass.from_pretrained(self.internal_testing_model)
if "LlamaTokenizer" in tokenizer.__class__.__name__:
if tokenizer.__class__.__name__ == "LlamaTokenizer":
tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "<pad>"
model = self.CausalLMClass.from_pretrained(self.internal_testing_model)

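A note on the exact-match change above: the old substring test also matched the fast tokenizer class, while the new comparison targets only the slow one. A two-line illustration:

name = "LlamaTokenizerFast"
assert "LlamaTokenizer" in name  # old substring check also caught the fast class
assert name != "LlamaTokenizer"  # new exact check applies only to the slow class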