✨ add stream mode to openai api
- add interrupt to api
- add stream mode to openai api

ref #70
zhzLuke96 committed Jul 4, 2024
1 parent 7f19d4f commit 989c3c5
Showing 9 changed files with 72 additions and 39 deletions.
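For context on what the new `stream` flag enables, here is a minimal client sketch for the OpenAI-compatible speech endpoint. It is an illustration only: the URL, port, and the `input`/`voice` field names are assumptions, since only `stream` and `response_format` are visible in this diff.

# Hedged example: stream audio from the OpenAI-compatible speech endpoint.
# The URL, port, and the `input`/`voice` fields are assumptions; only
# `stream` and `response_format` are confirmed by this commit.
import requests

resp = requests.post(
    "http://localhost:7870/v1/audio/speech",
    json={
        "input": "Hello, streaming world.",
        "voice": "female2",
        "response_format": "mp3",
        "stream": True,
    },
    stream=True,
)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    # With stream=True on the request, chunks arrive as the server produces
    # them instead of after the whole clip is synthesized.
    for chunk in resp.iter_content(chunk_size=4096):
        if chunk:
            f.write(chunk)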
2 changes: 1 addition & 1 deletion modules/ChatTTS/ChatTTS/core.py
@@ -213,7 +213,7 @@ class InferCodeParams(RefineTextParams):
repetition_penalty: float = 1.05
max_new_token: int = 2048

stream_chunk_size: int = 128
stream_chunk_size: int = 96

prompt1: str = ""
prompt2: str = ""
12 changes: 4 additions & 8 deletions modules/ChatTTS/ChatTTS/model/cuda/te_llama.py
@@ -6,25 +6,21 @@
#
# Edited by fumiama.

import gc
import os
import re
import gc
from contextlib import contextmanager

import transformer_engine as te
import transformers
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init

import transformers
from transformers.models.llama.modeling_llama import (
LlamaModel,
LlamaConfig,
)
from transformers.modeling_utils import (
_add_variant,
load_state_dict,
_load_state_dict_into_model,
load_state_dict,
)
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaModel
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files

13 changes: 7 additions & 6 deletions modules/ChatTTS/ChatTTS/model/gpt.py
@@ -1,13 +1,14 @@
import os, platform
import os
import platform

os.environ["TOKENIZERS_PARALLELISM"] = "false"
"""
https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
"""

from dataclasses import dataclass
import logging
from typing import Union, List, Optional, Tuple
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import omegaconf
import torch
@@ -16,13 +17,13 @@
import torch.nn.utils.parametrize as P
from torch.nn.utils.parametrizations import weight_norm
from tqdm import tqdm
from transformers import LlamaModel, LlamaConfig, LogitsWarper
from transformers import LlamaConfig, LlamaModel, LogitsWarper
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import is_flash_attn_2_available

from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
from ..utils import del_all
from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat


class GPT(nn.Module):
@@ -364,7 +365,7 @@ def generate(
stream=False,
show_tqdm=True,
ensure_non_empty=True,
stream_chunk_size=128,
stream_chunk_size=96,
context=Context(),
):

21 changes: 16 additions & 5 deletions modules/ChatTTSInfer.py
@@ -1,6 +1,6 @@
import threading
import logging
from dataclasses import dataclass, is_dataclass
from typing import Generator, List, Optional, Union
from typing import Generator, List, Union

import numpy as np
import torch
@@ -38,14 +38,24 @@ def del_all(d: Union[dict, list]):


class ChatTTSInfer:
infer_lock = threading.Lock()
logger = logging.getLogger(__name__)

current_infer = None

def __init__(self, instance: Chat) -> None:
self.instance = instance
ChatTTSInfer.current_infer = self

def get_tokenizer(self) -> LlamaTokenizer:
return self.instance.pretrain_models["tokenizer"]

@classmethod
def interrupt(cls):
# FIXME: this cannot stop immediately yet; it waits until the next chunk? Probably needs changes in `gpt.py`
if cls.current_infer:
cls.current_infer.instance.interrupt()
cls.logger.info("Interrupted current infer")

def infer(
self,
text: str,
@@ -85,11 +95,12 @@ def _infer(
if not isinstance(text, list):
text = [text]

# NOTE: a bit problematic — it felt like this should help, but in practice there was no improvement at all... so disabled for now
# NOTE: the goal is to avoid having vocos process short sequences as much as possible (may cost a little performance),
# but the results aren't great... disabled for now
# smooth_decoding = stream
smooth_decoding = False

with torch.no_grad(), self.infer_lock:
with torch.no_grad():

if not skip_refine_text:
refined = self.instance._refine_text(
16 changes: 15 additions & 1 deletion modules/api/impl/handler/AudioHandler.py
@@ -1,12 +1,14 @@
import base64
import io
import wave
from typing import Generator
from typing import AsyncGenerator, Generator

import numpy as np
from fastapi import Request
from pydub import AudioSegment

from modules.api.impl.model.audio_model import AudioFormat
from modules.ChatTTSInfer import ChatTTSInfer
from modules.utils.audio import ndarray_to_segment


@@ -91,6 +93,18 @@ def enqueue_to_stream(self, format: AudioFormat) -> Generator[bytes, None, None]

# print("AudioHandler: enqueue_to_stream done")

async def enqueue_to_stream_with_request(
self, request: Request, format: AudioFormat
) -> AsyncGenerator[bytes, None]:
buffer_gen = self.enqueue_to_stream(format=AudioFormat(format))
for chunk in buffer_gen:
disconnected = await request.is_disconnected()
if disconnected:
ChatTTSInfer.interrupt()
break

yield chunk

# just for test
def enqueue_to_stream_join(
self, format: AudioFormat
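The new `enqueue_to_stream_with_request` wrapper polls `request.is_disconnected()` between chunks and calls `ChatTTSInfer.interrupt()` when the client goes away, so cancellation only takes effect at chunk boundaries (see the FIXME above). Below is a hedged sketch of how a client could exercise that path by abandoning the stream early; the `/v1/tts` path, port, and query parameter names are assumptions based on `tts_api.py`.

# Hypothetical client: stop reading the stream early so the server-side
# is_disconnected() check fires and ChatTTSInfer.interrupt() is called.
# Path, port, and parameter names are assumptions, not confirmed by this diff.
import httpx

with httpx.stream(
    "GET",
    "http://localhost:7870/v1/tts",
    params={"text": "a fairly long passage to synthesize", "stream": True, "format": "mp3"},
    timeout=None,
) as resp:
    for i, chunk in enumerate(resp.iter_bytes()):
        print(f"chunk {i}: {len(chunk)} bytes")
        if i >= 2:
            # Leaving the context manager closes the connection; generation
            # should stop at the next chunk boundary on the server.
            break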
6 changes: 6 additions & 0 deletions modules/api/impl/handler/SSMLHandler.py
@@ -1,3 +1,5 @@
from typing import Generator

import numpy as np
from fastapi import HTTPException

@@ -97,3 +99,7 @@ def enqueue(self) -> tuple[np.ndarray, int]:
)

return audio_data, sample_rate

def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]:
# TODO: stream should be fairly easy to support here...
raise NotImplementedError("Stream is not supported for SSMLHandler.")
16 changes: 13 additions & 3 deletions modules/api/impl/openai_api.py
@@ -38,6 +38,8 @@ class AudioSpeechRequest(BaseModel):
enhance: bool = False
denoise: bool = False

stream: bool = False


async def openai_speech_api(
request: AudioSpeechRequest = Body(
@@ -50,6 +52,7 @@
style = request.style
eos = request.eos
seed = request.seed
stream = request.stream

response_format = request.response_format
if not isinstance(response_format, AudioFormat) and isinstance(
@@ -105,12 +108,19 @@
enhancer_config=enhancer_config,
)

buffer = handler.enqueue_to_buffer(response_format)

mime_type = f"audio/{response_format.value}"
if response_format == AudioFormat.mp3:
mime_type = "audio/mpeg"
return StreamingResponse(buffer, media_type=mime_type)

if stream:
gen = handler.enqueue_to_stream_with_request(
request=request,
format=response_format,
)
return StreamingResponse(gen, media_type=mime_type)
else:
buffer = handler.enqueue_to_buffer(response_format)
return StreamingResponse(buffer, media_type=mime_type)

except Exception as e:
import logging
13 changes: 6 additions & 7 deletions modules/api/impl/tts_api.py
@@ -1,7 +1,7 @@
import io
import logging

from fastapi import Depends, HTTPException, Query
from fastapi import Depends, HTTPException, Query, Request
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel

@@ -52,7 +52,7 @@ class TTSParams(BaseModel):
stream: bool = Query(False, description="Stream the audio")


async def synthesize_tts(params: TTSParams = Depends()):
async def synthesize_tts(request: Request, params: TTSParams = Depends()):
try:
# Validate text
if not params.text.strip():
@@ -150,11 +150,10 @@ async def synthesize_tts(params: TTSParams = Depends()):
f"Batch size {infer_config.batch_size} is not supported in streaming mode, will set to 1"
)

buffer_gen = handler.enqueue_to_stream(format=AudioFormat(params.format))
# buffer_gen = handler.enqueue_to_stream_join(
# format=AudioFormat(params.format)
# )
return StreamingResponse(buffer_gen, media_type=media_type)
gen = handler.enqueue_to_stream_with_request(
request=request, format=AudioFormat(params.format)
)
return StreamingResponse(gen, media_type=media_type)
else:
buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
return StreamingResponse(buffer, media_type=media_type)
12 changes: 4 additions & 8 deletions modules/api/impl/xtts_v2_api.py
@@ -183,15 +183,11 @@ async def tts_stream(
enhancer_config=enhancer_config,
)

async def generator():
for chunk in handler.enqueue_to_stream(AudioFormat.mp3):
disconnected = await request.is_disconnected()
if disconnected:
break

yield chunk
gen = handler.enqueue_to_stream_with_request(
request=request, format=AudioFormat.mp3
)

return StreamingResponse(generator(), media_type="audio/mpeg")
return StreamingResponse(gen, media_type="audio/mpeg")

@app.post("/v1/xtts_v2/set_tts_settings")
async def set_tts_settings(request: TTSSettingsRequest):