Skip to content

Commit

Permalink
feat: add iolet reusage
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Sep 13, 2024
1 parent b567456 commit 361f08c
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 38 deletions.
28 changes: 17 additions & 11 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class BaseClient(InstrumentationMixin, ABC):
"""

def __init__(
self,
args: Optional['argparse.Namespace'] = None,
**kwargs,
self,
args: Optional['argparse.Namespace'] = None,
**kwargs,
):
if args and isinstance(args, argparse.Namespace):
self.args = args
Expand Down Expand Up @@ -63,6 +63,11 @@ def __init__(
)
send_telemetry_event(event='start', obj_cls_name=self.__class__.__name__)

async def close(self):
"""Closes the potential resources of the Client.
"""
return self.teardown_instrumentation()

def teardown_instrumentation(self):
"""Shut down the OpenTelemetry tracer and meter if available. This ensures that the daemon threads for
exporting metrics data is properly cleaned up.
Expand Down Expand Up @@ -118,7 +123,7 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None:
raise BadClientInput from ex

def _get_requests(
self, **kwargs
self, **kwargs
) -> Union[Iterator['Request'], AsyncIterator['Request']]:
"""
Get request in generator.
Expand Down Expand Up @@ -177,13 +182,14 @@ def inputs(self, bytes_gen: 'InputType') -> None:

@abc.abstractmethod
async def _get_results(
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
**kwargs,
): ...
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
**kwargs,
):
...

@abc.abstractmethod
def _is_flow_ready(self, **kwargs) -> bool:
Expand Down
21 changes: 9 additions & 12 deletions jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class AioHttpClientlet(ABC):

def __init__(
self,
url: str,
logger: 'JinaLogger',
max_attempts: int = 1,
initial_backoff: float = 0.5,
Expand All @@ -59,7 +58,6 @@ def __init__(
) -> None:
"""HTTP Client to be used with the streamer
:param url: url to send http/websocket request to
:param logger: jina logger
:param max_attempts: Number of sending attempts, including the original request.
:param initial_backoff: The first retry will happen with a delay of random(0, initial_backoff)
Expand All @@ -68,7 +66,6 @@ def __init__(
:param tracer_provider: Optional tracer_provider that will be used to configure aiohttp tracing.
:param kwargs: kwargs which will be forwarded to the `aiohttp.Session` instance. Used to pass headers to requests
"""
self.url = url
self.logger = logger
self.msg_recv = 0
self.msg_sent = 0
Expand Down Expand Up @@ -154,7 +151,7 @@ class HTTPClientlet(AioHttpClientlet):

UPDATE_EVENT_PREFIX = 14 # the update event has the following format: "event: update: {document_json}"

async def send_message(self, request: 'Request'):
async def send_message(self, url, request: 'Request'):
"""Sends a POST request to the server
:param request: request as dict
Expand All @@ -166,7 +163,7 @@ async def send_message(self, request: 'Request'):
req_dict['target_executor'] = req_dict['header']['target_executor']
for attempt in range(1, self.max_attempts + 1):
try:
request_kwargs = {'url': self.url}
request_kwargs = {'url': url}
if not docarray_v2:
request_kwargs['json'] = req_dict
else:
Expand All @@ -179,10 +176,10 @@ async def send_message(self, request: 'Request'):
except aiohttp.ContentTypeError:
r_str = await response.text()
r_status = response.status
handle_response_status(response.status, r_str, self.url)
handle_response_status(response.status, r_str, url)
return r_status, r_str
except (ValueError, ConnectionError, BadClient, aiohttp.ClientError, aiohttp.ClientConnectionError) as err:
self.logger.debug(f'Got an error: {err} sending POST to {self.url} in attempt {attempt}/{self.max_attempts}')
self.logger.debug(f'Got an error: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}')
await retry.wait_or_raise_err(
attempt=attempt,
err=err,
Expand All @@ -193,10 +190,10 @@ async def send_message(self, request: 'Request'):
)
except Exception as exc:
self.logger.debug(
f'Got a non-retried error: {exc} sending POST to {self.url}')
f'Got a non-retried error: {exc} sending POST to {url}')
raise exc

async def send_streaming_message(self, doc: 'Document', on: str):
async def send_streaming_message(self, url, doc: 'Document', on: str):
"""Sends a GET SSE request to the server
:param doc: Request Document
Expand All @@ -205,7 +202,7 @@ async def send_streaming_message(self, doc: 'Document', on: str):
"""
req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict()
request_kwargs = {
'url': self.url,
'url': url,
'headers': {'Accept': 'text/event-stream'},
'json': req_dict,
}
Expand All @@ -219,13 +216,13 @@ async def send_streaming_message(self, doc: 'Document', on: str):
elif event.startswith(b'end'):
pass

