From 461214aef4b51f42574d920effeaa7cbb58f5a92 Mon Sep 17 00:00:00 2001 From: karootplx Date: Thu, 14 Nov 2024 20:16:04 +0800 Subject: [PATCH 1/3] perf: add retries for dojo api calls --- commons/human_feedback/dojo.py | 207 ++++++++++++++++++--------------- 1 file changed, 115 insertions(+), 92 deletions(-) diff --git a/commons/human_feedback/dojo.py b/commons/human_feedback/dojo.py index 61fc3e7..c6a609a 100644 --- a/commons/human_feedback/dojo.py +++ b/commons/human_feedback/dojo.py @@ -1,3 +1,4 @@ +import asyncio import json import os from typing import Dict, List @@ -56,15 +57,34 @@ async def get_task_results_by_task_id(cls, task_id: str) -> List[Dict] | None: # if is_completed is False: # return - task_results_response = await cls._get_task_results_by_task_id(task_id) - task_results = task_results_response.get("body", {}).get("taskResults") - if task_results is None: - return - - if not task_results: - return + max_retries = 3 + base_delay = 1 + + for attempt in range(max_retries): + try: + task_results_response = await cls._get_task_results_by_task_id(task_id) + task_results = task_results_response.get("body", {}).get("taskResults") + if task_results is None or not task_results: + return None + return task_results + except Exception as e: + if attempt < max_retries - 1: + delay = base_delay * 2**attempt + logger.warning( + f"Error occurred while getting task results for task_id {task_id}: {e}. " + f"Retrying in {delay:.2f} seconds..." + ) + await asyncio.sleep(delay) + else: + logger.error( + f"Failed to get task results for task_id {task_id} after {max_retries} attempts: {e}" + ) + return None - return task_results + logger.error( + f"Failed to get task results for task_id {task_id} after {max_retries} retries" + ) + return None @staticmethod def serialize_feedback_request(data: FeedbackRequest): @@ -95,89 +115,92 @@ async def create_task( ): response_text = "" response_json = {} - try: - path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks" - taskData = cls.serialize_feedback_request(feedback_request) - for criteria_type in feedback_request.criteria_types: - if isinstance(criteria_type, RankingCriteria) or isinstance( - criteria_type, MultiScoreCriteria - ): - taskData["criteria"].append( - { - **criteria_type.model_dump(), - "options": [ - option - for option in criteria_type.model_dump().get( - "options", [] - ) - ], - } + max_retries = 5 + base_delay = 1 + + for attempt in range(max_retries): + try: + path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks" + taskData = cls.serialize_feedback_request(feedback_request) + for criteria_type in feedback_request.criteria_types: + if isinstance(criteria_type, RankingCriteria) or isinstance( + criteria_type, MultiScoreCriteria + ): + taskData["criteria"].append( + { + **criteria_type.model_dump(), + "options": [ + option + for option in criteria_type.model_dump().get( + "options", [] + ) + ], + } + ) + else: + logger.error( + f"Unrecognized criteria type: {type(criteria_type)}" + ) + + expire_at = set_expire_time(dojo.TASK_DEADLINE) + + max_results = _get_max_results_param() + form_body = { + "title": ("", "LLM Code Generation Task"), + "body": ("", feedback_request.prompt), + "expireAt": ("", expire_at), + "taskData": ("", json.dumps([taskData])), + "maxResults": ("", str(max_results)), + } + + payload_size = sum(len(str(v[1])) for v in form_body.values()) + logger.info(f"Payload size: {payload_size} bytes") + + DOJO_API_KEY = loaddotenv("DOJO_API_KEY") + + response = await cls._http_client.post( + path, + files=form_body, + headers={ + "x-api-key": DOJO_API_KEY, + }, + timeout=15.0, + ) + + response_text = response.text + response_json = response.json() + + task_ids = [] + if response.status_code == 200: + task_ids = response.json()["body"] + logger.success( + f"Successfully created task with\ntask ids:{task_ids}" ) else: - logger.error(f"Unrecognized criteria type: {type(criteria_type)}") - - expire_at = set_expire_time(dojo.TASK_DEADLINE) - - max_results = _get_max_results_param() - form_body = { - "title": ("", "LLM Code Generation Task"), - "body": ("", feedback_request.prompt), - "expireAt": ("", expire_at), - "taskData": ("", json.dumps([taskData])), - "maxResults": ("", str(max_results)), - } - - payload_size = sum(len(str(v[1])) for v in form_body.values()) - logger.info(f"Payload size: {payload_size} bytes") - - DOJO_API_KEY = loaddotenv("DOJO_API_KEY") - - response = await cls._http_client.post( - path, - files=form_body, - headers={ - "x-api-key": DOJO_API_KEY, - }, - timeout=15.0, - ) - - response_text = response.text - response_json = response.json() - # if DEBUG is True: - # try: - # from curlify2 import Curlify - - # curl_req = Curlify(response.request) - # print("CURL REQUEST >>> ") - # print(curl_req.to_curl()) - # except ImportError: - # print("Curlify not installed") - # except Exception as e: - # print("Tried to export create task request as curl, but failed.") - # print(f"Exception: {e}") - - task_ids = [] - if response.status_code == 200: - task_ids = response.json()["body"] - logger.success(f"Successfully created task with\ntask ids:{task_ids}") - else: - logger.error( - f"Error occurred when trying to create task\nErr:{response.json()['error']}" - ) - response.raise_for_status() - return task_ids - except json.JSONDecodeError as e1: - message = f"While trying to create task got JSON decode error: {e1}, response_text: {response_text}" - logger.error(message) - raise CreateTaskFailed("Failed to create task due to JSON decode error") - except httpx.HTTPStatusError as e: - logger.error( - f"HTTP error occurred: {e}. Status code: {e.response.status_code}. Response content: {e.response.text}" - ) - raise CreateTaskFailed( - f"Failed to create task due to HTTP error: {e}, response_text: {response_text}, response_json: {response_json}" - ) - except Exception as e: - raise CreateTaskFailed( - f"Failed to create task due to unexpected error: {e}, response_text: {response_text}, response_json: {response_json}" - ) + logger.error( + f"Error occurred when trying to create task\nErr:{response.json()['error']}" + ) + response.raise_for_status() + return task_ids + except Exception as e: + if attempt < max_retries - 1: + delay = base_delay * 2**attempt + logger.warning( + f"Error occurred: {e}. Retrying in {delay:.2f} seconds..." + ) + await asyncio.sleep(delay) + else: + logger.error(f"Error occurred after {max_retries} attempts: {e}") + if isinstance(e, httpx.HTTPStatusError | httpx.RequestError): + raise CreateTaskFailed( + f"Failed to create task after {max_retries} attempts due to HTTP error: {e}" + ) + elif isinstance(e, json.JSONDecodeError): + raise CreateTaskFailed( + f"Failed to create task due to JSON decode error: {e}, response_text: {response_text}" + ) + else: + raise CreateTaskFailed( + f"Failed to create task due to unexpected error: {e}, response_text: {response_text}, response_json: {response_json}" + ) + raise CreateTaskFailed(f"Failed to create task after {max_retries} retries") From 69e19742f23fd1a2ac81f197f4e2f07fe9018049 Mon Sep 17 00:00:00 2001 From: karootplx Date: Thu, 14 Nov 2024 22:11:29 +0800 Subject: [PATCH 2/3] perf: extend delay --- commons/human_feedback/dojo.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/commons/human_feedback/dojo.py b/commons/human_feedback/dojo.py index c6a609a..bd3aca2 100644 --- a/commons/human_feedback/dojo.py +++ b/commons/human_feedback/dojo.py @@ -1,6 +1,7 @@ import asyncio import json import os +import random from typing import Dict, List import httpx @@ -57,7 +58,7 @@ async def get_task_results_by_task_id(cls, task_id: str) -> List[Dict] | None: # if is_completed is False: # return - max_retries = 3 + max_retries = 5 base_delay = 1 for attempt in range(max_retries): @@ -184,7 +185,7 @@ async def create_task( return task_ids except Exception as e: if attempt < max_retries - 1: - delay = base_delay * 2**attempt + delay = base_delay * 2**attempt + random.uniform(0, 1) logger.warning( f"Error occurred: {e}. Retrying in {delay:.2f} seconds..." ) From 97a9b0417106c4e3cfc11501c9408f15574daf23 Mon Sep 17 00:00:00 2001 From: karootplx Date: Fri, 15 Nov 2024 05:32:10 +0800 Subject: [PATCH 3/3] perf: extend delay --- commons/human_feedback/dojo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/commons/human_feedback/dojo.py b/commons/human_feedback/dojo.py index bd3aca2..f049eb1 100644 --- a/commons/human_feedback/dojo.py +++ b/commons/human_feedback/dojo.py @@ -70,7 +70,7 @@ async def get_task_results_by_task_id(cls, task_id: str) -> List[Dict] | None: return task_results except Exception as e: if attempt < max_retries - 1: - delay = base_delay * 2**attempt + delay = base_delay * 2**attempt + random.uniform(0, 1) logger.warning( f"Error occurred while getting task results for task_id {task_id}: {e}. " f"Retrying in {delay:.2f} seconds..."