Skip to content

Commit

Permalink
chore: Correct type annotations and add mypy for code checking
Browse files Browse the repository at this point in the history
  • Loading branch information
chyroc committed Oct 9, 2024
1 parent 24a371b commit 05bc61e
Show file tree
Hide file tree
Showing 28 changed files with 767 additions and 307 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ jobs:
pip install poetry
poetry install
- name: Build
run: |
poetry build
- name: Check
run: |
poetry run ruff check cozepy
poetry run ruff format --check
poetry build
poetry run mypy .
- name: Run tests
run: poetry run pytest --cov --cov-report=xml
- name: Upload coverage to Codecov
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
.idea/
.venv/
.venv*/
.DS_Store
__pycache__/
dist/
Expand Down
108 changes: 81 additions & 27 deletions cozepy/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Optional
from urllib.parse import quote_plus, urlparse

from authlib.jose import jwt
from authlib.jose import jwt # type: ignore
from typing_extensions import Literal

from cozepy.config import COZE_CN_BASE_URL, COZE_COM_BASE_URL
Expand Down Expand Up @@ -55,7 +55,7 @@ class Scope(CozeModel):
attribute_constraint: Optional[ScopeAttributeConstraint] = None

@staticmethod
def from_bot_chat(bot_id_list: List[str], permission_list: List[str] = None) -> "Scope":
def from_bot_chat(bot_id_list: List[str], permission_list: Optional[List[str]] = None) -> "Scope":
if not permission_list:
permission_list = ["Connector.botChat"]
return Scope(
Expand All @@ -80,9 +80,9 @@ def _get_oauth_url(
self,
redirect_uri: str,
state: str,
code_challenge: str = None,
code_challenge_method: str = None,
workspace_id: str = None,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
workspace_id: Optional[str] = None,
):
params = {
"response_type": "code",
Expand All @@ -92,7 +92,9 @@ def _get_oauth_url(
}
if code_challenge:
params["code_challenge"] = code_challenge
if code_challenge_method:
params["code_challenge_method"] = code_challenge_method

uri = f"{self._get_www_base_url}/api/permission/oauth2/authorize"
if workspace_id:
uri = f"{self._get_www_base_url}/api/permission/oauth2/workspace_id/{workspace_id}/authorize"
Expand All @@ -106,7 +108,7 @@ def _refresh_access_token(self, refresh_token: str, secret: str = "") -> OAuthTo
"client_id": self._client_id,
"refresh_token": refresh_token,
}
return self._requester.request("post", url, OAuthToken, headers=headers, body=body)
return self._requester.request("post", url, False, OAuthToken, headers=headers, body=body)

async def _arefresh_access_token(self, refresh_token: str, secret: str = "") -> OAuthToken:
url = f"{self._base_url}/api/permission/oauth2/token"
Expand All @@ -116,7 +118,7 @@ async def _arefresh_access_token(self, refresh_token: str, secret: str = "") ->
"client_id": self._client_id,
"refresh_token": refresh_token,
}
return await self._requester.arequest("post", url, OAuthToken, headers=headers, body=body)
return await self._requester.arequest("post", url, False, OAuthToken, headers=headers, body=body)

