Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: enhance type hints and improve audio message handling in TTS publisher #11947

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
import queue
import re
import threading
from collections.abc import Iterable

from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
WorkflowQueueMessage,
)
from core.model_manager import ModelManager
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType


Expand All @@ -21,21 +24,27 @@ def __init__(self, status: str, audio):
self.status = status


def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
)


def _process_future(future_queue, audio_queue):
def _process_future(
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
audio_queue: queue.Queue[AudioTrunk],
):
while True:
try:
future = future_queue.get()
if future is None:
break
for audio in future.result():
invoke_result = future.result()
if not invoke_result:
continue
for audio in invoke_result:
audio_base64 = base64.b64encode(bytes(audio))
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
except Exception as e:
Expand All @@ -49,8 +58,8 @@ def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ""
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
Expand All @@ -66,14 +75,11 @@ def __init__(self, tenant_id: str, voice: str):
self._runtime_thread = threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

def publish(self, message):
try:
self._msg_queue.put(message)
except Exception as e:
self.logger.warning(e)
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
self._msg_queue.put(message)

def _runtime(self):
future_queue = queue.Queue()
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
while True:
try:
Expand Down Expand Up @@ -110,7 +116,7 @@ def _runtime(self):
break
future_queue.put(None)

def check_and_get_audio(self) -> AudioTrunk | None:
def check_and_get_audio(self):
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
Expand Down
10 changes: 5 additions & 5 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,11 @@ def _to_stream_response(
stream_response=stream_response,
)

def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None

Expand All @@ -222,7 +222,7 @@ def _wrapper_process_stream_response(

for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
Expand Down Expand Up @@ -511,7 +511,7 @@ def _process_stream_response(

# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
yield self._message_to_stream_response(
Expand Down
7 changes: 4 additions & 3 deletions api/core/app/apps/base_app_queue_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import queue
import time
from abc import abstractmethod
from collections.abc import Generator
from enum import Enum
from typing import Any

Expand All @@ -11,9 +10,11 @@
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueErrorEvent,
QueuePingEvent,
QueueStopEvent,
WorkflowQueueMessage,
)
from extensions.ext_redis import redis_client

Expand All @@ -37,11 +38,11 @@ def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)

q = queue.Queue()
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()

self._q = q

def listen(self) -> Generator:
def listen(self):
"""
Listen to queue
:return:
Expand Down
10 changes: 5 additions & 5 deletions api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ def _to_stream_response(

yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)

def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None

Expand All @@ -196,7 +196,7 @@ def _wrapper_process_stream_response(

for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
Expand Down Expand Up @@ -421,7 +421,7 @@ def _process_stream_response(

# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ def _to_stream_response(
stream_response=stream_response,
)

def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
Expand Down
Loading