Skip to content

Commit

Permalink
chore(python): proper type hinting (#3839)
Browse files Browse the repository at this point in the history
Co-authored-by: Pierre Millot <pierre.millot@algolia.com>
  • Loading branch information
Fluf22 and millotp authored Oct 4, 2024
1 parent 0cdb866 commit 6b42a26
Show file tree
Hide file tree
Showing 30 changed files with 238 additions and 183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ApiResponse(Generic[T]):
def __init__(
self,
verb: Verb,
data: T = None,
data: Optional[T] = None,
error_message: str = "",
headers: Optional[Dict[str, str]] = None,
host: str = "",
Expand Down Expand Up @@ -94,6 +94,6 @@ def deserialize(klass: Any = None, data: Any = None) -> Any:
return data

if isinstance(data, str):
return klass.from_json(data)
return klass.from_json(data) # pyright: ignore

return klass.from_dict(data)
return klass.from_dict(data) # pyright: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,6 @@


class BaseConfig:
app_id: Optional[str]
api_key: Optional[str]

read_timeout: int
write_timeout: int
connect_timeout: int

wait_task_time_before_retry: Optional[int]

headers: Dict[str, str]
proxies: Dict[str, str]

hosts: HostsCollection

def __init__(self, app_id: Optional[str] = None, api_key: Optional[str] = None):
app_id = environ.get("ALGOLIA_APP_ID") if app_id is None else app_id

Expand All @@ -36,12 +22,14 @@ def __init__(self, app_id: Optional[str] = None, api_key: Optional[str] = None):
self.write_timeout = 30000
self.connect_timeout = 2000

self.wait_task_time_before_retry = None
self.headers = None
self.proxies = None
self.hosts = None
self.wait_task_time_before_retry: Optional[int] = None
self.headers: Optional[Dict[str, str]] = None
self.proxies: Optional[Dict[str, str]] = None
self.hosts: Optional[HostsCollection] = None

def set_client_api_key(self, api_key: str) -> None:
"""Sets a new API key to authenticate requests."""
self.api_key = api_key
if self.headers is None:
self.headers = {}
self.headers["x-algolia-api-key"] = api_key
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@


class BaseTransporter:
_config: BaseConfig
_retry_strategy: RetryStrategy
_hosts: List[Host]

def __init__(self, config: BaseConfig) -> None:
self._config = config
self._retry_strategy = RetryStrategy()
self._hosts = []
self._hosts: List[Host] = []
self._timeout = 5000

@property
def config(self) -> BaseConfig:
return self._config

def prepare(
self,
Expand All @@ -25,13 +26,18 @@ def prepare(

if use_read_transporter:
self._timeout = request_options.timeouts["read"]
self._hosts = self._config.hosts.read()
self._hosts = (
self._config.hosts.read() if self._config.hosts is not None else []
)
if isinstance(request_options.data, dict):
query_parameters.update(request_options.data)
return query_parameters
else:
self._timeout = request_options.timeouts["write"]
self._hosts = (
self._config.hosts.write() if self._config.hosts is not None else []
)

self._timeout = request_options.timeouts["write"]
self._hosts = self._config.hosts.write()
return query_parameters

def build_path(self, path, query_parameters):
if query_parameters is not None and len(query_parameters) > 0:
Expand All @@ -54,9 +60,21 @@ def build_url(self, host, path):
)

def get_proxy(self, url):
if self._config.proxies is None:
return None

if url.startswith("https"):
return self._config.proxies.get("https")
elif url.startswith("http"):
return self._config.proxies.get("http")
else:
return None

def get_proxies(self, url):
if self._config.proxies is None:
return None

if url.startswith("http"):
return self._config.proxies
else:
return None
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,12 @@ def __init__(self, msg, path_to_item=None) -> None:


class ApiException(AlgoliaException):
def __init__(self, status_code=None, error_message=None, raw_data=None) -> None:
def __init__(
self,
status_code: int = -1,
error_message: str = "Unknown error",
raw_data: bytes = b"",
) -> None:
self.status_code = status_code
self.error_message = error_message
self.body = raw_data.decode("utf-8")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def __call__(self, retry_count: int = 0) -> int:
async def create_iterable(
func: Callable[[Optional[T]], Awaitable[T]],
validate: Callable[[T], bool],
aggregator: Callable[[T], None],
timeout: Timeout = Timeout(),
aggregator: Optional[Callable[[T], None]],
timeout: Callable[[], int] = Timeout(),
error_validate: Optional[Callable[[T], bool]] = None,
error_message: Optional[Callable[[T], str]] = None,
) -> T:
Expand Down Expand Up @@ -55,8 +55,8 @@ async def retry(prev: Optional[T] = None) -> T:
def create_iterable_sync(
func: Callable[[Optional[T]], T],
validate: Callable[[T], bool],
aggregator: Callable[[T], None],
timeout: Timeout = Timeout(),
aggregator: Optional[Callable[[T], None]],
timeout: Callable[[], int] = Timeout(),
error_validate: Optional[Callable[[T], bool]] = None,
error_message: Optional[Callable[[T], str]] = None,
) -> T:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ def __init__(
self.port = port
self.priority = cast(int, priority)
self.accept = (CallType.WRITE | CallType.READ) if accept is None else accept

self.reset()
self.last_use = 0.0
self.retry_count = 0
self.up = True

def reset(self) -> None:
self.last_use = 0.0
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from copy import deepcopy
from sys import version_info
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
from urllib.parse import quote

from algoliasearch.http.base_config import BaseConfig
Expand All @@ -13,20 +13,20 @@


class RequestOptions:
_config: BaseConfig
headers: Dict[str, str]
query_parameters: Dict[str, Any]
timeouts: Dict[str, int]
data: Dict[str, Any]

def __init__(
self,
config: BaseConfig,
headers: Dict[str, str] = {},
query_parameters: Dict[str, Any] = {},
timeouts: Dict[str, int] = {},
data: Dict[str, Any] = {},
headers: Optional[Dict[str, str]] = None,
query_parameters: Optional[Dict[str, Any]] = None,
timeouts: Optional[Dict[str, int]] = None,
data: Optional[Dict[str, Any]] = None,
) -> None:
if headers is None:
headers = {}
if query_parameters is None:
query_parameters = {}
if timeouts is None:
timeouts = {}
self._config = config
self.headers = headers
self.query_parameters = {
Expand All @@ -51,25 +51,28 @@ def from_dict(self, data: Dict[str, Dict[str, Any]]) -> Self:
query_parameters=data.get("query_parameters", {}),
timeouts=data.get("timeouts", {}),
data=data.get("data", {}),
)
) # pyright: ignore

def merge(
self,
query_parameters: List[Tuple[str, str]] = [],
headers: Dict[str, Optional[str]] = {},
_: Dict[str, int] = {},
query_parameters: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
data: Optional[str] = None,
user_request_options: Optional[Union[Self, Dict[str, Any]]] = None,
) -> Self:
"""
Merges the default config values with the user given request options if it exists.
"""

headers.update(self._config.headers)
if query_parameters is None:
query_parameters = {}
if headers is None:
headers = {}
headers.update(self._config.headers or {})

request_options = {
"headers": headers,
"query_parameters": {k: v for k, v in query_parameters},
"query_parameters": query_parameters,
"timeouts": {
"read": self._config.read_timeout,
"write": self._config.write_timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class RetryOutcome:
class RetryStrategy:
def valid_hosts(self, hosts: List[Host]) -> List[Host]:
for host in hosts:
if not host.up and self._now() - host.last_use > Host.TTL:
if not host.up and time.time() - host.last_use > Host.TTL:
host.up = True

return [host for host in hosts if host.up]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@ class QueryParametersSerializer:
Parses the given 'query_parameters' values of each keys into their string value.
"""

query_parameters: Dict[str, Any] = {}
def __init__(self, query_parameters: Optional[Dict[str, Any]]) -> None:
self.query_parameters: Dict[str, Any] = {}
if query_parameters is None:
return
for key, value in query_parameters.items():
if isinstance(value, dict):
for dkey, dvalue in value.items():
self.query_parameters[dkey] = self.parse(dvalue)
else:
self.query_parameters[key] = self.parse(value)

def parse(self, value) -> Any:
if isinstance(value, list):
Expand All @@ -27,19 +36,8 @@ def encoded(self) -> str:
dict(sorted(self.query_parameters.items(), key=lambda val: val[0]))
).replace("+", "%20")

def __init__(self, query_parameters: Optional[Dict[str, Any]]) -> None:
self.query_parameters = {}
if query_parameters is None:
return
for key, value in query_parameters.items():
if isinstance(value, dict):
for dkey, dvalue in value.items():
self.query_parameters[dkey] = self.parse(dvalue)
else:
self.query_parameters[key] = self.parse(value)


def bodySerializer(obj: Any) -> Any:
def body_serializer(obj: Any) -> Any:
"""Builds a JSON POST object.
If obj is None, return None.
Expand All @@ -57,14 +55,14 @@ def bodySerializer(obj: Any) -> Any:
elif isinstance(obj, PRIMITIVE_TYPES):
return obj
elif isinstance(obj, list):
return [bodySerializer(sub_obj) for sub_obj in obj]
return [body_serializer(sub_obj) for sub_obj in obj]
elif isinstance(obj, tuple):
return tuple(bodySerializer(sub_obj) for sub_obj in obj)
return tuple(body_serializer(sub_obj) for sub_obj in obj)
elif isinstance(obj, dict):
obj_dict = obj
else:
obj_dict = obj.to_dict()
if obj_dict is None:
return None

return {key: bodySerializer(val) for key, val in obj_dict.items()}
return {key: body_serializer(val) for key, val in obj_dict.items()}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from asyncio import TimeoutError
from json import loads
from typing import List, Optional

from aiohttp import ClientSession, TCPConnector
from async_timeout import timeout
Expand All @@ -11,19 +12,19 @@
AlgoliaUnreachableHostException,
RequestException,
)
from algoliasearch.http.hosts import Host
from algoliasearch.http.request_options import RequestOptions
from algoliasearch.http.retry import RetryOutcome, RetryStrategy
from algoliasearch.http.verb import Verb


class Transporter(BaseTransporter):
_session: ClientSession

def __init__(self, config: BaseConfig) -> None:
self._session = None
super().__init__(config)
self._session: Optional[ClientSession] = None
self._config = config
self._retry_strategy = RetryStrategy()
self._hosts = []
self._hosts: List[Host] = []

async def close(self) -> None:
if self._session is not None:
Expand Down Expand Up @@ -71,7 +72,7 @@ async def request(
url=url,
host=host.url,
status_code=resp.status,
headers=resp.headers,
headers=resp.headers, # pyright: ignore # insensitive dict is still a dict
data=_raw_data,
raw_data=_raw_data,
error_message=str(resp.reason),
Expand Down Expand Up @@ -103,6 +104,7 @@ async def request(

class EchoTransporter(Transporter):
def __init__(self, config: BaseConfig) -> None:
super().__init__(config)
self._config = config
self._retry_strategy = RetryStrategy()

Expand Down
Loading

0 comments on commit 6b42a26

Please sign in to comment.