diff --git a/README.md b/README.md
index 0c61920..0e3a92a 100644
--- a/README.md
+++ b/README.md
@@ -87,6 +87,18 @@ ChatTTS-Forge is a powerful text-to-speech generation tool, supporting
 
 > Because `MKL FFT doesn't support tensors of type: Half`, `--half` and `--use_cpu="all"` cannot be used together
 
+## Inference Speed
+
+> Test platform: `GeForce RTX 2080 Ti`
+
+| Flag combination   | Inference speed (tokens/s) | Chinese speed (chars/s) |
+| ------------------ | -------------------------- | ----------------------- |
+| `--compile`        | 54                         | 5.4                     |
+| `--half --compile` | 51                         | 5.1                     |
+| Default (no flags) | 30                         | 3.0                     |
+| `--half`           | 28                         | 2.8                     |
+| `--use_cpu=all`    | 20                         | 2.0                     |
+
 ### launch.py
 
 launch.py is the startup script for ChatTTS-Forge, used to configure and start the API server.
@@ -172,6 +184,7 @@ webui.py is a script for configuring and launching the Gradio Web UI.
 | `--server_port` | `int`  | `7860`  | Server port                                                                  |
 | `--share`       | `bool` | `False` | Enable sharing mode, allowing external access                                |
 | `--debug`       | `bool` | `False` | Enable debug mode                                                            |
+| `--compile`     | `bool` | `False` | Enable model compilation                                                     |
 | `--auth`        | `str`  | `None`  | Username and password for authentication, in the format `username:password` |
 | `--half`        | `bool` | `False` | Enable f16 half-precision inference                                          |
 | `--off_tqdm`    | `bool` | `False` | Disable the tqdm progress bar                                                |
diff --git a/launch.py b/launch.py
index 38a0392..3432ff8 100644
--- a/launch.py
+++ b/launch.py
@@ -140,15 +140,6 @@ def get_and_update_env(*args):
     device_id = get_and_update_env(args, "device_id", None, str)
     use_cpu = get_and_update_env(args, "use_cpu", [], list)
 
-    if compile:
-        print("Model compile is enabled")
-        config.enable_model_compile = True
-
-    def should_cache(*args, **kwargs):
-        spk_seed = kwargs.get("spk_seed", -1)
-        infer_seed = kwargs.get("infer_seed", -1)
-        return spk_seed != -1 and infer_seed != -1
-
     api = create_api(no_docs=no_docs, exclude=exclude.split(","))
 
     config.api = api
@@ -158,27 +149,15 @@ def should_cache(*args, **kwargs):
     if not no_playground:
         api.setup_playground()
 
-    if half:
-        config.model_config["half"] = True
-
-    if off_tqdm:
-        config.disable_tqdm = True
-
     if compile:
         logger.info("Model compile is enabled")
-        config.enable_model_compile = True
 
     def should_cache(*args, **kwargs):
         spk_seed = kwargs.get("spk_seed", -1)
         infer_seed = kwargs.get("infer_seed", -1)
         return spk_seed != -1 and infer_seed != -1
 
-    if lru_size > 0:
-        config.lru_size = lru_size
-        generate.generate_audio_batch = conditional_cache(should_cache)(
-            generate.generate_audio_batch
-        )
-
+    generate.setup_lru_cache()
 
     devices.reset_device()
     devices.first_time_calculation()
diff --git a/modules/config.py b/modules/config.py
index 4a61d92..3dd24c4 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -1,13 +1,5 @@
 from modules.utils.JsonObject import JsonObject
 
-enable_model_compile = False
-
-lru_size = 64
-
 runtime_env_vars = JsonObject({})
 
 api = None
-
-model_config = {"half": False}
-
-disable_tqdm = False
diff --git a/modules/generate_audio.py b/modules/generate_audio.py
index 5d25bc3..da84a00 100644
--- a/modules/generate_audio.py
+++ b/modules/generate_audio.py
@@ -11,6 +11,8 @@ from modules.devices import devices
 
 from typing import Union
 
+from modules.utils.cache import conditional_cache
+
 logger = logging.getLogger(__name__)
 
 
@@ -65,7 +67,7 @@ def generate_audio_batch(
         "prompt2": prompt2 or "",
         "prefix": prefix or "",
         "repetition_penalty": 1.0,
-        "disable_tqdm": config.disable_tqdm,
+        "disable_tqdm": config.runtime_env_vars.off_tqdm,
     }
 
     if isinstance(spk, int):
@@ -103,6 +105,32 @@ def generate_audio_batch(
     return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
 
 
+lru_cache_enabled = False
+
+
+def setup_lru_cache():
+    global generate_audio_batch
+    global lru_cache_enabled
+
+    if lru_cache_enabled:
+        return
+    lru_cache_enabled = True
+
+    def should_cache(*args, **kwargs):
+        spk_seed = kwargs.get("spk", -1)
+        infer_seed = kwargs.get("infer_seed", -1)
+        return spk_seed != -1 and infer_seed != -1
+
+    lru_size = config.runtime_env_vars.lru_size
+    if isinstance(lru_size, int):
+        generate_audio_batch = conditional_cache(lru_size, should_cache)(
+            generate_audio_batch
+        )
+        logger.info(f"LRU cache enabled with size {lru_size}")
+    else:
+        logger.debug(f"LRU cache failed to enable, invalid size {lru_size}")
+
+
 if __name__ == "__main__":
     import soundfile as sf
diff --git a/modules/models.py b/modules/models.py
index a813560..a6a818b 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -13,9 +13,10 @@ def load_chat_tts():
     global chat_tts
     if chat_tts:
         return chat_tts
+
     chat_tts = ChatTTS.Chat()
     chat_tts.load_models(
-        compile=config.enable_model_compile,
+        compile=config.runtime_env_vars.compile,
         source="local",
         local_path="./models/ChatTTS",
         device=devices.device,
@@ -26,6 +27,8 @@ def load_chat_tts():
         dtype_decoder=devices.dtype_decoder,
     )
 
+    devices.torch_gc()
+
     return chat_tts
diff --git a/modules/refiner.py b/modules/refiner.py
index 7d4d6bf..bf72918 100644
--- a/modules/refiner.py
+++ b/modules/refiner.py
@@ -29,7 +29,7 @@ def refine_text(
             "temperature": temperature,
             "repetition_penalty": repetition_penalty,
             "max_new_token": max_new_token,
-            "disable_tqdm": config.disable_tqdm,
+            "disable_tqdm": config.runtime_env_vars.off_tqdm,
         },
         do_text_normalization=False,
     )
diff --git a/modules/speaker.py b/modules/speaker.py
index 5ca0418..3af6f62 100644
--- a/modules/speaker.py
+++ b/modules/speaker.py
@@ -55,7 +55,7 @@ def fix(self):
         return is_update
 
     def __hash__(self):
-        return str(self.id)
+        return hash(str(self.id))
 
     def __eq__(self, other):
         if not isinstance(other, Speaker):
diff --git a/modules/utils/cache.py b/modules/utils/cache.py
index 81b695c..329627d 100644
--- a/modules/utils/cache.py
+++ b/modules/utils/cache.py
@@ -1,10 +1,12 @@
-from functools import lru_cache
-from typing import Callable
+from typing import Callable, TypeVar, Any
+from typing_extensions import ParamSpec
+from functools import lru_cache, _CacheInfo
 
-def conditional_cache(condition: Callable):
+
+def conditional_cache(maxsize: int, condition: Callable):
     def decorator(func):
-        @lru_cache(None)
+        @lru_cache_ext(maxsize=maxsize)
         def cached_func(*args, **kwargs):
             return func(*args, **kwargs)
 
@@ -17,3 +19,74 @@ def wrapper(*args, **kwargs):
         return wrapper
 
     return decorator
+
+
+def hash_list(l: list) -> int:
+    __hash = 0
+    for i, e in enumerate(l):
+        __hash = hash((__hash, i, hash_item(e)))
+    return __hash
+
+
+def hash_dict(d: dict) -> int:
+    __hash = 0
+    for k, v in d.items():
+        __hash = hash((__hash, k, hash_item(v)))
+    return __hash
+
+
+def hash_item(e) -> int:
+    if hasattr(e, "__hash__") and callable(e.__hash__):
+        try:
+            return hash(e)
+        except TypeError:
+            pass
+    if isinstance(e, (list, set, tuple)):
+        return hash_list(list(e))
+    elif isinstance(e, dict):
+        return hash_dict(e)
+    else:
+        raise TypeError(f"unhashable type: {e.__class__}")
+
+
+PT = ParamSpec("PT")
+RT = TypeVar("RT")
+
+
+def lru_cache_ext(
+    *opts, hashfunc: Callable[..., int] = hash_item, **kwopts
+) -> Callable[[Callable[PT, RT]], Callable[PT, RT]]:
+    def decorator(func: Callable[PT, RT]) -> Callable[PT, RT]:
+        class _lru_cache_ext_wrapper:
+            args: tuple
+            kwargs: dict[str, Any]
+
+            def cache_info(self) -> _CacheInfo: ...
+            def cache_clear(self) -> None: ...
+
+            @classmethod
+            @lru_cache(*opts, **kwopts)
+            def cached_func(cls, args_hash: int) -> RT:
+                return func(*cls.args, **cls.kwargs)
+
+            @classmethod
+            def __call__(cls, *args: PT.args, **kwargs: PT.kwargs) -> RT:
+                __hash = hashfunc(
+                    (
+                        id(func),
+                        *[hashfunc(a) for a in args],
+                        *[(hashfunc(k), hashfunc(v)) for k, v in kwargs.items()],
+                    )
+                )
+
+                cls.args = args
+                cls.kwargs = kwargs
+
+                cls.cache_info = cls.cached_func.cache_info
+                cls.cache_clear = cls.cached_func.cache_clear
+
+                return cls.cached_func(__hash)
+
+        return _lru_cache_ext_wrapper()
+
+    return decorator
diff --git a/webui.py b/webui.py
index 368ed68..3412e32 100644
--- a/webui.py
+++ b/webui.py
@@ -887,6 +887,7 @@ def create_interface():
         default=[],
         type=str.lower,
     )
+    parser.add_argument("--compile", action="store_true", help="Enable model compile")
 
     args = parser.parse_args()
 
@@ -906,6 +907,7 @@ def get_and_update_env(*args):
     lru_size = get_and_update_env(args, "lru_size", 64, int)
     device_id = get_and_update_env(args, "device_id", None, str)
     use_cpu = get_and_update_env(args, "use_cpu", [], list)
+    compile = get_and_update_env(args, "compile", False, bool)
 
     webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
     webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
@@ -916,23 +918,7 @@ def get_and_update_env(*args):
     if auth:
         auth = tuple(auth.split(":"))
 
-    if half:
-        config.model_config["half"] = True
-
-    if off_tqdm:
-        config.disable_tqdm = True
-
-    def should_cache(*args, **kwargs):
-        spk_seed = kwargs.get("spk_seed", -1)
-        infer_seed = kwargs.get("infer_seed", -1)
-        return spk_seed != -1 and infer_seed != -1
-
-    if lru_size > 0:
-        config.lru_size = lru_size
-        generate.generate_audio_batch = conditional_cache(should_cache)(
-            generate.generate_audio_batch
-        )
-
+    generate.setup_lru_cache()
 
     devices.reset_device()
     devices.first_time_calculation()
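
Note for reviewers: `setup_lru_cache` in `modules/generate_audio.py` wraps `generate_audio_batch` so that a result is memoized only when both the speaker and the inference seed are fixed, i.e. only when the output is reproducible. Below is a minimal, self-contained sketch of that pattern, using the stdlib `lru_cache` in place of the patch's `lru_cache_ext` and a hypothetical `fake_generate` stand-in; it illustrates the mechanism, it is not the shipped implementation.

```python
from functools import lru_cache
from typing import Callable


def conditional_cache(maxsize: int, condition: Callable):
    def decorator(func):
        # Memoized path; stdlib lru_cache stands in for lru_cache_ext here.
        @lru_cache(maxsize=maxsize)
        def cached_func(*args, **kwargs):
            return func(*args, **kwargs)

        def wrapper(*args, **kwargs):
            # Route through the cache only when the predicate approves the call.
            if condition(*args, **kwargs):
                return cached_func(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    return decorator


def should_cache(*args, **kwargs):
    # Cache only fully seeded (deterministic) generations.
    return kwargs.get("spk", -1) != -1 and kwargs.get("infer_seed", -1) != -1


calls = 0


@conditional_cache(64, should_cache)
def fake_generate(text, spk=-1, infer_seed=-1):  # stand-in for generate_audio_batch
    global calls
    calls += 1
    return f"audio({text}, spk={spk}, seed={infer_seed})"


fake_generate("hello", spk=2, infer_seed=42)
fake_generate("hello", spk=2, infer_seed=42)  # cache hit, no new call
fake_generate("hello")                        # unseeded, never cached
fake_generate("hello")                        # unseeded again, runs again
print(calls)  # 3
```

The predicate mirrors the patched `should_cache`, which now keys on `spk`, the kwarg `generate_audio_batch` actually receives, instead of the never-present `spk_seed` the removed copies in `launch.py` and `webui.py` looked up.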
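On why `lru_cache_ext` exists at all: `generate_audio_batch` receives unhashable arguments such as a list of texts, which the stdlib `lru_cache` rejects, so the new helper first folds every argument into a single `int` via `hash_item`/`hash_list`/`hash_dict` and caches on that key. A condensed sketch of the idea (a simplification of the wrapper class in the patch, not a drop-in):

```python
from functools import lru_cache


def hash_item(e) -> int:
    # Recursively fold lists/dicts into one int so they can key a cache.
    if isinstance(e, (list, tuple, set)):
        h = 0
        for i, v in enumerate(e):
            h = hash((h, i, hash_item(v)))
        return h
    if isinstance(e, dict):
        h = 0
        for k, v in e.items():
            h = hash((h, k, hash_item(v)))
        return h
    return hash(e)


@lru_cache(maxsize=None)
def plain(texts):
    return len(texts)


try:
    plain(["a", "b"])
except TypeError as err:
    print(err)  # unhashable type: 'list'

# Folding the list to an int first sidesteps the restriction:
print(hash_item(["a", "b"]) == hash_item(["a", "b"]))  # True
print(hash_item(["a", "b"]) == hash_item(["b", "a"]))  # False: order matters
```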
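Finally, the one-line `modules/speaker.py` change is a genuine bug fix rather than a refactor: Python requires `__hash__` to return an `int`, so the old `return str(self.id)` raised `TypeError` the moment a `Speaker` was used as a dict key, set member, or LRU-cache key. A toy reproduction, with hypothetical `Broken`/`Fixed` classes:

```python
class Broken:
    def __init__(self, id):
        self.id = id

    def __hash__(self):
        return str(self.id)  # wrong: hash() requires an int


class Fixed(Broken):
    def __hash__(self):
        return hash(str(self.id))  # the patch's fix


try:
    {Broken(1): "x"}  # hashing the key triggers the error
except TypeError as err:
    print(err)  # __hash__ method should return an integer

print({Fixed(1): "x"})  # works as a dict key
```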