Skip to content

Commit

Permalink
fix: enhance type hints and improve audio message handling in TTS pub…
Browse files Browse the repository at this point in the history
…lisher

Signed-off-by: -LAN- <laipz8200@outlook.com>
  • Loading branch information
laipz8200 committed Dec 21, 2024
1 parent 9ee9e9c commit 5093e4c
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 29 deletions.
32 changes: 19 additions & 13 deletions api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
import queue
import re
import threading
from collections.abc import Iterable

from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
WorkflowQueueMessage,
)
from core.model_manager import ModelManager
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType


Expand All @@ -21,21 +24,27 @@ def __init__(self, status: str, audio):
self.status = status


def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
)


def _process_future(future_queue, audio_queue):
def _process_future(
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
audio_queue: queue.Queue[AudioTrunk],
):
while True:
try:
future = future_queue.get()
if future is None:
break
for audio in future.result():
invoke_result = future.result()
if not invoke_result:
continue
for audio in invoke_result:
audio_base64 = base64.b64encode(bytes(audio))
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
except Exception as e:
Expand All @@ -49,8 +58,8 @@ def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ""
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
Expand All @@ -66,14 +75,11 @@ def __init__(self, tenant_id: str, voice: str):
self._runtime_thread = threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

def publish(self, message):
try:
self._msg_queue.put(message)
except Exception as e:
self.logger.warning(e)
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
self._msg_queue.put(message)

def _runtime(self):
future_queue = queue.Queue()
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
while True:
try:
Expand Down Expand Up @@ -110,7 +116,7 @@ def _runtime(self):
break
future_queue.put(None)

def check_and_get_audio(self) -> AudioTrunk | None:
def check_and_get_audio(self):
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
Expand Down
10 changes: 5 additions & 5 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,11 @@ def _to_stream_response(
stream_response=stream_response,
)

def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None

Expand All @@ -222,7 +222,7 @@ def _wrapper_process_stream_response(

for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
Expand Down Expand Up @@ -511,7 +511,7 @@ def _process_stream_response(

# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
yield self._message_to_stream_response(
Expand Down
7 changes: 4 additions & 3 deletions api/core/app/apps/base_app_queue_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import queue
import time
from abc import abstractmethod
from collections.abc import Generator
from enum import Enum
from typing import Any

Expand All @@ -11,9 +10,11 @@
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueErrorEvent,
QueuePingEvent,
QueueStopEvent,
WorkflowQueueMessage,
)
from extensions.ext_redis import redis_client

Expand All @@ -37,11 +38,11 @@ def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)

q = queue.Queue()
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()

self._q = q

def listen(self) -> Generator:
def listen(self):
"""
Listen to queue
:return:
Expand Down
10 changes: 5 additions & 5 deletions api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ def _to_stream_response(

yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)

def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None

Expand All @@ -196,7 +196,7 @@ def _wrapper_process_stream_response(

for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
Expand Down Expand Up @@ -421,7 +421,7 @@ def _process_stream_response(

# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ def _to_stream_response(
stream_response=stream_response,
)

def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
Expand Down

0 comments on commit 5093e4c

Please sign in to comment.