Skip to content

Commit

Permalink
fix: fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 23, 2024
1 parent dbd278f commit ef12539
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 30 deletions.
2 changes: 1 addition & 1 deletion jina/clients/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ async def _get_results(*args, **kwargs):
results_in_order=results_in_order,
stream=stream,
prefetch=prefetch,
return_type=return_type,
on=on,
**kwargs,
)
Expand Down Expand Up @@ -507,7 +508,6 @@ async def post(
c.continue_on_error = continue_on_error

parameters = _include_results_field_in_param(parameters)

async for result in c._get_results(
on=on,
inputs=inputs,
Expand Down
1 change: 0 additions & 1 deletion jina/serve/runtimes/worker/http_fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ async def post(body: input_model, response: Response):
docs_response = resp.docs.to_dict()
else:
docs_response = resp.docs

ret = output_model(data=docs_response, parameters=resp.parameters)

return ret
Expand Down
33 changes: 19 additions & 14 deletions jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def call_handle(request):
'is_generator'
]

return self.process_single_data(request, None, is_generator=is_generator)
return self.process_single_data(request, None, http=True, is_generator=is_generator)

Check warning on line 180 in jina/serve/runtimes/worker/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/request_handling.py#L180

Added line #L180 was not covered by tests

app = get_fastapi_app(
request_models_map=request_models_map, caller=call_handle, **kwargs
Expand All @@ -201,7 +201,7 @@ def call_handle(request):
'is_generator'
]

return self.process_single_data(request, None, is_generator=is_generator)
return self.process_single_data(request, None, http=True, is_generator=is_generator)

Check warning on line 204 in jina/serve/runtimes/worker/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/request_handling.py#L204

Added line #L204 was not covered by tests

app = get_fastapi_app(
request_models_map=request_models_map, caller=call_handle, **kwargs
Expand Down Expand Up @@ -548,7 +548,7 @@ def _record_response_size_monitoring(self, requests):
requests[0].nbytes, attributes=attributes
)

def _set_result(self, requests, return_data, docs):
def _set_result(self, requests, return_data, docs, http=False):
# assigning result back to request
if return_data is not None:
if isinstance(return_data, DocumentArray):
Expand All @@ -568,10 +568,12 @@ def _set_result(self, requests, return_data, docs):
f'The return type must be DocList / Dict / `None`, '
f'but getting {return_data!r}'
)

WorkerRequestHandler.replace_docs(
requests[0], docs, self.args.output_array_type
)
if not http:
WorkerRequestHandler.replace_docs(
requests[0], docs, self.args.output_array_type
)
else:
requests[0].direct_docs = docs

Check warning on line 576 in jina/serve/runtimes/worker/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/request_handling.py#L576

Added line #L576 was not covered by tests
return docs

def _setup_req_doc_array_cls(self, requests, exec_endpoint, is_response=False):
Expand Down Expand Up @@ -659,11 +661,12 @@ async def handle_generator(
)

async def handle(
self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
self, requests: List['DataRequest'], http=False, tracing_context: Optional['Context'] = None
) -> DataRequest:
"""Initialize private parameters and execute private loading functions.
:param requests: The messages to handle containing a DataRequest
:param http: Flag indicating if it is used by the HTTP server for some optims
:param tracing_context: Optional OpenTelemetry tracing context from the originating request.
:returns: the processed message
"""
Expand Down Expand Up @@ -721,7 +724,7 @@ async def handle(
docs_map=docs_map,
tracing_context=tracing_context,
)
_ = self._set_result(requests, return_data, docs)
_ = self._set_result(requests, return_data, docs, http=http)

for req in requests:
req.add_executor(self.deployment_name)
Expand Down Expand Up @@ -909,18 +912,19 @@ def reduce_requests(requests: List['DataRequest']) -> 'DataRequest':

# serving part
async def process_single_data(
self, request: DataRequest, context, is_generator: bool = False
self, request: DataRequest, context, http: bool = False, is_generator: bool = False
) -> DataRequest:
"""
Process the received requests and return the result as a new request
:param request: the data request to process
:param context: grpc context
:param http: Flag indicating if it is used by the HTTP server for some optims
:param is_generator: whether the request should be handled with streaming
:returns: the response request
"""
self.logger.debug('recv a process_single_data request')
return await self.process_data([request], context, is_generator=is_generator)
return await self.process_data([request], context, http=http, is_generator=is_generator)

