Skip to content

Commit

Permalink
fix: invalid input raise exception (#5141)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Oct 7, 2022
1 parent 425b029 commit 573f607
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 14 deletions.
1 change: 1 addition & 0 deletions jina/clients/base/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ async def _get_results(
self.logger.warning('user cancel the process')
except asyncio.CancelledError as ex:
self.logger.warning(f'process error: {ex!r}')
raise
except:
# Not sure why, adding this line helps in fixing a hanging test
raise
41 changes: 28 additions & 13 deletions jina/clients/base/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@ async def _send():
return False

async def _get_results(
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
max_attempts: int = 1,
initial_backoff: float = 0.5,
max_backoff: float = 0.1,
backoff_multiplier: float = 1.5,
**kwargs,
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
max_attempts: int = 1,
initial_backoff: float = 0.5,
max_backoff: float = 0.1,
backoff_multiplier: float = 1.5,
**kwargs,
):
"""
:param inputs: the callable
Expand Down Expand Up @@ -110,7 +110,8 @@ async def _get_results(
proto = 'wss' if self.args.tls else 'ws'
url = f'{proto}://{self.args.host}:{self.args.port}/'
iolet = await stack.enter_async_context(
WebsocketClientlet(url=url, logger=self.logger, max_attempts=max_attempts, initial_backoff=initial_backoff,
WebsocketClientlet(url=url, logger=self.logger, max_attempts=max_attempts,
initial_backoff=initial_backoff,
max_backoff=max_backoff, backoff_multiplier=backoff_multiplier, **kwargs)
)

Expand Down Expand Up @@ -149,7 +150,7 @@ def _handle_end_of_iter():
asyncio.create_task(iolet.send_eoi())

def _request_handler(
request: 'Request',
request: 'Request',
) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]':
"""
For each request in the iterator, we send the `Message` using `iolet.send_message()`.
Expand All @@ -176,6 +177,8 @@ def _request_handler(

receive_task = asyncio.create_task(_receive())

exception_raised = None

if receive_task.done():
raise RuntimeError('receive task not running, can not send messages')
try:
Expand All @@ -191,7 +194,19 @@ def _request_handler(
if self.show_progress:
p_bar.update()
yield response
except Exception as ex:
exception_raised = ex
try:
receive_task.cancel()
except:
raise ex
finally:
if iolet.close_code == status.WS_1011_INTERNAL_ERROR:
raise ConnectionError(iolet.close_message)
await receive_task
try:
await receive_task
except asyncio.CancelledError:
if exception_raised is not None:
raise exception_raised
else:
raise
1 change: 1 addition & 0 deletions jina/clients/request/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,4 @@ def request_generator(
except Exception as ex:
# must be handled here, as grpc channel wont handle Python exception
default_logger.critical(f'inputs is not valid! {ex!r}', exc_info=True)
raise
1 change: 1 addition & 0 deletions jina/clients/request/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@ async def request_generator(
except Exception as ex:
# must be handled here, as grpc channel wont handle Python exception
default_logger.critical(f'inputs is not valid! {ex!r}', exc_info=True)
raise
13 changes: 12 additions & 1 deletion jina/serve/stream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def update_all_handled():
async def end_future():
raise self._EndOfStreaming

async def exception_raise(exception):
raise exception

def callback(future: 'asyncio.Future'):
"""callback to be run after future is completed.
1. Put the future in the result queue.
Expand Down Expand Up @@ -187,7 +190,7 @@ async def handle_floating_responses():
except self._EndOfStreaming:
pass

asyncio.create_task(iterate_requests())
iterate_requests_task = asyncio.create_task(iterate_requests())
handle_floating_task = asyncio.create_task(handle_floating_responses())
self.total_num_floating_tasks_alive += 1

Expand All @@ -196,6 +199,14 @@ def floating_task_done(*args):

handle_floating_task.add_done_callback(floating_task_done)

def iterating_task_done(task):
if task.exception() is not None:
all_requests_handled.set()
future_cancel = asyncio.ensure_future(exception_raise(task.exception()))
result_queue.put_nowait(future_cancel)

iterate_requests_task.add_done_callback(iterating_task_done)

while not all_requests_handled.is_set():
future = await result_queue.get()
try:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest
from datetime import datetime
from jina import Flow, DocumentArray, Document


class MyOwnException(Exception):
pass


@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
def test_invalid_input_raise(protocol):
f = Flow(protocol=protocol).add()

try:
with f:
da = DocumentArray([Document(text='hello', tags={'date': datetime.now()})])
try:
f.post(on='/', inputs=da) # process should stop here and raise an exception
except Exception:
raise MyOwnException
assert False
except MyOwnException:
pass

0 comments on commit 573f607

Please sign in to comment.