Skip to content

Commit

Permalink
整理: UserDictionary クラスを追加 (#1222)
Browse files: browse the repository at this point in the history
refactor: `UserDictionary` クラスを追加
  • Loading branch information
tarepan committed May 13, 2024
1 parent 477cdb6 commit c356918
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 122 deletions.
2 changes: 2 additions & 0 deletions build_util/make_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from voicevox_engine.preset.PresetManager import PresetManager
from voicevox_engine.setting.SettingLoader import USER_SETTING_PATH, SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import CoreAdapter
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.path_utility import engine_root


Expand Down Expand Up @@ -44,6 +45,7 @@ def generate_api_docs_html(schema: str) -> str:
preset_manager=PresetManager( # FIXME: impl MockPresetManager
preset_path=engine_root() / "presets.yaml",
),
user_dict=UserDictionary(),
)
api_schema = json.dumps(app.openapi())

Expand Down
4 changes: 4 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from voicevox_engine.setting.Setting import CorsPolicyMode
from voicevox_engine.setting.SettingLoader import USER_SETTING_PATH, SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import make_tts_engines_from_cores
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.core_version_utility import get_latest_version
from voicevox_engine.utility.path_utility import engine_root
from voicevox_engine.utility.run_utility import decide_boolean_from_env
Expand Down Expand Up @@ -294,6 +295,8 @@ def main() -> None:
# ファイルの存在に関わらず指定されたパスをプリセットファイルとして使用する
preset_manager = PresetManager(preset_path)

use_dict = UserDictionary()

if arg_disable_mutable_api:
disable_mutable_api = True
else:
Expand All @@ -306,6 +309,7 @@ def main() -> None:
latest_core_version,
setting_loader,
preset_manager,
use_dict,
cancellable_engine,
root_dir,
cors_policy_mode,
Expand Down
3 changes: 3 additions & 0 deletions test/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from voicevox_engine.preset.PresetManager import PresetManager
from voicevox_engine.setting.SettingLoader import SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import make_tts_engines_from_cores
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.core_version_utility import get_latest_version


Expand All @@ -26,13 +27,15 @@ def app_params(tmp_path: Path) -> dict[str, Any]:
preset_path = tmp_path / "presets.yaml"
shutil.copyfile(original_preset_path, preset_path)
preset_manager = PresetManager(preset_path)
user_dict = UserDictionary()

return {
"tts_engines": tts_engines,
"cores": cores,
"latest_core_version": latest_core_version,
"setting_loader": setting_loader,
"preset_manager": preset_manager,
"user_dict": user_dict,
}


Expand Down
138 changes: 70 additions & 68 deletions test/user_dict/test_user_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,8 @@
)
from voicevox_engine.user_dict.user_dict import (
UserDictInputError,
UserDictionary,
_create_word,
apply_word,
delete_word,
import_user_dict,
read_dict,
rewrite_word,
update_dict,
)

# jsonとして保存される正しい形式の辞書データ
Expand Down Expand Up @@ -85,15 +80,22 @@ def tearDown(self) -> None:
self.tmp_dir.cleanup()

def test_read_not_exist_json(self) -> None:
user_dict = UserDictionary(user_dict_path=self.tmp_dir_path / "not_exist.json")
self.assertEqual(
read_dict(user_dict_path=(self.tmp_dir_path / "not_exist.json")),
user_dict.read_dict(),
{},
)

