Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: get oauth token by jwt flow #4

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ jobs:
env:
COZE_TOKEN: ${{ secrets.COZE_TOKEN }}
SPACE_ID_1: ${{ secrets.SPACE_ID_1 }}
COZE_JWT_AUTH_CLIENT_ID: ${{ secrets.COZE_JWT_AUTH_CLIENT_ID }}
COZE_JWT_AUTH_PRIVATE_KEY: ${{ secrets.COZE_JWT_AUTH_PRIVATE_KEY }}
COZE_JWT_AUTH_KEY_ID: ${{ secrets.COZE_JWT_AUTH_KEY_ID }}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3

Expand Down
9 changes: 6 additions & 3 deletions cozepy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from .auth import Auth, TokenAuth

from .auth import ApplicationOAuth, Auth, TokenAuth
from .config import COZE_COM_BASE_URL, COZE_CN_BASE_URL
from .coze import Coze

from .model import TokenPaged, NumberPaged

__all__ = [
'ApplicationOAuth',
'Auth',
'TokenAuth',

'COZE_COM_BASE_URL',
'COZE_CN_BASE_URL',

'Coze',

'TokenPaged',
Expand Down
84 changes: 84 additions & 0 deletions cozepy/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,88 @@
import abc
import random
import time
from urllib.parse import urlparse

from authlib.jose import jwt

from cozepy.model import CozeModel
from cozepy.request import Requester
from .config import COZE_COM_BASE_URL


def _random_hex(length):
hex_characters = '0123456789abcdef'
return ''.join(random.choice(hex_characters) for _ in range(length))


class OAuthToken(CozeModel):
# The requested access token. The app can use this token to authenticate to the Coze resource.
access_token: str
# How long the access token is valid, in seconds (UNIX timestamp)
expires_in: int
# An OAuth 2.0 refresh token. The app can use this token to acquire other access tokens after the current access token expires. Refresh tokens are long-lived.
refresh_token: str = ''
# fixed: Bearer
token_type: str = ''


class DeviceAuthCode(CozeModel):
# device code
device_code: str
# The user code
user_code: str
# The verification uri
verification_uri: str
# The interval of the polling request
interval: int = 5
# The expiration time of the device code
expires_in: int

@property
def verification_url(self):
return f'{self.verification_uri}?user_code={self.user_code}'


class ApplicationOAuth(object):
"""
App OAuth process to support obtaining token and refreshing token.
"""

def __init__(self, client_id: str, client_secret: str = '', base_url: str = COZE_COM_BASE_URL):
self._client_id = client_id
self._client_secret = client_secret
self._base_url = base_url
self._api_endpoint = urlparse(base_url).netloc
self._token = ''
self._requester = Requester()

def jwt_auth(self, private_key: str, kid: str, ttl: int):
"""
Get the token by jwt with jwt auth flow.
"""
jwt_token = self._gen_jwt(self._api_endpoint, private_key, self._client_id, kid, 3600)
url = f'{self._base_url}/api/permission/oauth2/token'
headers = {
'Authorization': f'Bearer {jwt_token}'
}
body = {
'duration_seconds': ttl,
'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
}
return self._requester.request('post', url, OAuthToken, headers=headers, body=body)

def _gen_jwt(self, api_endpoint: str, private_key: str, client_id: str, kid: str, ttl: int):
now = int(time.time())
header = {'alg': 'RS256', 'typ': 'JWT', 'kid': kid}
payload = {
"iss": client_id,
'aud': api_endpoint,
"iat": now,
"exp": now + ttl,
'jti': _random_hex(16),
}
s = jwt.encode(header, payload, private_key)
return s.decode('utf-8')


class Auth(abc.ABC):
Expand Down
2 changes: 2 additions & 0 deletions cozepy/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
COZE_COM_BASE_URL = 'https://api.coze.com'
COZE_CN_BASE_URL = 'https://api.coze.cn'
3 changes: 2 additions & 1 deletion cozepy/coze.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING

from cozepy.auth import Auth
from cozepy.config import COZE_COM_BASE_URL
from cozepy.request import Requester

if TYPE_CHECKING:
Expand All @@ -10,7 +11,7 @@
class Coze(object):
def __init__(self,
auth: Auth,
base_url: str = 'https://api.coze.com',
base_url: str = COZE_COM_BASE_URL,
):
self._auth = auth
self._base_url = base_url
Expand Down
47 changes: 29 additions & 18 deletions cozepy/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Tuple, Optional

import requests
from requests import Response

if TYPE_CHECKING:
from cozepy.auth import Auth
Expand All @@ -25,39 +26,49 @@ class Requester(object):
"""

def __init__(self,
auth: Optional["Auth"]
auth: 'Auth' = None
):
self._auth = auth

def request(self, method: str, url: str, model: Type[T], params: dict = None, headers: dict = None) -> T:
def request(self, method: str, url: str, model: Type[T], params: dict = None, headers: dict = None,
body: dict = None, ) -> T:
"""
Send a request to the server.
"""
if headers is None:
headers = {}
self._auth.authentication(headers)
r = requests.request(method, url, params=params, headers=headers)
logid = r.headers.get('x-tt-logid')
if self._auth:
self._auth.authentication(headers)
r = requests.request(method, url, params=params, headers=headers, json=body)

try:
json = r.json()
code = json.get('code') or 0
msg = json.get('msg') or ''
data = json.get('data')
except:
r.raise_for_status()

code = 0
msg = ''
data = {}
code, msg, data = self.__parse_requests_code_msg(r)

if code > 0:
if code is not None and code > 0:
# TODO: Exception 自定义类型
logid = r.headers.get('x-tt-logid')
raise Exception(f'{code}: {msg}, logid:{logid}')
elif code is None and msg != "":
logid = r.headers.get('x-tt-logid')
raise Exception(f'{msg}, logid:{logid}')
return model.model_validate(data)

async def arequest(self, method: str, path: str, **kwargs) -> dict:
"""
Send a request to the server with asyncio.
"""
pass

def __parse_requests_code_msg(self, r: Response) -> Tuple[Optional[int], str, Optional[T]]:
try:
json = r.json()
except:
r.raise_for_status()
return

if 'code' in json and 'msg' in json and int(json['code']) > 0:
return int(json['code']), json['msg'], json['data']
if 'error_message' in json and json['error_message'] != '':
return None, json['error_message'], None
if 'data' in json:
return 0, '', json['data']
return 0, '', json
Loading
Loading