Skip to content

Commit

Permalink
✅ add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jun 6, 2024
1 parent 49088c5 commit e7f9385
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 18 deletions.
15 changes: 15 additions & 0 deletions modules/api/impl/google_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from modules.normalization import text_normalize

from modules import generate_audio as generate
from modules.speaker import speaker_mgr


from modules.ssml import parse_ssml
Expand Down Expand Up @@ -74,6 +75,8 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
volume_gain_db = audioConfig.get("volumeGainDb", 0)

batch_size = audioConfig.get("batchSize", 1)

# TODO spliter_threshold
spliter_threshold = audioConfig.get("spliterThreshold", 100)

# TODO sample_rate
Expand All @@ -84,6 +87,18 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
# TODO maybe need to change the sample rate
sample_rate = 24000

# TODO 使用 speaker
spk = speaker_mgr.get_speaker(voice_name)
if spk is None:
raise HTTPException(
status_code=400, detail="The specified voice name is not supported."
)

if audio_format != "mp3" and audio_format != "wav":
raise HTTPException(
status_code=400, detail="Invalid audio encoding format specified."
)

try:
if input.text:
# 处理文本合成逻辑
Expand Down
34 changes: 22 additions & 12 deletions modules/api/impl/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from modules.api import utils as api_utils
from modules.api.Api import APIManager

from modules.speaker import speaker_mgr
from modules.data import styles_mgr

import numpy as np


Expand All @@ -45,20 +48,27 @@ async def openai_speech_api(
..., description="JSON body with model, input text, and voice"
)
):
model = request.model
input_text = request.input
voice = request.voice
style = request.style
response_format = request.response_format
batch_size = request.batch_size
spliter_threshold = request.spliter_threshold
speed = request.speed
speed = clip(speed, 0.1, 10)

if not input_text:
raise HTTPException(status_code=400, detail="Input text is required.")
if speaker_mgr.get_speaker(voice) is None:
raise HTTPException(status_code=400, detail="Invalid voice.")
try:
model = request.model
input_text = request.input
voice = request.voice
style = request.style
response_format = request.response_format
batch_size = request.batch_size
spliter_threshold = request.spliter_threshold
speed = request.speed
speed = clip(speed, 0.1, 10)

if not input_text:
raise HTTPException(status_code=400, detail="Input text is required.")
if style:
styles_mgr.find_item_by_name(style)
except:
raise HTTPException(status_code=400, detail="Invalid style.")

try:
# Normalize the text
text = text_normalize(input_text, is_end=True)

Expand Down
6 changes: 0 additions & 6 deletions modules/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ class BaseResponse(BaseModel):
message: str
data: Any

class Config:
json_encoders = {
torch.Tensor: lambda v: v.tolist(),
Speaker: lambda v: v.to_json(),
}


def success_response(data: Any, message: str = "ok") -> BaseResponse:
return BaseResponse(message=message, data=data)
Expand Down
1 change: 1 addition & 0 deletions modules/utils/CsvMgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class DataNotFoundError(Exception):
pass


# FIXME: 😓这个东西写的比较拉跨,最好找个什么csv库替代掉...
class BaseManager:
def __init__(self, csv_file):
self.csv_file = csv_file
Expand Down
Empty file added tests/__init__.py
Empty file.
Empty file added tests/api/__init__.py
Empty file.
111 changes: 111 additions & 0 deletions tests/api/test_google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import pytest
from fastapi.testclient import TestClient
from launch import create_api
import base64
import os

import tests.conftest

app_instance = create_api()
client = TestClient(app_instance.app)


@pytest.fixture
def google_text_synthesize_request():
return {
"input": {"text": "这是一个测试文本。"},
"voice": {
"languageCode": "ZH-CN",
"name": "female2",
"style": "",
"temperature": 0.5,
"topP": 0.8,
"topK": 50,
"seed": 42,
},
"audioConfig": {
"audioEncoding": "mp3",
"speakingRate": 1.0,
"pitch": 0.0,
"volumeGainDb": 0.0,
"sampleRateHertz": 24000,
"batchSize": 1,
"spliterThreshold": 100,
},
}


def test_google_text_synthesize_success(google_text_synthesize_request):
response = client.post("/v1/text:synthesize", json=google_text_synthesize_request)
assert response.status_code == 200
assert "audioContent" in response.json()

