:zep: improve cache
- Adjust the cache logic
- Fix speaker cache misses
- Fix an environment variable issue
zhzLuke96 committed Jun 6, 2024
1 parent 400afe6 commit eedc558
Showing 9 changed files with 129 additions and 55 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -87,6 +87,18 @@ ChatTTS-Forge is a powerful text-to-speech generation tool that supports
> Because `MKL FFT doesn't support tensors of type: Half`, `--half` and `--use_cpu="all"` cannot be used together.
## Inference Speed

> Test platform: `GeForce RTX 2080 Ti`

| Flag combination   | Inference speed (tk/s) | Chinese speed (char/s) |
| ------------------ | ---------------------- | ---------------------- |
| `--compile`        | 54                     | 5.4                    |
| `--half --compile` | 51                     | 5.1                    |
| default (no flags) | 30                     | 3.0                    |
| `--half`           | 28                     | 2.8                    |
| `--use_cpu=all`    | 20                     | 2.0                    |
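> The two speed columns track each other at roughly 10 tokens per Chinese character on this model.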

### launch.py

launch.py is the startup script for ChatTTS-Forge, used to configure and start the API server.
@@ -172,6 +184,7 @@ webui.py is a script for configuring and launching the Gradio Web UI.
| `--server_port` | `int` | `7860` | Server port |
| `--share` | `bool` | `False` | Enable sharing mode, allowing external access |
| `--debug` | `bool` | `False` | Enable debug mode |
| `--compile` | `bool` | `False` | Enable model compilation |
| `--auth` | `str` | `None` | Username and password for authentication, in the form `username:password` |
| `--half` | `bool` | `False` | Enable f16 half-precision inference |
| `--off_tqdm` | `bool` | `False` | Disable the tqdm progress bar |
23 changes: 1 addition & 22 deletions launch.py
@@ -140,15 +140,6 @@ def get_and_update_env(*args):
device_id = get_and_update_env(args, "device_id", None, str)
use_cpu = get_and_update_env(args, "use_cpu", [], list)

if compile:
print("Model compile is enabled")
config.enable_model_compile = True

def should_cache(*args, **kwargs):
spk_seed = kwargs.get("spk_seed", -1)
infer_seed = kwargs.get("infer_seed", -1)
return spk_seed != -1 and infer_seed != -1

api = create_api(no_docs=no_docs, exclude=exclude.split(","))
config.api = api

@@ -158,27 +149,15 @@ def should_cache(*args, **kwargs):
if not no_playground:
api.setup_playground()

if half:
config.model_config["half"] = True

if off_tqdm:
config.disable_tqdm = True

if compile:
logger.info("Model compile is enabled")
config.enable_model_compile = True

def should_cache(*args, **kwargs):
spk_seed = kwargs.get("spk_seed", -1)
infer_seed = kwargs.get("infer_seed", -1)
return spk_seed != -1 and infer_seed != -1

if lru_size > 0:
config.lru_size = lru_size
generate.generate_audio_batch = conditional_cache(should_cache)(
generate.generate_audio_batch
)

generate.setup_lru_cache()
devices.reset_device()
devices.first_time_calculation()

8 changes: 0 additions & 8 deletions modules/config.py
@@ -1,13 +1,5 @@
from modules.utils.JsonObject import JsonObject

enable_model_compile = False

lru_size = 64

runtime_env_vars = JsonObject({})

api = None

model_config = {"half": False}

disable_tqdm = False
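For context: the module-level flags deleted here (`enable_model_compile`, `lru_size`, `disable_tqdm`, `model_config`) are replaced by reads from `config.runtime_env_vars`, which both entrypoints now populate via `get_and_update_env`. A minimal sketch of the attribute-style access this relies on — an illustrative stand-in, not the real `modules/utils/JsonObject.py` implementation:

```python
# Illustrative stand-in for modules.utils.JsonObject, NOT the real class:
class JsonObjectSketch:
    def __init__(self, data: dict):
        object.__setattr__(self, "_data", dict(data))

    def __getattr__(self, name):
        # Unset flags read as None, so callers can feature-test them.
        return self._data.get(name)

    def __setattr__(self, name, value):
        self._data[name] = value

runtime_env_vars = JsonObjectSketch({})
runtime_env_vars.off_tqdm = True   # written once by get_and_update_env
print(runtime_env_vars.off_tqdm)   # True, read anywhere via config
print(runtime_env_vars.compile)    # None until the flag is parsed
```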
30 changes: 29 additions & 1 deletion modules/generate_audio.py
@@ -11,6 +11,8 @@
from modules.devices import devices
from typing import Union

from modules.utils.cache import conditional_cache

logger = logging.getLogger(__name__)


@@ -65,7 +67,7 @@ def generate_audio_batch(
"prompt2": prompt2 or "",
"prefix": prefix or "",
"repetition_penalty": 1.0,
"disable_tqdm": config.disable_tqdm,
"disable_tqdm": config.runtime_env_vars.off_tqdm,
}

if isinstance(spk, int):
@@ -103,6 +105,32 @@ def generate_audio_batch(
return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]


lru_cache_enabled = False


def setup_lru_cache():
global generate_audio_batch
global lru_cache_enabled

if lru_cache_enabled:
return
lru_cache_enabled = True

def should_cache(*args, **kwargs):
spk_seed = kwargs.get("spk", -1)
infer_seed = kwargs.get("infer_seed", -1)
return spk_seed != -1 and infer_seed != -1

lru_size = config.runtime_env_vars.lru_size
if isinstance(lru_size, int):
generate_audio_batch = conditional_cache(lru_size, should_cache)(
generate_audio_batch
)
logger.info(f"LRU cache enabled with size {lru_size}")
else:
logger.debug(f"LRU cache failed to enable, invalid size {lru_size}")


if __name__ == "__main__":
import soundfile as sf

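Taken together, `setup_lru_cache` centralizes what launch.py and webui.py previously wired up by hand: wrapping `generate_audio_batch` in a conditional LRU cache so only deterministic requests are memoized. A hedged sketch of the resulting behavior, assuming the elided `wrapper` in `conditional_cache` routes a call through the cache only when `condition(...)` is true (function names below are illustrative):

```python
from modules.utils.cache import conditional_cache

def should_cache(*args, **kwargs):
    # Cache only fully deterministic requests: both seeds must be explicit.
    return kwargs.get("spk", -1) != -1 and kwargs.get("infer_seed", -1) != -1

@conditional_cache(64, should_cache)
def synthesize(text: str, spk: int = -1, infer_seed: int = -1) -> str:
    print("computing:", text, spk, infer_seed)
    return f"audio<{text}|{spk}|{infer_seed}>"

synthesize("hello", spk=2, infer_seed=42)  # computed, result cached
synthesize("hello", spk=2, infer_seed=42)  # cache hit, nothing printed
synthesize("hello")                        # seeds unset, recomputed every call
```

The `if lru_cache_enabled: return` guard makes the setup idempotent, so calling it from both entrypoints cannot double-wrap the function.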
5 changes: 4 additions & 1 deletion modules/models.py
@@ -13,9 +13,10 @@ def load_chat_tts():
global chat_tts
if chat_tts:
return chat_tts

chat_tts = ChatTTS.Chat()
chat_tts.load_models(
compile=config.enable_model_compile,
compile=config.runtime_env_vars.compile,
source="local",
local_path="./models/ChatTTS",
device=devices.device,
@@ -26,6 +27,8 @@
dtype_decoder=devices.dtype_decoder,
)

devices.torch_gc()

return chat_tts


2 changes: 1 addition & 1 deletion modules/refiner.py
@@ -29,7 +29,7 @@ def refine_text(
"temperature": temperature,
"repetition_penalty": repetition_penalty,
"max_new_token": max_new_token,
"disable_tqdm": config.disable_tqdm,
"disable_tqdm": config.runtime_env_vars.off_tqdm,
},
do_text_normalization=False,
)
2 changes: 1 addition & 1 deletion modules/speaker.py
@@ -55,7 +55,7 @@ def fix(self):
return is_update

def __hash__(self):
return str(self.id)
return hash(str(self.id))

def __eq__(self, other):
if not isinstance(other, Speaker):
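This one-line change is the speaker-cache fix from the commit message: Python requires `__hash__` to return an `int`, so the old `str` return made every `hash(speaker)` call raise `TypeError`, and `Speaker` objects passed as `spk` could never serve as cache keys. A standalone illustration (class names hypothetical):

```python
class BrokenSpeaker:
    def __init__(self, id):
        self.id = id

    def __hash__(self):
        return str(self.id)  # old behavior: non-int hash value

class FixedSpeaker(BrokenSpeaker):
    def __hash__(self):
        return hash(str(self.id))  # fixed: proper integer hash

try:
    hash(BrokenSpeaker("uuid-1"))
except TypeError as e:
    print(e)  # __hash__ method should return an integer

cache = {FixedSpeaker("uuid-1"): "entry"}  # now usable as a dict/LRU key
```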
81 changes: 77 additions & 4 deletions modules/utils/cache.py
@@ -1,10 +1,12 @@
from functools import lru_cache
from typing import Callable
from typing import Callable, TypeVar, Any
from typing_extensions import ParamSpec

from functools import lru_cache, _CacheInfo

def conditional_cache(condition: Callable):

def conditional_cache(maxsize: int, condition: Callable):
def decorator(func):
@lru_cache(None)
@lru_cache_ext(maxsize=maxsize)
def cached_func(*args, **kwargs):
return func(*args, **kwargs)

@@ -17,3 +19,74 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def hash_list(l: list) -> int:
__hash = 0
for i, e in enumerate(l):
__hash = hash((__hash, i, hash_item(e)))
return __hash


def hash_dict(d: dict) -> int:
__hash = 0
for k, v in d.items():
__hash = hash((__hash, k, hash_item(v)))
return __hash


def hash_item(e) -> int:
if hasattr(e, "__hash__") and callable(e.__hash__):
try:
return hash(e)
except TypeError:
pass
if isinstance(e, (list, set, tuple)):
return hash_list(list(e))
elif isinstance(e, (dict)):
return hash_dict(e)
else:
raise TypeError(f"unhashable type: {e.__class__}")


PT = ParamSpec("PT")
RT = TypeVar("RT")


def lru_cache_ext(
*opts, hashfunc: Callable[..., int] = hash_item, **kwopts
) -> Callable[[Callable[PT, RT]], Callable[PT, RT]]:
def decorator(func: Callable[PT, RT]) -> Callable[PT, RT]:
class _lru_cache_ext_wrapper:
args: tuple
kwargs: dict[str, Any]

def cache_info(self) -> _CacheInfo: ...
def cache_clear(self) -> None: ...

@classmethod
@lru_cache(*opts, **kwopts)
def cached_func(cls, args_hash: int) -> RT:
return func(*cls.args, **cls.kwargs)

@classmethod
def __call__(cls, *args: PT.args, **kwargs: PT.kwargs) -> RT:
__hash = hashfunc(
(
id(func),
*[hashfunc(a) for a in args],
*[(hashfunc(k), hashfunc(v)) for k, v in kwargs.items()],
)
)

cls.args = args
cls.kwargs = kwargs

cls.cache_info = cls.cached_func.cache_info
cls.cache_clear = cls.cached_func.cache_clear

return cls.cached_func(__hash)

return _lru_cache_ext_wrapper()

return decorator
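The point of `lru_cache_ext` is that the stdlib `functools.lru_cache` fails at call time on unhashable arguments such as dicts and lists, while `generate_audio_batch` receives exactly those (a batch of texts, parameter dicts, `Speaker` objects). Pre-hashing every argument with `hash_item` and caching on the resulting integer sidesteps that. A hedged usage sketch against the code above:

```python
from modules.utils.cache import lru_cache_ext

@lru_cache_ext(maxsize=8)
def total(params: dict) -> int:
    print("computing")
    return sum(params.values())

print(total({"a": 1, "b": 2}))  # prints "computing", then 3
print(total({"a": 1, "b": 2}))  # 3 again, served from the cache
print(total.cache_info())       # hits=1, misses=1 via the forwarded cache_info
```

One caveat worth noting: `cls.args`/`cls.kwargs` are stashed as class-level state and only read on a cache miss, so the wrapper is not safe under concurrent calls to the same function.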
20 changes: 3 additions & 17 deletions webui.py
@@ -887,6 +887,7 @@ def create_interface():
default=[],
type=str.lower,
)
parser.add_argument("--compile", action="store_true", help="Enable model compile")

args = parser.parse_args()

@@ -906,6 +907,7 @@ def get_and_update_env(*args):
lru_size = get_and_update_env(args, "lru_size", 64, int)
device_id = get_and_update_env(args, "device_id", None, str)
use_cpu = get_and_update_env(args, "use_cpu", [], list)
compile = get_and_update_env(args, "compile", False, bool)

webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
@@ -916,23 +918,7 @@ def get_and_update_env(*args):
if auth:
auth = tuple(auth.split(":"))

if half:
config.model_config["half"] = True

if off_tqdm:
config.disable_tqdm = True

def should_cache(*args, **kwargs):
spk_seed = kwargs.get("spk_seed", -1)
infer_seed = kwargs.get("infer_seed", -1)
return spk_seed != -1 and infer_seed != -1

if lru_size > 0:
config.lru_size = lru_size
generate.generate_audio_batch = conditional_cache(should_cache)(
generate.generate_audio_batch
)

generate.setup_lru_cache()
devices.reset_device()
devices.first_time_calculation()

