Skip to content

Commit

Permalink
Merge pull request #72 from tensorplex-labs/perf/add-retries-for-api-calls
Browse files Browse the repository at this point in the history

perf: add retries for dojo api calls
  • Loading branch information
karootplx authored Nov 15, 2024
2 parents e520542 + 97a9b04 commit 91792f6
Showing 1 changed file with 116 additions and 92 deletions.
208 changes: 116 additions & 92 deletions commons/human_feedback/dojo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import json
import os
import random
from typing import Dict, List

import httpx
Expand Down Expand Up @@ -56,15 +58,34 @@ async def get_task_results_by_task_id(cls, task_id: str) -> List[Dict] | None:
# if is_completed is False:
# return

task_results_response = await cls._get_task_results_by_task_id(task_id)
task_results = task_results_response.get("body", {}).get("taskResults")
if task_results is None:
return

if not task_results:
return
max_retries = 5
base_delay = 1

for attempt in range(max_retries):
try:
task_results_response = await cls._get_task_results_by_task_id(task_id)
task_results = task_results_response.get("body", {}).get("taskResults")
if task_results is None or not task_results:
return None
return task_results
except Exception as e:
if attempt < max_retries - 1:
delay = base_delay * 2**attempt + random.uniform(0, 1)
logger.warning(
f"Error occurred while getting task results for task_id {task_id}: {e}. "
f"Retrying in {delay:.2f} seconds..."
)
await asyncio.sleep(delay)
else:
logger.error(
f"Failed to get task results for task_id {task_id} after {max_retries} attempts: {e}"
)
return None

return task_results
logger.error(
f"Failed to get task results for task_id {task_id} after {max_retries} retries"
)
return None

@staticmethod
def serialize_feedback_request(data: FeedbackRequest):
Expand Down Expand Up @@ -95,89 +116,92 @@ async def create_task(
):
response_text = ""
response_json = {}
try:
path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks"
taskData = cls.serialize_feedback_request(feedback_request)
for criteria_type in feedback_request.criteria_types:
if isinstance(criteria_type, RankingCriteria) or isinstance(
criteria_type, MultiScoreCriteria
):
taskData["criteria"].append(
{
**criteria_type.model_dump(),
"options": [
option
for option in criteria_type.model_dump().get(
"options", []
)
],
}
max_retries = 5
base_delay = 1

for attempt in range(max_retries):
try:
path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks"
taskData = cls.serialize_feedback_request(feedback_request)
for criteria_type in feedback_request.criteria_types:
if isinstance(criteria_type, RankingCriteria) or isinstance(
criteria_type, MultiScoreCriteria
):
taskData["criteria"].append(
{
**criteria_type.model_dump(),
"options": [
option
for option in criteria_type.model_dump().get(
"options", []
)
],
}
)
else:
logger.error(
f"Unrecognized criteria type: {type(criteria_type)}"
)

expire_at = set_expire_time(dojo.TASK_DEADLINE)

max_results = _get_max_results_param()
form_body = {
"title": ("", "LLM Code Generation Task"),
"body": ("", feedback_request.prompt),
"expireAt": ("", expire_at),
"taskData": ("", json.dumps([taskData])),
"maxResults": ("", str(max_results)),
}

payload_size = sum(len(str(v[1])) for v in form_body.values())
logger.info(f"Payload size: {payload_size} bytes")

DOJO_API_KEY = loaddotenv("DOJO_API_KEY")

response = await cls._http_client.post(
path,
files=form_body,
headers={
"x-api-key": DOJO_API_KEY,
},
timeout=15.0,
)

response_text = response.text
response_json = response.json()

task_ids = []
if response.status_code == 200:
task_ids = response.json()["body"]
logger.success(
f"Successfully created task with\ntask ids:{task_ids}"
)
else:
logger.error(f"Unrecognized criteria type: {type(criteria_type)}")

expire_at = set_expire_time(dojo.TASK_DEADLINE)

max_results = _get_max_results_param()
form_body = {
"title": ("", "LLM Code Generation Task"),
"body": ("", feedback_request.prompt),
"expireAt": ("", expire_at),
"taskData": ("", json.dumps([taskData])),
"maxResults": ("", str(max_results)),
}

payload_size = sum(len(str(v[1])) for v in form_body.values())
logger.info(f"Payload size: {payload_size} bytes")

DOJO_API_KEY = loaddotenv("DOJO_API_KEY")

response = await cls._http_client.post(
path,
files=form_body,
headers={
"x-api-key": DOJO_API_KEY,
},
timeout=15.0,
)

response_text = response.text
response_json = response.json()
# if DEBUG is True:
# try:
# from curlify2 import Curlify

# curl_req = Curlify(response.request)
# print("CURL REQUEST >>> ")
# print(curl_req.to_curl())
# except ImportError:
# print("Curlify not installed")
# except Exception as e:
# print("Tried to export create task request as curl, but failed.")
# print(f"Exception: {e}")

task_ids = []
if response.status_code == 200:
task_ids = response.json()["body"]
logger.success(f"Successfully created task with\ntask ids:{task_ids}")
else:
logger.error(
f"Error occurred when trying to create task\nErr:{response.json()['error']}"
)
response.raise_for_status()
return task_ids
except json.JSONDecodeError as e1:
message = f"While trying to create task got JSON decode error: {e1}, response_text: {response_text}"
logger.error(message)
raise CreateTaskFailed("Failed to create task due to JSON decode error")
except httpx.HTTPStatusError as e:
logger.error(
f"HTTP error occurred: {e}. Status code: {e.response.status_code}. Response content: {e.response.text}"
)
raise CreateTaskFailed(
f"Failed to create task due to HTTP error: {e}, response_text: {response_text}, response_json: {response_json}"
)
except Exception as e:
raise CreateTaskFailed(
f"Failed to create task due to unexpected error: {e}, response_text: {response_text}, response_json: {response_json}"
)
logger.error(
f"Error occurred when trying to create task\nErr:{response.json()['error']}"
)
response.raise_for_status()
return task_ids
except Exception as e:
if attempt < max_retries - 1:
delay = base_delay * 2**attempt + random.uniform(0, 1)
logger.warning(
f"Error occurred: {e}. Retrying in {delay:.2f} seconds..."
)
await asyncio.sleep(delay)
else:
logger.error(f"Error occurred after {max_retries} attempts: {e}")
if isinstance(e, httpx.HTTPStatusError | httpx.RequestError):
raise CreateTaskFailed(
f"Failed to create task after {max_retries} attempts due to HTTP error: {e}"
)
elif isinstance(e, json.JSONDecodeError):
raise CreateTaskFailed(
f"Failed to create task due to JSON decode error: {e}, response_text: {response_text}"
)
else:
raise CreateTaskFailed(
f"Failed to create task due to unexpected error: {e}, response_text: {response_text}, response_json: {response_json}"
)
raise CreateTaskFailed(f"Failed to create task after {max_retries} retries")

0 comments on commit 91792f6

Please sign in to comment.