Properly tokenize the template for hunyuan video.
comfyanonymous committed Dec 17, 2024
1 parent cd6f615 commit ca457f7
Showing 2 changed files with 17 additions and 14 deletions.
comfy/sd1_clip.py (22 changes: 12 additions & 10 deletions)
@@ -10,6 +10,7 @@
 import json
 import logging
 import numbers
+import re
 
 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -429,13 +430,14 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         self.end_token = None
 
         empty = self.tokenizer('')["input_ids"]
+        self.tokenizer_adds_end_token = has_end_token
         if has_start_token:
             self.tokens_start = 1
             self.start_token = empty[0]
-            if has_end_token:
-                if end_token is not None:
-                    self.end_token = end_token
-                else:
+            if end_token is not None:
+                self.end_token = end_token
+            else:
+                if has_end_token:
                     self.end_token = empty[1]
         else:
             self.tokens_start = 0
@@ -468,7 +470,7 @@ def _try_get_embedding(self, embedding_name:str):
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        split_embed = embedding_name.split(' ')
+        split_embed = embedding_name.split()
         embedding_name = split_embed[0]
         leftover = ' '.join(split_embed[1:])
         embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
@@ -491,18 +493,18 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
-        #tokenize words
+        # tokenize words
         tokens = []
         for weighted_segment, weight in parsed_weights:
-            to_tokenize = unescape_important(weighted_segment).replace("\n", " ")
-            split = to_tokenize.split(' {}'.format(self.embedding_identifier))
+            to_tokenize = unescape_important(weighted_segment)
+            split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
             to_tokenize = [split[0]]
             for i in range(1, len(split)):
                 to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
 
             to_tokenize = [x for x in to_tokenize if x != ""]
             for word in to_tokenize:
-                #if we find an embedding, deal with the embedding
+                # if we find an embedding, deal with the embedding
                 if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                     embedding_name = word[len(self.embedding_identifier):].strip('\n')
                     embed, leftover = self._try_get_embedding(embedding_name)
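The re.split line above is the core of the fix: the old code had to flatten every newline to a space before it could find an embedding reference, which also destroyed the "\n\n" sequences that the Hunyuan video template below relies on. The new pattern matches the identifier after either a space or a newline and leaves the rest of the text untouched. A small standalone sketch of the difference, assuming the default "embedding:" identifier (illustrative code, not part of the commit):

import re

embedding_identifier = "embedding:"   # assumed default identifier of SDTokenizer
text = "<|start_header_id|>user<|end_header_id|>\n\na photo of\nembedding:myface smiling"

# Old approach: flatten newlines so " embedding:" can be found, which also
# destroys the "\n\n" that the Hunyuan video template needs to keep intact.
old = text.replace("\n", " ").split(' {}'.format(embedding_identifier))

# New approach: split on the identifier after a space OR a newline and leave
# every other newline in place, then re-attach the identifier as the loop above does.
new = re.split(' {0}|\n{0}'.format(embedding_identifier), text)
to_tokenize = [new[0]] + ["{}{}".format(embedding_identifier, s) for s in new[1:]]

print(old)          # ['<|start_header_id|>user<|end_header_id|>  a photo of', 'myface smiling']
print(to_tokenize)  # ['<|start_header_id|>user<|end_header_id|>\n\na photo of', 'embedding:myface smiling']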
@@ -519,7 +521,7 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
                 else:
                     continue
                 end = 999999999999
-                if self.end_token is not None:
+                if self.tokenizer_adds_end_token:
                     end = -1
                 #parse word
                 tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
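The new tokenizer_adds_end_token flag decouples "an end token id is configured" from "the wrapped tokenizer actually appends one". Under the old check, any tokenizer constructed with an explicit end_token had the last id of every tokenized segment sliced off (end = -1), even when the underlying tokenizer never added an end token; that is exactly the situation of the Llama 3 tokenizer below, which still passes end_token=128009 but, after this commit, declares has_end_token=False. A minimal sketch of the effect of the two checks (illustrative only, with made-up token ids):

# A tokenizer that prepends a start token but never appends an end token,
# as has_end_token=False now declares for the Llama 3 tokenizer.
input_ids = [128000, 15339, 1917]   # hypothetical ids: <|begin_of_text|>, "hello", "world"
tokens_start = 1                    # skip the start token

# old check (an end_token id is configured, so assume an end token was appended):
print(input_ids[tokens_start:-1])             # [15339]  the last real token is lost

# new check (the tokenizer adds no end token, so keep everything):
print(input_ids[tokens_start:999999999999])   # [15339, 1917]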
comfy/text_encoders/hunyuan_video.py (9 changes: 5 additions & 4 deletions)
@@ -22,7 +22,7 @@ def llama_detect(state_dict, prefix=""):
 class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
 
 class LLAMAModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
@@ -38,9 +38,7 @@ class HunyuanVideoTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
         self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
-        self.llama_template = """<|start_header_id|>system<|end_header_id|>
-
-Describe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>""" # 93 tokens
+        self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
         self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
 
     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -86,6 +84,9 @@ def encode_token_weights(self, token_weight_pairs):
             if v[0] == 128007: # <|end_header_id|>
                 template_end = i
 
+        if llama_out.shape[1] > (template_end + 2):
+            if token_weight_pairs_llama[0][template_end + 1][0] == 271:
+                template_end += 2
         llama_out = llama_out[:, template_end:]
         llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
         if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
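With the template now carrying explicit "\n\n" markers, encode_token_weights can crop the template out of the Llama hidden states more precisely: it records the position of the last <|end_header_id|> (id 128007, per the existing comment) and, if the id that follows is 271 (assumed here to be the "\n\n" that closes the user header), moves the cut point two positions further so the header marker and the double newline are dropped along with the rest of the template, for both the hidden states and the attention mask. A standalone sketch of that cut-point logic with made-up surrounding ids (illustrative, not the commit's code):

def template_cut_point(token_ids):
    # index of the last <|end_header_id|> (128007); everything up to there is template
    template_end = 0
    for i, t in enumerate(token_ids):
        if t == 128007:
            template_end = i
    # if the template's trailing "\n\n" (id 271, assumed) follows, skip past it too
    if len(token_ids) > template_end + 2:
        if token_ids[template_end + 1] == 271:
            template_end += 2
    return template_end

ids = [128000, 1000, 1001, 128007, 271, 500, 501]   # hypothetical: BOS, header ids, <|end_header_id|>, "\n\n", user prompt
cut = template_cut_point(ids)
print(cut)        # 5 -> llama_out and the attention mask are sliced as [:, cut:]
print(ids[cut:])  # [500, 501]  only the user prompt conditions the video model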
