Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Merge Chat and Workflow Iterator into a generic Stream Iterator #32

Merged
merged 1 commit into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ coze.files.retrieve(file_id=file.id)
### Workflows

```python
from cozepy import Coze, TokenAuth, WorkflowEventType, WorkflowEventIterator
from cozepy import Coze, TokenAuth, Stream, WorkflowEvent, WorkflowEventType

coze = Coze(auth=TokenAuth("your_token"))

Expand All @@ -197,7 +197,7 @@ result = coze.workflows.runs.create(


# stream workflow run
def handle_workflow_iterator(iterator: WorkflowEventIterator):
def handle_workflow_iterator(iterator: Stream[WorkflowEvent]):
for event in iterator:
if event.event == WorkflowEventType.MESSAGE:
print('got message', event.message)
Expand Down
6 changes: 2 additions & 4 deletions cozepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from .chat import (
Chat,
ChatChatIterator,
ChatEvent,
ChatEventType,
ChatStatus,
Expand Down Expand Up @@ -47,6 +46,7 @@
from .model import (
LastIDPaged,
NumberPaged,
Stream,
TokenPaged,
)
from .request import HTTPClient
Expand All @@ -56,7 +56,6 @@
WorkflowEventError,
WorkflowEventInterrupt,
WorkflowEventInterruptData,
WorkflowEventIterator,
WorkflowEventMessage,
WorkflowEventType,
WorkflowRunResult,
Expand Down Expand Up @@ -91,7 +90,6 @@
"Message",
"Chat",
"ChatEvent",
"ChatChatIterator",
"ToolOutput",
# conversations
"Conversation",
Expand All @@ -115,7 +113,6 @@
"WorkflowEventInterrupt",
"WorkflowEventError",
"WorkflowEvent",
"WorkflowEventIterator",
# workspaces
"WorkspaceRoleType",
"WorkspaceType",
Expand All @@ -137,6 +134,7 @@
"TokenPaged",
"NumberPaged",
"LastIDPaged",
"Stream",
# request
"HTTPClient",
]
96 changes: 37 additions & 59 deletions cozepy/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from enum import Enum
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from cozepy.auth import Auth
from cozepy.exception import CozeEventError
from cozepy.model import CozeModel
from cozepy.model import CozeModel, Stream
from cozepy.request import Requester

if TYPE_CHECKING:
Expand Down Expand Up @@ -232,57 +231,28 @@
message: Message = None


class ChatChatIterator(object):
def __init__(self, iters: Iterator[str]):
self._iters = iters

def __iter__(self):
return self

def __next__(self) -> ChatEvent:
event = ""
data = ""
line = ""
times = 0

while times < 2:
line = next(self._iters)
if line == "":
continue
elif line.startswith("event:"):
if event == "":
event = line[6:]
else:
raise CozeEventError("event", line)
elif line.startswith("data:"):
if data == "":
data = line[5:]
else:
raise CozeEventError("data", line)
else:
raise CozeEventError("", line)

times += 1

if event == ChatEventType.DONE:
raise StopIteration
elif event == ChatEventType.ERROR:
raise Exception(f"error event: {line}") # TODO: error struct format
elif event in [
ChatEventType.CONVERSATION_MESSAGE_DELTA,
ChatEventType.CONVERSATION_MESSAGE_COMPLETED,
]:
return ChatEvent(event=event, message=Message.model_validate_json(data))
elif event in [
ChatEventType.CONVERSATION_CHAT_CREATED,
ChatEventType.CONVERSATION_CHAT_IN_PROGRESS,
ChatEventType.CONVERSATION_CHAT_COMPLETED,
ChatEventType.CONVERSATION_CHAT_FAILED,
ChatEventType.CONVERSATION_CHAT_REQUIRES_ACTION,
]:
return ChatEvent(event=event, chat=Chat.model_validate_json(data))
else:
raise ValueError(f"invalid chat.event: {event}, {data}")
def _chat_stream_handler(data: Dict) -> ChatEvent:
event = data["event"]
data = data["data"]
if event == ChatEventType.DONE:
raise StopIteration
elif event == ChatEventType.ERROR:
raise Exception(f"error event: {data}") # TODO: error struct format

Check warning on line 240 in cozepy/chat/__init__.py

View check run for this annotation

Codecov / codecov/patch

cozepy/chat/__init__.py#L240

Added line #L240 was not covered by tests
elif event in [
ChatEventType.CONVERSATION_MESSAGE_DELTA,
ChatEventType.CONVERSATION_MESSAGE_COMPLETED,
]:
return ChatEvent(event=event, message=Message.model_validate_json(data))
elif event in [
ChatEventType.CONVERSATION_CHAT_CREATED,
ChatEventType.CONVERSATION_CHAT_IN_PROGRESS,
ChatEventType.CONVERSATION_CHAT_COMPLETED,
ChatEventType.CONVERSATION_CHAT_FAILED,
ChatEventType.CONVERSATION_CHAT_REQUIRES_ACTION,
]:
return ChatEvent(event=event, chat=Chat.model_validate_json(data))
else:
raise ValueError(f"invalid chat.event: {event}, {data}")

Check warning on line 255 in cozepy/chat/__init__.py

View check run for this annotation

Codecov / codecov/patch

cozepy/chat/__init__.py#L255

Added line #L255 was not covered by tests


class ToolOutput(CozeModel):
Expand Down Expand Up @@ -350,7 +320,7 @@
auto_save_history: bool = True,
meta_data: Dict[str, str] = None,
conversation_id: str = None,
) -> ChatChatIterator:
) -> Stream[ChatEvent]:
"""
Call the Chat API with streaming to send messages to a published Coze bot.

Expand Down Expand Up @@ -390,7 +360,7 @@
auto_save_history: bool = True,
meta_data: Dict[str, str] = None,
conversation_id: str = None,
) -> Union[Chat, ChatChatIterator]:
) -> Union[Chat, Stream[ChatEvent]]:
"""
Create a conversation.
Conversation is an interaction between a bot and a user, including one or more messages.
Expand All @@ -409,7 +379,11 @@
if not stream:
return self._requester.request("post", url, Chat, body=body, stream=stream)

return ChatChatIterator(self._requester.request("post", url, Chat, body=body, stream=stream))
return Stream(
self._requester.request("post", url, Chat, body=body, stream=stream),
fields=["event", "data"],
handler=_chat_stream_handler,
)

def retrieve(
self,
Expand All @@ -436,7 +410,7 @@

def submit_tool_outputs(
self, *, conversation_id: str, chat_id: str, tool_outputs: List[ToolOutput], stream: bool
) -> Union[Chat, ChatChatIterator]:
) -> Union[Chat, Stream[ChatEvent]]:
"""
Call this API to submit the results of tool execution.

Expand Down Expand Up @@ -466,7 +440,11 @@
if not stream:
return self._requester.request("post", url, Chat, params=params, body=body, stream=stream)

return ChatChatIterator(self._requester.request("post", url, Chat, params=params, body=body, stream=stream))
return Stream(

Check warning on line 443 in cozepy/chat/__init__.py

View check run for this annotation

Codecov / codecov/patch

cozepy/chat/__init__.py#L443

Added line #L443 was not covered by tests
self._requester.request("post", url, Chat, params=params, body=body, stream=stream),
fields=["event", "data"],
handler=_chat_stream_handler,
)

def cancel(
self,
Expand Down
42 changes: 40 additions & 2 deletions cozepy/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Generic, List, TypeVar
from typing import Callable, Dict, Generic, Iterator, List, Tuple, TypeVar

from pydantic import BaseModel, ConfigDict

T = TypeVar("T", bound=BaseModel)
from cozepy.exception import CozeEventError

T = TypeVar("T")


class CozeModel(BaseModel):
Expand Down Expand Up @@ -67,3 +69,39 @@

def __repr__(self):
return f"LastIDPaged(items={self.items}, first_id={self.first_id}, last_id={self.last_id}, has_more={self.has_more})"


class Stream(Generic[T]):
def __init__(self, iters: Iterator[str], fields: List[str], handler: Callable[[Dict[str, str]], T]):
self._iters = iters
self._fields = fields
self._handler = handler

def __iter__(self):
return self

def __next__(self) -> T:
return self._handler(self._extra_event())

def _extra_event(self) -> Dict[str, str]:
data = dict(map(lambda x: (x, ""), self._fields))
times = 0

while times < len(data):
line = next(self._iters)
if line == "":
continue

field, value = self._extra_field_data(line, data)
data[field] = value
times += 1
return data

def _extra_field_data(self, line: str, data: Dict[str, str]) -> Tuple[str, str]:
for field in self._fields:
if line.startswith(field + ":"):
if data[field] == "":
return field, line[len(field) + 1 :].strip()
else:
raise CozeEventError(field, line)
raise CozeEventError("", line)

Check warning on line 107 in cozepy/model.py

View check run for this annotation

Codecov / codecov/patch

cozepy/model.py#L106-L107

Added lines #L106 - L107 were not covered by tests
Loading
Loading