diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ea6648d..bc52a26 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,10 +32,13 @@ jobs: pip install poetry poetry install - name: Build + run: | + poetry build + - name: Check run: | poetry run ruff check cozepy poetry run ruff format --check - poetry build + poetry run mypy . - name: Run tests run: poetry run pytest --cov --cov-report=xml - name: Upload coverage to Codecov diff --git a/.gitignore b/.gitignore index 6d062cf..5fdcaa5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ .idea/ -.venv/ +.venv*/ .DS_Store __pycache__/ dist/ diff --git a/cozepy/auth/__init__.py b/cozepy/auth/__init__.py index 2e85861..d1dd06e 100644 --- a/cozepy/auth/__init__.py +++ b/cozepy/auth/__init__.py @@ -3,7 +3,7 @@ from typing import List, Optional from urllib.parse import quote_plus, urlparse -from authlib.jose import jwt +from authlib.jose import jwt # type: ignore from typing_extensions import Literal from cozepy.config import COZE_CN_BASE_URL, COZE_COM_BASE_URL @@ -55,7 +55,7 @@ class Scope(CozeModel): attribute_constraint: Optional[ScopeAttributeConstraint] = None @staticmethod - def from_bot_chat(bot_id_list: List[str], permission_list: List[str] = None) -> "Scope": + def from_bot_chat(bot_id_list: List[str], permission_list: Optional[List[str]] = None) -> "Scope": if not permission_list: permission_list = ["Connector.botChat"] return Scope( @@ -80,9 +80,9 @@ def _get_oauth_url( self, redirect_uri: str, state: str, - code_challenge: str = None, - code_challenge_method: str = None, - workspace_id: str = None, + code_challenge: Optional[str] = None, + code_challenge_method: Optional[str] = None, + workspace_id: Optional[str] = None, ): params = { "response_type": "code", @@ -92,7 +92,9 @@ def _get_oauth_url( } if code_challenge: params["code_challenge"] = code_challenge + if code_challenge_method: params["code_challenge_method"] = code_challenge_method + uri = f"{self._get_www_base_url}/api/permission/oauth2/authorize" if workspace_id: uri = f"{self._get_www_base_url}/api/permission/oauth2/workspace_id/{workspace_id}/authorize" @@ -106,7 +108,7 @@ def _refresh_access_token(self, refresh_token: str, secret: str = "") -> OAuthTo "client_id": self._client_id, "refresh_token": refresh_token, } - return self._requester.request("post", url, OAuthToken, headers=headers, body=body) + return self._requester.request("post", url, False, OAuthToken, headers=headers, body=body) async def _arefresh_access_token(self, refresh_token: str, secret: str = "") -> OAuthToken: url = f"{self._base_url}/api/permission/oauth2/token" @@ -116,7 +118,7 @@ async def _arefresh_access_token(self, refresh_token: str, secret: str = "") -> "client_id": self._client_id, "refresh_token": refresh_token, } - return await self._requester.arequest("post", url, OAuthToken, headers=headers, body=body) + return await self._requester.arequest("post", url, False, OAuthToken, headers=headers, body=body) @property def _get_www_base_url(self) -> str: @@ -149,7 +151,7 @@ def get_oauth_url( self, redirect_uri: str, state: str, - workspace_id: str = None, + workspace_id: Optional[str] = None, ): """ Get the pkce flow authorized url. @@ -183,13 +185,13 @@ def get_access_token( "code": code, "redirect_uri": redirect_uri, } - return self._requester.request("post", url, OAuthToken, headers=headers, body=body) + return self._requester.request("post", url, False, OAuthToken, headers=headers, body=body) def refresh_access_token(self, refresh_token: str) -> OAuthToken: return self._refresh_access_token(refresh_token, self._client_secret) -class AsyncWebOAuthApp(WebOAuthApp): +class AsyncWebOAuthApp(OAuthApp): """ Normal OAuth App. """ @@ -205,7 +207,29 @@ def __init__(self, client_id: str, client_secret: str, base_url: str = COZE_COM_ self._base_url = base_url self._api_endpoint = urlparse(base_url).netloc self._token = "" - super().__init__(client_id, client_secret, base_url, www_base_url=www_base_url) + super().__init__(client_id, base_url, www_base_url=www_base_url) + + def get_oauth_url( + self, + redirect_uri: str, + state: str, + workspace_id: Optional[str] = None, + ): + """ + Get the pkce flow authorized url. + + :param redirect_uri: The redirect_uri of your app, where authentication responses can be sent and received by + your app. It must exactly match one of the redirect URIs you registered in the OAuth Apps. + :param state: A value included in the request that is also returned in the token response. It can be a string + of any hash value. + :param workspace_id: + :return: + """ + return self._get_oauth_url( + redirect_uri, + state, + workspace_id=workspace_id, + ) async def get_access_token( self, @@ -223,7 +247,7 @@ async def get_access_token( "code": code, "redirect_uri": redirect_uri, } - return await self._requester.arequest("post", url, OAuthToken, headers=headers, body=body) + return await self._requester.arequest("post", url, False, OAuthToken, headers=headers, body=body) async def refresh_access_token(self, refresh_token: str) -> OAuthToken: return await self._arefresh_access_token(refresh_token, self._client_secret) @@ -249,7 +273,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur self._public_key_id = public_key_id super().__init__(client_id, base_url, www_base_url="") - def get_access_token(self, ttl: int, scope: Scope = None) -> OAuthToken: + def get_access_token(self, ttl: int, scope: Optional[Scope] = None) -> OAuthToken: """ Get the token by jwt with jwt auth flow. """ @@ -261,7 +285,7 @@ def get_access_token(self, ttl: int, scope: Scope = None) -> OAuthToken: "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "scope": scope.model_dump() if scope else None, } - return self._requester.request("post", url, OAuthToken, headers=headers, body=body) + return self._requester.request("post", url, False, OAuthToken, headers=headers, body=body) def _gen_jwt(self, ttl: int): now = int(time.time()) @@ -297,7 +321,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur self._public_key_id = public_key_id super().__init__(client_id, base_url, www_base_url="") - async def get_access_token(self, ttl: int, scope: Scope = None) -> OAuthToken: + async def get_access_token(self, ttl: int, scope: Optional[Scope] = None) -> OAuthToken: """ Get the token by jwt with jwt auth flow. """ @@ -309,7 +333,7 @@ async def get_access_token(self, ttl: int, scope: Scope = None) -> OAuthToken: "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "scope": scope.model_dump() if scope else None, } - return await self._requester.arequest("post", url, OAuthToken, headers=headers, body=body) + return await self._requester.arequest("post", url, False, OAuthToken, headers=headers, body=body) def _gen_jwt(self, ttl: int): now = int(time.time()) @@ -345,7 +369,7 @@ def get_oauth_url( state: str, code_verifier: str, code_challenge_method: Literal["plain", "S256"] = "plain", - workspace_id: str = None, + workspace_id: Optional[str] = None, ): """ Get the pkce flow authorized url. @@ -386,13 +410,13 @@ def get_access_token(self, redirect_uri: str, code: str, code_verifier: str) -> "redirect_uri": redirect_uri, "code_verifier": code_verifier, } - return self._requester.request("post", url, OAuthToken, body=body) + return self._requester.request("post", url, False, OAuthToken, body=body) def refresh_access_token(self, refresh_token: str) -> OAuthToken: return self._refresh_access_token(refresh_token) -class AsyncPKCEOAuthApp(PKCEOAuthApp): +class AsyncPKCEOAuthApp(OAuthApp): """ PKCE OAuth App. """ @@ -406,6 +430,36 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u www_base_url, ) + def get_oauth_url( + self, + redirect_uri: str, + state: str, + code_verifier: str, + code_challenge_method: Literal["plain", "S256"] = "plain", + workspace_id: Optional[str] = None, + ): + """ + Get the pkce flow authorized url. + + :param redirect_uri: The redirect_uri of your app, where authentication responses can be sent and received by + your app. It must exactly match one of the redirect URIs you registered in the OAuth Apps. + :param state: A value included in the request that is also returned in the token response. It can be a string + of any hash value. + :param code_verifier: + :param code_challenge_method: + :param workspace_id: + :return: + """ + code_challenge = code_verifier if code_challenge_method == "plain" else gen_s256_code_challenge(code_verifier) + + return self._get_oauth_url( + redirect_uri, + state, + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + workspace_id=workspace_id, + ) + async def get_access_token(self, redirect_uri: str, code: str, code_verifier: str) -> OAuthToken: """ Get the token with pkce auth flow. @@ -423,7 +477,7 @@ async def get_access_token(self, redirect_uri: str, code: str, code_verifier: st "redirect_uri": redirect_uri, "code_verifier": code_verifier, } - return await self._requester.arequest("post", url, OAuthToken, body=body) + return await self._requester.arequest("post", url, False, OAuthToken, body=body) async def refresh_access_token(self, refresh_token: str) -> OAuthToken: return await self._arefresh_access_token(refresh_token) @@ -445,7 +499,7 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u def get_device_code( self, - workspace_id: str = None, + workspace_id: Optional[str] = None, ) -> DeviceAuthCode: """ Get the pkce flow authorized url. @@ -463,7 +517,7 @@ def get_device_code( headers = { "Content-Type": "application/json", } - res = self._requester.request("post", uri, DeviceAuthCode, headers=headers, body=body) + res = self._requester.request("post", uri, False, DeviceAuthCode, headers=headers, body=body) res.verification_url = f"{res.verification_uri}?user_code={res.user_code}" return res @@ -508,13 +562,13 @@ def _get_access_token(self, device_code: str, poll: bool = False) -> OAuthToken: "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code, } - return self._requester.request("post", url, OAuthToken, body=body) + return self._requester.request("post", url, False, OAuthToken, body=body) def refresh_access_token(self, refresh_token: str) -> OAuthToken: return self._refresh_access_token(refresh_token) -class AsyncDeviceOAuthApp(DeviceOAuthApp): +class AsyncDeviceOAuthApp(OAuthApp): """ Device OAuth App. """ @@ -528,7 +582,7 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u www_base_url, ) - async def get_device_code(self, workspace_id: str = None) -> DeviceAuthCode: + async def get_device_code(self, workspace_id: Optional[str] = None) -> DeviceAuthCode: """ Get the pkce flow authorized url. @@ -545,7 +599,7 @@ async def get_device_code(self, workspace_id: str = None) -> DeviceAuthCode: headers = { "Content-Type": "application/json", } - res = await self._requester.arequest("post", uri, DeviceAuthCode, headers=headers, body=body) + res = await self._requester.arequest("post", uri, False, DeviceAuthCode, headers=headers, body=body) res.verification_url = f"{res.verification_uri}?user_code={res.user_code}" return res @@ -590,7 +644,7 @@ async def _get_access_token(self, device_code: str, poll: bool = False) -> OAuth "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code, } - return await self._requester.arequest("post", url, OAuthToken, body=body) + return await self._requester.arequest("post", url, False, OAuthToken, body=body) async def refresh_access_token(self, refresh_token: str) -> OAuthToken: return await self._arefresh_access_token(refresh_token) diff --git a/cozepy/bots/__init__.py b/cozepy/bots/__init__.py index af2e986..d8005d4 100644 --- a/cozepy/bots/__init__.py +++ b/cozepy/bots/__init__.py @@ -58,17 +58,17 @@ class Bot(CozeModel): # The ID for the bot. bot_id: str # The name of the bot. - name: str = None + name: Optional[str] = None # The description of the bot. - description: str = None + description: Optional[str] = None # The URL address for the bot's avatar. - icon_url: str = None + icon_url: Optional[str] = None # The creation time, in the format of a 10-digit Unix timestamp in seconds (s). - create_time: int = None + create_time: Optional[int] = None # The update time, in the format of a 10-digit Unix timestamp in seconds (s). - update_time: int = None + update_time: Optional[int] = None # The latest version of the bot. - version: str = None + version: Optional[str] = None # The prompt configuration for the bot. For more information, see Prompt object. prompt_info: Optional[BotPromptInfo] = None # The onboarding message configuration for the bot. For more information, see Onboarding object. @@ -110,10 +110,10 @@ def create( *, space_id: str, name: str, - description: str = None, - icon_file_id: str = None, - prompt_info: BotPromptInfo = None, - onboarding_info: BotOnboardingInfo = None, + description: Optional[str] = None, + icon_file_id: Optional[str] = None, + prompt_info: Optional[BotPromptInfo] = None, + onboarding_info: Optional[BotOnboardingInfo] = None, ) -> Bot: url = f"{self._base_url}/v1/bot/create" body = { @@ -125,17 +125,17 @@ def create( "onboarding_info": onboarding_info.model_dump() if onboarding_info else None, } - return self._requester.request("post", url, Bot, body=body) + return self._requester.request("post", url, False, Bot, body=body) def update( self, *, bot_id: str, - name: str = None, - description: str = None, - icon_file_id: str = None, - prompt_info: BotPromptInfo = None, - onboarding_info: BotOnboardingInfo = None, + name: Optional[str] = None, + description: Optional[str] = None, + icon_file_id: Optional[str] = None, + prompt_info: Optional[BotPromptInfo] = None, + onboarding_info: Optional[BotOnboardingInfo] = None, ) -> None: """ Update the configuration of a bot. @@ -166,13 +166,19 @@ def update( "onboarding_info": onboarding_info.model_dump() if onboarding_info else None, } - return self._requester.request("post", url, None, body=body) + return self._requester.request( + "post", + url, + False, + None, + body=body, + ) def publish( self, *, bot_id: str, - connector_ids: List[str] = None, + connector_ids: Optional[List[str]] = None, ) -> Bot: url = f"{self._base_url}/v1/bot/publish" if not connector_ids: @@ -182,7 +188,7 @@ def publish( "connector_ids": connector_ids, } - return self._requester.request("post", url, Bot, body=body) + return self._requester.request("post", url, False, Bot, body=body) def retrieve(self, *, bot_id: str) -> Bot: """ @@ -201,7 +207,7 @@ def retrieve(self, *, bot_id: str) -> Bot: url = f"{self._base_url}/v1/bot/get_online_info" params = {"bot_id": bot_id} - return self._requester.request("get", url, Bot, params=params) + return self._requester.request("get", url, False, Bot, params=params) def list(self, *, space_id: str, page_num: int = 1, page_size: int = 20) -> NumberPaged[SimpleBot]: """ @@ -227,7 +233,7 @@ def list(self, *, space_id: str, page_num: int = 1, page_size: int = 20) -> Numb "page_size": page_size, "page_index": page_num, } - data = self._requester.request("get", url, self._PrivateListPublishedBotsV1Data, params=params) + data = self._requester.request("get", url, False, self._PrivateListPublishedBotsV1Data, params=params) return NumberPaged( items=data.space_bots, page_num=page_num, @@ -255,10 +261,10 @@ async def create( *, space_id: str, name: str, - description: str = None, - icon_file_id: str = None, - prompt_info: BotPromptInfo = None, - onboarding_info: BotOnboardingInfo = None, + description: Optional[str] = None, + icon_file_id: Optional[str] = None, + prompt_info: Optional[BotPromptInfo] = None, + onboarding_info: Optional[BotOnboardingInfo] = None, ) -> Bot: url = f"{self._base_url}/v1/bot/create" body = { @@ -270,17 +276,17 @@ async def create( "onboarding_info": onboarding_info.model_dump() if onboarding_info else None, } - return await self._requester.arequest("post", url, Bot, body=body) + return await self._requester.arequest("post", url, False, Bot, body=body) async def update( self, *, bot_id: str, - name: str = None, - description: str = None, - icon_file_id: str = None, - prompt_info: BotPromptInfo = None, - onboarding_info: BotOnboardingInfo = None, + name: Optional[str] = None, + description: Optional[str] = None, + icon_file_id: Optional[str] = None, + prompt_info: Optional[BotPromptInfo] = None, + onboarding_info: Optional[BotOnboardingInfo] = None, ) -> None: """ Update the configuration of a bot. @@ -311,13 +317,13 @@ async def update( "onboarding_info": onboarding_info.model_dump() if onboarding_info else None, } - return await self._requester.arequest("post", url, None, body=body) + return await self._requester.arequest("post", url, False, None, body=body) async def publish( self, *, bot_id: str, - connector_ids: List[str] = None, + connector_ids: Optional[List[str]] = None, ) -> Bot: url = f"{self._base_url}/v1/bot/publish" if not connector_ids: @@ -327,7 +333,7 @@ async def publish( "connector_ids": connector_ids, } - return await self._requester.arequest("post", url, Bot, body=body) + return await self._requester.arequest("post", url, False, Bot, body=body) async def retrieve(self, *, bot_id: str) -> Bot: """ @@ -346,7 +352,7 @@ async def retrieve(self, *, bot_id: str) -> Bot: url = f"{self._base_url}/v1/bot/get_online_info" params = {"bot_id": bot_id} - return await self._requester.arequest("get", url, Bot, params=params) + return await self._requester.arequest("get", url, False, Bot, params=params) async def list(self, *, space_id: str, page_num: int = 1, page_size: int = 20) -> NumberPaged[SimpleBot]: """ @@ -372,7 +378,7 @@ async def list(self, *, space_id: str, page_num: int = 1, page_size: int = 20) - "page_size": page_size, "page_index": page_num, } - data = await self._requester.arequest("get", url, self._PrivateListPublishedBotsV1Data, params=params) + data = await self._requester.arequest("get", url, False, self._PrivateListPublishedBotsV1Data, params=params) return NumberPaged( items=data.space_bots, page_num=page_num, diff --git a/cozepy/chat/__init__.py b/cozepy/chat/__init__.py index b96de56..4722694 100644 --- a/cozepy/chat/__init__.py +++ b/cozepy/chat/__init__.py @@ -1,14 +1,15 @@ from enum import Enum from functools import partial -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Union, cast, overload + +from typing_extensions import Literal from cozepy.auth import Auth from cozepy.model import AsyncStream, CozeModel, Stream from cozepy.request import Requester if TYPE_CHECKING: - from .message import AsyncMessagesClient as AsyncChatMessagesClient - from .message import MessagesClient as ChatMessagesClient + from .message import AsyncChatMessagesClient, ChatMessagesClient class MessageRole(str, Enum): @@ -96,7 +97,7 @@ class Message(CozeModel): # The entity that sent this message. role: MessageRole # The type of message. - type: MessageType = "" + type: MessageType = MessageType.UNKNOWN # The content of the message. It supports various types of content, including plain text, multimodal (a mix of text, images, and files), message cards, and more. # 消息的内容,支持纯文本、多模态(文本、图片、文件混合输入)、卡片等多种类型的内容。 content: str @@ -229,24 +230,24 @@ class ChatEventType(str, Enum): class ChatEvent(CozeModel): event: ChatEventType - chat: Chat = None - message: Message = None + chat: Optional[Chat] = None + message: Optional[Message] = None def _chat_stream_handler(data: Dict, is_async: bool = False) -> ChatEvent: event = data["event"] - data = data["data"] + event_data = data["data"] # type: str if event == ChatEventType.DONE: if is_async: raise StopAsyncIteration raise StopIteration elif event == ChatEventType.ERROR: - raise Exception(f"error event: {data}") # TODO: error struct format + raise Exception(f"error event: {event_data}") # TODO: error struct format elif event in [ ChatEventType.CONVERSATION_MESSAGE_DELTA, ChatEventType.CONVERSATION_MESSAGE_COMPLETED, ]: - return ChatEvent(event=event, message=Message.model_validate_json(data)) + return ChatEvent(event=event, message=Message.model_validate_json(event_data)) elif event in [ ChatEventType.CONVERSATION_CHAT_CREATED, ChatEventType.CONVERSATION_CHAT_IN_PROGRESS, @@ -254,13 +255,15 @@ def _chat_stream_handler(data: Dict, is_async: bool = False) -> ChatEvent: ChatEventType.CONVERSATION_CHAT_FAILED, ChatEventType.CONVERSATION_CHAT_REQUIRES_ACTION, ]: - return ChatEvent(event=event, chat=Chat.model_validate_json(data)) + return ChatEvent(event=event, chat=Chat.model_validate_json(event_data)) else: raise ValueError(f"invalid chat.event: {event}, {data}") -_sync_chat_stream_handler = partial(_chat_stream_handler, is_async=False) -_async_chat_stream_handler = partial(_chat_stream_handler, is_async=True) +_sync_chat_stream_handler = cast(Callable[[Dict[str, str]], ChatEvent], partial(_chat_stream_handler, is_async=False)) +_async_chat_stream_handler = cast( + Callable[[Dict[str, str]], Awaitable[ChatEvent]], partial(_chat_stream_handler, is_async=True) +) class ToolOutput(CozeModel): @@ -277,18 +280,18 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth self._requester = requester - self._messages = None + self._messages: Optional[ChatMessagesClient] = None def create( self, *, bot_id: str, user_id: str, - conversation_id: str = None, - additional_messages: List[Message] = None, - custom_variables: Dict[str, str] = None, + conversation_id: Optional[str] = None, + additional_messages: Optional[List[Message]] = None, + custom_variables: Optional[Dict[str, str]] = None, auto_save_history: bool = True, - meta_data: Dict[str, str] = None, + meta_data: Optional[Dict[str, str]] = None, ) -> Chat: """ Call the Chat API with non-streaming to send messages to a published Coze bot. @@ -310,8 +313,8 @@ def create( return self._create( bot_id=bot_id, user_id=user_id, - additional_messages=additional_messages, stream=False, + additional_messages=additional_messages, custom_variables=custom_variables, auto_save_history=auto_save_history, meta_data=meta_data, @@ -323,11 +326,11 @@ def stream( *, bot_id: str, user_id: str, - additional_messages: List[Message] = None, - custom_variables: Dict[str, str] = None, + additional_messages: Optional[List[Message]] = None, + custom_variables: Optional[Dict[str, str]] = None, auto_save_history: bool = True, - meta_data: Dict[str, str] = None, - conversation_id: str = None, + meta_data: Optional[Dict[str, str]] = None, + conversation_id: Optional[str] = None, ) -> Stream[ChatEvent]: """ Call the Chat API with streaming to send messages to a published Coze bot. @@ -349,25 +352,53 @@ def stream( return self._create( bot_id=bot_id, user_id=user_id, - additional_messages=additional_messages, stream=True, + additional_messages=additional_messages, custom_variables=custom_variables, auto_save_history=auto_save_history, meta_data=meta_data, conversation_id=conversation_id, ) + @overload def _create( self, *, bot_id: str, user_id: str, - additional_messages: List[Message] = None, - stream: bool = False, - custom_variables: Dict[str, str] = None, + stream: Literal[True], + additional_messages: Optional[List[Message]] = ..., + custom_variables: Optional[Dict[str, str]] = ..., + auto_save_history: bool = ..., + meta_data: Optional[Dict[str, str]] = ..., + conversation_id: Optional[str] = ..., + ) -> Stream[ChatEvent]: ... + + @overload + def _create( + self, + *, + bot_id: str, + user_id: str, + stream: Literal[False], + additional_messages: Optional[List[Message]] = ..., + custom_variables: Optional[Dict[str, str]] = ..., + auto_save_history: bool = ..., + meta_data: Optional[Dict[str, str]] = ..., + conversation_id: Optional[str] = ..., + ) -> Chat: ... + + def _create( + self, + *, + bot_id: str, + user_id: str, + stream: Literal[True, False], + additional_messages: Optional[List[Message]] = None, + custom_variables: Optional[Dict[str, str]] = None, auto_save_history: bool = True, - meta_data: Dict[str, str] = None, - conversation_id: str = None, + meta_data: Optional[Dict[str, str]] = None, + conversation_id: Optional[str] = None, ) -> Union[Chat, Stream[ChatEvent]]: """ Create a conversation. @@ -385,9 +416,21 @@ def _create( "meta_data": meta_data, } if not stream: - return self._requester.request("post", url, Chat, body=body, stream=stream) - - steam_iters, logid = self._requester.request("post", url, Chat, body=body, stream=stream) + return self._requester.request( + "post", + url, + False, + Chat, + body=body, + ) + + steam_iters, logid = self._requester.request( + "post", + url, + True, + None, + body=body, + ) return Stream(steam_iters, fields=["event", "data"], handler=_sync_chat_stream_handler, logid=logid) def retrieve( @@ -411,7 +454,7 @@ def retrieve( "conversation_id": conversation_id, "chat_id": chat_id, } - return self._requester.request("post", url, Chat, params=params) + return self._requester.request("post", url, False, Chat, params=params) def submit_tool_outputs( self, *, conversation_id: str, chat_id: str, tool_outputs: List[ToolOutput], stream: bool @@ -443,9 +486,23 @@ def submit_tool_outputs( } if not stream: - return self._requester.request("post", url, Chat, params=params, body=body, stream=stream) - - steam_iters, logid = self._requester.request("post", url, Chat, params=params, body=body, stream=stream) + return self._requester.request( + "post", + url, + False, + Chat, + params=params, + body=body, + ) + + steam_iters, logid = self._requester.request( + "post", + url, + True, + None, + params=params, + body=body, + ) return Stream(steam_iters, fields=["event", "data"], handler=_sync_chat_stream_handler, logid=logid) def cancel( @@ -471,16 +528,16 @@ def cancel( "conversation_id": conversation_id, "chat_id": chat_id, } - return self._requester.request("post", url, Chat, params=params) + return self._requester.request("post", url, False, Chat, params=params) @property def messages( self, ) -> "ChatMessagesClient": if self._messages is None: - from .message import MessagesClient + from .message import ChatMessagesClient - self._messages = MessagesClient(self._base_url, self._auth, self._requester) + self._messages = ChatMessagesClient(self._base_url, self._auth, self._requester) return self._messages @@ -489,18 +546,18 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth self._requester = requester - self._messages = None + self._messages: Optional[AsyncChatMessagesClient] = None async def create( self, *, bot_id: str, user_id: str, - conversation_id: str = None, - additional_messages: List[Message] = None, - custom_variables: Dict[str, str] = None, + conversation_id: Optional[str] = None, + additional_messages: Optional[List[Message]] = None, + custom_variables: Optional[Dict[str, str]] = None, auto_save_history: bool = True, - meta_data: Dict[str, str] = None, + meta_data: Optional[Dict[str, str]] = None, ) -> Chat: """ Call the Chat API with non-streaming to send messages to a published Coze bot. @@ -535,11 +592,11 @@ async def stream( *, bot_id: str, user_id: str, - additional_messages: List[Message] = None, - custom_variables: Dict[str, str] = None, + additional_messages: Optional[List[Message]] = None, + custom_variables: Optional[Dict[str, str]] = None, auto_save_history: bool = True, - meta_data: Dict[str, str] = None, - conversation_id: str = None, + meta_data: Optional[Dict[str, str]] = None, + conversation_id: Optional[str] = None, ) -> AsyncStream[ChatEvent]: """ Call the Chat API with streaming to send messages to a published Coze bot. @@ -569,17 +626,45 @@ async def stream( conversation_id=conversation_id, ) + @overload + async def _create( + self, + *, + bot_id: str, + user_id: str, + stream: Literal[True], + additional_messages: Optional[List[Message]] = ..., + custom_variables: Optional[Dict[str, str]] = ..., + auto_save_history: bool = ..., + meta_data: Optional[Dict[str, str]] = ..., + conversation_id: Optional[str] = ..., + ) -> AsyncStream[ChatEvent]: ... + + @overload + async def _create( + self, + *, + bot_id: str, + user_id: str, + stream: Literal[False], + additional_messages: Optional[List[Message]] = ..., + custom_variables: Optional[Dict[str, str]] = ..., + auto_save_history: bool = ..., + meta_data: Optional[Dict[str, str]] = ..., + conversation_id: Optional[str] = ..., + ) -> Chat: ... + async def _create( self, *, bot_id: str, user_id: str, - additional_messages: List[Message] = None, - stream: bool = False, - custom_variables: Dict[str, str] = None, + stream: Literal[True, False], + additional_messages: Optional[List[Message]] = None, + custom_variables: Optional[Dict[str, str]] = None, auto_save_history: bool = True, - meta_data: Dict[str, str] = None, - conversation_id: str = None, + meta_data: Optional[Dict[str, str]] = None, + conversation_id: Optional[str] = None, ) -> Union[Chat, AsyncStream[ChatEvent]]: """ Create a conversation. @@ -597,9 +682,21 @@ async def _create( "meta_data": meta_data, } if not stream: - return await self._requester.arequest("post", url, Chat, body=body, stream=stream) - - steam_iters, logid = await self._requester.arequest("post", url, Chat, body=body, stream=stream) + return await self._requester.arequest( + "post", + url, + False, + Chat, + body=body, + ) + + steam_iters, logid = await self._requester.arequest( + "post", + url, + True, + None, + body=body, + ) return AsyncStream(steam_iters, fields=["event", "data"], handler=_async_chat_stream_handler, logid=logid) async def retrieve( @@ -623,7 +720,7 @@ async def retrieve( "conversation_id": conversation_id, "chat_id": chat_id, } - return await self._requester.arequest("post", url, Chat, params=params) + return await self._requester.arequest("post", url, False, Chat, params=params) async def submit_tool_outputs( self, *, conversation_id: str, chat_id: str, tool_outputs: List[ToolOutput], stream: bool @@ -655,9 +752,9 @@ async def submit_tool_outputs( } if not stream: - return await self._requester.arequest("post", url, Chat, params=params, body=body, stream=stream) + return await self._requester.arequest("post", url, False, Chat, params=params, body=body) - steam_iters, logid = await self._requester.arequest("post", url, Chat, params=params, body=body, stream=stream) + steam_iters, logid = await self._requester.arequest("post", url, True, None, params=params, body=body) return AsyncStream(steam_iters, fields=["event", "data"], handler=_async_chat_stream_handler, logid=logid) async def cancel( @@ -683,14 +780,14 @@ async def cancel( "conversation_id": conversation_id, "chat_id": chat_id, } - return await self._requester.arequest("post", url, Chat, params=params) + return await self._requester.arequest("post", url, False, Chat, params=params) @property def messages( self, ) -> "AsyncChatMessagesClient": if self._messages is None: - from .message import AsyncMessagesClient + from .message import AsyncChatMessagesClient - self._messages = AsyncMessagesClient(self._base_url, self._auth, self._requester) + self._messages = AsyncChatMessagesClient(self._base_url, self._auth, self._requester) return self._messages diff --git a/cozepy/chat/message/__init__.py b/cozepy/chat/message/__init__.py index f4988e4..ec754ae 100644 --- a/cozepy/chat/message/__init__.py +++ b/cozepy/chat/message/__init__.py @@ -5,7 +5,7 @@ from cozepy.request import Requester -class MessagesClient(object): +class ChatMessagesClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth @@ -33,10 +33,10 @@ def list( "conversation_id": conversation_id, "chat_id": chat_id, } - return self._requester.request("post", url, [Message], params=params) + return self._requester.request("post", url, False, [Message], params=params) -class AsyncMessagesClient(object): +class AsyncChatMessagesClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth @@ -64,4 +64,4 @@ async def list( "conversation_id": conversation_id, "chat_id": chat_id, } - return await self._requester.arequest("post", url, [Message], params=params) + return await self._requester.arequest("post", url, False, [Message], params=params) diff --git a/cozepy/conversations/__init__.py b/cozepy/conversations/__init__.py index 31fa431..5a6b410 100644 --- a/cozepy/conversations/__init__.py +++ b/cozepy/conversations/__init__.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Optional from cozepy.auth import Auth from cozepy.chat import Message @@ -19,7 +19,9 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._requester = requester self._messages = None - def create(self, *, messages: List[Message] = None, meta_data: Dict[str, str] = None) -> Conversation: + def create( + self, *, messages: Optional[List[Message]] = None, meta_data: Optional[Dict[str, str]] = None + ) -> Conversation: """ Create a conversation. Conversation is an interaction between a bot and a user, including one or more messages. @@ -37,7 +39,7 @@ def create(self, *, messages: List[Message] = None, meta_data: Dict[str, str] = "messages": [i.model_dump() for i in messages] if messages and len(messages) > 0 else [], "meta_data": meta_data, } - return self._requester.request("post", url, Conversation, body=body) + return self._requester.request("post", url, False, Conversation, body=body) def retrieve(self, *, conversation_id: str) -> Conversation: """ @@ -53,7 +55,7 @@ def retrieve(self, *, conversation_id: str) -> Conversation: params = { "conversation_id": conversation_id, } - return self._requester.request("get", url, Conversation, params=params) + return self._requester.request("get", url, False, Conversation, params=params) @property def messages(self): @@ -71,7 +73,9 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._requester = requester self._messages = None - async def create(self, *, messages: List[Message] = None, meta_data: Dict[str, str] = None) -> Conversation: + async def create( + self, *, messages: Optional[List[Message]] = None, meta_data: Optional[Dict[str, str]] = None + ) -> Conversation: """ Create a conversation. Conversation is an interaction between a bot and a user, including one or more messages. @@ -89,7 +93,7 @@ async def create(self, *, messages: List[Message] = None, meta_data: Dict[str, s "messages": [i.model_dump() for i in messages] if messages and len(messages) > 0 else [], "meta_data": meta_data, } - return await self._requester.arequest("post", url, Conversation, body=body) + return await self._requester.arequest("post", url, False, Conversation, body=body) async def retrieve(self, *, conversation_id: str) -> Conversation: """ @@ -105,7 +109,7 @@ async def retrieve(self, *, conversation_id: str) -> Conversation: params = { "conversation_id": conversation_id, } - return await self._requester.arequest("get", url, Conversation, params=params) + return await self._requester.arequest("get", url, False, Conversation, params=params) @property def messages(self): diff --git a/cozepy/conversations/message/__init__.py b/cozepy/conversations/message/__init__.py index 9dcda42..a680086 100644 --- a/cozepy/conversations/message/__init__.py +++ b/cozepy/conversations/message/__init__.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Optional from cozepy.auth import Auth from cozepy.chat import Message, MessageContentType, MessageRole @@ -23,7 +23,7 @@ def create( role: MessageRole, content: str, content_type: MessageContentType, - meta_data: Dict[str, str] = None, + meta_data: Optional[Dict[str, str]] = None, ) -> Message: """ Create a message and add it to the specified conversation. @@ -51,16 +51,16 @@ def create( "meta_data": meta_data, } - return self._requester.request("post", url, Message, params=params, body=body) + return self._requester.request("post", url, False, Message, params=params, body=body) def list( self, *, conversation_id: str, order: str = "desc", - chat_id: str = None, - before_id: str = None, - after_id: str = None, + chat_id: Optional[str] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, limit: int = 50, ) -> LastIDPaged[Message]: """ @@ -89,7 +89,7 @@ def list( "limit": limit, } - res = self._requester.request("post", url, self._PrivateListMessageResp, params=params, body=body) + res = self._requester.request("post", url, False, self._PrivateListMessageResp, params=params, body=body) return LastIDPaged(res.items, res.first_id, res.last_id, res.has_more) def retrieve( @@ -114,16 +114,16 @@ def retrieve( "message_id": message_id, } - return self._requester.request("get", url, Message, params=params) + return self._requester.request("get", url, False, Message, params=params) def update( self, *, conversation_id: str, message_id: str, - content: str = None, - content_type: MessageContentType = None, - meta_data: Dict[str, str] = None, + content: Optional[str] = None, + content_type: Optional[MessageContentType] = None, + meta_data: Optional[Dict[str, str]] = None, ) -> Message: """ Modify a message, supporting the modification of message content, additional content, and message type. @@ -150,7 +150,7 @@ def update( "meta_data": meta_data, } - return self._requester.request("post", url, Message, params=params, body=body, data_field="message") + return self._requester.request("post", url, False, Message, params=params, body=body, data_field="message") def delete( self, @@ -174,7 +174,7 @@ def delete( "message_id": message_id, } - return self._requester.request("post", url, Message, params=params) + return self._requester.request("post", url, False, Message, params=params) class _PrivateListMessageResp(CozeModel): first_id: str @@ -200,7 +200,7 @@ async def create( role: MessageRole, content: str, content_type: MessageContentType, - meta_data: Dict[str, str] = None, + meta_data: Optional[Dict[str, str]] = None, ) -> Message: """ Create a message and add it to the specified conversation. @@ -228,16 +228,16 @@ async def create( "meta_data": meta_data, } - return await self._requester.arequest("post", url, Message, params=params, body=body) + return await self._requester.arequest("post", url, False, Message, params=params, body=body) async def list( self, *, conversation_id: str, order: str = "desc", - chat_id: str = None, - before_id: str = None, - after_id: str = None, + chat_id: Optional[str] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, limit: int = 50, ) -> LastIDPaged[Message]: """ @@ -266,7 +266,7 @@ async def list( "limit": limit, } - res = await self._requester.arequest("post", url, self._PrivateListMessageResp, params=params, body=body) + res = await self._requester.arequest("post", url, False, self._PrivateListMessageResp, params=params, body=body) return LastIDPaged(res.items, res.first_id, res.last_id, res.has_more) async def retrieve( @@ -291,16 +291,16 @@ async def retrieve( "message_id": message_id, } - return await self._requester.arequest("get", url, Message, params=params) + return await self._requester.arequest("get", url, False, Message, params=params) async def update( self, *, conversation_id: str, message_id: str, - content: str = None, - content_type: MessageContentType = None, - meta_data: Dict[str, str] = None, + content: Optional[str] = None, + content_type: Optional[MessageContentType] = None, + meta_data: Optional[Dict[str, str]] = None, ) -> Message: """ Modify a message, supporting the modification of message content, additional content, and message type. @@ -327,7 +327,9 @@ async def update( "meta_data": meta_data, } - return await self._requester.arequest("post", url, Message, params=params, body=body, data_field="message") + return await self._requester.arequest( + "post", url, False, Message, params=params, body=body, data_field="message" + ) async def delete( self, @@ -351,7 +353,7 @@ async def delete( "message_id": message_id, } - return await self._requester.arequest("post", url, Message, params=params) + return await self._requester.arequest("post", url, False, Message, params=params) class _PrivateListMessageResp(CozeModel): first_id: str diff --git a/cozepy/coze.py b/cozepy/coze.py index 26ca2f5..80b6856 100644 --- a/cozepy/coze.py +++ b/cozepy/coze.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from cozepy.auth import Auth from cozepy.config import COZE_COM_BASE_URL @@ -19,20 +19,20 @@ def __init__( self, auth: Auth, base_url: str = COZE_COM_BASE_URL, - http_client: SyncHTTPClient = None, + http_client: Optional[SyncHTTPClient] = None, ): self._auth = auth self._base_url = base_url self._requester = Requester(auth=auth, sync_client=http_client) # service client - self._bots = None - self._workspaces = None - self._conversations = None - self._chat = None - self._files = None - self._workflows = None - self._knowledge = None + self._bots: Optional[BotsClient] = None + self._workspaces: Optional[WorkspacesClient] = None + self._conversations: Optional[ConversationsClient] = None + self._chat: Optional[ChatClient] = None + self._files: Optional[FilesClient] = None + self._workflows: Optional[WorkflowsClient] = None + self._knowledge: Optional[KnowledgeClient] = None @property def bots(self) -> "BotsClient": @@ -96,20 +96,20 @@ def __init__( self, auth: Auth, base_url: str = COZE_COM_BASE_URL, - http_client: AsyncHTTPClient = None, + http_client: Optional[AsyncHTTPClient] = None, ): self._auth = auth self._base_url = base_url self._requester = Requester(auth=auth, async_client=http_client) # service client - self._bots = None - self._chat = None - self._conversations = None - self._files = None - self._knowledge = None - self._workflows = None - self._workspaces = None + self._bots: Optional[AsyncBotsClient] = None + self._chat: Optional[AsyncChatClient] = None + self._conversations: Optional[AsyncConversationsClient] = None + self._files: Optional[AsyncFilesClient] = None + self._knowledge: Optional[AsyncKnowledgeClient] = None + self._workflows: Optional[AsyncWorkflowsClient] = None + self._workspaces: Optional[AsyncWorkspacesClient] = None @property def bots(self) -> "AsyncBotsClient": diff --git a/cozepy/exception.py b/cozepy/exception.py index 6ffe3bc..9fcb9bb 100644 --- a/cozepy/exception.py +++ b/cozepy/exception.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional class CozeError(Exception): @@ -14,7 +15,7 @@ class CozeAPIError(CozeError): base class for all api errors """ - def __init__(self, code: int = None, msg: str = "", logid: str = None): + def __init__(self, code: Optional[int] = None, msg: str = "", logid: Optional[str] = None): self.code = code self.msg = msg self.logid = logid @@ -39,7 +40,7 @@ class CozePKCEAuthError(CozeError): base class for all pkce auth errors """ - def __init__(self, error: CozePKCEAuthErrorType, logid: str = None): + def __init__(self, error: CozePKCEAuthErrorType, logid: Optional[str] = None): super().__init__(f"pkce auth error: {error.value}") self.error = error self.logid = logid diff --git a/cozepy/files/__init__.py b/cozepy/files/__init__.py index d48e585..34ceb2d 100644 --- a/cozepy/files/__init__.py +++ b/cozepy/files/__init__.py @@ -46,7 +46,7 @@ def upload(self, *, file: str) -> File: """ url = f"{self._base_url}/v1/files/upload" files = {"file": open(file, "rb")} - return self._requester.request("post", url, File, files=files) + return self._requester.request("post", url, False, File, files=files) def retrieve(self, *, file_id: str): """ @@ -62,7 +62,7 @@ def retrieve(self, *, file_id: str): """ url = f"{self._base_url}/v1/files/retrieve" params = {"file_id": file_id} - return self._requester.request("get", url, File, params=params) + return self._requester.request("get", url, False, File, params=params) class AsyncFilesClient(object): @@ -90,7 +90,7 @@ async def upload(self, *, file: str) -> File: """ url = f"{self._base_url}/v1/files/upload" files = {"file": open(file, "rb")} - return await self._requester.arequest("post", url, File, files=files) + return await self._requester.arequest("post", url, False, File, files=files) async def retrieve(self, *, file_id: str): """ @@ -106,4 +106,4 @@ async def retrieve(self, *, file_id: str): """ url = f"{self._base_url}/v1/files/retrieve" params = {"file_id": file_id} - return await self._requester.arequest("get", url, File, params=params) + return await self._requester.arequest("get", url, False, File, params=params) diff --git a/cozepy/knowledge/__init__.py b/cozepy/knowledge/__init__.py index 7475265..18b50cf 100644 --- a/cozepy/knowledge/__init__.py +++ b/cozepy/knowledge/__init__.py @@ -1,3 +1,5 @@ +from typing import Optional + from cozepy.auth import Auth from cozepy.request import Requester @@ -9,7 +11,7 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth self._requester = requester - self._documents = None + self._documents: Optional[DocumentsClient] = None @property def documents(self) -> DocumentsClient: @@ -23,7 +25,7 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth self._requester = requester - self._documents = None + self._documents: Optional[AsyncDocumentsClient] = None @property def documents(self) -> AsyncDocumentsClient: diff --git a/cozepy/knowledge/documents/__init__.py b/cozepy/knowledge/documents/__init__.py index 30b5039..ceeb95b 100644 --- a/cozepy/knowledge/documents/__init__.py +++ b/cozepy/knowledge/documents/__init__.py @@ -15,13 +15,13 @@ class DocumentChunkStrategy(CozeModel): # 分段设置。取值包括: # 0:自动分段与清洗。采用扣子预置规则进行数据分段与处理。 # 1:自定义。此时需要通过 separator、max_tokens、remove_extra_spaces 和 remove_urls_emails 分段规则细节。 - chunk_type: int = None + chunk_type: Optional[int] = None # Maximum segment length, with a range of 100 to 2000. # Required when chunk_type=1. # 最大分段长度,取值范围为 100~2000。 # 在 chunk_type=1 时必选。 - max_tokens: int = None + max_tokens: Optional[int] = None # Whether to automatically filter continuous spaces, line breaks, and tabs. Values include: # true: Automatically filter @@ -30,7 +30,7 @@ class DocumentChunkStrategy(CozeModel): # true:自动过滤 # false:(默认)不自动过滤 # 在 chunk_type=1 时生效。 - remove_extra_spaces: bool = None + remove_extra_spaces: Optional[bool] = None # Whether to automatically filter all URLs and email addresses. Values include: # true: Automatically filter @@ -40,13 +40,13 @@ class DocumentChunkStrategy(CozeModel): # true:自动过滤 # false:(默认)不自动过滤 # 在 chunk_type=1 时生效。 - remove_urls_emails: bool = None + remove_urls_emails: Optional[bool] = None # Segmentation identifier. # Required when chunk_type=1. # 分段标识符。 # 在 chunk_type=1 时必选。 - separator: str = None + separator: Optional[str] = None @staticmethod def auto() -> "DocumentChunkStrategy": @@ -199,20 +199,20 @@ class Document(CozeModel): class DocumentSourceInfo(CozeModel): # 本地文件的 Base64 编码。 # 上传本地文件时必选 - file_base64: str = None + file_base64: Optional[str] = None # 本地文件格式,即文件后缀,例如 txt。格式支持 pdf、txt、doc、docx 类型。 # 上传的文件类型应与知识库类型匹配,例如 txt 文件只能上传到文档类型的知识库中。 # 上传本地文件时必选 - file_type: str = None + file_type: Optional[str] = None # 网页的 URL 地址。 # 上传在线网页时必选 - web_url: str = None + web_url: Optional[str] = None # 文件的上传方式。支持设置为 1,表示上传在线网页。 # 上传在线网页时必选 - document_source: int = None + document_source: Optional[int] = None @staticmethod def from_local_file(content: str, file_type: str = "txt") -> "DocumentSourceInfo": @@ -249,7 +249,7 @@ class DocumentBase(CozeModel): source_info: DocumentSourceInfo # 在线网页的更新策略。默认不自动更新。 - update_rule: DocumentUpdateRule = None + update_rule: Optional[DocumentUpdateRule] = None class DocumentsClient(object): @@ -263,7 +263,7 @@ def create( *, dataset_id: str, document_bases: List[DocumentBase], - chunk_strategy: DocumentChunkStrategy = None, + chunk_strategy: Optional[DocumentChunkStrategy] = None, ) -> List[Document]: """ Upload files to the specific knowledge. @@ -288,14 +288,16 @@ def create( "document_bases": [i.model_dump() for i in document_bases], "chunk_strategy": chunk_strategy.model_dump() if chunk_strategy else None, } - return self._requester.request("post", url, [Document], headers=headers, body=body, data_field="document_infos") + return self._requester.request( + "post", url, False, [Document], headers=headers, body=body, data_field="document_infos" + ) def update( self, *, document_id: str, - document_name: str = None, - update_rule: DocumentUpdateRule = None, + document_name: Optional[str] = None, + update_rule: Optional[DocumentUpdateRule] = None, ) -> None: """ Modify the knowledge base file name and update strategy. @@ -319,6 +321,7 @@ def update( return self._requester.request( "post", url, + False, None, headers=headers, body=body, @@ -347,6 +350,7 @@ def delete( return self._requester.request( "post", url, + False, None, headers=headers, body=body, @@ -379,7 +383,7 @@ def list( "size": page_size, } headers = {"Agw-Js-Conv": "str"} - res = self._requester.request("post", url, self._PrivateListDocumentsData, body=body, headers=headers) + res = self._requester.request("post", url, False, self._PrivateListDocumentsData, body=body, headers=headers) return NumberPaged( items=res.document_infos, page_num=page_num, @@ -403,7 +407,7 @@ async def create( *, dataset_id: str, document_bases: List[DocumentBase], - chunk_strategy: DocumentChunkStrategy = None, + chunk_strategy: Optional[DocumentChunkStrategy] = None, ) -> List[Document]: """ Upload files to the specific knowledge. @@ -429,15 +433,15 @@ async def create( "chunk_strategy": chunk_strategy.model_dump() if chunk_strategy else None, } return await self._requester.arequest( - "post", url, [Document], headers=headers, body=body, data_field="document_infos" + "post", url, False, [Document], headers=headers, body=body, data_field="document_infos" ) async def update( self, *, document_id: str, - document_name: str = None, - update_rule: DocumentUpdateRule = None, + document_name: Optional[str] = None, + update_rule: Optional[DocumentUpdateRule] = None, ) -> None: """ Modify the knowledge base file name and update strategy. @@ -461,7 +465,8 @@ async def update( return await self._requester.arequest( "post", url, - None, + False, + model=None, headers=headers, body=body, ) @@ -489,7 +494,8 @@ async def delete( return await self._requester.arequest( "post", url, - None, + False, + model=None, headers=headers, body=body, ) @@ -521,7 +527,9 @@ async def list( "size": page_size, } headers = {"Agw-Js-Conv": "str"} - res = await self._requester.arequest("post", url, self._PrivateListDocumentsData, body=body, headers=headers) + res = await self._requester.arequest( + "post", url, False, self._PrivateListDocumentsData, body=body, headers=headers + ) return NumberPaged( items=res.document_infos, page_num=page_num, diff --git a/cozepy/model.py b/cozepy/model.py index 39a1a66..52bebcb 100644 --- a/cozepy/model.py +++ b/cozepy/model.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Awaitable, Callable, Dict, Generic, Iterator, List, Tuple, TypeVar +from typing import AsyncIterator, Awaitable, Callable, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar from pydantic import BaseModel, ConfigDict @@ -28,7 +28,7 @@ class TokenPaged(PagedBase[T]): return is next_page_token + has_more. """ - def __init__(self, items: List[T], next_page_token: str = "", has_more: bool = None): + def __init__(self, items: List[T], next_page_token: str = "", has_more: Optional[bool] = None): has_more = has_more if has_more is not None else next_page_token != "" super().__init__(items, has_more) self.next_page_token = next_page_token @@ -38,7 +38,7 @@ def __repr__(self): class NumberPaged(PagedBase[T]): - def __init__(self, items: List[T], page_num: int, page_size: int, total: int = None): + def __init__(self, items: List[T], page_num: int, page_size: int, total: Optional[int] = None): has_more = len(items) >= page_size super().__init__(items, has_more) self.page_num = page_num @@ -57,7 +57,7 @@ def __init__( items: List[T], first_id: str = "", last_id: str = "", - has_more: bool = None, + has_more: Optional[bool] = None, ): has_more = has_more if has_more is not None else last_id != "" super().__init__(items, has_more) @@ -123,7 +123,7 @@ def __aiter__(self): return self async def __anext__(self) -> T: - return self._handler(await self._extra_event()) + return self._handler(await self._extra_event()) # type: ignore async def _extra_event(self) -> Dict[str, str]: data = dict(map(lambda x: (x, ""), self._fields)) diff --git a/cozepy/request.py b/cozepy/request.py index 1d94d17..f7ad88d 100644 --- a/cozepy/request.py +++ b/cozepy/request.py @@ -9,11 +9,13 @@ Type, TypeVar, Union, + overload, ) import httpx from httpx import Response from pydantic import BaseModel +from typing_extensions import Literal from cozepy.config import DEFAULT_CONNECTION_LIMITS, DEFAULT_TIMEOUT from cozepy.exception import COZE_PKCE_AUTH_ERROR_TYPE_ENUMS, CozeAPIError, CozePKCEAuthError, CozePKCEAuthErrorType @@ -47,21 +49,82 @@ class Requester(object): http request helper class. """ - def __init__(self, auth: "Auth" = None, sync_client: SyncHTTPClient = None, async_client: AsyncHTTPClient = None): + def __init__( + self, + auth: Optional["Auth"] = None, + sync_client: Optional[SyncHTTPClient] = None, + async_client: Optional[AsyncHTTPClient] = None, + ): self._auth = auth self._sync_client = sync_client self._async_client = async_client + @overload + def request( + self, + method: str, + url: str, + stream: Literal[False], + model: Type[T], + params: dict = ..., + headers: dict = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> T: ... + + @overload + def request( + self, + method: str, + url: str, + stream: Literal[False], + model: List[Type[T]], + params: dict = ..., + headers: dict = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> List[T]: ... + + @overload + def request( + self, + method: str, + url: str, + stream: Literal[True], + model: None, + params: dict = ..., + headers: dict = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> Tuple[Iterator[str], str]: ... + + @overload + def request( + self, + method: str, + url: str, + stream: Literal[False], + model: None, + params: dict = ..., + headers: dict = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> None: ... + def request( self, method: str, url: str, + stream: Literal[True, False], model: Union[Type[T], List[Type[T]], None], - params: dict = None, - headers: dict = None, - body: dict = None, - files: dict = None, - stream: bool = False, + params: Optional[dict] = None, + headers: Optional[dict] = None, + body: Optional[dict] = None, + files: Optional[dict] = None, data_field: str = "data", ) -> Union[T, List[T], Tuple[Iterator[str], str], None]: """ @@ -80,18 +143,76 @@ def request( log_debug("request %s#%s sending, params=%s, json=%s, stream=%s", method, url, params, body, stream) response = self.sync_client.send(request, stream=stream) - return self._parse_response(method, url, response=response, model=model, stream=stream, data_field=data_field) + return self._parse_response( + method, url, False, response=response, model=model, stream=stream, data_field=data_field + ) + @overload async def arequest( self, method: str, url: str, + stream: Literal[False], + model: Type[T], + params: dict = ..., + headers: dict = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> T: ... + + @overload + async def arequest( + self, + method: str, + url: str, + stream: Literal[False], + model: List[Type[T]], + params: dict = ..., + headers: dict = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> List[T]: ... + + @overload + async def arequest( + self, + method: str, + url: str, + stream: Literal[False], + model: None, + params: Optional[dict] = ..., + headers: Optional[dict] = ..., + body: Optional[dict] = ..., + files: Optional[dict] = ..., + data_field: str = ..., + ) -> None: ... + + @overload + async def arequest( + self, + method: str, + url: str, + stream: Literal[True], + model: None, + params: Optional[dict] = ..., + headers: Optional[dict] = ..., + body: Optional[dict] = ..., + files: Optional[dict] = ..., + data_field: str = ..., + ) -> Tuple[AsyncIterator[str], str]: ... + + async def arequest( + self, + method: str, + url: str, + stream: Literal[True, False], model: Union[Type[T], List[Type[T]], None], - params: dict = None, - headers: dict = None, - body: dict = None, - files: dict = None, - stream: bool = False, + params: Optional[dict] = None, + headers: Optional[dict] = None, + body: Optional[dict] = None, + files: Optional[dict] = None, data_field: str = "data", ) -> Union[T, List[T], Tuple[AsyncIterator[str], str], None]: """ @@ -110,7 +231,7 @@ async def arequest( response = await self.async_client.send(request, stream=stream) return self._parse_response( - method, url, response=response, model=model, stream=stream, data_field=data_field, is_async=True + method, url, True, response=response, model=model, stream=stream, data_field=data_field ) @property @@ -129,10 +250,10 @@ def _make_request( self, method: str, url: str, - params: dict = None, - headers: dict = None, - json: dict = None, - files: dict = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + json: Optional[dict] = None, + files: Optional[dict] = None, ) -> httpx.Request: if headers is None: headers = {} @@ -149,15 +270,39 @@ def _make_request( files=files, ) + @overload + def _parse_response( + self, + method: str, + url: str, + is_async: Literal[False], + response: httpx.Response, + model: Union[Type[T], List[Type[T]], None], + stream: bool = ..., + data_field: str = ..., + ) -> Union[T, List[T], Tuple[Iterator[str], str], None]: ... + + @overload + def _parse_response( + self, + method: str, + url: str, + is_async: Literal[True], + response: httpx.Response, + model: Union[Type[T], List[Type[T]], None], + stream: bool = ..., + data_field: str = ..., + ) -> Union[T, List[T], Tuple[AsyncIterator[str], str], None]: ... + def _parse_response( self, method: str, url: str, + is_async: Literal[True, False], response: httpx.Response, model: Union[Type[T], List[Type[T]], None], stream: bool = False, data_field: str = "data", - is_async: bool = False, ) -> Union[T, List[T], Tuple[Iterator[str], str], Tuple[AsyncIterator[str], str], None]: logid = response.headers.get("x-tt-logid") if stream: diff --git a/cozepy/workflows/__init__.py b/cozepy/workflows/__init__.py index 8e00be8..bd4604e 100644 --- a/cozepy/workflows/__init__.py +++ b/cozepy/workflows/__init__.py @@ -1,11 +1,10 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from cozepy.auth import Auth from cozepy.request import Requester if TYPE_CHECKING: - from .runs import AsyncWorkflowsClient as AsyncWorkflowsRunsClient - from .runs import WorkflowsClient as WorkflowsRunsClient + from .runs import AsyncWorkflowsRunsClient, WorkflowsRunsClient class WorkflowsClient(object): @@ -13,14 +12,14 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth self._requester = requester - self._runs = None + self._runs: Optional[WorkflowsRunsClient] = None @property def runs(self) -> "WorkflowsRunsClient": if not self._runs: - from .runs import WorkflowsClient + from .runs import WorkflowsRunsClient - self._runs = WorkflowsClient(self._base_url, self._auth, self._requester) + self._runs = WorkflowsRunsClient(self._base_url, self._auth, self._requester) return self._runs @@ -29,12 +28,12 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth self._requester = requester - self._runs = None + self._runs: Optional[AsyncWorkflowsRunsClient] = None @property def runs(self) -> "AsyncWorkflowsRunsClient": if not self._runs: - from .runs import AsyncWorkflowsClient + from .runs import AsyncWorkflowsRunsClient - self._runs = AsyncWorkflowsClient(self._base_url, self._auth, self._requester) + self._runs = AsyncWorkflowsRunsClient(self._base_url, self._auth, self._requester) return self._runs diff --git a/cozepy/workflows/runs/__init__.py b/cozepy/workflows/runs/__init__.py index 40574f4..ee173f2 100644 --- a/cozepy/workflows/runs/__init__.py +++ b/cozepy/workflows/runs/__init__.py @@ -1,6 +1,6 @@ from enum import Enum from functools import partial -from typing import Any, Dict +from typing import Any, Awaitable, Callable, Dict, Optional, cast from cozepy.auth import Auth from cozepy.model import AsyncStream, CozeModel, Stream @@ -52,7 +52,7 @@ class WorkflowEventMessage(CozeModel): # Additional fields. # 额外字段。 - ext: Dict[str, Any] = None + ext: Optional[Dict[str, Any]] = None class WorkflowEventInterruptData(CozeModel): @@ -93,17 +93,17 @@ class WorkflowEvent(CozeModel): # The current streaming data packet event. event: WorkflowEventType - message: WorkflowEventMessage = None + message: Optional[WorkflowEventMessage] = None - interrupt: WorkflowEventInterrupt = None + interrupt: Optional[WorkflowEventInterrupt] = None - error: WorkflowEventError = None + error: Optional[WorkflowEventError] = None def _workflow_stream_handler(data: Dict[str, str], is_async: bool = False) -> WorkflowEvent: - id = data["id"] + id = int(data["id"]) event = data["event"] - data = data["data"] + event_data = data["data"] # type: str if event == WorkflowEventType.DONE: if is_async: raise StopAsyncIteration @@ -112,25 +112,29 @@ def _workflow_stream_handler(data: Dict[str, str], is_async: bool = False) -> Wo return WorkflowEvent( id=id, event=event, - message=WorkflowEventMessage.model_validate_json(data), + message=WorkflowEventMessage.model_validate_json(event_data), ) elif event == WorkflowEventType.ERROR: - return WorkflowEvent(id=id, event=event, error=WorkflowEventError.model_validate_json(data)) + return WorkflowEvent(id=id, event=event, error=WorkflowEventError.model_validate_json(event_data)) elif event == WorkflowEventType.INTERRUPT: return WorkflowEvent( id=id, event=event, - interrupt=WorkflowEventInterrupt.model_validate_json(data), + interrupt=WorkflowEventInterrupt.model_validate_json(event_data), ) else: - raise ValueError(f"invalid workflows.event: {event}, {data}") + raise ValueError(f"invalid workflows.event: {event}, {event_data}") -_sync_workflow_stream_handler = partial(_workflow_stream_handler, is_async=False) -_async_workflow_stream_handler = partial(_workflow_stream_handler, is_async=True) +_sync_workflow_stream_handler = cast( + Callable[[Dict[str, str]], WorkflowEvent], partial(_workflow_stream_handler, is_async=False) +) +_async_workflow_stream_handler = cast( + Callable[[Dict[str, str]], Awaitable[WorkflowEvent]], partial(_workflow_stream_handler, is_async=True) +) -class WorkflowsClient(object): +class WorkflowsRunsClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth @@ -140,9 +144,9 @@ def create( self, *, workflow_id: str, - parameters: Dict[str, Any] = None, - bot_id: str = None, - ext: Dict[str, Any] = None, + parameters: Optional[Dict[str, Any]] = None, + bot_id: Optional[str] = None, + ext: Optional[Dict[str, Any]] = None, ) -> WorkflowRunResult: """ Run the published workflow. @@ -167,15 +171,15 @@ def create( "bot_id": bot_id, "ext": ext, } - return self._requester.request("post", url, WorkflowRunResult, body=body) + return self._requester.request("post", url, False, WorkflowRunResult, body=body) def stream( self, *, workflow_id: str, - parameters: Dict[str, Any] = None, - bot_id: str = None, - ext: Dict[str, Any] = None, + parameters: Optional[Dict[str, Any]] = None, + bot_id: Optional[str] = None, + ext: Optional[Dict[str, Any]] = None, ) -> Stream[WorkflowEvent]: """ Execute the published workflow with a streaming response method. @@ -198,7 +202,13 @@ def stream( "bot_id": bot_id, "ext": ext, } - steam_iters, logid = self._requester.request("post", url, None, body=body, stream=True) + steam_iters, logid = self._requester.request( + "post", + url, + True, + None, + body=body, + ) return Stream(steam_iters, fields=["id", "event", "data"], handler=_sync_workflow_stream_handler, logid=logid) def resume( @@ -225,11 +235,17 @@ def resume( "resume_data": resume_data, "interrupt_type": interrupt_type, } - steam_iters, logid = self._requester.request("post", url, None, body=body, stream=True) + steam_iters, logid = self._requester.request( + "post", + url, + True, + None, + body=body, + ) return Stream(steam_iters, fields=["id", "event", "data"], handler=_sync_workflow_stream_handler, logid=logid) -class AsyncWorkflowsClient(object): +class AsyncWorkflowsRunsClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth @@ -239,9 +255,9 @@ async def create( self, *, workflow_id: str, - parameters: Dict[str, Any] = None, - bot_id: str = None, - ext: Dict[str, Any] = None, + parameters: Optional[Dict[str, Any]] = None, + bot_id: Optional[str] = None, + ext: Optional[Dict[str, Any]] = None, ) -> WorkflowRunResult: """ Run the published workflow. @@ -266,15 +282,15 @@ async def create( "bot_id": bot_id, "ext": ext, } - return await self._requester.arequest("post", url, WorkflowRunResult, body=body) + return await self._requester.arequest("post", url, False, WorkflowRunResult, body=body) async def stream( self, *, workflow_id: str, - parameters: Dict[str, Any] = None, - bot_id: str = None, - ext: Dict[str, Any] = None, + parameters: Optional[Dict[str, Any]] = None, + bot_id: Optional[str] = None, + ext: Optional[Dict[str, Any]] = None, ) -> AsyncStream[WorkflowEvent]: """ Execute the published workflow with a streaming response method. @@ -297,7 +313,13 @@ async def stream( "bot_id": bot_id, "ext": ext, } - steam_iters, logid = await self._requester.arequest("post", url, None, body=body, stream=True) + steam_iters, logid = await self._requester.arequest( + "post", + url, + True, + None, + body=body, + ) return AsyncStream( steam_iters, fields=["id", "event", "data"], handler=_async_workflow_stream_handler, logid=logid ) @@ -326,7 +348,13 @@ async def resume( "resume_data": resume_data, "interrupt_type": interrupt_type, } - steam_iters, logid = await self._requester.arequest("post", url, None, body=body, stream=True) + steam_iters, logid = await self._requester.arequest( + "post", + url, + True, + None, + body=body, + ) return AsyncStream( steam_iters, fields=["id", "event", "data"], handler=_async_workflow_stream_handler, logid=logid ) diff --git a/cozepy/workspaces/__init__.py b/cozepy/workspaces/__init__.py index 68abc74..deb2547 100644 --- a/cozepy/workspaces/__init__.py +++ b/cozepy/workspaces/__init__.py @@ -49,6 +49,7 @@ def list(self, *, page_num: int = 1, page_size: int = 20, headers=None) -> Numbe data = self._requester.request( "get", url, + False, self._PrivateListPublishedBotsV1Data, headers=headers, params=params, @@ -84,6 +85,7 @@ async def list(self, *, page_num: int = 1, page_size: int = 20, headers=None) -> data = await self._requester.arequest( "get", url, + False, self._PrivateListPublishedBotsV1Data, headers=headers, params=params, diff --git a/poetry.lock b/poetry.lock index afccc76..5d1bee5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -451,6 +451,64 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "mypy" +version = "1.4.1" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mypy-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:566e72b0cd6598503e48ea610e0052d1b8168e60a46e0bfd34b3acf2d57f96a8"}, + {file = "mypy-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ca637024ca67ab24a7fd6f65d280572c3794665eaf5edcc7e90a866544076878"}, + {file = "mypy-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dde1d180cd84f0624c5dcaaa89c89775550a675aff96b5848de78fb11adabcd"}, + {file = "mypy-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c4d8e89aa7de683e2056a581ce63c46a0c41e31bd2b6d34144e2c80f5ea53dc"}, + {file = "mypy-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:bfdca17c36ae01a21274a3c387a63aa1aafe72bff976522886869ef131b937f1"}, + {file = "mypy-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7549fbf655e5825d787bbc9ecf6028731973f78088fbca3a1f4145c39ef09462"}, + {file = "mypy-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98324ec3ecf12296e6422939e54763faedbfcc502ea4a4c38502082711867258"}, + {file = "mypy-1.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:141dedfdbfe8a04142881ff30ce6e6653c9685b354876b12e4fe6c78598b45e2"}, + {file = "mypy-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8207b7105829eca6f3d774f64a904190bb2231de91b8b186d21ffd98005f14a7"}, + {file = "mypy-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:16f0db5b641ba159eff72cff08edc3875f2b62b2fa2bc24f68c1e7a4e8232d01"}, + {file = "mypy-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:470c969bb3f9a9efcedbadcd19a74ffb34a25f8e6b0e02dae7c0e71f8372f97b"}, + {file = "mypy-1.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5952d2d18b79f7dc25e62e014fe5a23eb1a3d2bc66318df8988a01b1a037c5b"}, + {file = "mypy-1.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:190b6bab0302cec4e9e6767d3eb66085aef2a1cc98fe04936d8a42ed2ba77bb7"}, + {file = "mypy-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9d40652cc4fe33871ad3338581dca3297ff5f2213d0df345bcfbde5162abf0c9"}, + {file = "mypy-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:01fd2e9f85622d981fd9063bfaef1aed6e336eaacca00892cd2d82801ab7c042"}, + {file = "mypy-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2460a58faeea905aeb1b9b36f5065f2dc9a9c6e4c992a6499a2360c6c74ceca3"}, + {file = "mypy-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2746d69a8196698146a3dbe29104f9eb6a2a4d8a27878d92169a6c0b74435b6"}, + {file = "mypy-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ae704dcfaa180ff7c4cfbad23e74321a2b774f92ca77fd94ce1049175a21c97f"}, + {file = "mypy-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:43d24f6437925ce50139a310a64b2ab048cb2d3694c84c71c3f2a1626d8101dc"}, + {file = "mypy-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c482e1246726616088532b5e964e39765b6d1520791348e6c9dc3af25b233828"}, + {file = "mypy-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:43b592511672017f5b1a483527fd2684347fdffc041c9ef53428c8dc530f79a3"}, + {file = "mypy-1.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34a9239d5b3502c17f07fd7c0b2ae6b7dd7d7f6af35fbb5072c6208e76295816"}, + {file = "mypy-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5703097c4936bbb9e9bce41478c8d08edd2865e177dc4c52be759f81ee4dd26c"}, + {file = "mypy-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e02d700ec8d9b1859790c0475df4e4092c7bf3272a4fd2c9f33d87fac4427b8f"}, + {file = "mypy-1.4.1-py3-none-any.whl", hash = "sha256:45d32cec14e7b97af848bddd97d85ea4f0db4d5a149ed9676caa4eb2f7402bb4"}, + {file = "mypy-1.4.1.tar.gz", hash = "sha256:9bbcd9ab8ea1f2e1c8031c21445b511442cc45c89951e49bbf852cbb70755b1b"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typed-ast = {version = ">=1.4.0,<2", markers = "python_version < \"3.8\""} +typing-extensions = ">=4.1.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -858,6 +916,56 @@ files = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +[[package]] +name = "typed-ast" +version = "1.5.5" +description = "a fork of Python 2 and 3 ast modules with type comment support" +optional = false +python-versions = ">=3.6" +files = [ + {file = "typed_ast-1.5.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4bc1efe0ce3ffb74784e06460f01a223ac1f6ab31c6bc0376a21184bf5aabe3b"}, + {file = "typed_ast-1.5.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5f7a8c46a8b333f71abd61d7ab9255440d4a588f34a21f126bbfc95f6049e686"}, + {file = "typed_ast-1.5.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:597fc66b4162f959ee6a96b978c0435bd63791e31e4f410622d19f1686d5e769"}, + {file = "typed_ast-1.5.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d41b7a686ce653e06c2609075d397ebd5b969d821b9797d029fccd71fdec8e04"}, + {file = "typed_ast-1.5.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5fe83a9a44c4ce67c796a1b466c270c1272e176603d5e06f6afbc101a572859d"}, + {file = "typed_ast-1.5.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d5c0c112a74c0e5db2c75882a0adf3133adedcdbfd8cf7c9d6ed77365ab90a1d"}, + {file = "typed_ast-1.5.5-cp310-cp310-win_amd64.whl", hash = "sha256:e1a976ed4cc2d71bb073e1b2a250892a6e968ff02aa14c1f40eba4f365ffec02"}, + {file = "typed_ast-1.5.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c631da9710271cb67b08bd3f3813b7af7f4c69c319b75475436fcab8c3d21bee"}, + {file = "typed_ast-1.5.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b445c2abfecab89a932b20bd8261488d574591173d07827c1eda32c457358b18"}, + {file = "typed_ast-1.5.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc95ffaaab2be3b25eb938779e43f513e0e538a84dd14a5d844b8f2932593d88"}, + {file = "typed_ast-1.5.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61443214d9b4c660dcf4b5307f15c12cb30bdfe9588ce6158f4a005baeb167b2"}, + {file = "typed_ast-1.5.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6eb936d107e4d474940469e8ec5b380c9b329b5f08b78282d46baeebd3692dc9"}, + {file = "typed_ast-1.5.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e48bf27022897577d8479eaed64701ecaf0467182448bd95759883300ca818c8"}, + {file = "typed_ast-1.5.5-cp311-cp311-win_amd64.whl", hash = "sha256:83509f9324011c9a39faaef0922c6f720f9623afe3fe220b6d0b15638247206b"}, + {file = "typed_ast-1.5.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:44f214394fc1af23ca6d4e9e744804d890045d1643dd7e8229951e0ef39429b5"}, + {file = "typed_ast-1.5.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:118c1ce46ce58fda78503eae14b7664163aa735b620b64b5b725453696f2a35c"}, + {file = "typed_ast-1.5.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be4919b808efa61101456e87f2d4c75b228f4e52618621c77f1ddcaae15904fa"}, + {file = "typed_ast-1.5.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:fc2b8c4e1bc5cd96c1a823a885e6b158f8451cf6f5530e1829390b4d27d0807f"}, + {file = "typed_ast-1.5.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:16f7313e0a08c7de57f2998c85e2a69a642e97cb32f87eb65fbfe88381a5e44d"}, + {file = "typed_ast-1.5.5-cp36-cp36m-win_amd64.whl", hash = "sha256:2b946ef8c04f77230489f75b4b5a4a6f24c078be4aed241cfabe9cbf4156e7e5"}, + {file = "typed_ast-1.5.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2188bc33d85951ea4ddad55d2b35598b2709d122c11c75cffd529fbc9965508e"}, + {file = "typed_ast-1.5.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0635900d16ae133cab3b26c607586131269f88266954eb04ec31535c9a12ef1e"}, + {file = "typed_ast-1.5.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:57bfc3cf35a0f2fdf0a88a3044aafaec1d2f24d8ae8cd87c4f58d615fb5b6311"}, + {file = "typed_ast-1.5.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:fe58ef6a764de7b4b36edfc8592641f56e69b7163bba9f9c8089838ee596bfb2"}, + {file = "typed_ast-1.5.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d09d930c2d1d621f717bb217bf1fe2584616febb5138d9b3e8cdd26506c3f6d4"}, + {file = "typed_ast-1.5.5-cp37-cp37m-win_amd64.whl", hash = "sha256:d40c10326893ecab8a80a53039164a224984339b2c32a6baf55ecbd5b1df6431"}, + {file = "typed_ast-1.5.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fd946abf3c31fb50eee07451a6aedbfff912fcd13cf357363f5b4e834cc5e71a"}, + {file = "typed_ast-1.5.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ed4a1a42df8a3dfb6b40c3d2de109e935949f2f66b19703eafade03173f8f437"}, + {file = "typed_ast-1.5.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:045f9930a1550d9352464e5149710d56a2aed23a2ffe78946478f7b5416f1ede"}, + {file = "typed_ast-1.5.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381eed9c95484ceef5ced626355fdc0765ab51d8553fec08661dce654a935db4"}, + {file = "typed_ast-1.5.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bfd39a41c0ef6f31684daff53befddae608f9daf6957140228a08e51f312d7e6"}, + {file = "typed_ast-1.5.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8c524eb3024edcc04e288db9541fe1f438f82d281e591c548903d5b77ad1ddd4"}, + {file = "typed_ast-1.5.5-cp38-cp38-win_amd64.whl", hash = "sha256:7f58fabdde8dcbe764cef5e1a7fcb440f2463c1bbbec1cf2a86ca7bc1f95184b"}, + {file = "typed_ast-1.5.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:042eb665ff6bf020dd2243307d11ed626306b82812aba21836096d229fdc6a10"}, + {file = "typed_ast-1.5.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:622e4a006472b05cf6ef7f9f2636edc51bda670b7bbffa18d26b255269d3d814"}, + {file = "typed_ast-1.5.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1efebbbf4604ad1283e963e8915daa240cb4bf5067053cf2f0baadc4d4fb51b8"}, + {file = "typed_ast-1.5.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0aefdd66f1784c58f65b502b6cf8b121544680456d1cebbd300c2c813899274"}, + {file = "typed_ast-1.5.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:48074261a842acf825af1968cd912f6f21357316080ebaca5f19abbb11690c8a"}, + {file = "typed_ast-1.5.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:429ae404f69dc94b9361bb62291885894b7c6fb4640d561179548c849f8492ba"}, + {file = "typed_ast-1.5.5-cp39-cp39-win_amd64.whl", hash = "sha256:335f22ccb244da2b5c296e6f96b06ee9bed46526db0de38d2f0e5a6597b81155"}, + {file = "typed_ast-1.5.5.tar.gz", hash = "sha256:94282f7a354f36ef5dbce0ef3467ebf6a258e370ab33d5b40c249fa996e590dd"}, +] + [[package]] name = "typing-extensions" version = "4.7.1" @@ -908,4 +1016,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "dcb3d43044689ce57f32d0d15c55fed77aa38dabf9197cc22e635346e3aab130" +content-hash = "7f8efa60c14958d3af1d0c4bf574ce658c263852c45ebbcc6acf9d58ac2c1d56" diff --git a/pyproject.toml b/pyproject.toml index 738da26..5d6834c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ pytest-asyncio = "^0.21.0" ruff = "^0.6.0" pre-commit = "^2.9.0" respx = "^0.21.1" +mypy = "^1.0.0" [tool.ruff] line-length = 120 diff --git a/tests/test_bot.py b/tests/test_bot.py index 1a9db57..f43675d 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -6,7 +6,7 @@ @pytest.mark.respx(base_url="https://api.coze.com") class TestBot: - def test_create(self, respx_mock): + def test_bot_create(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/v1/bot/create").mock( @@ -30,14 +30,14 @@ def test_create(self, respx_mock): assert bot assert bot.bot_id == "bot_id" - def test_update(self, respx_mock): + def test_bot_update(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/v1/bot/update").mock(httpx.Response(200, json={"data": None})) coze.bots.update(bot_id="bot id", name="name") - def test_publish(self, respx_mock): + def test_bot_publish(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/v1/bot/publish").mock( @@ -117,7 +117,7 @@ def test_list(self, respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") @pytest.mark.asyncio class TestAsyncBot: - async def test_create(self, respx_mock): + async def test_bot_create(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) respx_mock.post("/v1/bot/create").mock( @@ -141,7 +141,7 @@ async def test_create(self, respx_mock): assert bot assert bot.bot_id == "bot_id" - async def test_update(self, respx_mock): + async def test_bot_update(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) respx_mock.post("/v1/bot/update").mock(httpx.Response(200, json={"data": None})) diff --git a/tests/test_chat.py b/tests/test_chat.py index 7d9dfdb..eaa79f1 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -48,8 +48,8 @@ @pytest.mark.respx(base_url="https://api.coze.com") -class TestConversationMessage: - def test_create(self, respx_mock): +class TestChat: + def test_chat_create(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/v3/chat").mock(httpx.Response(200, json={"data": chat_testdata.model_dump()})) @@ -58,7 +58,7 @@ def test_create(self, respx_mock): assert res assert res.conversation_id == chat_testdata.conversation_id - def test_stream(self, respx_mock): + def test_chat_stream(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/v3/chat").mock(httpx.Response(200, content=chat_stream_testdata)) @@ -81,7 +81,7 @@ def test_stream(self, respx_mock): ) assert events[len(events) - 1].event == ChatEventType.CONVERSATION_CHAT_COMPLETED - def test_stream_error(self, respx_mock): + def test_chat_stream_error(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/v3/chat").mock( @@ -96,7 +96,7 @@ def test_stream_error(self, respx_mock): with pytest.raises(Exception, match="error event"): list(coze.chat.stream(bot_id="bot", user_id="user")) - def test_stream_invalid_event(self, respx_mock): + def test_chat_stream_invalid_event(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/v3/chat").mock( @@ -111,7 +111,7 @@ def test_stream_invalid_event(self, respx_mock): with pytest.raises(Exception, match="invalid chat.event: invalid"): list(coze.chat.stream(bot_id="bot", user_id="user")) - def test_retrieve(self, respx_mock): + def test_chat_retrieve(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/v3/chat/retrieve").mock(httpx.Response(200, json={"data": chat_testdata.model_dump()})) @@ -285,7 +285,7 @@ async def test_submit_tool_outputs_stream(self, respx_mock): ) assert events[len(events) - 1].event == ChatEventType.CONVERSATION_CHAT_COMPLETED - async def test_cancel(self, respx_mock): + async def test_chat_cancel(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) respx_mock.post("/v3/chat/cancel").mock(httpx.Response(200, json={"data": chat_testdata.model_dump()})) diff --git a/tests/test_chat_message.py b/tests/test_chat_message.py index f49f54c..3dfb0ce 100644 --- a/tests/test_chat_message.py +++ b/tests/test_chat_message.py @@ -26,7 +26,7 @@ def test_create(self, respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") @pytest.mark.asyncio class TestAsyncChatMessage: - async def test_create(self, respx_mock): + async def test_chat_message_list(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) msg = Message.user_text_message("hi") diff --git a/tests/test_conversation.py b/tests/test_conversation.py index 91b83e0..7143b26 100644 --- a/tests/test_conversation.py +++ b/tests/test_conversation.py @@ -16,7 +16,7 @@ def test_create(self, respx_mock): assert res assert res.id == conversation.id - def test_retrieve(self, respx_mock): + def test_conversations_retrieve(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) conversation = Conversation(id="id", created_at=1, meta_data={}) @@ -30,7 +30,7 @@ def test_retrieve(self, respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") @pytest.mark.asyncio class TestAsyncConversation: - async def test_create(self, respx_mock): + async def test_conversation_create(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) conversation = Conversation(id="id", created_at=1, meta_data={}) diff --git a/tests/test_conversation_message.py b/tests/test_conversation_message.py index 1230014..11b5ce1 100644 --- a/tests/test_conversation_message.py +++ b/tests/test_conversation_message.py @@ -6,7 +6,7 @@ @pytest.mark.respx(base_url="https://api.coze.com") class TestConversationMessage: - def test_create(self, respx_mock): + def test_conversations_messages_create(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) msg = Message.assistant_text_message("hi") @@ -18,7 +18,7 @@ def test_create(self, respx_mock): assert message assert message.content == msg.content - def test_list(self, respx_mock): + def test_conversations_messages_list(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) msg = Message.user_text_message("hi") @@ -32,7 +32,7 @@ def test_list(self, respx_mock): assert message_list assert len(message_list.items) == 1 - def test_retrieve(self, respx_mock): + def test_conversations_messages_retrieve(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) msg = Message.user_text_message("hi") @@ -42,7 +42,7 @@ def test_retrieve(self, respx_mock): assert message assert message.content == msg.content - def test_update(self, respx_mock): + def test_conversations_messages_update(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) msg = Message.user_text_message("hi") @@ -52,7 +52,7 @@ def test_update(self, respx_mock): assert message assert message.content == msg.content - def test_delete(self, respx_mock): + def test_conversations_messages_delete(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) msg = Message.user_text_message("hi") diff --git a/tests/test_file.py b/tests/test_file.py index 56e873d..cf5373f 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -8,7 +8,7 @@ @pytest.mark.respx(base_url="https://api.coze.com") class TestFile: - def test_create(self, respx_mock): + def test_file_create(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) with patch("builtins.open", mock_open(read_data="data")): @@ -20,7 +20,7 @@ def test_create(self, respx_mock): assert file assert "name" == file.file_name - def test_retrieve(self, respx_mock): + def test_file_retrieve(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.get("/v1/files/retrieve").mock( diff --git a/tests/test_knowledge_documents.py b/tests/test_knowledge_documents.py index c8bbbf9..413138f 100644 --- a/tests/test_knowledge_documents.py +++ b/tests/test_knowledge_documents.py @@ -93,14 +93,14 @@ def test_update(self, respx_mock): coze.knowledge.documents.update(document_id="id", document_name="name") - def test_delete(self, respx_mock): + def test_knowledge_documents_delete(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/open_api/knowledge/document/delete").mock(httpx.Response(200, json={"data": None})) coze.knowledge.documents.delete(document_ids=["id"]) - def test_list(self, respx_mock): + def test_knowledge_documents_list(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) respx_mock.post("/open_api/knowledge/document/list").mock( diff --git a/tests/test_request.py b/tests/test_request.py index 343b6d2..71bf0ab 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -23,7 +23,7 @@ def test_code_msg(self, respx_mock): ) with pytest.raises(CozeAPIError, match="code: 100, msg: request failed, logid: mock-logid"): - Requester().request("post", "https://api.coze.com/api/test", ModelForTest) + Requester().request("post", "https://api.coze.com/api/test", False, ModelForTest) def test_auth_slow_down(self, respx_mock): respx_mock.post("/api/test").mock( @@ -37,7 +37,7 @@ def test_auth_slow_down(self, respx_mock): ) with pytest.raises(CozePKCEAuthError, match="pkce auth error: slow_down"): - Requester().request("post", "https://api.coze.com/api/test", ModelForTest) + Requester().request("post", "https://api.coze.com/api/test", False, ModelForTest) def test_error_message(self, respx_mock): respx_mock.post("/api/test").mock( @@ -51,7 +51,7 @@ def test_error_message(self, respx_mock): ) with pytest.raises(CozeAPIError, match="msg: error_message, logid: mock-logid"): - Requester().request("post", "https://api.coze.com/api/test", ModelForTest) + Requester().request("post", "https://api.coze.com/api/test", False, ModelForTest) def test_debug_url(self, respx_mock): respx_mock.post("/api/test").mock( @@ -65,7 +65,7 @@ def test_debug_url(self, respx_mock): ) ) - Requester().request("post", "https://api.coze.com/api/test", DebugModelForTest) + Requester().request("post", "https://api.coze.com/api/test", False, DebugModelForTest) @pytest.mark.respx(base_url="https://api.coze.com") @@ -77,7 +77,7 @@ async def test_code_msg(self, respx_mock): ) with pytest.raises(CozeAPIError, match="code: 100, msg: request failed, logid: mock-logid"): - await Requester().arequest("post", "https://api.coze.com/api/test", ModelForTest) + await Requester().arequest("post", "https://api.coze.com/api/test", False, ModelForTest) async def test_auth_slow_down(self, respx_mock): respx_mock.post("/api/test").mock( @@ -91,7 +91,7 @@ async def test_auth_slow_down(self, respx_mock): ) with pytest.raises(CozePKCEAuthError, match="pkce auth error: slow_down"): - await Requester().arequest("post", "https://api.coze.com/api/test", ModelForTest) + await Requester().arequest("post", "https://api.coze.com/api/test", False, ModelForTest) async def test_error_message(self, respx_mock): respx_mock.post("/api/test").mock( @@ -105,7 +105,7 @@ async def test_error_message(self, respx_mock): ) with pytest.raises(CozeAPIError, match="msg: error_message, logid: mock-logid"): - await Requester().arequest("post", "https://api.coze.com/api/test", ModelForTest) + await Requester().arequest("post", "https://api.coze.com/api/test", False, ModelForTest) async def test_debug_url(self, respx_mock): respx_mock.post("/api/test").mock( @@ -119,4 +119,4 @@ async def test_debug_url(self, respx_mock): ) ) - await Requester().arequest("post", "https://api.coze.com/api/test", DebugModelForTest) + await Requester().arequest("post", "https://api.coze.com/api/test", False, DebugModelForTest)