
Add support for Hugging Face style chat templating #28

Open

simonw opened this issue Dec 16, 2023 · 4 comments
Labels
enhancement New feature or request

Comments

simonw (Owner) commented Dec 16, 2023

This plugin urgently needs a better solution for handling chat templates, to better support models like Mixtral.

Currently it only supports one, for Llama 2, which is hard-coded like this:

# Now build the prompt pieces
first = True
if conversation is not None:
    for prev_response in conversation.responses:
        prompt_bits.append("<s>[INST] ")
        if first:
            prompt_bits.append(
                f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n",
            )
            first = False
        prompt_bits.append(
            f"{prev_response.prompt.prompt} [/INST] ",
        )
        prompt_bits.append(
            f"{prev_response.text()} </s>",
        )
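For reference, here is a minimal self-contained sketch of the Llama 2 chat format that snippet produces. The function name and the tuple-based exchanges argument are illustrative only, not the plugin's actual API:

```python
def llama2_prompt(system_prompt, exchanges):
    """Build a Llama 2 style prompt.

    exchanges: list of (user_prompt, assistant_response) tuples.
    """
    bits = []
    first = True
    for user, assistant in exchanges:
        bits.append("<s>[INST] ")
        if first:
            # The system prompt is only wrapped into the first turn
            bits.append(f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n")
            first = False
        bits.append(f"{user} [/INST] ")
        bits.append(f"{assistant} </s>")
    return "".join(bits)

print(llama2_prompt("Be brief.", [("2+2?", "4")]))
# <s>[INST] <<SYS>>
# Be brief.
# <</SYS>>
#
# 2+2? [/INST] 4 </s>
```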

I think templating is the right way to go here. Rather than invent something new I'd like to reuse this Hugging Face mechanism, which was created back in September as far as I can tell:

https://huggingface.co/docs/transformers/chat_templating

Templates can use Jinja and end up looking something like this:

{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ ' ' }}
    {% endif %}
    {{ message['content'] }}
    {% if not loop.last %}
        {{ '  ' }}
    {% endif %}
{% endfor %}
{{ eos_token }}
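For illustration, here is a sketch of rendering that example template directly with jinja2, using the same sandboxed environment that Transformers uses (the messages and the eos_token value are made up):

```python
from jinja2.sandbox import ImmutableSandboxedEnvironment

# The example template from the Hugging Face docs, as a single string
template_str = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
    "{{ message['content'] }}"
    "{% if not loop.last %}{{ '  ' }}{% endif %}"
    "{% endfor %}"
    "{{ eos_token }}"
)

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
template = env.from_string(template_str)

rendered = template.render(
    messages=[
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello"},
    ],
    eos_token="</s>",
)
print(repr(rendered))
# ' Hi  Hello</s>'
```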
simonw added the enhancement label Dec 16, 2023
simonw (Owner) commented Dec 16, 2023

Here's an example template for Mixtral Instruct: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json#L42

"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"

Decoded with https://observablehq.com/@simonw/display-content-from-a-json-string

{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}

Pretty printed by ChatGPT:

{{ bos_token }}
{% for message in messages %}
    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {% endif %}
    {% if message['role'] == 'user' %}
        {{ '[INST] ' + message['content'] + ' [/INST]' }}
    {% elif message['role'] == 'assistant' %}
        {{ message['content'] + eos_token }}
    {% else %}
        {{ raise_exception('Only user and assistant roles are supported!') }}
    {% endif %}
{% endfor %}

I think bos_token and eos_token are defined in that JSON too:

  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
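Putting those pieces together, here is a sketch of rendering the Mixtral template with jinja2 directly, supplying bos_token and eos_token from the values above (the sample messages are made up, and raise_exception is wired up the way Transformers does it):

```python
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment

# The Mixtral Instruct chat_template, reassembled as a Python string
chat_template = (
    "{{ bos_token }}{% for message in messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
    "{% endif %}{% if message['role'] == 'user' %}"
    "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ message['content'] + eos_token }}"
    "{% else %}"
    "{{ raise_exception('Only user and assistant roles are supported!') }}"
    "{% endif %}{% endfor %}"
)

def raise_exception(message):
    raise TemplateError(message)

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
env.globals["raise_exception"] = raise_exception
template = env.from_string(chat_template)

rendered = template.render(
    messages=[
        {"role": "user", "content": "2+2"},
        {"role": "assistant", "content": "4!"},
        {"role": "user", "content": "+2"},
    ],
    bos_token="<s>",
    eos_token="</s>",
)
print(rendered)
# <s>[INST] 2+2 [/INST]4!</s>[INST] +2 [/INST]
```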

simonw (Owner) commented Dec 16, 2023

Here's the code that renders that: https://github.com/huggingface/transformers/blob/238d2e3c44366aba9dc5c770c95475765a6725cb/src/transformers/tokenization_utils_base.py#L1760-L1779

    @lru_cache
    def _compile_jinja_template(self, chat_template):
        try:
            import jinja2
            from jinja2.exceptions import TemplateError
            from jinja2.sandbox import ImmutableSandboxedEnvironment
        except ImportError:
            raise ImportError("apply_chat_template requires jinja2 to be installed.")

        if version.parse(jinja2.__version__) <= version.parse("3.0.0"):
            raise ImportError(
                "apply_chat_template requires jinja2>=3.0.0 to be installed. Your version is " f"{jinja2.__version__}."
            )

        def raise_exception(message):
            raise TemplateError(message)

        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
        jinja_env.globals["raise_exception"] = raise_exception
        return jinja_env.from_string(chat_template)
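The ImmutableSandboxedEnvironment matters here because chat templates are arbitrary Jinja downloaded from model repos. As a sketch of the kind of thing the sandbox refuses, per jinja2's sandbox behavior, calls that would mutate a passed-in object raise SecurityError:

```python
from jinja2.exceptions import SecurityError
from jinja2.sandbox import ImmutableSandboxedEnvironment

env = ImmutableSandboxedEnvironment()

# The immutable sandbox blocks methods that modify known mutable types,
# so a template cannot append to the messages list it was handed
template = env.from_string("{{ messages.append('injected') }}")
try:
    template.render(messages=["hello"])
except SecurityError as exc:
    print("blocked:", exc)
```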

Which is called from here:

https://github.com/huggingface/transformers/blob/238d2e3c44366aba9dc5c770c95475765a6725cb/src/transformers/tokenization_utils_base.py#L1738-L1743

        # Compilation function uses a cache to avoid recompiling the same template
        compiled_template = self._compile_jinja_template(chat_template)

        rendered = compiled_template.render(
            messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
        )

Not yet sure where self.special_tokens_map is populated.

The add_generation_prompt argument controls whether the equivalent of Assistant: gets appended to the end of the prompt: true for completion models, while instruction-tuned models tend not to need it.
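Here is a sketch of how a template might consume that flag. This toy template is purely illustrative (the Mixtral template above happens not to use the variable):

```python
from jinja2.sandbox import ImmutableSandboxedEnvironment

# A made-up template that opens an assistant turn when asked to
template_str = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)

env = ImmutableSandboxedEnvironment()
template = env.from_string(template_str)
messages = [{"role": "user", "content": "2+2?"}]

# With the flag set, the prompt ends ready for the model to complete
print(repr(template.render(messages=messages, add_generation_prompt=True)))
# 'user: 2+2?\nassistant: '
print(repr(template.render(messages=messages, add_generation_prompt=False)))
# 'user: 2+2?\n'
```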

simonw (Owner) commented Dec 16, 2023

Useful note about Mistral: https://discord.com/channels/1144547040454508606/1156609509674975262/1184860885748035594

from transformers import AutoTokenizer
from typing import List, Dict

def build_prompt(
    messages: List[Dict[str, str]],
    tokenizer: AutoTokenizer,
):
    prompt = ""
    for i, msg in enumerate(messages):
        is_user = {"user": True, "assistant": False}[msg["role"]]
        assert (i % 2 == 0) == is_user
        content = msg["content"]
        assert content == content.strip()
        if is_user:
            prompt += f"[INST] {content} [/INST]"
        else:
            prompt += f" {content}</s>"
    tokens_ids = tokenizer.encode(prompt)
    token_str = tokenizer.convert_ids_to_tokens(tokens_ids)
    return tokens_ids, token_str

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
messages = [
    {"role": "user", "content": "2+2"},
    {"role": "assistant", "content": "4!"},
    {"role": "user", "content": "+2"},
    {"role": "assistant", "content": "6!"},
    {"role": "user", "content": "+4"},
]

tokens_ids, token_str = build_prompt(messages, tok)
print(tokens_ids)
# [1, 733, 16289, 28793, 28705, 28750, 28806, 28750, 733, 28748, 16289, 28793, 28705, 28781, 28808, 2, 733, 16289, 28793, 648, 28750, 733, 28748, 16289, 28793, 28705, 28784, 28808, 2, 733, 16289, 28793, 648, 28781, 733, 28748, 16289, 28793]
print(token_str)
# ['<s>', '▁[', 'INST', ']', '▁', '2', '+', '2', '▁[', '/', 'INST', ']', '▁', '4', '!', '</s>', '▁[', 'INST', ']', '▁+', '2', '▁[', '/', 'INST', ']', '▁', '6', '!', '</s>', '▁[', 'INST', ']', '▁+', '4', '▁[', '/', 'INST', ']']

aawadat commented Jul 6, 2024