# Add support for Hugging Face style chat templating #28
Here's an example template for Mixtral Instruct: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json#L42
Decoded with https://observablehq.com/@simonw/display-content-from-a-json-string:

```jinja
{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
```

Pretty printed by ChatGPT:

```jinja
{{ bos_token }}
{% for message in messages %}
  {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
    {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
  {% endif %}
  {% if message['role'] == 'user' %}
    {{ '[INST] ' + message['content'] + ' [/INST]' }}
  {% elif message['role'] == 'assistant' %}
    {{ message['content'] + eos_token }}
  {% else %}
    {{ raise_exception('Only user and assistant roles are supported!') }}
  {% endif %}
{% endfor %}
```
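To make the mechanics concrete, here's an illustrative sketch (mine, not from the thread) that renders that template with plain jinja2; the `<s>`/`</s>` token values are assumed from Mistral's tokenizer:

```python
# Illustrative sketch: render the Mixtral template above with plain jinja2.
# The bos/eos token strings are assumed values from Mistral's tokenizer.
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment

template_str = (
    "{{ bos_token }}{% for message in messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
    "{% endif %}{% if message['role'] == 'user' %}"
    "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ message['content'] + eos_token }}"
    "{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}"
    "{% endif %}{% endfor %}"
)


def raise_exception(message):
    raise TemplateError(message)


env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
env.globals["raise_exception"] = raise_exception

print(env.from_string(template_str).render(
    messages=[
        {"role": "user", "content": "2+2"},
        {"role": "assistant", "content": "4!"},
        {"role": "user", "content": "+2"},
    ],
    bos_token="<s>",
    eos_token="</s>",
))
# <s>[INST] 2+2 [/INST]4!</s>[INST] +2 [/INST]
```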
Here's the code that renders that: https://github.com/huggingface/transformers/blob/238d2e3c44366aba9dc5c770c95475765a6725cb/src/transformers/tokenization_utils_base.py#L1760-L1779

```python
@lru_cache
def _compile_jinja_template(self, chat_template):
    try:
        import jinja2
        from jinja2.exceptions import TemplateError
        from jinja2.sandbox import ImmutableSandboxedEnvironment
    except ImportError:
        raise ImportError("apply_chat_template requires jinja2 to be installed.")

    if version.parse(jinja2.__version__) <= version.parse("3.0.0"):
        raise ImportError(
            "apply_chat_template requires jinja2>=3.0.0 to be installed. Your version is " f"{jinja2.__version__}."
        )

    def raise_exception(message):
        raise TemplateError(message)

    jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    jinja_env.globals["raise_exception"] = raise_exception
    return jinja_env.from_string(chat_template)
```

Which is called from here:

```python
# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)

rendered = compiled_template.render(
    messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
)
```

Not yet sure where …
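For what it's worth, callers don't normally touch those internals directly; the public entry point is `tokenizer.apply_chat_template`. A minimal usage sketch, assuming transformers >= 4.34 and Hub access for the tokenizer files:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

messages = [
    {"role": "user", "content": "2+2"},
    {"role": "assistant", "content": "4!"},
    {"role": "user", "content": "+2"},
]

# tokenize=False returns the rendered prompt string rather than token IDs
prompt = tok.apply_chat_template(messages, tokenize=False)
print(prompt)
# Per the template above, this should be:
# <s>[INST] 2+2 [/INST]4!</s>[INST] +2 [/INST]
```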
Useful note about Mistral: https://discord.com/channels/1144547040454508606/1156609509674975262/1184860885748035594

```python
from transformers import AutoTokenizer
from typing import List, Dict


def build_prompt(
    messages: List[Dict[str, str]],
    tokenizer: AutoTokenizer,
):
    prompt = ""
    for i, msg in enumerate(messages):
        is_user = {"user": True, "assistant": False}[msg["role"]]
        assert (i % 2 == 0) == is_user
        content = msg["content"]
        assert content == content.strip()
        if is_user:
            prompt += f"[INST] {content} [/INST]"
        else:
            prompt += f" {content}</s>"
    tokens_ids = tokenizer.encode(prompt)
    token_str = tokenizer.convert_ids_to_tokens(tokens_ids)
    return tokens_ids, token_str


tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

messages = [
    {"role": "user", "content": "2+2"},
    {"role": "assistant", "content": "4!"},
    {"role": "user", "content": "+2"},
    {"role": "assistant", "content": "6!"},
    {"role": "user", "content": "+4"},
]

tokens_ids, token_str = build_prompt(messages, tok)

print(tokens_ids)
# [1, 733, 16289, 28793, 28705, 28750, 28806, 28750, 733, 28748, 16289, 28793, 28705, 28781, 28808, 2, 733, 16289, 28793, 648, 28750, 733, 28748, 16289, 28793, 28705, 28784, 28808, 2, 733, 16289, 28793, 648, 28781, 733, 28748, 16289, 28793]

print(token_str)
# ['<s>', '▁[', 'INST', ']', '▁', '2', '+', '2', '▁[', '/', 'INST', ']', '▁', '4', '!', '</s>', '▁[', 'INST', ']', '▁+', '2', '▁[', '/', 'INST', ']', '▁', '6', '!', '</s>', '▁[', 'INST', ']', '▁+', '4', '▁[', '/', 'INST', ']']
```
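One way to check the note above (my sketch, untested, continuing from the snippet just shown): render the same conversation through the tokenizer's own chat template and compare token IDs; any divergence would surface exactly the whitespace subtleties being discussed.

```python
# Untested sketch: compare the hand-built prompt against the tokenizer's own
# chat template (tokenize=True is the default, returning a list of token IDs).
template_ids = tok.apply_chat_template(messages)
print(template_ids == tokens_ids)  # True only if the two constructions agree
```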
This plugin urgently needs a better solution for handling chat templates, in order to properly support models like Mixtral.

Currently it only supports one, for Llama 2, which is hard-coded in llm-llama-cpp/llm_llama_cpp.py (lines 220 to 235 at dc53ef9).

I think templating is the right way to go here. Rather than invent something new, I'd like to reuse this Hugging Face mechanism, which was created back in September as far as I can tell:

https://huggingface.co/docs/transformers/chat_templating

Templates can use Jinja and end up looking like the Mixtral example earlier in this thread.
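A minimal sketch of what adopting that mechanism in this plugin could look like, assuming the template string is read from a model's tokenizer_config.json (the `chat_template` key linked at the top of this thread); `render_chat_template` and the config path are illustrative, not existing plugin API:

```python
# Hypothetical sketch, not existing plugin code: render a Hugging Face style
# chat_template with the same sandboxed Jinja setup transformers uses.
import json
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment


def render_chat_template(chat_template, messages, bos_token="<s>", eos_token="</s>"):
    def raise_exception(message):
        raise TemplateError(message)

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    env.globals["raise_exception"] = raise_exception
    return env.from_string(chat_template).render(
        messages=messages, bos_token=bos_token, eos_token=eos_token
    )


# e.g. pull the template straight out of a model's tokenizer_config.json
with open("tokenizer_config.json") as fp:
    config = json.load(fp)

print(render_chat_template(config["chat_template"], [{"role": "user", "content": "Hello"}]))
# <s>[INST] Hello [/INST]   (for the Mixtral template above)
```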