diff --git a/llm/utils.py b/llm/utils.py index d15d72dcddae..2b90bf950c14 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -467,6 +467,7 @@ def dybatch_preprocess( max_length=src_length, return_attention_mask=False, return_token_type_ids=False, + add_special_tokens=tokenizer.chat_template is None or isinstance(tokenizer, ChatGLMv2Tokenizer), ) input_ids.append(tokens["input_ids"][0]) diff --git a/tests/fixtures/llm/pretrain.yaml b/tests/fixtures/llm/pretrain.yaml index df8c27a293f2..c438c32f7ea5 100644 --- a/tests/fixtures/llm/pretrain.yaml +++ b/tests/fixtures/llm/pretrain.yaml @@ -33,6 +33,7 @@ inference-predict: batch_size: 2 decode_strategy: greedy_search dtype: float16 + chat_template: none inference-to-static: default: @@ -45,3 +46,4 @@ inference-infer: batch_size: 2 decode_strategy: greedy_search max_length: 20 + chat_template: none