diff --git a/modules/api/impl/google_api.py b/modules/api/impl/google_api.py index a329803..ae1863e 100644 --- a/modules/api/impl/google_api.py +++ b/modules/api/impl/google_api.py @@ -11,6 +11,7 @@ from modules.normalization import text_normalize from modules import generate_audio as generate +from modules.speaker import speaker_mgr from modules.ssml import parse_ssml @@ -74,6 +75,8 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest): volume_gain_db = audioConfig.get("volumeGainDb", 0) batch_size = audioConfig.get("batchSize", 1) + + # TODO spliter_threshold spliter_threshold = audioConfig.get("spliterThreshold", 100) # TODO sample_rate @@ -84,6 +87,18 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest): # TODO maybe need to change the sample rate sample_rate = 24000 + # TODO 使用 speaker + spk = speaker_mgr.get_speaker(voice_name) + if spk is None: + raise HTTPException( + status_code=400, detail="The specified voice name is not supported." + ) + + if audio_format != "mp3" and audio_format != "wav": + raise HTTPException( + status_code=400, detail="Invalid audio encoding format specified." 
+ ) + try: if input.text: # 处理文本合成逻辑 diff --git a/modules/api/impl/openai_api.py b/modules/api/impl/openai_api.py index d3ae9ad..e644b10 100644 --- a/modules/api/impl/openai_api.py +++ b/modules/api/impl/openai_api.py @@ -20,6 +20,9 @@ from modules.api import utils as api_utils from modules.api.Api import APIManager +from modules.speaker import speaker_mgr +from modules.data import styles_mgr + import numpy as np @@ -45,20 +48,27 @@ async def openai_speech_api( ..., description="JSON body with model, input text, and voice" ) ): + model = request.model + input_text = request.input + voice = request.voice + style = request.style + response_format = request.response_format + batch_size = request.batch_size + spliter_threshold = request.spliter_threshold + speed = request.speed + speed = clip(speed, 0.1, 10) + + if not input_text: + raise HTTPException(status_code=400, detail="Input text is required.") + if speaker_mgr.get_speaker(voice) is None: + raise HTTPException(status_code=400, detail="Invalid voice.") try: - model = request.model - input_text = request.input - voice = request.voice - style = request.style - response_format = request.response_format - batch_size = request.batch_size - spliter_threshold = request.spliter_threshold - speed = request.speed - speed = clip(speed, 0.1, 10) - - if not input_text: - raise HTTPException(status_code=400, detail="Input text is required.") + if style: + styles_mgr.find_item_by_name(style) + except Exception: + raise HTTPException(status_code=400, detail="Invalid style.") + try: # Normalize the text text = text_normalize(input_text, is_end=True) diff --git a/modules/api/utils.py b/modules/api/utils.py index dcda587..e75a615 100644 --- a/modules/api/utils.py +++ b/modules/api/utils.py @@ -29,12 +29,6 @@ class BaseResponse(BaseModel): message: str data: Any - class Config: - json_encoders = { - torch.Tensor: lambda v: v.tolist(), - Speaker: lambda v: v.to_json(), - } - def success_response(data: Any, message: str = "ok") ->
BaseResponse: return BaseResponse(message=message, data=data) diff --git a/modules/utils/CsvMgr.py b/modules/utils/CsvMgr.py index 31c0ebc..61ee99e 100644 --- a/modules/utils/CsvMgr.py +++ b/modules/utils/CsvMgr.py @@ -15,6 +15,7 @@ class DataNotFoundError(Exception): pass +# FIXME: 😓这个东西写的比较拉跨,最好找个什么csv库替代掉... class BaseManager: def __init__(self, csv_file): self.csv_file = csv_file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/test_google.py b/tests/api/test_google.py new file mode 100644 index 0000000..6ce819c --- /dev/null +++ b/tests/api/test_google.py @@ -0,0 +1,111 @@ +import pytest +from fastapi.testclient import TestClient +from launch import create_api +import base64 +import os + +import tests.conftest + +app_instance = create_api() +client = TestClient(app_instance.app) + + +@pytest.fixture +def google_text_synthesize_request(): + return { + "input": {"text": "这是一个测试文本。"}, + "voice": { + "languageCode": "ZH-CN", + "name": "female2", + "style": "", + "temperature": 0.5, + "topP": 0.8, + "topK": 50, + "seed": 42, + }, + "audioConfig": { + "audioEncoding": "mp3", + "speakingRate": 1.0, + "pitch": 0.0, + "volumeGainDb": 0.0, + "sampleRateHertz": 24000, + "batchSize": 1, + "spliterThreshold": 100, + }, + } + + +def test_google_text_synthesize_success(google_text_synthesize_request): + response = client.post("/v1/text:synthesize", json=google_text_synthesize_request) + assert response.status_code == 200 + assert "audioContent" in response.json() + + with open( + os.path.join( + tests.conftest.test_outputs_dir, "google_text_synthesize_success.mp3" + ), + "wb", + ) as f: + b64_str = response.json()["audioContent"] + b64_str = b64_str.split(",")[1] + f.write(base64.b64decode(b64_str)) + + +def test_google_text_synthesize_missing_input(): + response = client.post("/v1/text:synthesize", 
json={}) + assert response.status_code == 422 + assert "Field required" == response.json()["detail"][0]["msg"] + + +def test_google_text_synthesize_invalid_voice(): + request = { + "input": {"text": "这是一个测试文本。"}, + "voice": { + "languageCode": "EN-US", + "name": "invalid_voice", + "style": "", + "temperature": 0.5, + "topP": 0.8, + "topK": 50, + "seed": 42, + }, + "audioConfig": { + "audioEncoding": "mp3", + "speakingRate": 1.0, + "pitch": 0.0, + "volumeGainDb": 0.0, + "sampleRateHertz": 24000, + "batchSize": 1, + "spliterThreshold": 100, + }, + } + response = client.post("/v1/text:synthesize", json=request) + assert response.status_code == 400 + assert "detail" in response.json() + + +def test_google_text_synthesize_invalid_audio_encoding(): + request = { + "input": {"text": "这是一个测试文本。"}, + "voice": { + "languageCode": "ZH-CN", + "name": "female2", + "style": "", + "temperature": 0.5, + "topP": 0.8, + "topK": 50, + "seed": 42, + }, + "audioConfig": { + "audioEncoding": "invalid_format", + "speakingRate": 1.0, + "pitch": 0.0, + "volumeGainDb": 0.0, + "sampleRateHertz": 24000, + "batchSize": 1, + "spliterThreshold": 100, + }, + } + response = client.post("/v1/text:synthesize", json=request) + assert response.status_code == 400 + assert "detail" in response.json() diff --git a/tests/api/test_openai.py b/tests/api/test_openai.py new file mode 100644 index 0000000..7195d59 --- /dev/null +++ b/tests/api/test_openai.py @@ -0,0 +1,72 @@ +import os +from pytest import fixture, mark, raises +from fastapi.testclient import TestClient +from modules.api.impl.openai_api import AudioSpeechRequest + +from launch import create_api + +import tests.conftest + +app_instance = create_api() + + +@fixture +def client(): + yield TestClient(app_instance.app) + + +@mark.parametrize( + "input_text, voice", + [ + ("Hello, world", "female2"), + ("Test text", "Alice"), + ("Invalid voice", "unknown_voice"), + ], +) +def test_openai_speech_api(client, input_text, voice): + request = 
AudioSpeechRequest(input=input_text, voice=voice) + response = client.post("/v1/audio/speech", json=request.model_dump()) + + if voice == "unknown_voice": + assert response.status_code == 400 + assert "Invalid voice" in response.json().get("detail", "") + else: + assert response.status_code == 200 + assert response.headers["Content-Type"] == "audio/mp3" + with open( + os.path.join( + tests.conftest.test_outputs_dir, + f"{input_text.replace(' ', '_')}_{voice}.mp3", + ), + "wb", + ) as f: + f.write(response.content) + + +def test_openai_speech_api_with_invalid_style(client): + request = AudioSpeechRequest( + input="Test text", voice="female2", style="invalid_style" + ) + response = client.post("/v1/audio/speech", json=request.model_dump()) + + assert response.status_code == 400 + assert "Invalid style" in response.json().get("detail", "") + + +# def test_transcribe_not_implemented(client): +# file = {"file": ("test.wav", b"test audio data")} +# response = client.post("/v1/audio/transcriptions", files=file) + +# assert response.status_code == 200 +# assert response.json() == success_response("not implemented yet") + + +# TODO +# @mark.parametrize("file_name, file_content", [("test.wav", b"test audio data")]) +# def test_transcribe_with_file(client, file_name, file_content): +# file = {"file": (file_name, file_content)} +# response = client.post("/v1/audio/transcriptions", files=file) + +# assert response.status_code == 200 +# assert isinstance(response.json(), TranscriptionsVerboseResponse) +# assert response.json().text == "not implemented yet" diff --git a/tests/api/test_speakers.py b/tests/api/test_speakers.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..30e3f82 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +import os + + +test_inputs_dir = os.path.dirname(__file__) + "/test_inputs" +test_outputs_dir = os.path.dirname(__file__) + "/test_outputs"