async def stream_doc(
self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext'
Expand Down Expand Up @@ -1065,13 +1069,14 @@ def _extract_tracing_context(
return None

async def process_data(
self, requests: List[DataRequest], context, is_generator: bool = False
self, requests: List[DataRequest], context, http=False, is_generator: bool = False
) -> DataRequest:
"""
Process the received requests and return the result as a new request
:param requests: the data requests to process
:param context: grpc context
:param http: Flag indicating if it is used by the HTTP server for some optims
:param is_generator: whether the request should be handled with streaming
:returns: the response request
"""
Expand All @@ -1094,11 +1099,11 @@ async def process_data(

if is_generator:
result = await self.handle_generator(
requests=requests, tracing_context=tracing_context
requests=requests,tracing_context=tracing_context
)
else:
result = await self.handle(
requests=requests, tracing_context=tracing_context
requests=requests, http=http, tracing_context=tracing_context
)

if self._successful_requests_metrics:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def _assert_all_docs_processed(port, num_docs, endpoint):
target=f'0.0.0.0:{port}',
endpoint=endpoint,
)
docs = resp.data.docs
docs = resp.docs
assert docs.texts == ['long timeout' for _ in range(num_docs)]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_flow_returned_collect(protocol, port_generator):
def validate_func(resp):
num_evaluations = 0
scores = set()
for doc in resp.data.docs:
for doc in resp.docs:
num_evaluations += len(doc.evaluations)
scores.add(doc.evaluations['evaluate'].value)
assert num_evaluations == 1
Expand Down
24 changes: 12 additions & 12 deletions tests/unit/serve/dynamic_batching/test_batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def foo(docs, **kwargs):
three_data_requests = [DataRequest() for _ in range(3)]
for req in three_data_requests:
req.data.docs = DocumentArray.empty(1)
assert req.data.docs[0].text == ''
assert req.docs[0].text == ''

async def process_request(req):
q = await bq.push(req)
Expand All @@ -42,20 +42,20 @@ async def process_request(req):
assert time_spent >= 2000
# Test that since no more docs arrived, the function was triggerred after timeout
for resp in responses:
assert resp.data.docs[0].text == 'Done'
assert resp.docs[0].text == 'Done'

four_data_requests = [DataRequest() for _ in range(4)]
for req in four_data_requests:
req.data.docs = DocumentArray.empty(1)
assert req.data.docs[0].text == ''
assert req.docs[0].text == ''
init_time = time.time()
tasks = [asyncio.create_task(process_request(req)) for req in four_data_requests]
responses = await asyncio.gather(*tasks)
time_spent = (time.time() - init_time) * 1000
assert time_spent < 2000
# Test that since no more docs arrived, the function was triggerred after timeout
for resp in responses:
assert resp.data.docs[0].text == 'Done'
assert resp.docs[0].text == 'Done'

await bq.close()

Expand Down Expand Up @@ -135,7 +135,7 @@ async def foo(docs, **kwargs):
data_requests = [DataRequest() for _ in range(3)]
for req in data_requests:
req.data.docs = DocumentArray.empty(10) # 30 docs in total
assert req.data.docs[0].text == ''
assert req.docs[0].text == ''

async def process_request(req):
q = await bq.push(req)
Expand All @@ -150,7 +150,7 @@ async def process_request(req):
assert time_spent < 2000
# Test that since no more docs arrived, the function was triggerred after timeout
for resp in responses:
assert resp.data.docs[0].text == 'Done'
assert resp.docs[0].text == 'Done'

await bq.close()

Expand Down Expand Up @@ -196,9 +196,9 @@ async def process_request(req):
assert isinstance(item, Exception)
for i, req in enumerate(data_requests):
if i not in BAD_REQUEST_IDX:
assert req.data.docs[0].text == f'{i} Processed'
assert req.docs[0].text == f'{i} Processed'
else:
assert req.data.docs[0].text == 'Bad'
assert req.docs[0].text == 'Bad'


@pytest.mark.asyncio
Expand Down Expand Up @@ -227,7 +227,7 @@ async def foo(docs, **kwargs):

data_requests = [DataRequest() for _ in range(35)]
for i, req in enumerate(data_requests):
req.data.docs = DocumentArray(
req.docs = DocumentArray(
Document(text='' if i not in TRIGGER_BAD_REQUEST_IDX else 'Bad')
)

Expand All @@ -246,11 +246,11 @@ async def process_request(req):
assert isinstance(item, Exception)
for i, req in enumerate(data_requests):
if i not in EXPECTED_BAD_REQUESTS:
assert req.data.docs[0].text == 'Processed'
assert req.docs[0].text == 'Processed'
elif i in TRIGGER_BAD_REQUEST_IDX:
assert req.data.docs[0].text == 'Bad'
assert req.docs[0].text == 'Bad'
else:
assert req.data.docs[0].text == ''
assert req.docs[0].text == ''


@pytest.mark.asyncio
Expand Down

0 comments on commit ef12539

Please sign in to comment.