def test_create_word(self) -> None:
# 将来的に品詞などが追加された時にテストを増やす
self.assertEqual(
_create_word(surface="test", pronunciation="テスト", accent_type=1),
_create_word(
surface="test",
pronunciation="テスト",
accent_type=1,
word_type=None,
priority=None,
),
UserDictWord(
surface="test",
priority=5,
Expand All @@ -113,15 +115,13 @@ def test_create_word(self) -> None:
)

def test_apply_word_without_json(self) -> None:
user_dict_path = self.tmp_dir_path / "test_apply_word_without_json.json"
apply_word(
surface="test",
pronunciation="テスト",
accent_type=1,
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_apply_word_without_json.dic"),

user_dict = UserDictionary(
user_dict_path=self.tmp_dir_path / "test_apply_word_without_json.json",
compiled_dict_path=self.tmp_dir_path / "test_apply_word_without_json.dic",
)
res = read_dict(user_dict_path=user_dict_path)
user_dict.apply_word(surface="test", pronunciation="テスト", accent_type=1)
res = user_dict.read_dict()
self.assertEqual(len(res), 1)
new_word = get_new_word(res)
self.assertEqual(
Expand All @@ -138,14 +138,16 @@ def test_apply_word_with_json(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
apply_word(
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=self.tmp_dir_path / "test_apply_word_with_json.dic",
)
user_dict.apply_word(
surface="test2",
pronunciation="テストツー",
accent_type=3,
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_apply_word_with_json.dic"),
)
res = read_dict(user_dict_path=user_dict_path)
res = user_dict.read_dict()
self.assertEqual(len(res), 2)
new_word = get_new_word(res)
self.assertEqual(
Expand All @@ -162,33 +164,35 @@ def test_rewrite_word_invalid_id(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_rewrite_word_invalid_id.dic"),
)
self.assertRaises(
UserDictInputError,
rewrite_word,
user_dict.rewrite_word,
word_uuid="c2be4dc5-d07d-4767-8be1-04a1bb3f05a9",
surface="test2",
pronunciation="テストツー",
accent_type=2,
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_rewrite_word_invalid_id.dic"),
)

def test_rewrite_word_valid_id(self) -> None:
user_dict_path = self.tmp_dir_path / "test_rewrite_word_valid_id.json"
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
rewrite_word(
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=self.tmp_dir_path / "test_rewrite_word_valid_id.dic",
)
user_dict.rewrite_word(
word_uuid="aab7dda2-0d97-43c8-8cb7-3f440dab9b4e",
surface="test2",
pronunciation="テストツー",
accent_type=2,
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_rewrite_word_valid_id.dic"),
)
new_word = read_dict(user_dict_path=user_dict_path)[
"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"
]
new_word = user_dict.read_dict()["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"]
self.assertEqual(
(new_word.surface, new_word.pronunciation, new_word.accent_type),
("test2", "テストツー", 2),
Expand All @@ -199,25 +203,27 @@ def test_delete_word_invalid_id(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=self.tmp_dir_path / "test_delete_word_invalid_id.dic",
)
self.assertRaises(
UserDictInputError,
delete_word,
user_dict.delete_word,
word_uuid="c2be4dc5-d07d-4767-8be1-04a1bb3f05a9",
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_delete_word_invalid_id.dic"),
)

def test_delete_word_valid_id(self) -> None:
user_dict_path = self.tmp_dir_path / "test_delete_word_valid_id.json"
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
delete_word(
word_uuid="aab7dda2-0d97-43c8-8cb7-3f440dab9b4e",
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_delete_word_valid_id.dic"),
compiled_dict_path=self.tmp_dir_path / "test_delete_word_valid_id.dic",
)
self.assertEqual(len(read_dict(user_dict_path=user_dict_path)), 0)
user_dict.delete_word(word_uuid="aab7dda2-0d97-43c8-8cb7-3f440dab9b4e")
self.assertEqual(len(user_dict.read_dict()), 0)

def test_priority(self) -> None:
for pos in part_of_speech_data:
Expand All @@ -239,18 +245,18 @@ def test_import_dict(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
import_user_dict(
{"b1affe2a-d5f0-4050-926c-f28e0c1d9a98": import_word},
override=False,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.import_user_dict(
{"b1affe2a-d5f0-4050-926c-f28e0c1d9a98": import_word}, override=False
)
self.assertEqual(
read_dict(user_dict_path)["b1affe2a-d5f0-4050-926c-f28e0c1d9a98"],
user_dict.read_dict()["b1affe2a-d5f0-4050-926c-f28e0c1d9a98"],
import_word,
)
self.assertEqual(
read_dict(user_dict_path)["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
user_dict.read_dict()["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
UserDictWord(**valid_dict_dict_api["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"]),
)

Expand All @@ -260,14 +266,14 @@ def test_import_dict_no_override(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
import_user_dict(
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": import_word},
override=False,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.import_user_dict(
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": import_word}, override=False
)
self.assertEqual(
read_dict(user_dict_path)["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
user_dict.read_dict()["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
UserDictWord(**valid_dict_dict_api["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"]),
)

Expand All @@ -277,14 +283,14 @@ def test_import_dict_override(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
import_user_dict(
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": import_word},
override=True,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.import_user_dict(
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": import_word}, override=True
)
self.assertEqual(
read_dict(user_dict_path)["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
user_dict.read_dict()["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
import_word,
)

Expand All @@ -296,15 +302,16 @@ def test_import_invalid_word(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
self.assertRaises(
AssertionError,
import_user_dict,
user_dict.import_user_dict,
{
"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": invalid_accent_associative_rule_word
},
override=True,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
)
invalid_pos_word = deepcopy(import_word)
invalid_pos_word.context_id = 2
Expand All @@ -314,39 +321,34 @@ def test_import_invalid_word(self) -> None:
invalid_pos_word.part_of_speech_detail_3 = "*"
self.assertRaises(
ValueError,
import_user_dict,
user_dict.import_user_dict,
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": invalid_pos_word},
override=True,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
)

def test_update_dict(self) -> None:
user_dict_path = self.tmp_dir_path / "test_update_dict.json"
compiled_dict_path = self.tmp_dir_path / "test_update_dict.dic"
update_dict(
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.update_dict()
test_text = "テスト用の文字列"
success_pronunciation = "デフォルトノジショデハゼッタイニセイセイサレナイヨミ"

# 既に辞書に登録されていないか確認する
self.assertNotEqual(g2p(text=test_text, kana=True), success_pronunciation)

apply_word(
user_dict.apply_word(
surface=test_text,
pronunciation=success_pronunciation,
accent_type=1,
priority=10,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
)
self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)

# 疑似的にエンジンを再起動する
unset_user_dict()
update_dict(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.update_dict()

self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)
7 changes: 4 additions & 3 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from voicevox_engine.setting.Setting import CorsPolicyMode
from voicevox_engine.setting.SettingLoader import SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import TTSEngine
from voicevox_engine.user_dict.user_dict import update_dict
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.path_utility import engine_root, get_save_dir


Expand All @@ -36,6 +36,7 @@ def generate_app(
latest_core_version: str,
setting_loader: SettingHandler,
preset_manager: PresetManager,
user_dict: UserDictionary,
cancellable_engine: CancellableEngine | None = None,
root_dir: Path | None = None,
cors_policy_mode: CorsPolicyMode = CorsPolicyMode.localapps,
Expand All @@ -48,7 +49,7 @@ def generate_app(

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
update_dict()
user_dict.update_dict()
yield

app = FastAPI(
Expand Down Expand Up @@ -102,7 +103,7 @@ def get_core(core_version: str | None) -> CoreAdapter:
app.include_router(
generate_library_router(engine_manifest_data, library_manager)
)
app.include_router(generate_user_dict_router())
app.include_router(generate_user_dict_router(user_dict))
app.include_router(
generate_engine_info_router(get_core, cores, engine_manifest_data)
)
Expand Down
Loading

0 comments on commit c356918

Please sign in to comment.