@property
def _get_www_base_url(self) -> str:
Expand Down Expand Up @@ -149,7 +151,7 @@ def get_oauth_url(
self,
redirect_uri: str,
state: str,
workspace_id: str = None,
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand Down Expand Up @@ -183,13 +185,13 @@ def get_access_token(
"code": code,
"redirect_uri": redirect_uri,
}
return self._requester.request("post", url, OAuthToken, headers=headers, body=body)
return self._requester.request("post", url, False, OAuthToken, headers=headers, body=body)

def refresh_access_token(self, refresh_token: str) -> OAuthToken:
return self._refresh_access_token(refresh_token, self._client_secret)


class AsyncWebOAuthApp(WebOAuthApp):
class AsyncWebOAuthApp(OAuthApp):
"""
Normal OAuth App.
"""
Expand All @@ -205,7 +207,29 @@ def __init__(self, client_id: str, client_secret: str, base_url: str = COZE_COM_
self._base_url = base_url
self._api_endpoint = urlparse(base_url).netloc
self._token = ""
super().__init__(client_id, client_secret, base_url, www_base_url=www_base_url)
super().__init__(client_id, base_url, www_base_url=www_base_url)

def get_oauth_url(
self,
redirect_uri: str,
state: str,
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
:param redirect_uri: The redirect_uri of your app, where authentication responses can be sent and received by
your app. It must exactly match one of the redirect URIs you registered in the OAuth Apps.
:param state: A value included in the request that is also returned in the token response. It can be a string
of any hash value.
:param workspace_id:
:return:
"""
return self._get_oauth_url(
redirect_uri,
state,
workspace_id=workspace_id,
)

async def get_access_token(
self,
Expand All @@ -223,7 +247,7 @@ async def get_access_token(
"code": code,
"redirect_uri": redirect_uri,
}
return await self._requester.arequest("post", url, OAuthToken, headers=headers, body=body)
return await self._requester.arequest("post", url, False, OAuthToken, headers=headers, body=body)

async def refresh_access_token(self, refresh_token: str) -> OAuthToken:
return await self._arefresh_access_token(refresh_token, self._client_secret)
Expand All @@ -249,7 +273,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur
self._public_key_id = public_key_id
super().__init__(client_id, base_url, www_base_url="")

def get_access_token(self, ttl: int, scope: Scope = None) -> OAuthToken:
def get_access_token(self, ttl: int, scope: Optional[Scope] = None) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
"""
Expand All @@ -261,7 +285,7 @@ def get_access_token(self, ttl: int, scope: Scope = None) -> OAuthToken:
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"scope": scope.model_dump() if scope else None,
}
return self._requester.request("post", url, OAuthToken, headers=headers, body=body)
return self._requester.request("post", url, False, OAuthToken, headers=headers, body=body)

def _gen_jwt(self, ttl: int):
now = int(time.time())
Expand Down Expand Up @@ -297,7 +321,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur
self._public_key_id = public_key_id
super().__init__(client_id, base_url, www_base_url="")

async def get_access_token(self, ttl: int, scope: Scope = None) -> OAuthToken:
async def get_access_token(self, ttl: int, scope: Optional[Scope] = None) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
"""
Expand All @@ -309,7 +333,7 @@ async def get_access_token(self, ttl: int, scope: Scope = None) -> OAuthToken:
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"scope": scope.model_dump() if scope else None,
}
return await self._requester.arequest("post", url, OAuthToken, headers=headers, body=body)
return await self._requester.arequest("post", url, False, OAuthToken, headers=headers, body=body)

def _gen_jwt(self, ttl: int):
now = int(time.time())
Expand Down Expand Up @@ -345,7 +369,7 @@ def get_oauth_url(
state: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
workspace_id: str = None,
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand Down Expand Up @@ -386,13 +410,13 @@ def get_access_token(self, redirect_uri: str, code: str, code_verifier: str) ->
"redirect_uri": redirect_uri,
"code_verifier": code_verifier,
}
return self._requester.request("post", url, OAuthToken, body=body)
return self._requester.request("post", url, False, OAuthToken, body=body)

def refresh_access_token(self, refresh_token: str) -> OAuthToken:
return self._refresh_access_token(refresh_token)


class AsyncPKCEOAuthApp(PKCEOAuthApp):
class AsyncPKCEOAuthApp(OAuthApp):
"""
PKCE OAuth App.
"""
Expand All @@ -406,6 +430,36 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u
www_base_url,
)

def get_oauth_url(
self,
redirect_uri: str,
state: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
:param redirect_uri: The redirect_uri of your app, where authentication responses can be sent and received by
your app. It must exactly match one of the redirect URIs you registered in the OAuth Apps.
:param state: A value included in the request that is also returned in the token response. It can be a string
of any hash value.
:param code_verifier:
:param code_challenge_method:
:param workspace_id:
:return:
"""
code_challenge = code_verifier if code_challenge_method == "plain" else gen_s256_code_challenge(code_verifier)

return self._get_oauth_url(
redirect_uri,
state,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
workspace_id=workspace_id,
)

async def get_access_token(self, redirect_uri: str, code: str, code_verifier: str) -> OAuthToken:
"""
Get the token with pkce auth flow.
Expand All @@ -423,7 +477,7 @@ async def get_access_token(self, redirect_uri: str, code: str, code_verifier: st
"redirect_uri": redirect_uri,
"code_verifier": code_verifier,
}
return await self._requester.arequest("post", url, OAuthToken, body=body)
return await self._requester.arequest("post", url, False, OAuthToken, body=body)

async def refresh_access_token(self, refresh_token: str) -> OAuthToken:
return await self._arefresh_access_token(refresh_token)
Expand All @@ -445,7 +499,7 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u

def get_device_code(
self,
workspace_id: str = None,
workspace_id: Optional[str] = None,
) -> DeviceAuthCode:
"""
Get the pkce flow authorized url.
Expand All @@ -463,7 +517,7 @@ def get_device_code(
headers = {
"Content-Type": "application/json",
}
res = self._requester.request("post", uri, DeviceAuthCode, headers=headers, body=body)
res = self._requester.request("post", uri, False, DeviceAuthCode, headers=headers, body=body)
res.verification_url = f"{res.verification_uri}?user_code={res.user_code}"
return res

Expand Down Expand Up @@ -508,13 +562,13 @@ def _get_access_token(self, device_code: str, poll: bool = False) -> OAuthToken:
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code,
}
return self._requester.request("post", url, OAuthToken, body=body)
return self._requester.request("post", url, False, OAuthToken, body=body)

def refresh_access_token(self, refresh_token: str) -> OAuthToken:
return self._refresh_access_token(refresh_token)


class AsyncDeviceOAuthApp(DeviceOAuthApp):
class AsyncDeviceOAuthApp(OAuthApp):
"""
Device OAuth App.
"""
Expand All @@ -528,7 +582,7 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u
www_base_url,
)

async def get_device_code(self, workspace_id: str = None) -> DeviceAuthCode:
async def get_device_code(self, workspace_id: Optional[str] = None) -> DeviceAuthCode:
"""
Get the pkce flow authorized url.
Expand All @@ -545,7 +599,7 @@ async def get_device_code(self, workspace_id: str = None) -> DeviceAuthCode:
headers = {
"Content-Type": "application/json",
}
res = await self._requester.arequest("post", uri, DeviceAuthCode, headers=headers, body=body)
res = await self._requester.arequest("post", uri, False, DeviceAuthCode, headers=headers, body=body)
res.verification_url = f"{res.verification_uri}?user_code={res.user_code}"
return res

Expand Down Expand Up @@ -590,7 +644,7 @@ async def _get_access_token(self, device_code: str, poll: bool = False) -> OAuth
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code,
}
return await self._requester.arequest("post", url, OAuthToken, body=body)
return await self._requester.arequest("post", url, False, OAuthToken, body=body)

async def refresh_access_token(self, refresh_token: str) -> OAuthToken:
return await self._arefresh_access_token(refresh_token)
Expand Down
Loading

0 comments on commit 05bc61e

Please sign in to comment.