Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add async methods to the sync client class #22

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions modelz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .client import ModelzClient
from .aioclient import AioModelzClient

__all__ = [
"ModelzClient",
"AioModelzClient",
]
147 changes: 147 additions & 0 deletions modelz/aioclient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations
from typing import Any
from http import HTTPStatus
from urllib.parse import urljoin

import aiohttp

from rich.console import Console

from modelz.env import EnvConfig
from modelz.serde import Serde, SerdeEnum, TextSerde
from modelz.utils import get_ssl_context_no_verify


TIMEOUT = 300
console = Console()
DEFAULT_RESP_SERDE = TextSerde()
DEFAULT_RETRY = 3
tddschn marked this conversation as resolved.
Show resolved Hide resolved


class ModelzAuth:
def __init__(self, key: str | None = None) -> None:
config = EnvConfig()
self.key: str = key if key else config.api_key
if not self.key:
raise RuntimeError("cannot find the API key")

def get_headers(self) -> dict:
return {"X-API-Key": self.key}


class ModelzResponse:
def __init__(self, resp: aiohttp.ClientResponse, serde: Serde = DEFAULT_RESP_SERDE):
"""Modelz internal response."""
if resp.status != HTTPStatus.OK:
console.print(f"[bold red]err[{resp.status}][/bold red]: {resp.text}")
raise ValueError(f"inference err with code {resp.status}")
self.resp = resp
self.serde = serde
self._data = None

async def save_to_file(self, file: str):
with open(file, "wb") as f:
f.write(await self.data)

@property
async def data(self) -> Any:
if not self._data:
self._data = self.serde.decode(await self.resp.content.read())
return self._data

async def show(self):
console.print(await self.data)
Comment on lines +63 to +74
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the sync interface work with this?



class AioModelzClient:
def __init__(
self,
deployment: str | None = None,
key: str | None = None,
host: str | None = None,
timeout: float = TIMEOUT,
) -> None:
# ...
tddschn marked this conversation as resolved.
Show resolved Hide resolved
config = EnvConfig()
self.host = host if host else config.host
self.deployment = deployment
self.auth = ModelzAuth(key)
self.timeout = timeout
self.session_request_kwargs = {}
if not getattr(config, "ssl_verify", True):
self.session_request_kwargs.update({"ssl": get_ssl_context_no_verify()})

async def _post(self, url, content, timeout):
headers = self.auth.get_headers()
async with aiohttp.ClientSession() as session:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reuse the session.

Copy link
Contributor Author

@tddschn tddschn May 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that, see https://github.com/tddschn/nssurge-api/blob/f2aafe1b01877b8ff8bfc6e74a58736a7ba3b058/nssurge_api/api.py#L30

Definding self.session in init() would require also defining __aenter__ and __aexit__ methods, and request methods that uses self.session to use an async context manager (see https://github.com/tddschn/nssurge-cli/blob/112e600fad6299d4f78012b7aaa20dfb0cdb9758/nssurge_cli/cap_commands.py#L20), due to how aiohttp works.

I'll try to figure out a better way to support session reuse.

async with session.post(
url,
data=content,
headers=headers,
timeout=timeout,
**self.session_request_kwargs,
) as response:
return response

async def _get(self, url, timeout):
headers = self.auth.get_headers()
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers=headers, timeout=timeout, **self.session_request_kwargs
) as response:
return response

async def inference(
self,
params: Any,
deployment: str | None = None,
serde: str = "json",
) -> ModelzResponse:
"""Get the inference result.

Args:
params: request params, will be serialized by `serde`
deployment: deployment ID
serde: serialize/deserialize method, choose from ("json", "msg", "raw")
"""
deploy = deployment if deployment else self.deployment
assert deploy, "deployment is required"
self.serde = SerdeEnum[serde.lower()].value()

with console.status(f"[bold green]Modelz {deploy} inference..."):
resp = await self._post(
urljoin(self.host.format(deploy), "/inference"),
self.serde.encode(params),
self.timeout,
)

return ModelzResponse(resp, self.serde)

async def metrics(self, deployment: str | None = None) -> ModelzResponse:
"""Get deployment metrics.

Args:
deployment: deployment ID
"""
deploy = deployment if deployment else self.deployment
assert deploy, "deployment is required"

with console.status(f"[bold green]Modelz {deploy} metrics..."):
resp = await self._get(
urljoin(self.host.format(deploy), "/metrics"),
self.timeout,
)

return ModelzResponse(resp)

async def build(self, repo: str):
"""Build a Docker image and push it to the registry."""
with console.status(f"[bold green]Modelz build {repo}..."):
resp = await self._post(
urljoin(self.host.format("api"), "/build"),
None,
self.timeout,
)

ModelzResponse(resp)
console.print(f"created the build job for repo [bold cyan]{repo}[/bold cyan]")
130 changes: 41 additions & 89 deletions modelz/client.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,8 @@
from __future__ import annotations
from typing import Any, Generator
from http import HTTPStatus
from urllib.parse import urljoin

import httpx
from rich.console import Console

from .env import EnvConfig
from .serde import Serde, SerdeEnum, TextSerde


TIMEOUT = httpx.Timeout(5, read=300, write=300)
console = Console()
config = EnvConfig()
DEFAULT_RESP_SERDE = TextSerde()
DEFAULT_RETRY = 3


