Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: .get_engine() で latest version を取得できるように #1421

Merged
merged 10 commits into from
Jun 28, 2024
21 changes: 20 additions & 1 deletion test/unit/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import HTTPException

from voicevox_engine.dev.tts_engine.mock import MockTTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager


def test_tts_engines_register_engine() -> None:
Expand Down Expand Up @@ -48,6 +48,25 @@ def test_tts_engines_get_engine_existing() -> None:
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_latest() -> None:
"""TTSEngineManager.get_engine(LATEST_VERSION) で最新版の TTS エンジンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
tts_engine2 = MockTTSEngine()
tts_engine3 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")
tts_engines.register_engine(tts_engine3, "0.1.0")
# Expects
true_acquired_tts_engine = tts_engine3
# Outputs
acquired_tts_engine = tts_engines.get_engine(LATEST_VERSION)

# Test
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_missing() -> None:
"""TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。"""
# Inputs
Expand Down
6 changes: 2 additions & 4 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,9 @@ def _get_core_characters(version: str | None) -> list[CoreCharacter]:
)

app.include_router(
generate_tts_pipeline_router(
tts_engines, core_manager, preset_manager, cancellable_engine
)
generate_tts_pipeline_router(tts_engines, preset_manager, cancellable_engine)
)
app.include_router(generate_morphing_router(tts_engines, core_manager, metas_store))
app.include_router(generate_morphing_router(tts_engines, metas_store))
app.include_router(
generate_preset_router(preset_manager, verify_mutability_allowed)
)
Expand Down
9 changes: 3 additions & 6 deletions voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from starlette.background import BackgroundTask
from starlette.responses import FileResponse

from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.metas.MetasStore import MetasStore
from voicevox_engine.model import AudioQuery
Expand All @@ -24,7 +23,7 @@
synthesis_morphing_parameter as _synthesis_morphing_parameter,
)
from voicevox_engine.morphing.morphing import synthesize_morphed_wave
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager
from voicevox_engine.utility.file_utility import try_delete_file

# キャッシュを有効化
Expand All @@ -34,9 +33,7 @@


def generate_morphing_router(
tts_engines: TTSEngineManager,
core_manager: CoreManager,
metas_store: MetasStore,
tts_engines: TTSEngineManager, metas_store: MetasStore
) -> APIRouter:
"""モーフィング API Router を生成する"""
router = APIRouter(tags=["音声合成"])
Expand Down Expand Up @@ -89,7 +86,7 @@ def _synthesis_morphing(
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)

# モーフィングが許可されないキャラクターペアを拒否する
Expand Down
31 changes: 15 additions & 16 deletions voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
CancellableEngine,
CancellableEngineInternalError,
)
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.model import AudioQuery
from voicevox_engine.preset.preset_manager import (
Expand All @@ -39,6 +38,7 @@
Score,
)
from voicevox_engine.tts_pipeline.tts_engine import (
LATEST_VERSION,
TalkSingInvalidInputError,
TTSEngineManager,
)
Expand All @@ -65,7 +65,6 @@ def __init__(self, err: ParseKanaError):

def generate_tts_pipeline_router(
tts_engines: TTSEngineManager,
core_manager: CoreManager,
preset_manager: PresetManager,
cancellable_engine: CancellableEngine | None,
) -> APIRouter:
Expand All @@ -85,7 +84,7 @@ def audio_query(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
Expand Down Expand Up @@ -116,7 +115,7 @@ def audio_query_from_preset(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
try:
presets = preset_manager.load_presets()
Expand Down Expand Up @@ -175,7 +174,7 @@ def accent_phrases(
* アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。
* アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
if is_kana:
try:
Expand All @@ -197,7 +196,7 @@ def mora_data(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
return engine.update_length_and_pitch(accent_phrases, style_id)

Expand All @@ -211,7 +210,7 @@ def mora_length(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
return engine.update_length(accent_phrases, style_id)

Expand All @@ -225,7 +224,7 @@ def mora_pitch(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
return engine.update_pitch(accent_phrases, style_id)

Expand Down Expand Up @@ -253,7 +252,7 @@ def synthesis(
] = True,
core_version: str | SkipJsonSchema[None] = None,
) -> FileResponse:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
wave = engine.synthesize_wave(
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
Expand Down Expand Up @@ -294,8 +293,8 @@ def cancellable_synthesis(
status_code=404,
detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。",
)
version = core_version or core_manager.latest_version()
try:
version = core_version or LATEST_VERSION
f_name = cancellable_engine._synthesis_impl(
query, style_id, request, version=version
)
Expand Down Expand Up @@ -331,7 +330,7 @@ def multi_synthesis(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> FileResponse:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
sampling_rate = queries[0].outputSamplingRate

Expand Down Expand Up @@ -374,7 +373,7 @@ def sing_frame_audio_query(
"""
歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
try:
phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume(
Expand Down Expand Up @@ -403,7 +402,7 @@ def sing_frame_volume(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[float]:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
try:
return engine.create_sing_volume_from_phoneme_and_f0(
Expand Down Expand Up @@ -432,7 +431,7 @@ def frame_synthesis(
"""
歌唱音声合成を行います。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
try:
wave = engine.frame_synthsize_wave(query, style_id)
Expand Down Expand Up @@ -528,7 +527,7 @@ def initialize_speaker(
指定されたスタイルを初期化します。
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
engine.initialize_synthesis(style_id, skip_reinit=skip_reinit)

Expand All @@ -540,7 +539,7 @@ def is_initialized_speaker(
"""
指定されたスタイルが初期化されているかどうかを返します。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
return engine.is_synthesis_initialized(style_id)

Expand Down
10 changes: 5 additions & 5 deletions voicevox_engine/cancellable_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .core.core_initializer import initialize_cores
from .metas.Metas import StyleId
from .model import AudioQuery
from .tts_pipeline.tts_engine import make_tts_engines_from_cores
from .tts_pipeline.tts_engine import LatestVersion, make_tts_engines_from_cores


class CancellableEngineInternalError(Exception):
Expand Down Expand Up @@ -149,7 +149,7 @@ def _synthesis_impl(
query: AudioQuery,
style_id: StyleId,
request: Request,
version: str,
version: str | LatestVersion,
) -> str:
"""
音声合成を行う関数
Expand All @@ -163,7 +163,7 @@ def _synthesis_impl(
request: fastapi.Request
接続確立時に受け取ったものをそのまま渡せばよい
https://fastapi.tiangolo.com/advanced/using-request-directly/
version: str
version

Returns
-------
Expand Down Expand Up @@ -245,9 +245,9 @@ def start_synthesis_subprocess(
while True:
try:
query, style_id, version = sub_proc_con.recv()
if tts_engines.has_engine(version):
try:
_engine = tts_engines.get_engine(version)
else:
except Exception:
Comment on lines -248 to +250
Copy link
Member

@Hiroshiba Hiroshiba Jun 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

!!

has_engine危ないかも?(LagestVersionが受け付けられない・・・けど型的には大丈夫か。)

確認したらここでの利用がなくなって、どうやらhas_engineがテストからしか使われなくなったっぽみです。
一緒に消してしまっても良いかも・・・?

だけどまあこのPRは問題じゃないから大丈夫そう!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

見逃していました、指摘ありがとうございます!
#1446 にて削除します。

# バージョンが見つからないエラー
sub_proc_con.send("")
continue
Expand Down
16 changes: 14 additions & 2 deletions voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import copy
import math
from typing import Final, Literal, TypeAlias

import numpy as np
from fastapi import HTTPException
from numpy.typing import NDArray
from soxr import resample

from voicevox_engine.utility.core_version_utility import get_latest_version

from ..core.core_adapter import CoreAdapter, DeviceSupport
from ..core.core_initializer import CoreManager
from ..core.core_wrapper import CoreWrapper
Expand Down Expand Up @@ -697,6 +700,10 @@ def frame_synthsize_wave(
return wave


LatestVersion: TypeAlias = Literal["LATEST_VERSION"]
LATEST_VERSION: Final[LatestVersion] = "LATEST_VERSION"


class TTSEngineManager:
"""TTS エンジンの集まりを一括管理するマネージャー"""

Expand All @@ -707,13 +714,18 @@ def versions(self) -> list[str]:
"""登録されたエンジンのバージョン一覧を取得する。"""
return list(self._engines.keys())

def _latest_version(self) -> str:
return get_latest_version(self.versions())

def register_engine(self, engine: TTSEngine, version: str) -> None:
"""エンジンを登録する。"""
self._engines[version] = engine

def get_engine(self, version: str) -> TTSEngine:
def get_engine(self, version: str | LatestVersion) -> TTSEngine:
"""指定バージョンのエンジンを取得する。"""
if version in self._engines:
if version == LATEST_VERSION:
return self._engines[self._latest_version()]
elif version in self._engines:
return self._engines[version]

raise HTTPException(status_code=422, detail="不明なバージョンです")
Expand Down