async def send_dry_run(self, **kwargs):
async def send_dry_run(self, url, **kwargs):
"""Query the dry_run endpoint from Gateway
:param kwargs: keyword arguments to make sure compatible API with other clients
:return: send get message
"""
return await self.session.get(
url=self.url, timeout=kwargs.get('timeout', None)
url=url, timeout=kwargs.get('timeout', None)
).__aenter__()

async def recv_message(self):
Expand Down
45 changes: 34 additions & 11 deletions jina/clients/base/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class HTTPBaseClient(BaseClient):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._endpoints = []
self.iolet = None

async def close(self):
await super().close()
if self.iolet is not None:
await self.iolet.__aexit__()

async def _get_endpoints_from_openapi(self, **kwargs):
def extract_paths_by_method(spec):
Expand Down Expand Up @@ -69,14 +75,24 @@ async def _is_flow_ready(self, **kwargs) -> bool:
try:
proto = 'https' if self.args.tls else 'http'
url = f'{proto}://{self.args.host}:{self.args.port}/dry_run'
iolet = await stack.enter_async_context(
HTTPClientlet(
url=url,

if self.iolet is not None and self.args.reuse_session:
iolet = self.iolet
else:
iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
)

if self.args.reuse_session and self.iolet is None:
self.iolet = iolet
await self.iolet.__aenter__()

if not self.args.reuse_session:
iolet = await stack.enter_async_context(
iolet
)

response = await iolet.send_dry_run(**kwargs)
r_status = response.status
Expand Down Expand Up @@ -152,9 +168,10 @@ async def _get_results(
else:
url = f'{proto}://{self.args.host}:{self.args.port}/post'

iolet = await stack.enter_async_context(
HTTPClientlet(
url=url,
if self.iolet is not None and self.args.reuse_session:
iolet = self.iolet
else:
iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
max_attempts=max_attempts,
Expand All @@ -164,7 +181,14 @@ async def _get_results(
timeout=timeout,
**kwargs,
)
)
if self.args.reuse_session and self.iolet is None:
self.iolet = iolet
await self.iolet.__aenter__()

if not self.args.reuse_session:
iolet = await stack.enter_async_context(
iolet
)

def _request_handler(
request: 'Request', **kwargs
Expand All @@ -176,7 +200,7 @@ def _request_handler(
:param kwargs: kwargs
:return: asyncio Task for sending message
"""
return asyncio.ensure_future(iolet.send_message(request=request)), None
return asyncio.ensure_future(iolet.send_message(url=url, request=request)), None

def _result_handler(result):
return result
Expand Down Expand Up @@ -250,15 +274,14 @@ async def _get_streaming_results(
url = f'{proto}://{self.args.host}:{self.args.port}/default'

iolet = HTTPClientlet(
url=url,
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)

async with iolet:
async for doc in iolet.send_streaming_message(doc=inputs, on=on):
async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on):
if not docarray_v2:
yield Document.from_dict(json.loads(doc))
else:
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/clients/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ async def test_http_clientlet():
port = random_port()
with Flow(port=port, protocol='http').add():
async with HTTPClientlet(
url=f'http://localhost:{port}/post', logger=logger
logger=logger
) as iolet:
request = _new_data_request('/', None, {'a': 'b'})
assert request.header.target_executor == ''
r_status, r_json = await iolet.send_message(request)
r_status, r_json = await iolet.send_message(url=f'http://localhost:{port}/post', request=request)
response = DataRequest(r_json)
assert response.header.exec_endpoint == '/'
assert response.parameters == {'a': 'b'}
Expand All @@ -50,11 +50,11 @@ async def test_http_clientlet_target():
port = random_port()
with Flow(port=port, protocol='http').add():
async with HTTPClientlet(
url=f'http://localhost:{port}/post', logger=logger
logger=logger
) as iolet:
request = _new_data_request('/', 'nothing', {'a': 'b'})
assert request.header.target_executor == 'nothing'
r = await iolet.send_message(request)
r = await iolet.send_message(url=f'http://localhost:{port}/post', request=request)
r_status, r_json = r
response = DataRequest(r_json)
assert response.header.exec_endpoint == '/'
Expand Down

0 comments on commit 361f08c

Please sign in to comment.