Merge pull request #167 from wilhelm-lab/patch/async_retry_reorder
Patch/async retry reorder
picciama authored Dec 14, 2023
2 parents c9bfc1b + 3ad49ed commit 60ab7ad
Showing 8 changed files with 57 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .cookietemple.yml
@@ -15,5 +15,5 @@ full_name: Victor Giurcoiu
 email: victor.giurcoiu@tum.de
 project_name: oktoberfest
 project_short_description: Public repo oktoberfest
-version: 0.5.1
+version: 0.5.2
 license: MIT
4 changes: 2 additions & 2 deletions .github/release-drafter.yml
@@ -1,5 +1,5 @@
-name-template: "0.5.1 🌈" # <<COOKIETEMPLE_FORCE_BUMP>>
-tag-template: 0.5.1 # <<COOKIETEMPLE_FORCE_BUMP>>
+name-template: "0.5.2 🌈" # <<COOKIETEMPLE_FORCE_BUMP>>
+tag-template: 0.5.2 # <<COOKIETEMPLE_FORCE_BUMP>>
 exclude-labels:
   - "skip-changelog"
2 changes: 1 addition & 1 deletion cookietemple.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.5.1
+current_version = 0.5.2
 
 [bumpversion_files_whitelisted]
 init_file = oktoberfest/__init__.py
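For context: cookietemple's bump-version workflow reads the [bumpversion_files_whitelisted] section to decide which files to rewrite, which is why the same 0.5.1 → 0.5.2 string replacement repeats across the version-carrying files in this commit. A rough sketch of that substitution, assuming plain search-and-replace (the file list below is illustrative, not the full whitelist):

```python
from pathlib import Path

# Illustrative subset of the whitelisted files; the real list lives in cookietemple.cfg.
WHITELISTED = [
    Path("oktoberfest/__init__.py"),
    Path("docs/conf.py"),
    Path("pyproject.toml"),
]


def bump_version(old: str, new: str) -> None:
    """Rewrite each whitelisted file, swapping the old version string for the new one."""
    for path in WHITELISTED:
        path.write_text(path.read_text().replace(old, new))


# bump_version("0.5.1", "0.5.2") would produce the version edits seen in this commit.
```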
4 changes: 2 additions & 2 deletions docs/conf.py
@@ -54,9 +54,9 @@
 # the built documents.
 #
 # The short X.Y version.
-version = "0.5.1"
+version = "0.5.2"
 # The full version, including alpha/beta/rc tags.
-release = "0.5.1"
+release = "0.5.2"
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
2 changes: 1 addition & 1 deletion oktoberfest/__init__.py
@@ -5,7 +5,7 @@
 __author__ = """The Oktoberfest development team (Wilhelmlab at Technical University of Munich)"""
 __copyright__ = f"Copyright {datetime.now():%Y}, Wilhelmlab at Technical University of Munich"
 __license__ = "MIT"
-__version__ = "0.5.1"
+__version__ = "0.5.2"
 
 import logging.handlers
 import sys
67 changes: 48 additions & 19 deletions oktoberfest/predict/koina.py
@@ -388,6 +388,7 @@ def __async_predict_batch(
         infer_results: Dict[int, Union[InferResult, InferenceServerException]],
         request_id: int,
         timeout: int = 60000,
+        retries: int = 3,
     ):
         """
         Perform asynchronous batch inference on the given data using the Koina model.
@@ -401,13 +402,13 @@
         :param infer_results: A dictionary to which the results of asynchronous inference will be added.
         :param request_id: An identifier for the inference request, used to track the order of completion.
         :param timeout: The maximum time (in seconds) to wait for the inference to complete. Defaults to 60000.
+        :param retries: The maximum number of requests in case of failure.
+        :yield: None, this is to separate async client infer from checking the result.
         """
         batch_outputs = self.__get_batch_outputs(self.model_outputs.keys())
         batch_inputs = self.__get_batch_inputs(data)
 
-        max_requests = 3
-
-        for _ in range(max_requests):
+        for _ in range(retries):
             self.client.async_infer(
                 model_name=self.model_name,
                 request_id=str(request_id),
@@ -416,11 +417,9 @@
                 outputs=batch_outputs,
                 client_timeout=timeout,
             )
-            while infer_results.get(request_id) is None:
-                time.sleep(0.1)
+            yield
             if isinstance(infer_results.get(request_id), InferResult):
                 break
-            del infer_results[request_id]
 
     def predict(
         self,
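The reordering above turns __async_predict_batch into a generator: each resumption submits one attempt and then yields control back to the caller, so waiting and retry bookkeeping move out of the submission path entirely. A minimal, self-contained sketch of this submit-then-yield retry pattern (submit_fn, results, and ServerError are stand-ins, not Oktoberfest's API; the real code uses tritonclient's async_infer with a callback):

```python
from typing import Callable, Dict, Generator


class ServerError(Exception):
    """Stand-in for InferenceServerException."""


def retrying_submit(
    submit_fn: Callable[[int], None],
    results: Dict[int, object],
    request_id: int,
    retries: int = 3,
) -> Generator[None, None, None]:
    """Submit one attempt per resumption; the caller polls `results` between next() calls."""
    for attempt in range(retries):
        if attempt:
            del results[request_id]  # clear the failed outcome before resubmitting
        submit_fn(request_id)  # fire-and-forget; a callback later fills results[request_id]
        yield  # hand control back until the caller observes an outcome
        if not isinstance(results.get(request_id), ServerError):
            break  # success: stop retrying; further next() calls raise StopIteration
```

A driver calls next() once to fire the first attempt and again only when results[request_id] holds an error; StopIteration then signals that all retries are spent.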
@@ -480,31 +479,61 @@ def __predict_async(
         :param data: A dictionary containing input data for inference. Keys are input names, and values are numpy arrays.
         :param disable_progress_bar: If True, disable the progress bar during asynchronous inference. Defaults to False.
         :param debug: If True, store raw InferResult / InferServerException dictionary for later analysis.
-        :raises InferenceServerException: If at least one batch of predictions could not be inferred.
         :return: A dictionary containing the model's predictions. Keys are output names, and values are numpy arrays
             representing the model's output.
         """
         infer_results: Dict[int, Union[InferResult, InferenceServerException]] = {}
+        tasks = []
         for i, data_batch in enumerate(self.__slice_dict(data, self.batchsize)):
-            self.__async_predict_batch(data_batch, infer_results, request_id=i)
-
-        with tqdm(total=i + 1, desc="Getting predictions", disable=disable_progress_bar) as pbar:
-            while len(infer_results) != i + 1:
-                pbar.n = len(infer_results)
+            tasks.append(self.__async_predict_batch(data_batch, infer_results, request_id=i, retries=3))
+            next(tasks[i])
+
+        n_tasks = i + 1
+        with tqdm(total=n_tasks, desc="Getting predictions", disable=disable_progress_bar) as pbar:
+            unfinished_tasks = [i for i in range(n_tasks)]
+            while pbar.n != n_tasks:
+                time.sleep(0.2)
+                new_unfinished_tasks = []
+                for j in unfinished_tasks:
+                    result = infer_results.get(j)
+                    if result is None:
+                        new_unfinished_tasks.append(j)
+                        continue
+                    if isinstance(result, InferenceServerException):
+                        try:
+                            new_unfinished_tasks.append(j)
+                            next(tasks[j])
+                        except StopIteration:
+                            pbar.n += 1
+                        continue
+                    if isinstance(result, InferResult):
+                        pbar.n += 1
+
+                unfinished_tasks = new_unfinished_tasks
                 pbar.refresh()
-                time.sleep(1)
-            pbar.n = len(infer_results)
-            pbar.refresh()
 
+        return self.__handle_results(infer_results, debug)
+
+    def __handle_results(
+        self, infer_results: Dict[int, Union[InferResult, InferenceServerException]], debug: bool
+    ) -> Dict[str, np.ndarray]:
+        """
+        Handles the results.
+
+        :param infer_results: The dictionary containing the inferred results
+        :param debug: whether to store the infer_results in the response_dict attribute
+        :raises InferenceServerException: If at least one batch of predictions could not be inferred.
+        :return: A dictionary containing the model's predictions. Keys are output names, and values are numpy arrays
+            representing the model's output.
+        """
         if debug:
             self._response_dict = infer_results
 
         try:
             # sort according to request id
-            infer_results_to_return = [
-                self.__extract_predictions(infer_results[i]) for i in np.argsort(list(infer_results.keys()))
-            ]
+            infer_results_to_return = [self.__extract_predictions(infer_results[i]) for i in range(len(infer_results))]
             return self.__merge_list_dict_array(infer_results_to_return)
         except AttributeError:
             for res in infer_results.values():
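Downstream, the rewritten __predict_async drives one such generator per batch: all batches are submitted up front, and a single polling loop resumes a batch's generator whenever its stored outcome is an exception, counting the task settled once a real result arrives or retries run out. A simplified, self-contained sketch of that driver (pairs with the retrying_submit sketch above; the progress bar and result merging are omitted, and unlike the committed loop this version marks a task unfinished only after it was successfully resumed):

```python
import time
from typing import Dict, Generator, List


def drive(tasks: List[Generator[None, None, None]], results: Dict[int, object]) -> None:
    """Poll `results`, resuming failed tasks, until every batch is settled."""
    for task in tasks:
        next(task)  # submit the first attempt of every batch up front
    unfinished = list(range(len(tasks)))
    done = 0
    while done != len(tasks):
        time.sleep(0.2)  # polling interval, matching the committed loop
        still_open = []
        for j in unfinished:
            outcome = results.get(j)
            if outcome is None:
                still_open.append(j)  # no callback yet: keep waiting
            elif isinstance(outcome, Exception):  # the attempt failed
                try:
                    next(tasks[j])  # resume the generator: it resubmits batch j
                    still_open.append(j)
                except StopIteration:
                    done += 1  # retries exhausted; results[j] keeps the last error
            else:
                done += 1  # a real result arrived
        unfinished = still_open
```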
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "oktoberfest"
-version = "0.5.1" # <<COOKIETEMPLE_FORCE_BUMP>>
+version = "0.5.2" # <<COOKIETEMPLE_FORCE_BUMP>>
 description = "Public repo oktoberfest"
 authors = ["Victor Giurcoiu <victor.giurcoiu@tum.de>"]
 license = "MIT"
2 changes: 1 addition & 1 deletion tests/unit_tests/configs/ce_calib_ransac.json
@@ -9,7 +9,7 @@
     },
     "output": "../data/tims_calib_out/",
     "models": {
-        "intensity": "Prosit_2023_intensity_TOF",
+        "intensity": "Prosit_2023_intensity_timsTOF",
         "irt": "Prosit_2019_irt"
     },
     "prediction_server": "koina.proteomicsdb.org:443",
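This also corrects the intensity model identifier in the test config: Koina resolves models by exact name, and Prosit_2023_intensity_timsTOF is the published timsTOF intensity model, so the old Prosit_2023_intensity_TOF string would presumably match no served model. A small sketch of a config sanity check (the known-model set is an illustrative subset, not Koina's full registry):

```python
import json

# Illustrative subset of Koina intensity model names; not an exhaustive registry.
KNOWN_INTENSITY_MODELS = {
    "Prosit_2019_intensity",
    "Prosit_2023_intensity_timsTOF",
}

with open("tests/unit_tests/configs/ce_calib_ransac.json") as f:
    config = json.load(f)

intensity = config["models"]["intensity"]
if intensity not in KNOWN_INTENSITY_MODELS:
    raise ValueError(f"unknown intensity model: {intensity!r}")
print(f"using {intensity} via {config['prediction_server']}")
```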
