Skip to content

Commit

Permalink
fix: dubious ownership issue
Browse files Browse the repository at this point in the history
chore: use dev image temporarily

fix: add missing vali dendrite hotkey, catch error properly

refactor(prisma): add createdat/updatedat

perf: add exception handling for miner create task

fix: ensure consistency of datetime formats and objects, cleanup logs
  • Loading branch information
jarvis8x7b committed Oct 16, 2024
1 parent 4087de9 commit 1c16d1d
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 98 deletions.
8 changes: 8 additions & 0 deletions commons/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,11 @@ class EmptyScores(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)


class CreateTaskFailed(Exception):
    """Signals that creating a task did not succeed.

    The human-readable reason is kept on ``self.message`` and is also
    passed to :class:`Exception`, so ``str(exc)`` reports it directly.
    """

    def __init__(self, message):
        # Keep the reason accessible as an attribute for callers that
        # inspect it, mirroring the other exceptions in this module.
        self.message = message
        super().__init__(message)
146 changes: 84 additions & 62 deletions commons/human_feedback/dojo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from bittensor.btlogging import logging as logger

import dojo
from commons.exceptions import CreateTaskFailed
from commons.utils import loaddotenv, set_expire_time
from dojo import get_dojo_api_base_url
from dojo.protocol import FeedbackRequest, MultiScoreCriteria, RankingCriteria

DOJO_API_BASE_URL = get_dojo_api_base_url()
# to be able to get the curlify requests
DEBUG = False
# DEBUG = False


class DojoAPI:
Expand Down Expand Up @@ -57,7 +58,7 @@ async def get_task_results_by_task_id(cls, task_id: str) -> List[Dict] | None:
return task_results

