Skip to content

Commit

Permalink
chore: Correct type annotation to accurately identify (#60)
Browse files Browse the repository at this point in the history
- Fixed type annotation issue to correctly recognize model's Iterable
type
- Ensured proper unpacking of sub-elements within the Iterable
  • Loading branch information
chyroc authored Oct 9, 2024
1 parent db19fc5 commit 24a371b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
4 changes: 2 additions & 2 deletions cozepy/chat/message/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def list(
"conversation_id": conversation_id,
"chat_id": chat_id,
}
return self._requester.request("post", url, List[Message], params=params)
return self._requester.request("post", url, [Message], params=params)


class AsyncMessagesClient(object):
Expand Down Expand Up @@ -64,4 +64,4 @@ async def list(
"conversation_id": conversation_id,
"chat_id": chat_id,
}
return await self._requester.arequest("post", url, List[Message], params=params)
return await self._requester.arequest("post", url, [Message], params=params)
6 changes: 2 additions & 4 deletions cozepy/knowledge/documents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,7 @@ def create(
"document_bases": [i.model_dump() for i in document_bases],
"chunk_strategy": chunk_strategy.model_dump() if chunk_strategy else None,
}
return self._requester.request(
"post", url, List[Document], headers=headers, body=body, data_field="document_infos"
)
return self._requester.request("post", url, [Document], headers=headers, body=body, data_field="document_infos")

def update(
self,
Expand Down Expand Up @@ -431,7 +429,7 @@ async def create(
"chunk_strategy": chunk_strategy.model_dump() if chunk_strategy else None,
}
return await self._requester.arequest(
"post", url, List[Document], headers=headers, body=body, data_field="document_infos"
"post", url, [Document], headers=headers, body=body, data_field="document_infos"
)

async def update(
Expand Down
12 changes: 5 additions & 7 deletions cozepy/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
TYPE_CHECKING,
Any,
AsyncIterator,
Iterable,
Iterator,
List,
Optional,
Expand All @@ -15,7 +14,6 @@
import httpx
from httpx import Response
from pydantic import BaseModel
from typing_extensions import get_args, get_origin

from cozepy.config import DEFAULT_CONNECTION_LIMITS, DEFAULT_TIMEOUT
from cozepy.exception import COZE_PKCE_AUTH_ERROR_TYPE_ENUMS, CozeAPIError, CozePKCEAuthError, CozePKCEAuthErrorType
Expand Down Expand Up @@ -58,7 +56,7 @@ def request(
self,
method: str,
url: str,
model: Union[Type[T], Type[Iterable[T]], None],
model: Union[Type[T], List[Type[T]], None],
params: dict = None,
headers: dict = None,
body: dict = None,
Expand Down Expand Up @@ -88,7 +86,7 @@ async def arequest(
self,
method: str,
url: str,
model: Union[Type[T], Type[Iterable[T]], None],
model: Union[Type[T], List[Type[T]], None],
params: dict = None,
headers: dict = None,
body: dict = None,
Expand Down Expand Up @@ -156,7 +154,7 @@ def _parse_response(
method: str,
url: str,
response: httpx.Response,
model: Union[Type[T], Iterable[Type[T]], None],
model: Union[Type[T], List[Type[T]], None],
stream: bool = False,
data_field: str = "data",
is_async: bool = False,
Expand All @@ -177,8 +175,8 @@ def _parse_response(
if msg in COZE_PKCE_AUTH_ERROR_TYPE_ENUMS:
raise CozePKCEAuthError(CozePKCEAuthErrorType(msg), logid)
raise CozeAPIError(code, msg, logid)
if get_origin(model) is list:
item_model = get_args(model)[0]
if isinstance(model, List):
item_model = model[0]
return [item_model.model_validate(item) for item in data]
else:
if model is None:
Expand Down

0 comments on commit 24a371b

Please sign in to comment.