with open(
os.path.join(
tests.conftest.test_outputs_dir, "google_text_synthesize_success.mp3"
),
"wb",
) as f:
b64_str = response.json()["audioContent"]
b64_str = b64_str.split(",")[1]
f.write(base64.b64decode(b64_str))


def test_google_text_synthesize_missing_input():
response = client.post("/v1/text:synthesize", json={})
assert response.status_code == 422
assert "Field required" == response.json()["detail"][0]["msg"]


def test_google_text_synthesize_invalid_voice():
request = {
"input": {"text": "这是一个测试文本。"},
"voice": {
"languageCode": "EN-US",
"name": "invalid_voice",
"style": "",
"temperature": 0.5,
"topP": 0.8,
"topK": 50,
"seed": 42,
},
"audioConfig": {
"audioEncoding": "mp3",
"speakingRate": 1.0,
"pitch": 0.0,
"volumeGainDb": 0.0,
"sampleRateHertz": 24000,
"batchSize": 1,
"spliterThreshold": 100,
},
}
response = client.post("/v1/text:synthesize", json=request)
assert response.status_code == 400
assert "detail" in response.json()


def test_google_text_synthesize_invalid_audio_encoding():
request = {
"input": {"text": "这是一个测试文本。"},
"voice": {
"languageCode": "ZH-CN",
"name": "female2",
"style": "",
"temperature": 0.5,
"topP": 0.8,
"topK": 50,
"seed": 42,
},
"audioConfig": {
"audioEncoding": "invalid_format",
"speakingRate": 1.0,
"pitch": 0.0,
"volumeGainDb": 0.0,
"sampleRateHertz": 24000,
"batchSize": 1,
"spliterThreshold": 100,
},
}
response = client.post("/v1/text:synthesize", json=request)
assert response.status_code == 400
assert "detail" in response.json()
72 changes: 72 additions & 0 deletions tests/api/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
from pytest import fixture, mark, raises
from fastapi.testclient import TestClient
from modules.api.impl.openai_api import AudioSpeechRequest

from launch import create_api

import tests.conftest

app_instance = create_api()


@fixture
def client():
yield TestClient(app_instance.app)


@mark.parametrize(
"input_text, voice",
[
("Hello, world", "female2"),
("Test text", "Alice"),
("Invalid voice", "unknown_voice"),
],
)
def test_openai_speech_api(client, input_text, voice):
request = AudioSpeechRequest(input=input_text, voice=voice)
response = client.post("/v1/audio/speech", json=request.model_dump())

if voice == "unknown_voice":
assert response.status_code == 400
assert "Invalid voice" in response.json().get("detail", "")
else:
assert response.status_code == 200
assert response.headers["Content-Type"] == "audio/mp3"
with open(
os.path.join(
tests.conftest.test_outputs_dir,
f"{input_text.replace(' ', '_')}_{voice}.mp3",
),
"wb",
) as f:
f.write(response.content)


def test_openai_speech_api_with_invalid_style(client):
request = AudioSpeechRequest(
input="Test text", voice="female2", style="invalid_style"
)
response = client.post("/v1/audio/speech", json=request.model_dump())

assert response.status_code == 400
assert "Invalid style" in response.json().get("detail", "")


# def test_transcribe_not_implemented(client):
# file = {"file": ("test.wav", b"test audio data")}
# response = client.post("/v1/audio/transcriptions", files=file)

# assert response.status_code == 200
# assert response.json() == success_response("not implemented yet")


# TODO
# @mark.parametrize("file_name, file_content", [("test.wav", b"test audio data")])
# def test_transcribe_with_file(client, file_name, file_content):
# file = {"file": (file_name, file_content)}
# response = client.post("/v1/audio/transcriptions", files=file)

# assert response.status_code == 200
# assert isinstance(response.json(), TranscriptionsVerboseResponse)
# assert response.json().text == "not implemented yet"
Empty file added tests/api/test_speakers.py
Empty file.
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os


test_inputs_dir = os.path.dirname(__file__) + "/test_inputs"
test_outputs_dir = os.path.dirname(__file__) + "/test_outputs"

0 comments on commit e7f9385

Please sign in to comment.