@staticmethod
def serialize_feedback_request(data: FeedbackRequest) -> Dict[str, str]:
def serialize_feedback_request(data: FeedbackRequest):
output = dict(
prompt=data.prompt,
responses=[],
Expand All @@ -83,66 +84,87 @@ async def create_task(
cls,
feedback_request: FeedbackRequest,
):
logger.debug("Creating Task....")
path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks"
taskData = cls.serialize_feedback_request(feedback_request)
for criteria_type in feedback_request.criteria_types:
if isinstance(criteria_type, RankingCriteria) or isinstance(
criteria_type, MultiScoreCriteria
):
taskData["criteria"].append(
{
**criteria_type.model_dump(),
"options": [
option
for option in criteria_type.model_dump().get("options", [])
],
}
)
else:
logger.error(f"Unrecognized criteria type: {type(criteria_type)}")

expire_at = set_expire_time(dojo.TASK_DEADLINE)

form_body = {
"title": ("", "LLM Code Generation Task"),
"body": ("", feedback_request.prompt),
"expireAt": ("", expire_at),
"taskData": ("", json.dumps([taskData])),
"maxResults": ("", "1"),
}

DOJO_API_KEY = loaddotenv("DOJO_API_KEY")

response = await cls._http_client.post(
path,
files=form_body,
headers={
"x-api-key": DOJO_API_KEY,
},
timeout=15.0,
)
response_text = ""
response_json = {}
try:
path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks"
taskData = cls.serialize_feedback_request(feedback_request)
for criteria_type in feedback_request.criteria_types:
if isinstance(criteria_type, RankingCriteria) or isinstance(
criteria_type, MultiScoreCriteria
):
taskData["criteria"].append(
{
**criteria_type.model_dump(),
"options": [
option
for option in criteria_type.model_dump().get(
"options", []
)
],
}
)
else:
logger.error(f"Unrecognized criteria type: {type(criteria_type)}")

expire_at = set_expire_time(dojo.TASK_DEADLINE)

form_body = {
"title": ("", "LLM Code Generation Task"),
"body": ("", feedback_request.prompt),
"expireAt": ("", expire_at),
"taskData": ("", json.dumps([taskData])),
"maxResults": ("", "1"),
}

DOJO_API_KEY = loaddotenv("DOJO_API_KEY")

response = await cls._http_client.post(
path,
files=form_body,
headers={
"x-api-key": DOJO_API_KEY,
},
timeout=15.0,
)

if DEBUG is True:
try:
from curlify2 import Curlify

curl_req = Curlify(response.request)
print("CURL REQUEST >>> ")
print(curl_req.to_curl())
except ImportError:
print("Curlify not installed")
except Exception as e:
print("Tried to export create task request as curl, but failed.")
print(f"Exception: {e}")

task_ids = []
if response.status_code == 200:
task_ids = response.json()["body"]
logger.success(f"Successfully created task with\ntask ids:{task_ids}")
else:
response_text = response.text
response_json = response.json()
# if DEBUG is True:
# try:
# from curlify2 import Curlify

# curl_req = Curlify(response.request)
# print("CURL REQUEST >>> ")
# print(curl_req.to_curl())
# except ImportError:
# print("Curlify not installed")
# except Exception as e:
# print("Tried to export create task request as curl, but failed.")
# print(f"Exception: {e}")

task_ids = []
if response.status_code == 200:
task_ids = response.json()["body"]
logger.success(f"Successfully created task with\ntask ids:{task_ids}")
else:
logger.error(
f"Error occurred when trying to create task\nErr:{response.json()['error']}"
)
response.raise_for_status()
return task_ids
except json.JSONDecodeError as e1:
message = f"While trying to create task got JSON decode error: {e1}, response_text: {response_text}"
logger.error(message)
raise CreateTaskFailed("Failed to create task due to JSON decode error")
except httpx.HTTPStatusError as e:
logger.error(
f"Error occurred when trying to create task\nErr:{response.json()['error']}"
f"HTTP error occurred: {e}. Status code: {e.response.status_code}. Response content: {e.response.text}"
)
raise CreateTaskFailed(
f"Failed to create task due to HTTP error: {e}, response_text: {response_text}, response_json: {response_json}"
)
except Exception as e:
raise CreateTaskFailed(
f"Failed to create task due to unexpected error: {e}, response_text: {response_text}, response_json: {response_json}"
)
response.raise_for_status()
return task_ids
35 changes: 32 additions & 3 deletions commons/orm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import json
from datetime import datetime, timezone
from typing import AsyncGenerator, List

import torch
from bittensor.btlogging import logging as logger
from dotenv import find_dotenv, load_dotenv

from commons.exceptions import (
InvalidCompletion,
Expand All @@ -12,7 +14,7 @@
NoNewUnexpiredTasksYet,
UnexpiredTasksAlreadyProcessed,
)
from database.client import transaction
from database.client import connect_db, disconnect_db, transaction
from database.mappers import (
map_child_feedback_request_to_model,
map_completion_response_to_model,
Expand Down Expand Up @@ -74,13 +76,16 @@ async def get_unexpired_tasks(
"parent_request": True,
}
)
# now = datetime(2024, 10, 15, 18, 49, 25, tzinfo=timezone.utc)
now = datetime.now(timezone.utc)

vali_where_query_unprocessed = Feedback_Request_ModelWhereInput(
{
"hotkey": {"in": validator_hotkeys, "mode": "insensitive"},
"child_requests": {"some": {}},
# only check for expire at since miner may lie
"expire_at": {
"gt": datetime.now(timezone.utc),
"gt": now,
},
"is_processed": {"equals": False},
}
Expand All @@ -92,7 +97,7 @@ async def get_unexpired_tasks(
"child_requests": {"some": {}},
# only check for expire at since miner may lie
"expire_at": {
"gt": datetime.now(timezone.utc),
"gt": now,
},
"is_processed": {"equals": True},
}
Expand All @@ -107,6 +112,13 @@ async def get_unexpired_tasks(
where=vali_where_query_processed,
)

logger.info(f"Count of unprocessed tasks: {task_count_unprocessed}")
logger.info(f"Count of processed tasks: {task_count_processed}")

logger.debug(
f"Count of unprocessed tasks: {task_count_unprocessed}, count of processed tasks: {task_count_processed}"
)

if not task_count_unprocessed:
if task_count_processed:
raise UnexpiredTasksAlreadyProcessed(
Expand Down Expand Up @@ -436,3 +448,20 @@ async def get_validator_score() -> torch.Tensor | None:
return None

return torch.tensor(json.loads(score_record.score))


async def _test_get_unexpired_tasks():
    """Manual smoke test: print ``expire_at`` for every unexpired task.

    Loads the validator env vars, opens the database, streams task
    batches for a hard-coded validator hotkey, then disconnects.
    """
    load_dotenv(find_dotenv(".env.validator"))
    await connect_db()
    batches_seen = 0
    hotkeys = ["5Hdf4hSQoLGj4JyJuabTnp85ZYKezLE366SXqkWYjcUw5PfJ"]
    async for batch, _has_more in ORM.get_unexpired_tasks(validator_hotkeys=hotkeys):
        batches_seen += 1
        for task in batch:
            print(f"Task expire_at: {task.request.expire_at}")
    await disconnect_db()


if __name__ == "__main__":
    asyncio.run(_test_get_unexpired_tasks())
22 changes: 13 additions & 9 deletions commons/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ def get_epoch_time():
return time.time()


def datetime_as_utc(dt: datetime) -> datetime:
    """Tag *dt* as UTC without shifting its clock value.

    NOTE(review): this relabels ``tzinfo`` rather than converting; it
    assumes the incoming datetime already represents UTC wall time —
    confirm at call sites.
    """
    return dt.replace(tzinfo=timezone.utc)


def datetime_to_iso8601_str(dt: datetime) -> str:
    """Render *dt* as an ISO-8601 string labelled as UTC.

    NOTE(review): like :func:`datetime_as_utc`, this relabels rather
    than converts the timezone, so *dt* is assumed to hold UTC wall time.
    """
    utc_tagged = dt.replace(tzinfo=timezone.utc)
    return utc_tagged.isoformat()


def iso8601_str_to_datetime(dt_str: str) -> datetime:
    """Parse an ISO-8601 string into a timezone-aware UTC datetime.

    Naive timestamps are assumed to already represent UTC wall time and
    are tagged as such. Aware timestamps are *converted* to UTC so a
    non-UTC offset keeps the same instant instead of being silently
    relabelled (the previous ``replace(tzinfo=...)`` shifted the instant
    for any offset other than +00:00).
    """
    # fromisoformat() rejected a trailing "Z" before Python 3.11, but
    # set_expire_time() in this module emits "Z"-suffixed strings —
    # normalize so both spellings parse everywhere.
    parsed = datetime.fromisoformat(dt_str.replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        return parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


def loaddotenv(varname: str):
"""Wrapper to get env variables for sanity checking"""
value = os.getenv(varname)
Expand Down Expand Up @@ -325,10 +337,6 @@ def ttl_get_block(subtensor) -> int:
return subtensor.get_current_block()


def get_current_utc_time_iso():
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


def set_expire_time(expire_in_seconds: int) -> str:
"""
Sets the expiration time based on the current UTC time and the given number of seconds.
Expand All @@ -341,16 +349,12 @@ def set_expire_time(expire_in_seconds: int) -> str:
"""
return (
(datetime.now(timezone.utc) + timedelta(seconds=expire_in_seconds))
.replace(microsecond=0, tzinfo=timezone.utc)
.replace(tzinfo=timezone.utc)
.isoformat()
.replace("+00:00", "Z")
)


def datetime_as_utc(datetime: datetime) -> datetime:
return datetime.replace(microsecond=0, tzinfo=timezone.utc)


def is_valid_expiry(expire_at: str) -> bool:
"""
Checks if the given expiry time is not None and falls within a reasonable time period.
Expand Down
18 changes: 9 additions & 9 deletions database/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from loguru import logger

from commons.exceptions import InvalidMinerResponse, InvalidValidatorRequest
from commons.utils import datetime_as_utc
from commons.utils import (
datetime_as_utc,
datetime_to_iso8601_str,
iso8601_str_to_datetime,
)
from database.prisma import Json
from database.prisma.enums import CriteriaTypeEnum
from database.prisma.models import Criteria_Type_Model, Feedback_Request_Model
Expand Down Expand Up @@ -105,7 +109,7 @@ def map_completion_response_to_model(
result = Completion_Response_ModelCreateInput(
completion_id=response.completion_id,
model=response.model,
completion=Json(json.dumps(response.completion)),
completion=Json(json.dumps(response.completion, default=vars)),
rank_id=response.rank_id,
score=response.score,
feedback_request_id=feedback_request_id,
Expand All @@ -122,7 +126,7 @@ def map_parent_feedback_request_to_model(
if not request.expire_at:
raise InvalidValidatorRequest("Expire at is required")

expire_at = datetime.fromisoformat(request.expire_at)
expire_at = iso8601_str_to_datetime(request.expire_at)
if expire_at < datetime.now(timezone.utc):
raise InvalidValidatorRequest("Expire at must be in the future")

Expand Down Expand Up @@ -218,9 +222,7 @@ def map_feedback_request_model_to_feedback_request(
criteria_types=criteria_types,
completion_responses=completion_responses,
dojo_task_id=model.dojo_task_id,
expire_at=model.expire_at.replace(microsecond=0, tzinfo=timezone.utc)
.isoformat()
.replace("+00:00", "Z"),
expire_at=datetime_to_iso8601_str(model.expire_at),
axon=bt.TerminalInfo(hotkey=model.hotkey),
)
else:
Expand All @@ -231,9 +233,7 @@ def map_feedback_request_model_to_feedback_request(
criteria_types=criteria_types,
completion_responses=completion_responses,
dojo_task_id=model.dojo_task_id,
expire_at=model.expire_at.replace(microsecond=0, tzinfo=timezone.utc)
.isoformat()
.replace("+00:00", "Z"),
expire_at=datetime_to_iso8601_str(model.expire_at),
dendrite=bt.TerminalInfo(hotkey=model.hotkey),
ground_truth=ground_truth,
)
Expand Down
Loading

0 comments on commit 1c16d1d

Please sign in to comment.