Skip to content

Commit

Permalink
⚡ improve apply_character_map
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jun 20, 2024
1 parent b0de527 commit ea7399f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
6 changes: 4 additions & 2 deletions modules/ChatTTS/ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,14 @@ def _infer(
reserved_tokens = self.pretrain_models[
"tokenizer"
].additional_special_tokens
invalid_characters = count_invalid_characters(t, reserved_tokens)
invalid_characters = count_invalid_characters(
t, reserved_tokens=reserved_tokens
)
if len(invalid_characters):
self.logger.log(
logging.WARNING, f"Invalid characters found! : {invalid_characters}"
)
text[i] = apply_character_map(t)
text[i] = apply_character_map(t, reserved_tokens=reserved_tokens)

if not skip_refine_text:
text_tokens_gen = refine_text(
Expand Down
36 changes: 30 additions & 6 deletions modules/ChatTTS/ChatTTS/utils/infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,35 @@ def detect_language(sentence):
}


def apply_half2full_map(text):
translation_table = str.maketrans(halfwidth_2_fullwidth_map)
return text.translate(translation_table)
def replace_unsupported_chars(text, replace_dict, reserved_tokens: list = []):
escaped_tokens = [re.escape(token) for token in reserved_tokens]
special_tokens_pattern = "|".join(escaped_tokens)
tokens = re.split(f"({special_tokens_pattern})", text)

def replace_chars(segment):
for old_char, new_char in replace_dict.items():
segment = segment.replace(old_char, new_char)
return segment

result = "".join(
(
replace_chars(segment)
if not re.match(special_tokens_pattern, segment)
else segment
)
for segment in tokens
)

return result


def apply_half2full_map(text, reserved_tokens: list = []):
return replace_unsupported_chars(
text, halfwidth_2_fullwidth_map, reserved_tokens=reserved_tokens
)


def apply_character_map(text):
translation_table = str.maketrans(character_map)
return text.translate(translation_table)
def apply_character_map(text, reserved_tokens: list = []):
return replace_unsupported_chars(
text, character_map, reserved_tokens=reserved_tokens
)

0 comments on commit ea7399f

Please sign in to comment.