From ecc31770ade35f478d5144b2ead55eb33fb8e6ae Mon Sep 17 00:00:00 2001
From: lvdongyi
Date: Wed, 20 Nov 2024 05:25:40 +0000
Subject: [PATCH] Fix chat template for Gemma when answer is contained within
 question

---
 paddlenlp/transformers/gemma/tokenizer.py  | 31 ++++++++++++++++++++++
 tests/transformers/gemma/test_tokenizer.py | 26 ++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/paddlenlp/transformers/gemma/tokenizer.py b/paddlenlp/transformers/gemma/tokenizer.py
index 54a6413d4f2a..4cee34e6187c 100644
--- a/paddlenlp/transformers/gemma/tokenizer.py
+++ b/paddlenlp/transformers/gemma/tokenizer.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import os
+import re
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
 
@@ -310,3 +311,33 @@ def create_token_type_ids_from_sequences(
             output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
 
         return output
+
+    def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
+        regex_pattern = "|".join(map(re.escape, split_s))
+        rendered_messages = self.chat_template.render(
+            messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map
+        )
+        pattern = re.compile(r"(?:%s)" % regex_pattern)
+        split_positions = [match.span() for match in pattern.finditer(rendered_messages)]
+
+        filtered_positions = []
+        for start, end in split_positions:
+            # Find the last occurrence of '<start_of_turn>' before the split index
+            last_start = rendered_messages.rfind("<start_of_turn>", 0, start)
+            if last_start == -1:
+                continue  # Skip if '<start_of_turn>' is not found
+            model_start = last_start + len("<start_of_turn>")
+
+            # Get the text following 'model_start' and check if it starts with 'model'
+            following_text = rendered_messages[model_start:].lstrip()
+            if following_text.startswith("model"):
+                filtered_positions.append((start, end))
+        non_learnable_parts = []
+        last_end = 0
+        for start, end in filtered_positions:
+            non_learnable_parts.append(rendered_messages[last_end:start])
+            last_end = end
+        remaining_part = rendered_messages[last_end:]
+        if remaining_part:
+            non_learnable_parts.append(remaining_part)
+        return non_learnable_parts
diff --git a/tests/transformers/gemma/test_tokenizer.py b/tests/transformers/gemma/test_tokenizer.py
index e8527c40ee4b..17c792f7a447 100644
--- a/tests/transformers/gemma/test_tokenizer.py
+++ b/tests/transformers/gemma/test_tokenizer.py
@@ -223,3 +223,29 @@ def test_add_special_tokens(self):
         self.assertEqual(encoded, input_encoded + special_token_id)
         decoded = tokenizer.decode(encoded, skip_special_tokens=True)
         self.assertTrue(special_token not in decoded)
+
+    def test_extract_non_learnable_parts(self):
+        models_with_templates = ["google/gemma-2b-it", "google/gemma-7b-it"]
+        dummy_conversations = [
+            ["Q.", "A."],
+            ["Q.A.", "A."],
+            ["Q?", "A!"],
+        ]
+        decode_outputs = [
+            ["<start_of_turn>user\nQ.<end_of_turn>\n<start_of_turn>model\n", "A.<end_of_turn>\n"],
+            ["<start_of_turn>user\nQ.A.<end_of_turn>\n<start_of_turn>model\n", "A.<end_of_turn>\n"],
+            ["<start_of_turn>user\nQ?<end_of_turn>\n<start_of_turn>model\n", "A!<end_of_turn>\n"],
+        ]
+        context_data = {}
+        context_data["is_training"] = True
+        for model_id in models_with_templates:
+            tokenizer = GemmaTokenizer.from_pretrained(model_id)
+            if tokenizer.chat_template is None:
+                continue
+            conversation_result: dict[str, list[tuple[list[int], list[int]]]] = tokenizer.encode_chat_inputs(
+                dummy_conversations,
+                context_data=context_data,
+            )
+            for idx, round in enumerate(conversation_result["conversations"]):
+                self.assertEqual(tokenizer.decode(round[0]), decode_outputs[idx][0])
+                self.assertEqual(tokenizer.decode(round[1]), decode_outputs[idx][1])
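
--
Note (reviewer sketch, not part of the patch): the bug this fixes is easiest to
see in isolation. When the answer text also occurs inside the question, the
"Q.A." case in the new test, splitting the rendered conversation at the first
occurrence of the answer cuts inside the user turn. The patch instead keeps
only those matches whose nearest preceding '<start_of_turn>' opens a model
turn. Below is a minimal, standalone simulation of that filtering under the
standard Gemma turn markup; it does not import PaddleNLP, 'rendered' is a
hand-written stand-in for what the chat template would produce, and 'answers'
is an illustrative split list.

import re

# Hand-written stand-in for a rendered one-round Gemma conversation in which
# the answer "A." is contained within the question "Q.A.".
rendered = (
    "<start_of_turn>user\nQ.A.<end_of_turn>\n"
    "<start_of_turn>model\nA.<end_of_turn>\n"
)
answers = ["A."]
pattern = re.compile("|".join(map(re.escape, answers)))

# Naive split: the first match of "A." lands inside the user turn.
naive = pattern.search(rendered)
print(repr(rendered[: naive.start()]))
# -> '<start_of_turn>user\nQ.'   (wrong: the question is cut in half)

# Patched behaviour: keep a match only if the nearest preceding
# '<start_of_turn>' opens a model turn.
kept = []
for m in pattern.finditer(rendered):
    turn = rendered.rfind("<start_of_turn>", 0, m.start())
    if turn == -1:
        continue  # no turn marker before this match
    after_turn = rendered[turn + len("<start_of_turn>"):].lstrip()
    if after_turn.startswith("model"):
        kept.append(m.span())

for start, end in kept:
    print(repr(rendered[:start]))
# -> '<start_of_turn>user\nQ.A.<end_of_turn>\n<start_of_turn>model\n'

Under these assumptions the only surviving match is the answer inside the model
turn, so the non-learnable prefix ends exactly at '<start_of_turn>model\n',
which is the first decode output the new test asserts.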