From 1c16d1d83f61cd2407f7cedb4788aa545da91246 Mon Sep 17 00:00:00 2001
From: jarvis8x7b <157810922+jarvis8x7b@users.noreply.github.com>
Date: Tue, 15 Oct 2024 23:33:55 +0800
Subject: [PATCH] fix: dubious ownership issue

chore: use dev image temporarily
fix: add missing vali dendrite hotkey, catch error properly
refactor(prisma): add createdat/updatedat
perf: add exception handling for miner create task
fix: ensure consistency of datetime formats and objects, cleanup logs
---
 commons/exceptions.py          |   8 ++
 commons/human_feedback/dojo.py | 146 +++++++++++++++++++--------------
 commons/orm.py                 |  35 +++++++-
 commons/utils.py               |  22 +++--
 database/mappers.py            |  18 ++--
 docker-compose.miner.yaml      |  11 +--
 docker-compose.validator.yaml  |   5 +-
 docker/Dockerfile              |   2 +
 neurons/miner.py               |   1 +
 neurons/validator.py           |  15 ++--
 schema.prisma                  |   5 +-
 11 files changed, 170 insertions(+), 98 deletions(-)

diff --git a/commons/exceptions.py b/commons/exceptions.py
index a1d4d3f1..26686549 100644
--- a/commons/exceptions.py
+++ b/commons/exceptions.py
@@ -52,3 +52,11 @@ class EmptyScores(Exception):
     def __init__(self, message):
         self.message = message
         super().__init__(self.message)
+
+
+class CreateTaskFailed(Exception):
+    """Exception raised when creating a task fails."""
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
diff --git a/commons/human_feedback/dojo.py b/commons/human_feedback/dojo.py
index 07b8c903..60bb7076 100644
--- a/commons/human_feedback/dojo.py
+++ b/commons/human_feedback/dojo.py
@@ -5,13 +5,14 @@
 from bittensor.btlogging import logging as logger
 
 import dojo
+from commons.exceptions import CreateTaskFailed
 from commons.utils import loaddotenv, set_expire_time
 from dojo import get_dojo_api_base_url
 from dojo.protocol import FeedbackRequest, MultiScoreCriteria, RankingCriteria
 
 DOJO_API_BASE_URL = get_dojo_api_base_url()
 # to be able to get the curlify requests
-DEBUG = False
+# DEBUG = False
 
 
 class DojoAPI:
@@ -57,7 +58,7 @@ async def get_task_results_by_task_id(cls, task_id: str) -> List[Dict] | None:
         return task_results
 
     @staticmethod
-    def serialize_feedback_request(data: FeedbackRequest) -> Dict[str, str]:
+    def serialize_feedback_request(data: FeedbackRequest):
         output = dict(
             prompt=data.prompt,
             responses=[],
@@ -83,66 +84,87 @@ async def create_task(
         cls,
         feedback_request: FeedbackRequest,
     ):
-        logger.debug("Creating Task....")
-        path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks"
-        taskData = cls.serialize_feedback_request(feedback_request)
-        for criteria_type in feedback_request.criteria_types:
-            if isinstance(criteria_type, RankingCriteria) or isinstance(
-                criteria_type, MultiScoreCriteria
-            ):
-                taskData["criteria"].append(
-                    {
-                        **criteria_type.model_dump(),
-                        "options": [
-                            option
-                            for option in criteria_type.model_dump().get("options", [])
-                        ],
-                    }
-                )
-            else:
-                logger.error(f"Unrecognized criteria type: {type(criteria_type)}")
-
-        expire_at = set_expire_time(dojo.TASK_DEADLINE)
-
-        form_body = {
-            "title": ("", "LLM Code Generation Task"),
-            "body": ("", feedback_request.prompt),
-            "expireAt": ("", expire_at),
-            "taskData": ("", json.dumps([taskData])),
-            "maxResults": ("", "1"),
-        }
-
-        DOJO_API_KEY = loaddotenv("DOJO_API_KEY")
-
-        response = await cls._http_client.post(
-            path,
-            files=form_body,
-            headers={
-                "x-api-key": DOJO_API_KEY,
-            },
-            timeout=15.0,
-        )
+        response_text = ""
+        response_json = {}
+        try:
+            path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks"
+            taskData = cls.serialize_feedback_request(feedback_request)
+            for criteria_type in feedback_request.criteria_types:
+                if isinstance(criteria_type, RankingCriteria) or isinstance(
+                    criteria_type, MultiScoreCriteria
+                ):
+                    taskData["criteria"].append(
+                        {
+                            **criteria_type.model_dump(),
+                            "options": [
+                                option
+                                for option in criteria_type.model_dump().get(
+                                    "options", []
+                                )
+                            ],
+                        }
+                    )
+                else:
+                    logger.error(f"Unrecognized criteria type: {type(criteria_type)}")
+
+            expire_at = set_expire_time(dojo.TASK_DEADLINE)
+
+            form_body = {
+                "title": ("", "LLM Code Generation Task"),
+                "body": ("", feedback_request.prompt),
+                "expireAt": ("", expire_at),
+                "taskData": ("", json.dumps([taskData])),
+                "maxResults": ("", "1"),
+            }
+
+            DOJO_API_KEY = loaddotenv("DOJO_API_KEY")
+
+            response = await cls._http_client.post(
+                path,
+                files=form_body,
+                headers={
+                    "x-api-key": DOJO_API_KEY,
+                },
+                timeout=15.0,
+            )
 
-        if DEBUG is True:
-            try:
-                from curlify2 import Curlify
-
-                curl_req = Curlify(response.request)
-                print("CURL REQUEST >>> ")
-                print(curl_req.to_curl())
-            except ImportError:
-                print("Curlify not installed")
-            except Exception as e:
-                print("Tried to export create task request as curl, but failed.")
-                print(f"Exception: {e}")
-
-        task_ids = []
-        if response.status_code == 200:
-            task_ids = response.json()["body"]
-            logger.success(f"Successfully created task with\ntask ids:{task_ids}")
-        else:
+            response_text = response.text
+            response_json = response.json()
+            # if DEBUG is True:
+            #     try:
+            #         from curlify2 import Curlify
+
+            #         curl_req = Curlify(response.request)
+            #         print("CURL REQUEST >>> ")
+            #         print(curl_req.to_curl())
+            #     except ImportError:
+            #         print("Curlify not installed")
+            #     except Exception as e:
+            #         print("Tried to export create task request as curl, but failed.")
+            #         print(f"Exception: {e}")
+
+            task_ids = []
+            if response.status_code == 200:
+                task_ids = response.json()["body"]
+                logger.success(f"Successfully created task with\ntask ids:{task_ids}")
+            else:
+                logger.error(
+                    f"Error occurred when trying to create task\nErr:{response.json()['error']}"
+                )
+                response.raise_for_status()
+            return task_ids
+        except json.JSONDecodeError as e1:
+            message = f"While trying to create task got JSON decode error: {e1}, response_text: {response_text}"
+            logger.error(message)
+            raise CreateTaskFailed("Failed to create task due to JSON decode error")
+        except httpx.HTTPStatusError as e:
             logger.error(
-                f"Error occurred when trying to create task\nErr:{response.json()['error']}"
+                f"HTTP error occurred: {e}. Status code: {e.response.status_code}. Response content: {e.response.text}"
+            )
+            raise CreateTaskFailed(
+                f"Failed to create task due to HTTP error: {e}, response_text: {response_text}, response_json: {response_json}"
+            )
+        except Exception as e:
+            raise CreateTaskFailed(
+                f"Failed to create task due to unexpected error: {e}, response_text: {response_text}, response_json: {response_json}"
             )
-        response.raise_for_status()
-        return task_ids
diff --git a/commons/orm.py b/commons/orm.py
index cd831a48..ddb6f632 100644
--- a/commons/orm.py
+++ b/commons/orm.py
@@ -1,9 +1,11 @@
+import asyncio
 import json
 from datetime import datetime, timezone
 from typing import AsyncGenerator, List
 
 import torch
 from bittensor.btlogging import logging as logger
+from dotenv import find_dotenv, load_dotenv
 
 from commons.exceptions import (
     InvalidCompletion,
@@ -12,7 +14,7 @@
     NoNewUnexpiredTasksYet,
     UnexpiredTasksAlreadyProcessed,
 )
-from database.client import transaction
+from database.client import connect_db, disconnect_db, transaction
 from database.mappers import (
     map_child_feedback_request_to_model,
     map_completion_response_to_model,
@@ -74,13 +76,16 @@ async def get_unexpired_tasks(
                 "parent_request": True,
             }
         )
+        # now = datetime(2024, 10, 15, 18, 49, 25, tzinfo=timezone.utc)
+        now = datetime.now(timezone.utc)
+
         vali_where_query_unprocessed = Feedback_Request_ModelWhereInput(
             {
                 "hotkey": {"in": validator_hotkeys, "mode": "insensitive"},
                 "child_requests": {"some": {}},
                 # only check for expire at since miner may lie
                 "expire_at": {
-                    "gt": datetime.now(timezone.utc),
+                    "gt": now,
                 },
                 "is_processed": {"equals": False},
             }
@@ -92,7 +97,7 @@ async def get_unexpired_tasks(
                 "child_requests": {"some": {}},
                 # only check for expire at since miner may lie
                 "expire_at": {
-                    "gt": datetime.now(timezone.utc),
+                    "gt": now,
                 },
                 "is_processed": {"equals": True},
             }
@@ -107,6 +112,13 @@ async def get_unexpired_tasks(
             where=vali_where_query_processed,
         )
 
+        logger.info(f"Count of unprocessed tasks: {task_count_unprocessed}")
+        logger.info(f"Count of processed tasks: {task_count_processed}")
+
+        logger.debug(
+            f"Count of unprocessed tasks: {task_count_unprocessed}, count of processed tasks: {task_count_processed}"
+        )
+
         if not task_count_unprocessed:
             if task_count_processed:
                 raise UnexpiredTasksAlreadyProcessed(
@@ -436,3 +448,20 @@ async def get_validator_score() -> torch.Tensor | None:
             return None
 
         return torch.tensor(json.loads(score_record.score))
+
+
+async def _test_get_unexpired_tasks():
+    load_dotenv(find_dotenv(".env.validator"))
+    await connect_db()
+    batch_id = 0
+    async for task_batch, has_more_batches in ORM.get_unexpired_tasks(
+        validator_hotkeys=["5Hdf4hSQoLGj4JyJuabTnp85ZYKezLE366SXqkWYjcUw5PfJ"]
+    ):
+        for task in task_batch:
+            print(f"Task expire_at: {task.request.expire_at}")
+        batch_id += 1
+    await disconnect_db()
+
+
+if __name__ == "__main__":
+    asyncio.run(_test_get_unexpired_tasks())
diff --git a/commons/utils.py b/commons/utils.py
index d76024d8..2a01dcd9 100644
--- a/commons/utils.py
+++ b/commons/utils.py
@@ -28,6 +28,18 @@ def get_epoch_time():
     return time.time()
 
 
+def datetime_as_utc(dt: datetime) -> datetime:
+    return dt.replace(tzinfo=timezone.utc)
+
+
+def datetime_to_iso8601_str(dt: datetime) -> str:
+    return dt.replace(tzinfo=timezone.utc).isoformat()
+
+
+def iso8601_str_to_datetime(dt_str: str) -> datetime:
+    return datetime.fromisoformat(dt_str).replace(tzinfo=timezone.utc)
+
+
 def loaddotenv(varname: str):
     """Wrapper to get env variables for sanity checking"""
     value = os.getenv(varname)
@@ -325,10 +337,6 @@ def ttl_get_block(subtensor) -> int:
     return subtensor.get_current_block()
 
 
-def get_current_utc_time_iso():
-    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
-
-
 def set_expire_time(expire_in_seconds: int) -> str:
     """
     Sets the expiration time based on the current UTC time and the given number of seconds.
@@ -341,16 +349,12 @@ def set_expire_time(expire_in_seconds: int) -> str:
     """
     return (
         (datetime.now(timezone.utc) + timedelta(seconds=expire_in_seconds))
-        .replace(microsecond=0, tzinfo=timezone.utc)
+        .replace(tzinfo=timezone.utc)
        .isoformat()
         .replace("+00:00", "Z")
     )
 
 
-def datetime_as_utc(datetime: datetime) -> datetime:
-    return datetime.replace(microsecond=0, tzinfo=timezone.utc)
-
-
 def is_valid_expiry(expire_at: str) -> bool:
     """
     Checks if the given expiry time is not None and falls within a reasonable time period.
diff --git a/database/mappers.py b/database/mappers.py
index b8c6a517..4c0c750c 100644
--- a/database/mappers.py
+++ b/database/mappers.py
@@ -5,7 +5,11 @@
 from loguru import logger
 
 from commons.exceptions import InvalidMinerResponse, InvalidValidatorRequest
-from commons.utils import datetime_as_utc
+from commons.utils import (
+    datetime_as_utc,
+    datetime_to_iso8601_str,
+    iso8601_str_to_datetime,
+)
 from database.prisma import Json
 from database.prisma.enums import CriteriaTypeEnum
 from database.prisma.models import Criteria_Type_Model, Feedback_Request_Model
@@ -105,7 +109,7 @@ def map_completion_response_to_model(
     result = Completion_Response_ModelCreateInput(
         completion_id=response.completion_id,
         model=response.model,
-        completion=Json(json.dumps(response.completion)),
+        completion=Json(json.dumps(response.completion, default=vars)),
         rank_id=response.rank_id,
         score=response.score,
         feedback_request_id=feedback_request_id,
@@ -122,7 +126,7 @@ def map_parent_feedback_request_to_model(
     if not request.expire_at:
         raise InvalidValidatorRequest("Expire at is required")
 
-    expire_at = datetime.fromisoformat(request.expire_at)
+    expire_at = iso8601_str_to_datetime(request.expire_at)
     if expire_at < datetime.now(timezone.utc):
         raise InvalidValidatorRequest("Expire at must be in the future")
 
@@ -218,9 +222,7 @@ def map_feedback_request_model_to_feedback_request(
             criteria_types=criteria_types,
             completion_responses=completion_responses,
             dojo_task_id=model.dojo_task_id,
-            expire_at=model.expire_at.replace(microsecond=0, tzinfo=timezone.utc)
-            .isoformat()
-            .replace("+00:00", "Z"),
+            expire_at=datetime_to_iso8601_str(model.expire_at),
             axon=bt.TerminalInfo(hotkey=model.hotkey),
         )
     else:
@@ -231,9 +233,7 @@ def map_feedback_request_model_to_feedback_request(
             criteria_types=criteria_types,
             completion_responses=completion_responses,
             dojo_task_id=model.dojo_task_id,
-            expire_at=model.expire_at.replace(microsecond=0, tzinfo=timezone.utc)
-            .isoformat()
-            .replace("+00:00", "Z"),
+            expire_at=datetime_to_iso8601_str(model.expire_at),
             dendrite=bt.TerminalInfo(hotkey=model.hotkey),
             ground_truth=ground_truth,
         )
diff --git a/docker-compose.miner.yaml b/docker-compose.miner.yaml
index 53ffa6ad..42a40e7b 100644
--- a/docker-compose.miner.yaml
+++ b/docker-compose.miner.yaml
@@ -167,7 +167,7 @@ services:
     logging: *default-logging
 
   dojo-cli:
-    image: ghcr.io/tensorplex-labs/dojo:main
+    image: ghcr.io/tensorplex-labs/dojo:dev
     volumes:
       - ./:/app
       - ./.env.miner:/app/.env
@@ -181,7 +181,7 @@ services:
   # ============== TEST NET ============== #
 
   miner-testnet-decentralised:
-    image: ghcr.io/tensorplex-labs/dojo:main
+    image: ghcr.io/tensorplex-labs/dojo:dev
     working_dir: /app
     env_file:
       - .env.miner
@@ -206,7 +206,7 @@ services:
       - external
 
   miner-testnet-centralised:
-    image: ghcr.io/tensorplex-labs/dojo:main
+    image: ghcr.io/tensorplex-labs/dojo:dev
     working_dir: /app
     env_file:
       - .env.miner
@@ -228,7 +228,7 @@ services:
   # ============== MAIN NET ============== #

   miner-mainnet-decentralised:
-    image: ghcr.io/tensorplex-labs/dojo:main
+    image: ghcr.io/tensorplex-labs/dojo:dev
     working_dir: /app
     env_file:
       - .env.miner
@@ -254,7 +254,8 @@ services:
     logging: *default-logging
 
   miner-mainnet-centralised:
-    image: ghcr.io/tensorplex-labs/dojo:main
+    # TODO @dev regex change later
+    image: ghcr.io/tensorplex-labs/dojo:dev
     working_dir: /app
     env_file:
       - .env.miner
diff --git a/docker-compose.validator.yaml b/docker-compose.validator.yaml
index 0cb6b891..0fb50c99 100644
--- a/docker-compose.validator.yaml
+++ b/docker-compose.validator.yaml
@@ -118,7 +118,7 @@ services:
   # ============== TEST NET ============== #
 
   validator-testnet:
-    image: ghcr.io/tensorplex-labs/dojo:main
+    image: ghcr.io/tensorplex-labs/dojo:dev
     working_dir: /app
     env_file:
       - .env.validator
@@ -150,7 +150,8 @@ services:
   # ============== MAIN NET ============== #
 
   validator-mainnet:
-    image: ghcr.io/tensorplex-labs/dojo:main
+    # TODO @dev regex change later
+    image: ghcr.io/tensorplex-labs/dojo:dev
     working_dir: /app
     env_file:
       - .env.validator
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 58617238..34ad611f 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -29,6 +29,8 @@ ARG TARGETPLATFORM
 
 RUN echo "Building for TARGETPLATFORM: $TARGETPLATFORM"
 
+RUN git config --global --add safe.directory /app
+
 # jank because pytorch has different versions for cpu for darwin VS linux, see pyproject.toml for specifics
 RUN if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
     uv pip install --no-cache -e . --find-links https://download.pytorch.org/whl/torch_stable.html; \
diff --git a/neurons/miner.py b/neurons/miner.py
index 9968fe5e..df3931b6 100644
--- a/neurons/miner.py
+++ b/neurons/miner.py
@@ -108,6 +108,7 @@ async def forward_feedback_request(
             synapse.dojo_task_id = task_ids[0]
 
         except Exception:
+            traceback.print_exc()
             logger.error(
                 f"Error occurred while processing request id: {synapse.request_id}, error: {traceback.format_exc()}"
             )
diff --git a/neurons/validator.py b/neurons/validator.py
index 93e7c306..8b88f2a1 100644
--- a/neurons/validator.py
+++ b/neurons/validator.py
@@ -1,5 +1,6 @@
 import asyncio
 import copy
+import gc
 import random
 import time
 import traceback
@@ -19,7 +20,7 @@
 
 import dojo
 from commons.dataset.synthetic import SyntheticAPI
-from commons.exceptions import EmptyScores, InvalidMinerResponse
+from commons.exceptions import EmptyScores, InvalidMinerResponse, NoNewUnexpiredTasksYet
 from commons.obfuscation.obfuscation_utils import obfuscate_html_and_js
 from commons.orm import ORM
 from commons.scoring import Scoring
@@ -81,6 +82,7 @@ async def send_scores(self, synapse: ScoringResult, hotkeys: List[str]):
     async def update_score_and_send_feedback(self):
         while True:
             await asyncio.sleep(dojo.VALIDATOR_UPDATE_SCORE)
+            logger.info("📝 performing scoring ...")
             try:
                 validator_hotkeys = [
                     hotkey
@@ -100,6 +102,7 @@ async def update_score_and_send_feedback(self):
                         logger.success(
                             f"All tasks processed, total batches: {batch_id}, batch size: {batch_size}"
                         )
+                        gc.collect()
                         break
 
                     if not task_batch:
@@ -473,6 +476,7 @@ async def send_request(
 
         # include the ground_truth to keep in data manager
         synapse.ground_truth = data.ground_truth
+        synapse.dendrite.hotkey = self.wallet.hotkey.ss58_address
         response_data = DendriteQueryResponse(
             request=synapse,
             miner_responses=valid_miner_responses,
@@ -777,9 +781,6 @@ async def _get_task_results_from_miner(
             return []
 
     async def monitor_task_completions(self):
-        SLEEP_SECONDS = 30
-        await asyncio.sleep(dojo.DOJO_TASK_MONITORING)
-
         while not self._should_exit:
             try:
                 validator_hotkey = self.wallet.hotkey.ss58_address
@@ -792,6 +793,7 @@ async def monitor_task_completions(self):
                         logger.success(
                             "No more unexpired tasks found for processing, exiting task monitoring."
                         )
+                        gc.collect()
                         break
 
                     if not task_batch:
@@ -857,12 +859,13 @@ async def monitor_task_completions(self):
                     logger.info(
                         f"Updating task {request_id} with miner's completion data, success ? {success}"
                     )
-
+            except NoNewUnexpiredTasksYet as e:
+                logger.info(f"No new unexpired tasks yet: {e}")
             except Exception as e:
                 traceback.print_exc()
                 logger.error(f"Error during Dojo task monitoring {str(e)}")
                 pass
-            await asyncio.sleep(SLEEP_SECONDS)
+            await asyncio.sleep(dojo.DOJO_TASK_MONITORING)
 
     @staticmethod
     def _calculate_averages(
diff --git a/schema.prisma b/schema.prisma
index 1452667e..b656db00 100644
--- a/schema.prisma
+++ b/schema.prisma
@@ -17,14 +17,15 @@ enum CriteriaTypeEnum {
 }
 
 model Ground_Truth_Model {
-  id                  String @id @default(uuid())
+  id                  String                 @id @default(uuid())
   request_id          String
   obfuscated_model_id String
   real_model_id       String
   rank_id             Int
-
   feedback_request    Feedback_Request_Model @relation(fields: [feedback_request_id], references: [id])
   feedback_request_id String
+  created_at          DateTime               @default(now())
+  updated_at          DateTime               @updatedAt
 
   @@unique([request_id, obfuscated_model_id, rank_id])
 }
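
Sanity check of the datetime round trip this commit standardizes on. This is a minimal standalone sketch: the three helpers below are copied from the commons/utils.py hunk above, while the __main__ block is illustrative only and not part of the patch.

from datetime import datetime, timezone


def datetime_as_utc(dt: datetime) -> datetime:
    # Tag a datetime as UTC; note this labels rather than converts, and
    # unlike the removed variant it no longer zeroes microseconds.
    return dt.replace(tzinfo=timezone.utc)


def datetime_to_iso8601_str(dt: datetime) -> str:
    # Serialize for storage/transport, e.g. "2024-10-15T18:49:25.123456+00:00".
    return dt.replace(tzinfo=timezone.utc).isoformat()


def iso8601_str_to_datetime(dt_str: str) -> datetime:
    # Parse back and re-tag as UTC so comparisons against
    # datetime.now(timezone.utc) are always offset-aware.
    return datetime.fromisoformat(dt_str).replace(tzinfo=timezone.utc)


if __name__ == "__main__":
    now = datetime_as_utc(datetime.now())
    round_tripped = iso8601_str_to_datetime(datetime_to_iso8601_str(now))
    # Lossless round trip: a stored expire_at compares consistently with
    # the in-memory datetime it was derived from.
    assert round_tripped == now
    print(datetime_to_iso8601_str(now))

One caveat worth verifying: set_expire_time still emits "Z"-suffixed strings, and datetime.fromisoformat only accepts a trailing "Z" on Python 3.11+, so feeding those strings to iso8601_str_to_datetime assumes a 3.11+ interpreter.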