-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
226 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import pytest | ||
from fastapi.testclient import TestClient | ||
from launch import create_api | ||
import base64 | ||
import os | ||
|
||
import tests.conftest | ||
|
||
app_instance = create_api() | ||
client = TestClient(app_instance.app) | ||
|
||
|
||
@pytest.fixture | ||
def google_text_synthesize_request(): | ||
return { | ||
"input": {"text": "这是一个测试文本。"}, | ||
"voice": { | ||
"languageCode": "ZH-CN", | ||
"name": "female2", | ||
"style": "", | ||
"temperature": 0.5, | ||
"topP": 0.8, | ||
"topK": 50, | ||
"seed": 42, | ||
}, | ||
"audioConfig": { | ||
"audioEncoding": "mp3", | ||
"speakingRate": 1.0, | ||
"pitch": 0.0, | ||
"volumeGainDb": 0.0, | ||
"sampleRateHertz": 24000, | ||
"batchSize": 1, | ||
"spliterThreshold": 100, | ||
}, | ||
} | ||
|
||
|
||
def test_google_text_synthesize_success(google_text_synthesize_request): | ||
response = client.post("/v1/text:synthesize", json=google_text_synthesize_request) | ||
assert response.status_code == 200 | ||
assert "audioContent" in response.json() | ||
|
||
with open( | ||
os.path.join( | ||
tests.conftest.test_outputs_dir, "google_text_synthesize_success.mp3" | ||
), | ||
"wb", | ||
) as f: | ||
b64_str = response.json()["audioContent"] | ||
b64_str = b64_str.split(",")[1] | ||
f.write(base64.b64decode(b64_str)) | ||
|
||
|
||
def test_google_text_synthesize_missing_input(): | ||
response = client.post("/v1/text:synthesize", json={}) | ||
assert response.status_code == 422 | ||
assert "Field required" == response.json()["detail"][0]["msg"] | ||
|
||
|
||
def test_google_text_synthesize_invalid_voice(): | ||
request = { | ||
"input": {"text": "这是一个测试文本。"}, | ||
"voice": { | ||
"languageCode": "EN-US", | ||
"name": "invalid_voice", | ||
"style": "", | ||
"temperature": 0.5, | ||
"topP": 0.8, | ||
"topK": 50, | ||
"seed": 42, | ||
}, | ||
"audioConfig": { | ||
"audioEncoding": "mp3", | ||
"speakingRate": 1.0, | ||
"pitch": 0.0, | ||
"volumeGainDb": 0.0, | ||
"sampleRateHertz": 24000, | ||
"batchSize": 1, | ||
"spliterThreshold": 100, | ||
}, | ||
} | ||
response = client.post("/v1/text:synthesize", json=request) | ||
assert response.status_code == 400 | ||
assert "detail" in response.json() | ||
|
||
|
||
def test_google_text_synthesize_invalid_audio_encoding(): | ||
request = { | ||
"input": {"text": "这是一个测试文本。"}, | ||
"voice": { | ||
"languageCode": "ZH-CN", | ||
"name": "female2", | ||
"style": "", | ||
"temperature": 0.5, | ||
"topP": 0.8, | ||
"topK": 50, | ||
"seed": 42, | ||
}, | ||
"audioConfig": { | ||
"audioEncoding": "invalid_format", | ||
"speakingRate": 1.0, | ||
"pitch": 0.0, | ||
"volumeGainDb": 0.0, | ||
"sampleRateHertz": 24000, | ||
"batchSize": 1, | ||
"spliterThreshold": 100, | ||
}, | ||
} | ||
response = client.post("/v1/text:synthesize", json=request) | ||
assert response.status_code == 400 | ||
assert "detail" in response.json() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import os | ||
from pytest import fixture, mark, raises | ||
from fastapi.testclient import TestClient | ||
from modules.api.impl.openai_api import AudioSpeechRequest | ||
|
||
from launch import create_api | ||
|
||
import tests.conftest | ||
|
||
app_instance = create_api() | ||
|
||
|
||
@fixture | ||
def client(): | ||
yield TestClient(app_instance.app) | ||
|
||
|
||
@mark.parametrize( | ||
"input_text, voice", | ||
[ | ||
("Hello, world", "female2"), | ||
("Test text", "Alice"), | ||
("Invalid voice", "unknown_voice"), | ||
], | ||
) | ||
def test_openai_speech_api(client, input_text, voice): | ||
request = AudioSpeechRequest(input=input_text, voice=voice) | ||
response = client.post("/v1/audio/speech", json=request.model_dump()) | ||
|
||
if voice == "unknown_voice": | ||
assert response.status_code == 400 | ||
assert "Invalid voice" in response.json().get("detail", "") | ||
else: | ||
assert response.status_code == 200 | ||
assert response.headers["Content-Type"] == "audio/mp3" | ||
with open( | ||
os.path.join( | ||
tests.conftest.test_outputs_dir, | ||
f"{input_text.replace(' ', '_')}_{voice}.mp3", | ||
), | ||
"wb", | ||
) as f: | ||
f.write(response.content) | ||
|
||
|
||
def test_openai_speech_api_with_invalid_style(client): | ||
request = AudioSpeechRequest( | ||
input="Test text", voice="female2", style="invalid_style" | ||
) | ||
response = client.post("/v1/audio/speech", json=request.model_dump()) | ||
|
||
assert response.status_code == 400 | ||
assert "Invalid style" in response.json().get("detail", "") | ||
|
||
|
||
# def test_transcribe_not_implemented(client): | ||
# file = {"file": ("test.wav", b"test audio data")} | ||
# response = client.post("/v1/audio/transcriptions", files=file) | ||
|
||
# assert response.status_code == 200 | ||
# assert response.json() == success_response("not implemented yet") | ||
|
||
|
||
# TODO | ||
# @mark.parametrize("file_name, file_content", [("test.wav", b"test audio data")]) | ||
# def test_transcribe_with_file(client, file_name, file_content): | ||
# file = {"file": (file_name, file_content)} | ||
# response = client.post("/v1/audio/transcriptions", files=file) | ||
|
||
# assert response.status_code == 200 | ||
# assert isinstance(response.json(), TranscriptionsVerboseResponse) | ||
# assert response.json().text == "not implemented yet" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import os | ||
|
||
|
||
test_inputs_dir = os.path.dirname(__file__) + "/test_inputs" | ||
test_outputs_dir = os.path.dirname(__file__) + "/test_outputs" |