Skip to content

Commit

Permalink
[Improvement] add context_data support for chat_template rendering (#7480)
Browse files Browse the repository at this point in the history

* add context_data support for chat_template rendering

* update method name

* add chat_template_with_context_data file
  • Loading branch information
wj-Mcat authored Nov 22, 2023
1 parent 4645ddc commit d4faef6
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 2 deletions.
9 changes: 7 additions & 2 deletions paddlenlp/transformers/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,12 +638,17 @@ def chat_template_sep_token_id(self):
return self.sep_token_id

def apply_chat_template(
self, conversation: List[List[str, str]] | str, tokenize: bool = True, **tokenizer_kwargs
self,
conversation: List[List[str, str]] | str,
tokenize: bool = True,
context_data: Dict[str, Any] = {},
**tokenizer_kwargs
) -> str | dict[str, numpy.ndarray | paddle.Tensor]:
"""apply chat_template rules to conversation which should not be batched data
Args:
conversation (List[List[str, str]] | str): the conversation messages between user and bot
context_data (Dict[str, Any]): the context data for chat_template.json
tokenize (bool, optional): whether do tokenization. Defaults to True.
Returns:
Expand All @@ -657,7 +662,7 @@ def apply_chat_template(
"so you should apply the conversation one by one."
)

query = self.chat_template(conversation)
query = self.chat_template(conversation, context_data=context_data)
if not tokenize:
return query

Expand Down
5 changes: 5 additions & 0 deletions tests/fixtures/chat_template_with_context.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"system": "你是一个人工智能助手{{system}}-{{instruction}}\n",
"conversation": ["Human: {{user}}<sep> Bot:", "{{bot}}\n"],
"query": "Human: {{query}}<sep> Bot:"
}
32 changes: 32 additions & 0 deletions tests/transformers/test_chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,24 @@ def test_conversation(self):
assert final_query == "Human: 你好<sep>Bot: 您好,我是个人人工智能助手\n\n" + query


class ChatTemplateContextDataTest(unittest.TestCase):
    """Verify that context_data placeholders are substituted when a ChatTemplate renders."""

    chat_template_config_file = "./tests/fixtures/chat_template_with_context.json"

    @property
    def chat_template(self):
        # Re-parse the fixture on every access so each test works on a fresh template.
        return ChatTemplate.from_file(self.chat_template_config_file)

    def test_inference_template(self):
        conversation = [["你好"]]
        extra_context = {
            "system": "<<SYSTEM-MESSAGE>>",
            "instruction": "<<INSTRUCTION-MESSAGE>>",
        }
        rendered = self.chat_template(conversation, context_data=extra_context)
        self.assertEqual(
            rendered,
            "你是一个人工智能助手<<SYSTEM-MESSAGE>>-<<INSTRUCTION-MESSAGE>>\nHuman: 你好<sep> Bot:",
        )


class ChatTemplateIntegrationTest(unittest.TestCase):
def test_llama2_chat_template(self):
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat")
Expand Down Expand Up @@ -256,3 +274,17 @@ def test_at_least_one_turn(self):
sentence_result,
expected_sentence,
)

def test_inference_template_with_context_data(self):
    """apply_chat_template should forward context_data into the rendered template string."""
    tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-llama")
    tokenizer.init_chat_template("./tests/fixtures/chat_template_with_context.json")

    extra_context = {
        "system": "<<SYSTEM-MESSAGE>>",
        "instruction": "<<INSTRUCTION-MESSAGE>>",
    }
    # tokenize=False so the raw rendered prompt (not token ids) is returned for comparison.
    rendered = tokenizer.apply_chat_template("你好", context_data=extra_context, tokenize=False)
    self.assertEqual(
        rendered,
        "你是一个人工智能助手<<SYSTEM-MESSAGE>>-<<INSTRUCTION-MESSAGE>>\nHuman: 你好<sep> Bot:",
    )

0 comments on commit d4faef6

Please sign in to comment.