Separate chat templates into a single file #33957

Merged · 33 commits · Nov 26, 2024

Commits (33)

020b6d9  Initial draft (Rocketknight1, Oct 4, 2024)
b8f9e0f  Add .jinja file loading for processors (Rocketknight1, Oct 7, 2024)
ca101e6  Add processor saving of naked chat template files (Rocketknight1, Oct 7, 2024)
82c357e  make fixup (Rocketknight1, Oct 7, 2024)
3775e2b  Add save-load test for tokenizers (Rocketknight1, Oct 7, 2024)
91483a0  Add save-load test for tokenizers (Rocketknight1, Oct 7, 2024)
47205a2  stash commit (Rocketknight1, Oct 8, 2024)
820990b  Try popping the file (Rocketknight1, Oct 8, 2024)
2aeb319  make fixup (Rocketknight1, Oct 8, 2024)
a92dc10  Pop the arg correctly (Rocketknight1, Oct 8, 2024)
1d54b7d  Pop the arg correctly (Rocketknight1, Oct 8, 2024)
5b26748  Add processor test (Rocketknight1, Oct 9, 2024)
743a4a5  Fix processor code (Rocketknight1, Oct 9, 2024)
455a297  stash commit (Rocketknight1, Oct 10, 2024)
52614f3  Processor clobbers child tokenizer's chat template (Rocketknight1, Oct 10, 2024)
a200a42  Processor clobbers child tokenizer's chat template (Rocketknight1, Oct 10, 2024)
5c792bc  make fixup (Rocketknight1, Oct 10, 2024)
5305a87  Split processor/tokenizer files to avoid interactions (Rocketknight1, Oct 10, 2024)
3656468  fix test (Rocketknight1, Oct 10, 2024)
e412619  Expand processor tests (Rocketknight1, Oct 10, 2024)
bd81682  Rename arg to "save_raw_chat_template" across all classes (Rocketknight1, Oct 10, 2024)
50d374a  Update processor warning (Rocketknight1, Oct 10, 2024)
a81a7e1  Move templates to single file (Rocketknight1, Oct 31, 2024)
f9127c1  Move templates to single file (Rocketknight1, Nov 1, 2024)
3878828  Improve testing for processor/tokenizer clashes (Rocketknight1, Nov 1, 2024)
74cd295  Improve testing for processor/tokenizer clashes (Rocketknight1, Nov 1, 2024)
cf577e1  Extend saving test (Rocketknight1, Nov 5, 2024)
b97e76d  Test file priority correctly (Rocketknight1, Nov 5, 2024)
e5b8b76  make fixup (Rocketknight1, Nov 5, 2024)
5b40f04  Don't pop the chat template file before the slow tokenizer gets a look (Rocketknight1, Nov 5, 2024)
2159cdc  Remove breakpoint (Rocketknight1, Nov 5, 2024)
5b88768  make fixup (Rocketknight1, Nov 5, 2024)
1187047  Fix error (Rocketknight1, Nov 5, 2024)
51 changes: 38 additions & 13 deletions src/transformers/processing_utils.py
@@ -44,7 +44,6 @@
TruncationStrategy,
)
from .utils import (
PROCESSOR_NAME,
PushToHubMixin,
TensorType,
@@ -527,18 +526,24 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
# If we save using the predefined names, we can load using `from_pretrained`
# plus we save chat_template in its own file
output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
output_raw_chat_template_file = os.path.join(save_directory, "chat_template.jinja")
output_chat_template_file = os.path.join(save_directory, "chat_template.json")

processor_dict = self.to_dict()
# Save `chat_template` in its own file. We can't get it from `processor_dict`, as we popped it in `to_dict`
# to avoid serializing the chat template into the JSON config file, so we take it from `self` directly
if self.chat_template is not None:
if kwargs.get("save_raw_chat_template", False):
with open(output_raw_chat_template_file, "w", encoding="utf-8") as writer:
writer.write(self.chat_template)
logger.info(f"chat template saved in {output_raw_chat_template_file}")
else:
chat_template_json_string = (
json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
)
with open(output_chat_template_file, "w", encoding="utf-8") as writer:
writer.write(chat_template_json_string)
logger.info(f"chat template saved in {output_chat_template_file}")

# For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
# `auto_map` is not specified.
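
A minimal usage sketch of the save path above (not part of the diff; the checkpoint name is illustrative, and any processor that ships a chat template works):

```python
from transformers import AutoProcessor

# Illustrative checkpoint, assumed to define a chat template
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Default behaviour: the template is serialized into chat_template.json
processor.save_pretrained("saved_processor")

# With the new flag: the template is written verbatim to chat_template.jinja
# and kept out of the JSON sidecar
processor.save_pretrained("saved_processor_raw", save_raw_chat_template=True)
```
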
@@ -601,21 +606,23 @@ def get_processor_dict(
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME)
chat_template_file = os.path.join(pretrained_model_name_or_path, "chat_template.json")

if os.path.isfile(pretrained_model_name_or_path):
resolved_processor_file = pretrained_model_name_or_path
# can't load chat-template when given a file as pretrained_model_name_or_path
resolved_chat_template_file = None
resolved_raw_chat_template_file = None
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
processor_file = pretrained_model_name_or_path
resolved_processor_file = download_url(pretrained_model_name_or_path)
# can't load chat-template when given a file url as pretrained_model_name_or_path
resolved_chat_template_file = None
resolved_raw_chat_template_file = None
else:
processor_file = PROCESSOR_NAME
chat_template_file = "chat_template.json"
raw_chat_template_file = "chat_template.jinja"
try:
# Load from local folder or from cache or download from model Hub and cache
resolved_processor_file = cached_file(
@@ -650,6 +657,21 @@
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)

resolved_raw_chat_template_file = cached_file(
pretrained_model_name_or_path,
raw_chat_template_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
@@ -664,8 +686,11 @@
)

# Add chat template as kwarg before returning because most models don't have processor config
chat_template = None
if resolved_chat_template_file is not None:
if resolved_raw_chat_template_file is not None:
with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader:
chat_template = reader.read()
kwargs["chat_template"] = chat_template
elif resolved_chat_template_file is not None:
with open(resolved_chat_template_file, "r", encoding="utf-8") as reader:
text = reader.read()
chat_template = json.loads(text)["chat_template"]
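
In effect, loading resolves the template with a fixed priority. A local-only sketch of that order (the real code also resolves remote files via `cached_file`; `resolve_chat_template` is a hypothetical helper, not part of the diff):

```python
import json
import os


def resolve_chat_template(directory):
    """Sketch: a raw chat_template.jinja wins over the chat_template.json
    sidecar; a `chat_template` key left in the processor config is only
    consulted afterwards (and triggers the warning below)."""
    raw_path = os.path.join(directory, "chat_template.jinja")
    json_path = os.path.join(directory, "chat_template.json")
    if os.path.isfile(raw_path):
        with open(raw_path, "r", encoding="utf-8") as f:
            return f.read()
    if os.path.isfile(json_path):
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)["chat_template"]
    return None
```
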
@@ -696,7 +721,7 @@

if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
logger.warning_once(
"Chat templates should be in a 'chat_template.json' file but found key='chat_template' "
"Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' "
"in the processor's config. Make sure to move your template to its own file."
)

19 changes: 19 additions & 0 deletions src/transformers/tokenization_utils_base.py
@@ -145,6 +145,7 @@ class EncodingFast:
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
CHAT_TEMPLATE_FILE = "chat_template.jinja"

# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
FULL_TOKENIZER_FILE = "tokenizer.json"
@@ -1941,6 +1942,7 @@ def from_pretrained(
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding them in random order
"tokenizer_file": FULL_TOKENIZER_FILE,
"chat_template_file": CHAT_TEMPLATE_FILE,
}
vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files:
@@ -2097,6 +2099,12 @@ def _from_pretrained(
config_tokenizer_class = None
init_kwargs = init_configuration

# If an independent chat template file exists, it takes priority over template entries in the tokenizer config
chat_template_file = resolved_vocab_files.pop("chat_template_file", None)
if chat_template_file is not None:
with open(chat_template_file) as chat_template_handle:
init_kwargs["chat_template"] = chat_template_handle.read() # Clobbers any template in the config

if not _is_local:
if "auto_map" in init_kwargs:
# For backward compatibility with old format.
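
To illustrate the clobbering rule, a sketch with a hypothetical saved-tokenizer directory (it mirrors `test_chat_template_file_priority` further down):

```python
from pathlib import Path

from transformers import AutoTokenizer

# Hypothetical directory whose tokenizer_config.json contains "chat_template": "a"
save_dir = Path("my_tokenizer")

# Drop a standalone template file next to the config...
(save_dir / "chat_template.jinja").write_text("b", encoding="utf-8")

# ...and on reload the file wins over the config entry
tok = AutoTokenizer.from_pretrained(save_dir)
assert tok.chat_template == "b"
```
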
@@ -2396,6 +2404,9 @@ def save_pretrained(
tokenizer_config_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
)
chat_template_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE
)

tokenizer_config = copy.deepcopy(self.init_kwargs)

@@ -2418,7 +2429,15 @@
if isinstance(self.chat_template, dict):
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
# They will be reconstructed as a single dict during loading.
# We're trying to discourage chat template dicts, and they are always
# saved in the config, never as single files.
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
elif kwargs.get("save_raw_chat_template", False):
with open(chat_template_file, "w", encoding="utf-8") as f:
f.write(self.chat_template)
logger.info(f"chat template saved in {chat_template_file}")
if "chat_template" in tokenizer_config:
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
else:
tokenizer_config["chat_template"] = self.chat_template

25 changes: 25 additions & 0 deletions tests/test_processing_common.py
@@ -18,6 +18,7 @@
import json
import random
import tempfile
from pathlib import Path
from typing import Optional

import numpy as np
@@ -519,3 +520,27 @@ def test_prepare_and_validate_optional_call_args(self):
processor.prepare_and_validate_optional_call_args(
*(f"optional_{i}" for i in range(num_optional_call_args + 1))
)

def test_chat_template_save_loading(self):
processor = self.get_processor()
existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
processor.chat_template = "test template"
with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname)
self.assertTrue(Path(tmpdirname, "chat_template.json").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we don't use single-file chat template saving, processor and tokenizer chat templates
# should remain separate
self.assertEqual(getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template)

with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname, save_raw_chat_template=True)
self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we save as single files, tokenizers and processors share a chat template, which means
# the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
67 changes: 53 additions & 14 deletions tests/test_tokenization_common.py
@@ -25,6 +25,7 @@
import unittest
from collections import OrderedDict
from itertools import takewhile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union

from parameterized import parameterized
@@ -1107,13 +1108,29 @@ def test_chat_template(self):

with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)

self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
# Check that no error is raised
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)

with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
self.assertTrue(chat_template_file.is_file())
self.assertEqual(chat_template_file.read_text(), dummy_template)
config_dict = json.loads((Path(tmp_dir_name) / "tokenizer_config.json").read_text())
# Assert the chat template is not in the config when it's saved as a separate file
self.assertNotIn("chat_template", config_dict)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)

self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
# Check that no error is raised
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)

@require_jinja
def test_chat_template_batched(self):
@@ -1526,18 +1543,40 @@ def test_chat_template_dict_saving(self):
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
for save_raw_chat_template in (True, False):
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
with tempfile.TemporaryDirectory() as tmp_dir_name:
# Test that save_raw_chat_template is ignored when there's a dict of multiple templates
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template)
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[
{"name": "template1", "template": "{{'a'}}"},
{"name": "template2", "template": "{{'b'}}"},
],
)
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert that the serialized list is correctly reconstructed as a single dict
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)

@require_jinja
def test_chat_template_file_priority(self):
dummy_template1 = "a"
dummy_template2 = "b"
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.chat_template = dummy_template1
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=False)
with Path(tmp_dir_name, "chat_template.jinja").open("w") as f:
f.write(dummy_template2)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert the file template clobbers any template in the config
self.assertEqual(new_tokenizer.chat_template, dummy_template2)

def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)