Skip to content

Commit

Permalink
🧪 add tests
Browse files Browse the repository at this point in the history
- 增加 speaker 测试
- 修复 ssml parser 测试
  • Loading branch information
zhzLuke96 committed Jun 10, 2024
1 parent ef665da commit a807640
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 25 deletions.
26 changes: 13 additions & 13 deletions modules/api/impl/speaker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@ def setup(app: APIManager):

@app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
async def list_speakers():
return {
"message": "ok",
"data": [spk.to_json() for spk in speaker_mgr.list_speakers()],
}
return api_utils.success_response(
[spk.to_json() for spk in speaker_mgr.list_speakers()]
)

@app.post("/v1/speakers/refresh", response_model=api_utils.BaseResponse)
async def refresh_speakers():
speaker_mgr.refresh_speakers()
return api_utils.success_response(None)

@app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
async def update_speakers(request: SpeakersUpdate):
Expand All @@ -59,7 +63,8 @@ async def update_speakers(request: SpeakersUpdate):
# number array => Tensor
speaker.emb = torch.tensor(spk["tensor"])
speaker_mgr.save_all()
return {"message": "ok", "data": None}

return api_utils.success_response(None)

@app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
async def create_speaker(request: CreateSpeaker):
Expand Down Expand Up @@ -88,12 +93,7 @@ async def create_speaker(request: CreateSpeaker):
raise HTTPException(
status_code=400, detail="Missing tensor or seed in request"
)
return {"message": "ok", "data": speaker.to_json()}

@app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
async def refresh_speakers():
speaker_mgr.refresh_speakers()
return {"message": "ok"}
return api_utils.success_response(speaker.to_json())

@app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
async def update_speaker(request: UpdateSpeaker):
Expand All @@ -113,11 +113,11 @@ async def update_speaker(request: UpdateSpeaker):
# number array => Tensor
speaker.emb = torch.tensor(request.tensor)
speaker_mgr.update_speaker(speaker)
return {"message": "ok"}
return api_utils.success_response(None)

@app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
async def speaker_detail(request: SpeakerDetail):
speaker = speaker_mgr.get_speaker_by_id(request.id)
if speaker is None:
raise HTTPException(status_code=404, detail="Speaker not found")
return {"message": "ok", "data": speaker.to_json(with_emb=request.with_emb)}
return api_utils.success_response(speaker.to_json(with_emb=request.with_emb))
4 changes: 2 additions & 2 deletions tests/api/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
@mark.parametrize(
"input_text, voice",
[
("Hello, world", "female2"),
("Test text", "Alice"),
("Hello, world [lbreak]", "female2"),
("Test text [lbreak]", "Alice"),
("Invalid voice", "unknown_voice"),
],
)
Expand Down
6 changes: 3 additions & 3 deletions tests/api/test_speakers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
"path, method, status_code",
[
("/v1/speakers/list", "GET", 200),
# ("/v1/speakers/update", "POST", 200),
("/v1/speakers/refresh", "POST", 200),
],
)
@mark.speakers
@mark.speakers_api
def test_api_endpoints(client, path, method, status_code):
response = client.request(method, path)
assert response.status_code == status_code


@mark.speakers
@mark.speakers_api
def test_create_speaker(client):
data = {
"name": "测试发言人",
Expand Down
14 changes: 7 additions & 7 deletions tests/test_ssml/test_ssml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def test_speak_tag(parser):
assert len(segments) == 3
assert isinstance(segments[0], SSMLSegment)
assert segments[0].text == "你好"
assert segments[0].params.rate == "fast"
assert segments[0].attrs.rate == "fast"
assert isinstance(segments[1], SSMLBreak)
assert segments[1].duration == 500
assert segments[1].attrs.duration == 500
assert isinstance(segments[2], SSMLSegment)
assert segments[2].text == "你好"
assert segments[2].params.rate == "slow"
assert segments[2].attrs.rate == "slow"


@pytest.mark.ssml_parser
Expand All @@ -47,8 +47,8 @@ def test_voice_tag(parser):
assert len(segments) == 1
assert isinstance(segments[0], SSMLSegment)
assert segments[0].text == "你好"
assert segments[0].params.spk == "xiaoyan"
assert segments[0].params.style == "news"
assert segments[0].attrs.spk == "xiaoyan"
assert segments[0].attrs.style == "news"


@pytest.mark.ssml_parser
Expand All @@ -61,7 +61,7 @@ def test_break_tag(parser):
segments = parser.parse(ssml)
assert len(segments) == 1
assert isinstance(segments[0], SSMLBreak)
assert segments[0].duration == 500
assert segments[0].attrs.duration == 500


@pytest.mark.ssml_parser
Expand All @@ -75,7 +75,7 @@ def test_prosody_tag(parser):
assert len(segments) == 1
assert isinstance(segments[0], SSMLSegment)
assert segments[0].text == "你好"
assert segments[0].params.rate == "fast"
assert segments[0].attrs.rate == "fast"


@pytest.mark.ssml_parser
Expand Down

0 comments on commit a807640

Please sign in to comment.