-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
broadcast_queue.py
127 lines (94 loc) · 4.51 KB
/
broadcast_queue.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections import deque
from dataclasses import dataclass, field
from pydantic import Field, SkipValidation, ValidationError, model_validator
from semantic_kernel.agents.channels.agent_channel import AgentChannel
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.utils.experimental_decorator import experimental_class
@experimental_class
class QueueReference(KernelBaseModel):
    """Utility class to associate a queue with its specific lock.

    Attributes:
        queue: Pending items waiting to be processed, stored in a deque.
        queue_lock: Lock guarding the queue and the receive bookkeeping; excluded from serialization.
        receive_task: The background task processing ``queue``, or None if none has been started.
        receive_failure: An exception captured while processing the queue, or None.
    """

    queue: deque = Field(default_factory=deque)
    queue_lock: SkipValidation[asyncio.Lock] = Field(default_factory=asyncio.Lock, exclude=True)
    receive_task: SkipValidation[asyncio.Task | None] = None
    receive_failure: Exception | None = None

    @property
    def is_empty(self) -> bool:
        """Check if the queue is empty."""
        return len(self.queue) == 0

    @model_validator(mode="before")
    @classmethod
    def validate_receive_task(cls, values):
        """Validate the receive task.

        Args:
            values: The raw input to the model. Usually a dict of field values, but a
                "before" validator can also receive non-dict input (for example an
                existing model instance), which is passed through unchanged.

        Returns:
            The input, unchanged.

        Raises:
            ValueError: If ``receive_task`` is set to something other than an
                ``asyncio.Task`` or None. Pydantic converts this into a ValidationError.
        """
        if not isinstance(values, dict):
            # Only mapping input carries a "receive_task" key to inspect.
            return values
        receive_task = values.get("receive_task")
        if receive_task is not None and not isinstance(receive_task, asyncio.Task):
            # Validators must raise ValueError; pydantic's ValidationError cannot be
            # constructed from a plain message (doing so fails with TypeError).
            raise ValueError("receive_task must be an instance of asyncio.Task or None")
        return values
@experimental_class
@dataclass
class ChannelReference:
    """Pair an agent channel with the hash key that identifies it.

    Attributes:
        hash: String key identifying the channel (used as the lookup key for its queue).
        channel: The channel instance that messages are delivered to.
    """

    hash: str
    # NOTE(review): the default constructs AgentChannel() with no arguments; if
    # AgentChannel is abstract this default cannot be used — confirm.
    channel: AgentChannel = field(default_factory=AgentChannel)
@experimental_class
class BroadcastQueue(KernelBaseModel):
    """Fan out message batches to agent channels, keeping one FIFO queue per channel.

    Each channel is keyed by its ``ChannelReference.hash``. A background task created
    with ``asyncio.create_task`` drains that channel's queue into the channel.
    """

    queues: dict[str, QueueReference] = Field(default_factory=dict)
    # Seconds that ensure_synchronized sleeps between polls of a queue that is still busy.
    block_duration: float = 0.1

    @staticmethod
    def _receiver_idle(queue_ref: QueueReference) -> bool:
        """Return True when the queue has no receive task or its last task has finished."""
        task = queue_ref.receive_task
        return not task or task.done()

    async def enqueue(self, channel_refs: list[ChannelReference], messages: list[ChatMessageContent]) -> None:
        """Append a batch of messages to the queue of every referenced channel.

        A queue is created the first time a channel is seen. A receive task is started
        for any queue that has no task running.

        Args:
            channel_refs: The channels that should receive the messages.
            messages: The batch to broadcast; the same list object is queued for every channel.
        """
        for ref in channel_refs:
            key = ref.hash
            if key not in self.queues:
                self.queues[key] = QueueReference()
            target = self.queues[key]

            async with target.queue_lock:
                target.queue.append(messages)
                if self._receiver_idle(target):
                    target.receive_task = asyncio.create_task(self.receive(ref, target))

    async def ensure_synchronized(self, channel_ref: ChannelReference) -> None:
        """Wait until every batch queued for the channel has been processed.

        Polls every ``block_duration`` seconds. While the queue still holds batches but
        no receive task is running, a new one is started. Returns immediately for a
        channel that has no queue.

        Args:
            channel_ref: The channel whose history must be complete.

        Raises:
            Exception: If processing the channel's queue recorded a failure. The stored
                failure is cleared and chained as the cause.
        """
        key = channel_ref.hash
        if key not in self.queues:
            return
        target = self.queues[key]

        while True:
            async with target.queue_lock:
                drained = target.is_empty
                failure = target.receive_failure
                if failure is not None:
                    # Surface the failure only once: clear it before raising.
                    target.receive_failure = None
                    raise Exception(
                        f"Unexpected failure broadcasting to channel: {type(channel_ref.channel)}, failure: {failure}"
                    ) from failure
                if not drained and self._receiver_idle(target):
                    target.receive_task = asyncio.create_task(self.receive(channel_ref, target))

            if drained:
                return
            await asyncio.sleep(self.block_duration)

    async def receive(self, channel_ref: ChannelReference, queue_ref: QueueReference) -> None:
        """Deliver queued batches to the channel, oldest first, until the queue is empty.

        Each iteration works in three steps:
        1. Read the head batch under the lock.
        2. Deliver it with the lock released.
        3. Pop it afterwards, so the queue stays non-empty while a delivery is in flight.

        A delivery error is stored on ``queue_ref.receive_failure``; the failed batch is
        still popped. The loop stops after any batch while ``receive_failure`` is set.

        Args:
            channel_ref: The channel to deliver to.
            queue_ref: The queue to drain.
        """
        while True:
            async with queue_ref.queue_lock:
                if queue_ref.is_empty:
                    return
                head = queue_ref.queue[0]

            try:
                await channel_ref.channel.receive(head)
            except Exception as exc:
                queue_ref.receive_failure = exc

            async with queue_ref.queue_lock:
                if not queue_ref.is_empty:
                    queue_ref.queue.popleft()
                finished = queue_ref.receive_failure is not None or queue_ref.is_empty

            if finished:
                return