Skip to content

Commit

Permalink
[Tokenizer] Fix chat template for Gemma when answer is contained within question (#9462)
Browse files Browse the repository at this point in the history
  • Loading branch information
lvdongyi authored Nov 21, 2024
1 parent 5ceb930 commit 872aafa
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
31 changes: 31 additions & 0 deletions paddlenlp/transformers/gemma/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import os
import re
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -310,3 +311,33 @@ def create_token_type_ids_from_sequences(
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

return output

def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
    """Return the non-learnable text segments of a rendered conversation.

    Renders ``origin_msg`` through the tokenizer's chat template, then cuts
    out every occurrence of a ``split_s`` string that lies inside a model
    turn (i.e. the assistant's answers). The pieces of text between those
    cuts — prompts, turn markers, user content — are returned in order.

    Args:
        origin_msg: Chat messages (``role``/``content`` dicts) to render.
        split_s: Literal strings (typically the rendered answers) whose
            in-model-turn occurrences separate the non-learnable parts.

    Returns:
        List of non-learnable text segments, in document order.
    """
    rendered = self.chat_template.render(
        messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map
    )
    # Escape each split string so it is matched literally, not as a regex.
    splitter = re.compile("(?:%s)" % "|".join(re.escape(s) for s in split_s))

    # Keep only matches whose nearest preceding '<start_of_turn>' opens a
    # "model" turn — an answer echoed inside the user's question must not
    # be split out (that is exactly the bug this method works around).
    model_spans = []
    for match in splitter.finditer(rendered):
        begin, stop = match.span()
        turn_start = rendered.rfind("<start_of_turn>", 0, begin)
        if turn_start < 0:
            # No turn marker before this match; cannot attribute it to a role.
            continue
        after_marker = rendered[turn_start + len("<start_of_turn>"):]
        if after_marker.lstrip().startswith("model"):
            model_spans.append((begin, stop))

    # Everything between consecutive model-turn matches is non-learnable.
    segments = []
    cursor = 0
    for begin, stop in model_spans:
        segments.append(rendered[cursor:begin])
        cursor = stop
    tail = rendered[cursor:]
    if tail:
        segments.append(tail)
    return segments
26 changes: 26 additions & 0 deletions tests/transformers/gemma/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,29 @@ def test_add_special_tokens(self):
self.assertEqual(encoded, input_encoded + special_token_id)
decoded = tokenizer.decode(encoded, skip_special_tokens=True)
self.assertTrue(special_token not in decoded)

def test_extract_non_learnable_parts(self):
    """Check that encode_chat_inputs splits question/answer parts correctly,
    including the case where the answer text is contained within the question
    (e.g. question "Q.A." with answer "A.")."""
    models_with_templates = ["google/gemma-2b-it", "google/gemma-7b-it"]
    dummy_conversations = [
        ["Q.", "A."],
        ["Q.A.", "A."],  # answer is a substring of the question
        ["Q?", "A!"],
    ]
    decode_outputs = [
        ["<bos><start_of_turn>user\nQ.<end_of_turn>\n<start_of_turn>model\n", "A.<end_of_turn>\n"],
        ["<start_of_turn>user\nQ.A.<end_of_turn>\n<start_of_turn>model\n", "A.<end_of_turn>\n"],
        ["<start_of_turn>user\nQ?<end_of_turn>\n<start_of_turn>model\n", "A!<end_of_turn>\n"],
    ]
    context_data = {"is_training": True}
    for model_id in models_with_templates:
        tokenizer = GemmaTokenizer.from_pretrained(model_id)
        if tokenizer.chat_template is None:
            # Nothing to test for checkpoints that ship without a chat template.
            continue
        # encode_chat_inputs returns a dict; "conversations" maps to a list of
        # (prompt_ids, answer_ids) pairs, one per conversation round.
        conversation_result = tokenizer.encode_chat_inputs(
            dummy_conversations,
            context_data=context_data,
        )
        for idx, conv_round in enumerate(conversation_result["conversations"]):
            self.assertEqual(tokenizer.decode(conv_round[0]), decode_outputs[idx][0])
            self.assertEqual(tokenizer.decode(conv_round[1]), decode_outputs[idx][1])

0 comments on commit 872aafa

Please sign in to comment.