class ModelzAuth(httpx.Auth):
def __init__(self, key: str | None = None) -> None:
self.key: str = key if key else config.api_key
if not self.key:
raise RuntimeError("cannot find the API key")

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
request.headers["X-API-Key"] = self.key
yield request


class ModelzResponse:
def __init__(self, resp: httpx.Response, serde: Serde = DEFAULT_RESP_SERDE):
"""Modelz internal response."""
if resp.status_code != HTTPStatus.OK:
console.print(f"[bold red]err[{resp.status_code}][/bold red]: {resp.text}")
raise ValueError(f"inference err with code {resp.status_code}")
self.resp = resp
self.serde = serde
self._data = None

def save_to_file(self, file: str):
with open(file, "wb") as f:
f.write(self.data)

@property
def data(self) -> Any:
if not self._data:
self._data = self.serde.decode(self.resp.content)
return self._data

def show(self):
console.print(self.data)
import asyncio
from typing import Any
from modelz.aioclient import ModelzResponse, AioModelzClient
from modelz.aioclient import TIMEOUT


class ModelzClient:
Expand All @@ -60,7 +11,7 @@ def __init__(
deployment: str | None = None,
key: str | None = None,
host: str | None = None,
timeout: float | httpx.Timeout = TIMEOUT,
timeout: float = TIMEOUT,
) -> None:
"""Create a Modelz Client.

Expand All @@ -70,13 +21,12 @@ def __init__(
host: Modelz host address
timeout: request timeout (second)
"""
self.host = host if host else config.host
self.deployment = deployment
auth = ModelzAuth(key)
transport = httpx.HTTPTransport(retries=DEFAULT_RETRY)
self.client = httpx.Client(auth=auth, transport=transport)
self.serde: Serde
self.timeout = timeout
self.client = AioModelzClient(
deployment=deployment,
key=key,
host=host,
timeout=timeout,
)

def inference(
self,
Expand All @@ -91,43 +41,45 @@ def inference(
deployment: deployment ID
serde: serialize/deserialize method, choose from ("json", "msg", "raw")
"""
deploy = deployment if deployment else self.deployment
assert deploy, "deployment is required"
self.serde = SerdeEnum[serde.lower()].value()

with console.status(f"[bold green]Modelz {deploy} inference..."):
resp = self.client.post(
urljoin(self.host.format(deploy), "/inference"),
content=self.serde.encode(params),
timeout=self.timeout,
)

return ModelzResponse(resp, self.serde)
return asyncio.run(self.client.inference(params, deployment, serde))

def metrics(self, deployment: str | None = None) -> ModelzResponse:
"""Get deployment metrics.

Args:
deployment: deployment ID
"""
deploy = deployment if deployment else self.deployment
assert deploy, "deployment is required"
return asyncio.run(self.client.metrics(deployment))

with console.status(f"[bold green]Modelz {deploy} metrics..."):
resp = self.client.get(
urljoin(self.host.format(deploy), "/metrics"),
timeout=self.timeout,
)
def build(self, repo: str):
"""Build a Docker image and push it to the registry."""

return ModelzResponse(resp)
return asyncio.run(self.client.build(repo))

def build(self, repo: str):
async def ainference(
self,
params: Any,
deployment: str | None = None,
serde: str = "json",
) -> ModelzResponse:
"""Get the inference result.

Args:
params: request params, will be serialized by `serde`
deployment: deployment ID
serde: serialize/deserialize method, choose from ("json", "msg", "raw")
"""
return await self.client.inference(params, deployment, serde)

async def ametrics(self, deployment: str | None = None) -> ModelzResponse:
"""Get deployment metrics.

Args:
deployment: deployment ID
"""
return await self.client.metrics(deployment)

async def abuild(self, repo: str):
"""Build a Docker image and push it to the registry."""
with console.status(f"[bold green]Modelz build {repo}..."):
resp = self.client.post(
urljoin(self.host.format("api"), "/build"),
timeout=self.timeout,
)

ModelzResponse(resp)
console.print(f"created the build job for repo [bold cyan]{repo}[/bold cyan]")
return await self.client.build(repo)
6 changes: 4 additions & 2 deletions modelz/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import os

from modelz.utils import strtobool

PREFIX: str = "MODELZ_"

Expand All @@ -12,7 +12,9 @@ def __init__(self) -> None:
self.update_from_env()

def update_from_env(self):
for key in ("api_key", "host"):
for key in ("api_key", "host", "ssl_verify"):
tddschn marked this conversation as resolved.
Show resolved Hide resolved
val = os.environ.get(f"{PREFIX}{key.upper()}")
if key == "ssl_verify" and val is not None:
val = strtobool(val)
if val is not None:
setattr(self, key, val)
28 changes: 28 additions & 0 deletions modelz/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

from typing import Literal
import ssl


def strtobool(val: str) -> Literal[0, 1]:
# copied from distutils.util cuz importing it is too slow
"""Convert a string representation of truth to true (1) or false (0).

True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
"""
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return 1
elif val in ("n", "no", "f", "false", "off", "0"):
return 0
tddschn marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError("invalid truth value %r" % (val,))


def get_ssl_context_no_verify() -> ssl.SSLContext:
sslcontext = ssl.create_default_context()
sslcontext.check_hostname = False
sslcontext.verify_mode = ssl.CERT_NONE
return sslcontext
Loading