Skip to content

Commit

Permalink
♻️ refactor models_setup
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jun 23, 2024
1 parent 9af0361 commit ff9c7c0
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 123 deletions.
57 changes: 46 additions & 11 deletions launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

try:
setup_ffmpeg_path()
# NOTE: 因为 logger 都是在模块中初始化,所以这个 config 必须在最前面
logging.basicConfig(
level=os.getenv("LOG_LEVEL", "INFO"),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
Expand All @@ -16,29 +17,44 @@

import uvicorn

from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
from modules.api.api_setup import setup_api_args
from modules.models_setup import setup_model_args
from modules.utils import env
from modules.utils.ignore_warn import ignore_useless_warnings

ignore_useless_warnings()

logger = logging.getLogger(__name__)

if __name__ == "__main__":
import dotenv

dotenv.load_dotenv(
dotenv_path=os.getenv("ENV_FILE", ".env.api"),
def setup_uvicon_args(parser: argparse.ArgumentParser):
parser.add_argument("--host", type=str, help="Host to run the server on")
parser.add_argument("--port", type=int, help="Port to run the server on")
parser.add_argument(
"--reload", action="store_true", help="Enable auto-reload for development"
)
parser = argparse.ArgumentParser(
description="Start the FastAPI server with command line arguments"
parser.add_argument("--workers", type=int, help="Number of worker processes")
parser.add_argument("--log_level", type=str, help="Log level")
parser.add_argument("--access_log", action="store_true", help="Enable access log")
parser.add_argument(
"--proxy_headers", action="store_true", help="Enable proxy headers"
)
parser.add_argument(
"--timeout_keep_alive", type=int, help="Keep-alive timeout duration"
)
parser.add_argument(
"--timeout_graceful_shutdown",
type=int,
help="Graceful shutdown timeout duration",
)
parser.add_argument("--ssl_keyfile", type=str, help="SSL key file path")
parser.add_argument("--ssl_certfile", type=str, help="SSL certificate file path")
parser.add_argument(
"--ssl_keyfile_password", type=str, help="SSL key file password"
)
setup_api_args(parser)
setup_model_args(parser)
setup_uvicon_args(parser=parser)

args = parser.parse_args()

def process_uvicon_args(args):
host = env.get_and_update_env(args, "host", "0.0.0.0", str)
port = env.get_and_update_env(args, "port", 7870, int)
reload = env.get_and_update_env(args, "reload", False, bool)
Expand Down Expand Up @@ -71,3 +87,22 @@
ssl_certfile=ssl_certfile,
ssl_keyfile_password=ssl_keyfile_password,
)


if __name__ == "__main__":
import dotenv

dotenv.load_dotenv(
dotenv_path=os.getenv("ENV_FILE", ".env.api"),
)
parser = argparse.ArgumentParser(
description="Start the FastAPI server with command line arguments"
)
# NOTE: 主进程中不需要处理 model args / api args,但是要接收这些参数, 具体处理在 worker.py 中
setup_api_args(parser=parser)
setup_model_args(parser=parser)
setup_uvicon_args(parser=parser)

args = parser.parse_args()

process_uvicon_args(args)
104 changes: 5 additions & 99 deletions modules/api/api_setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import argparse
import logging

from modules import config, generate_audio
from fastapi import FastAPI

from modules import config
from modules.api.Api import APIManager
from modules.api.impl import (
google_api,
Expand All @@ -15,15 +17,12 @@
tts_api,
xtts_v2_api,
)
from modules.devices import devices
from modules.Enhancer.ResembleEnhance import load_enhancer
from modules.models import load_chat_tts
from modules.utils import env

logger = logging.getLogger(__name__)


def create_api(app, exclude=[]):
def create_api(app: FastAPI, exclude=[]):
app_mgr = APIManager(app=app, exclude_patterns=exclude)

ping_api.setup(app_mgr)
Expand All @@ -40,99 +39,6 @@ def create_api(app, exclude=[]):
return app_mgr


def setup_model_args(parser: argparse.ArgumentParser):
parser.add_argument("--compile", action="store_true", help="Enable model compile")
parser.add_argument(
"--no_half",
action="store_true",
help="Disalbe half precision for model inference",
)
parser.add_argument(
"--off_tqdm",
action="store_true",
help="Disable tqdm progress bar",
)
parser.add_argument(
"--device_id",
type=str,
help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
default=None,
)
parser.add_argument(
"--use_cpu",
nargs="+",
help="use CPU as torch device for specified modules",
default=[],
type=str.lower,
choices=["all", "chattts", "enhancer", "trainer"],
)
parser.add_argument(
"--lru_size",
type=int,
default=64,
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
)
parser.add_argument(
"--debug_generate",
action="store_true",
help="Enable debug mode for audio generation",
)
parser.add_argument(
"--preload_models",
action="store_true",
help="Preload all models at startup",
)


def process_model_args(args):
lru_size = env.get_and_update_env(args, "lru_size", 64, int)
compile = env.get_and_update_env(args, "compile", False, bool)
device_id = env.get_and_update_env(args, "device_id", None, str)
use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
no_half = env.get_and_update_env(args, "no_half", False, bool)
off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
preload_models = env.get_and_update_env(args, "preload_models", False, bool)

generate_audio.setup_lru_cache()
devices.reset_device()
devices.first_time_calculation()

if debug_generate:
generate_audio.logger.setLevel(logging.DEBUG)

if preload_models:
load_chat_tts()
load_enhancer()


def setup_uvicon_args(parser: argparse.ArgumentParser):
parser.add_argument("--host", type=str, help="Host to run the server on")
parser.add_argument("--port", type=int, help="Port to run the server on")
parser.add_argument(
"--reload", action="store_true", help="Enable auto-reload for development"
)
parser.add_argument("--workers", type=int, help="Number of worker processes")
parser.add_argument("--log_level", type=str, help="Log level")
parser.add_argument("--access_log", action="store_true", help="Enable access log")
parser.add_argument(
"--proxy_headers", action="store_true", help="Enable proxy headers"
)
parser.add_argument(
"--timeout_keep_alive", type=int, help="Keep-alive timeout duration"
)
parser.add_argument(
"--timeout_graceful_shutdown",
type=int,
help="Graceful shutdown timeout duration",
)
parser.add_argument("--ssl_keyfile", type=str, help="SSL key file path")
parser.add_argument("--ssl_certfile", type=str, help="SSL certificate file path")
parser.add_argument(
"--ssl_keyfile_password", type=str, help="SSL key file password"
)


def setup_api_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--cors_origin",
Expand All @@ -157,7 +63,7 @@ def setup_api_args(parser: argparse.ArgumentParser):
)


def process_api_args(args, app):
def process_api_args(args: argparse.Namespace, app: FastAPI):
cors_origin = env.get_and_update_env(args, "cors_origin", "*", str)
no_playground = env.get_and_update_env(args, "no_playground", False, bool)
no_docs = env.get_and_update_env(args, "no_docs", False, bool)
Expand Down
10 changes: 3 additions & 7 deletions modules/api/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import dotenv
from fastapi import FastAPI

from launch import setup_uvicon_args
from modules.ffmpeg_env import setup_ffmpeg_path
from modules.models_setup import process_model_args, setup_model_args

setup_ffmpeg_path()
logging.basicConfig(
Expand All @@ -14,13 +16,7 @@
)

from modules import config
from modules.api.api_setup import (
process_api_args,
process_model_args,
setup_api_args,
setup_model_args,
setup_uvicon_args,
)
from modules.api.api_setup import process_api_args, setup_api_args
from modules.api.app_config import app_description, app_title, app_version
from modules.utils.torch_opt import configure_torch_optimizations

Expand Down
74 changes: 74 additions & 0 deletions modules/models_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import argparse
import logging

from modules import generate_audio
from modules.devices import devices
from modules.Enhancer.ResembleEnhance import load_enhancer
from modules.models import load_chat_tts
from modules.utils import env


def setup_model_args(parser: argparse.ArgumentParser):
parser.add_argument("--compile", action="store_true", help="Enable model compile")
parser.add_argument(
"--no_half",
action="store_true",
help="Disalbe half precision for model inference",
)
parser.add_argument(
"--off_tqdm",
action="store_true",
help="Disable tqdm progress bar",
)
parser.add_argument(
"--device_id",
type=str,
help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
default=None,
)
parser.add_argument(
"--use_cpu",
nargs="+",
help="use CPU as torch device for specified modules",
default=[],
type=str.lower,
choices=["all", "chattts", "enhancer", "trainer"],
)
parser.add_argument(
"--lru_size",
type=int,
default=64,
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
)
parser.add_argument(
"--debug_generate",
action="store_true",
help="Enable debug mode for audio generation",
)
parser.add_argument(
"--preload_models",
action="store_true",
help="Preload all models at startup",
)


def process_model_args(args: argparse.Namespace):
lru_size = env.get_and_update_env(args, "lru_size", 64, int)
compile = env.get_and_update_env(args, "compile", False, bool)
device_id = env.get_and_update_env(args, "device_id", None, str)
use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
no_half = env.get_and_update_env(args, "no_half", False, bool)
off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
preload_models = env.get_and_update_env(args, "preload_models", False, bool)

generate_audio.setup_lru_cache()
devices.reset_device()
devices.first_time_calculation()

if debug_generate:
generate_audio.logger.setLevel(logging.DEBUG)

if preload_models:
load_chat_tts()
load_enhancer()
9 changes: 3 additions & 6 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

try:
setup_ffmpeg_path()
# NOTE: 因为 logger 都是在模块中初始化,所以这个 config 必须在最前面
logging.basicConfig(
level=os.getenv("LOG_LEVEL", "INFO"),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
Expand All @@ -16,14 +17,10 @@
import argparse

from modules import config
from modules.api.api_setup import (
process_api_args,
process_model_args,
setup_api_args,
setup_model_args,
)
from modules.api.api_setup import process_api_args, setup_api_args
from modules.api.app_config import app_description, app_title, app_version
from modules.gradio_dcls_fix import dcls_patch
from modules.models_setup import process_model_args, setup_model_args
from modules.utils.env import get_and_update_env
from modules.utils.ignore_warn import ignore_useless_warnings
from modules.utils.torch_opt import configure_torch_optimizations
Expand Down

0 comments on commit ff9c7c0

Please sign in to comment.