diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index 42206ccb..baf3e94b 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -2,6 +2,7 @@ name: Run Linter on: pull_request: + types: [opened, reopened, ready_for_review, synchronize] branches: - dev - staging @@ -9,6 +10,7 @@ on: jobs: lint: + if: github.event.pull_request.draft == false runs-on: self-hosted steps: - name: Checkout @@ -17,7 +19,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5.1.1 with: - python-version: '3.11' + python-version: "3.11" - name: Pipx uses: CfirTsabari/actions-pipx@v1.0.2 @@ -25,4 +27,4 @@ jobs: - uses: chartboost/ruff-action@v1 with: version: 0.4.10 - args: 'check . --config pyproject.toml --no-cache' + args: "check . --config pyproject.toml --no-cache" diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index d98a91c8..97ad8191 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -2,6 +2,7 @@ name: Run Tests on: pull_request: + types: [opened, reopened, ready_for_review, synchronize] branches: - dev - staging @@ -9,6 +10,7 @@ on: jobs: tests: + if: github.event.pull_request.draft == false runs-on: self-hosted strategy: matrix: diff --git a/README.md b/README.md index abc9de59..04ad5820 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ By creating an open platform for gathering human-generated datasets, Tensorplex - 4 cores - 16 GB RAM -- 256 SSD +- 2TB SSD ## Miner @@ -120,7 +120,7 @@ By creating an open platform for gathering human-generated datasets, Tensorplex - 2 cores - 8 GB RAM -- 32GB SSD +- 32GB SSD or 1TB SSD if decentralised # Getting Started diff --git a/commons/data_manager.py b/commons/data_manager.py deleted file mode 100644 index 53fc1512..00000000 --- a/commons/data_manager.py +++ /dev/null @@ -1,407 +0,0 @@ -import json -from collections import defaultdict -from datetime import datetime, timezone -from typing import List - -import torch -from bittensor.btlogging import logging as logger -from strenum import StrEnum - -from database.client import transaction -from database.mappers import ( - map_completion_response_to_model, - map_criteria_type_to_model, - map_feedback_request_to_model, - map_miner_response_to_model, - map_model_to_dendrite_query_response, -) -from database.prisma._fields import Json -from database.prisma.models import ( - Feedback_Request_Model, - Miner_Response_Model, - Score_Model, - Validator_State_Model, -) -from database.prisma.types import ( - Score_ModelCreateInput, - Score_ModelUpdateInput, - Validator_State_ModelCreateInput, -) -from dojo.protocol import ( - DendriteQueryResponse, - FeedbackRequest, - RidToHotKeyToTaskId, - RidToModelMap, - TaskExpiryDict, -) - - -class ValidatorStateKeys(StrEnum): - SCORES = "scores" - DOJO_TASKS_TO_TRACK = "dojo_tasks_to_track" - MODEL_MAP = "model_map" - TASK_TO_EXPIRY = "task_to_expiry" - - -class DataManager: - _instance = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - @classmethod - async def load(cls) -> List[DendriteQueryResponse] | None: - try: - feedback_requests = await Feedback_Request_Model.prisma().find_many( - include={ - "criteria_types": True, - "miner_responses": {"include": {"completions": True}}, - } - ) - - if not feedback_requests or len(feedback_requests) == 0: - logger.error("No Feedback_Request_Model data found.") - return None - - logger.info(f"Loaded {len(feedback_requests)} requests") - - result = [ - 
map_model_to_dendrite_query_response(r) for r in feedback_requests - ] - - return result - - except Exception as e: - logger.error(f"Failed to load data from database: {e}") - return None - - @classmethod - async def save_dendrite_response( - cls, response: DendriteQueryResponse - ) -> Feedback_Request_Model | None: - try: - feedback_request_model: Feedback_Request_Model | None = None - async with transaction() as tx: - logger.info( - f"Saving dendrite query response for request_id: {response.request.request_id}" - ) - logger.trace("Starting transaction for saving dendrite query response.") - - # Create the main feedback request record - feedback_request_model = await tx.feedback_request_model.create( - data=map_feedback_request_to_model(response.request) - ) - - # Create related criteria types - for criteria in response.request.criteria_types: - criteria_model = map_criteria_type_to_model( - criteria, feedback_request_model.request_id - ) - await tx.criteria_type_model.create(data=criteria_model) - - miner_responses: list[Miner_Response_Model] = [] - # Create related miner responses and their completion responses - for miner_response in response.miner_responses: - miner_response_data = map_miner_response_to_model( - miner_response, - feedback_request_model.request_id, # Use feedback_request_model.request_id - ) - - if not miner_response_data.get("dojo_task_id"): - logger.error("Dojo task id is required") - raise ValueError("Dojo task id is required") - - miner_response_model = await tx.miner_response_model.create( - data=miner_response_data - ) - - miner_responses.append(miner_response_model) - - # Create related completions for miner responses - for completion in miner_response.completion_responses: - completion_data = map_completion_response_to_model( - completion, miner_response_model.id - ) - await tx.completion_response_model.create(data=completion_data) - logger.trace(f"Created completion response: {completion_data}") - - feedback_request_model.miner_responses = miner_responses - return feedback_request_model - except Exception as e: - logger.error(f"Failed to save dendrite query response: {e}") - return None - - @classmethod - async def overwrite_miner_responses_by_request_id( - cls, request_id: str, miner_responses: List[FeedbackRequest] - ) -> bool: - try: - # TODO can improve this - async with transaction() as tx: - # Delete existing completion responses for the given request_id - await tx.completion_response_model.delete_many( - where={"miner_response": {"is": {"request_id": request_id}}} - ) - - # Delete existing miner responses for the given request_id - await tx.miner_response_model.delete_many( - where={"request_id": request_id} - ) - - # Create new miner responses - for miner_response in miner_responses: - miner_response_model = await tx.miner_response_model.create( - data=map_miner_response_to_model(miner_response, request_id) - ) - - # Create related completions for miner responses - for completion in miner_response.completion_responses: - await tx.completion_response_model.create( - data=map_completion_response_to_model( - completion, miner_response_model.id - ) - ) - - logger.success(f"Overwritten miner responses for requestId: {request_id}") - return True - except Exception as e: - logger.error(f"Failed to overwrite miner responses: {e}") - return False - - @classmethod - async def get_by_request_id(cls, request_id: str) -> DendriteQueryResponse | None: - try: - feedback_request = await Feedback_Request_Model.prisma().find_first( - where={"request_id": request_id}, - 
include={ - "criteria_types": True, - "miner_responses": {"include": {"completions": True}}, - }, - ) - if feedback_request: - return map_model_to_dendrite_query_response(feedback_request) - return None - except Exception as e: - logger.error(f"Failed to get feedback request by request_id: {e}") - return None - - @classmethod - async def remove_responses(cls, responses: List[DendriteQueryResponse]) -> bool: - try: - async with transaction() as tx: - request_ids = [] - for response in responses: - request_id = response.request.request_id - request_ids.append(request_id) - - # Delete completion responses associated with the miner responses - await tx.completion_response_model.delete_many( - where={"miner_response": {"is": {"request_id": request_id}}} - ) - - # Delete miner responses associated with the feedback request - await tx.miner_response_model.delete_many( - where={"request_id": request_id} - ) - - # Delete criteria types associated with the feedback request - await tx.criteria_type_model.delete_many( - where={"request_id": request_id} - ) - - # Delete the feedback request itself - await tx.feedback_request_model.delete_many( - where={"request_id": request_id} - ) - - logger.success(f"Successfully removed responses for {request_ids} requests") - return True - except Exception as e: - logger.error(f"Failed to remove responses: {e}") - return False - - @classmethod - async def validator_save( - cls, - scores: torch.Tensor, - requestid_to_mhotkey_to_task_id: RidToHotKeyToTaskId, - model_map: RidToModelMap, - task_to_expiry: TaskExpiryDict, - ): - """Saves the state of the validator to the database.""" - if cls._instance and cls._instance.step == 0: - return - try: - dojo_task_data = json.loads(json.dumps(requestid_to_mhotkey_to_task_id)) - if not dojo_task_data and torch.count_nonzero(scores).item() == 0: - raise ValueError("Dojo task data and scores are empty. Skipping save.") - - logger.trace(f"Saving validator dojo_task_data: {dojo_task_data}") - logger.trace(f"Saving validator score: {scores}") - - # Convert tensors to lists for JSON serialization - scores_list = scores.tolist() - - # Prepare nested data for creating the validator state - validator_state_data: list[Validator_State_ModelCreateInput] = [ - { - "request_id": request_id, - "miner_hotkey": miner_hotkey, - "task_id": task_id, - "expire_at": task_to_expiry[task_id], - "obfuscated_model": obfuscated_model, - "real_model": real_model, - } - for request_id, hotkey_to_task in dojo_task_data.items() - for miner_hotkey, task_id in hotkey_to_task.items() - for obfuscated_model, real_model in model_map[request_id].items() - ] - - # Save the validator state - await Validator_State_Model.prisma().create_many( - data=validator_state_data, skip_duplicates=True - ) - - if not torch.all(scores == 0): - # Save scores as a single record - score_model = await Score_Model.prisma().find_first() - - if score_model: - await Score_Model.prisma().update( - where={"id": score_model.id}, - data=Score_ModelUpdateInput( - score=Json(json.dumps(scores_list)) - ), - ) - else: - await Score_Model.prisma().create( - data=Score_ModelCreateInput( - score=Json(json.dumps(scores_list)), - ) - ) - - logger.success( - f"📦 Saved validator state with scores: {scores}, and for {len(dojo_task_data)} requests" - ) - else: - logger.warning("Scores are all zero. 
Skipping save.") - except Exception as e: - logger.error(f"Failed to save validator state: {e}") - - @classmethod - async def validator_load(cls) -> dict | None: - try: - # Query the latest validator state - states: List[ - Validator_State_Model - ] = await Validator_State_Model.prisma().find_many() - - if not states: - return None - - # Query the scores - score_record = await Score_Model.prisma().find_first( - order={"created_at": "desc"} - ) - - if not score_record: - logger.trace("Score record not found.") - return None - - # Deserialize the data - scores: torch.Tensor = torch.tensor(json.loads(score_record.score)) - - # Initialize the dictionaries with the correct types and default factories - dojo_tasks_to_track: RidToHotKeyToTaskId = defaultdict( - lambda: defaultdict(str) - ) - model_map: RidToModelMap = defaultdict(dict) - task_to_expiry: TaskExpiryDict = defaultdict(str) - - for state in states: - if ( - state.request_id not in dojo_tasks_to_track - ): # might not need to check - dojo_tasks_to_track[state.request_id] = {} - dojo_tasks_to_track[state.request_id][state.miner_hotkey] = ( - state.task_id - ) - - if state.request_id not in model_map: - model_map[state.request_id] = {} - model_map[state.request_id][state.obfuscated_model] = state.real_model - - task_to_expiry[state.task_id] = state.expire_at - - return { - "scores": scores, - "dojo_tasks_to_track": dojo_tasks_to_track, - "model_map": model_map, - "task_to_expiry": task_to_expiry, - } - - except Exception as e: - logger.error( - f"Unexpected error occurred while loading validator state: {e}" - ) - return None - - @staticmethod - async def remove_expired_tasks_from_storage(): - try: - state_data = await DataManager.validator_load() - if not state_data: - logger.error( - "Failed to load validator state while removing expired tasks, skipping" - ) - return - - # Identify expired tasks - current_time = datetime.now(timezone.utc) - task_to_expiry = state_data.get(ValidatorStateKeys.TASK_TO_EXPIRY, {}) - expired_tasks = [ - task_id - for task_id, expiry_time in task_to_expiry.items() - if datetime.fromisoformat(expiry_time) < current_time - ] - - # Remove expired tasks from the database - for task_id in expired_tasks: - await Validator_State_Model.prisma().delete_many( - where={"task_id": task_id} - ) - - # Update the in-memory state - for task_id in expired_tasks: - for request_id, hotkeys in list( - state_data[ValidatorStateKeys.DOJO_TASKS_TO_TRACK].items() - ): - for hotkey, t_id in list(hotkeys.items()): - if t_id == task_id: - del state_data[ValidatorStateKeys.DOJO_TASKS_TO_TRACK][ - request_id - ][hotkey] - if not state_data[ValidatorStateKeys.DOJO_TASKS_TO_TRACK][ - request_id - ]: - del state_data[ValidatorStateKeys.DOJO_TASKS_TO_TRACK][ - request_id - ] - del task_to_expiry[task_id] - - # Save the updated state - state_data[ValidatorStateKeys.TASK_TO_EXPIRY] = task_to_expiry - await DataManager.validator_save( - state_data[ValidatorStateKeys.SCORES], - state_data[ValidatorStateKeys.DOJO_TASKS_TO_TRACK], - state_data[ValidatorStateKeys.MODEL_MAP], - task_to_expiry, - ) - if len(expired_tasks) > 0: - logger.info( - f"Removed {len(expired_tasks)} expired tasks from database." 
- ) - except Exception as e: - logger.error(f"Failed to remove expired tasks: {e}") diff --git a/commons/dojo_task_tracker.py b/commons/dojo_task_tracker.py deleted file mode 100644 index d66c06fc..00000000 --- a/commons/dojo_task_tracker.py +++ /dev/null @@ -1,316 +0,0 @@ -import asyncio -import copy -import traceback -from collections import defaultdict -from datetime import datetime, timezone -from typing import Dict - -import bittensor as bt -from bittensor.btlogging import logging as logger - -import dojo -from commons.data_manager import DataManager -from commons.objects import ObjectManager -from commons.utils import get_epoch_time -from database.prisma.models import Feedback_Request_Model, Miner_Response_Model -from dojo.protocol import ( - CriteriaTypeEnum, - DendriteQueryResponse, - MultiScoreCriteria, - RankingCriteria, - RidToHotKeyToTaskId, - RidToModelMap, - TaskExpiryDict, - TaskResult, - TaskResultRequest, -) - - -class DojoTaskTracker: - _instance = None - # request id -> miner hotkey -> task id - _rid_to_mhotkey_to_task_id: RidToHotKeyToTaskId = defaultdict( - lambda: defaultdict(str) - ) - _rid_to_model_map: RidToModelMap = defaultdict(lambda: defaultdict(str)) - _task_to_expiry: TaskExpiryDict = defaultdict(str) - _lock = asyncio.Lock() - _should_exit: bool = False - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - @classmethod - async def update_task_map( - cls, - request_id: str, - fb_request_model: Feedback_Request_Model, - obfuscated_model_to_model: Dict, - ): - dojo_responses = fb_request_model.miner_responses - if dojo_responses is None or len(dojo_responses) == 0: - logger.warning("No Dojo responses found") - return - - async with cls._lock: - valid_responses: list[Miner_Response_Model] = list( - filter( - lambda r: r.request_id == request_id - and r.miner_hotkey - and r.dojo_task_id, - dojo_responses, - ) - ) - logger.debug( - f"Got {len(valid_responses)} valid Dojo responses to update task tracker" - ) - - if request_id not in cls._rid_to_mhotkey_to_task_id: - cls._rid_to_mhotkey_to_task_id[request_id] = {} - - for r in valid_responses: - cls._rid_to_mhotkey_to_task_id[request_id][r.miner_hotkey] = ( - r.dojo_task_id - ) - - cls._task_to_expiry[r.dojo_task_id] = r.expire_at - - cls._rid_to_model_map[request_id] = obfuscated_model_to_model - return - - @classmethod - async def remove_expired_tasks(cls): - # Identify expired tasks - current_time = datetime.now(timezone.utc) - expired_tasks = [ - task_id - for task_id, expiry_time in cls._task_to_expiry.items() - if datetime.fromisoformat(expiry_time) < current_time - ] - - async with cls._lock: - for task_id in expired_tasks: - # Remove from _rid_to_mhotkey_to_task_id - for request_id, hotkeys in list(cls._rid_to_mhotkey_to_task_id.items()): - for hotkey, t_id in list(hotkeys.items()): - if t_id == task_id: - del cls._rid_to_mhotkey_to_task_id[request_id][hotkey] - if not cls._rid_to_mhotkey_to_task_id[ - request_id - ]: # This means no more hotkeys for this request, so remove the request - del cls._rid_to_mhotkey_to_task_id[request_id] - # Remove from _task_to_expiry - del cls._task_to_expiry[task_id] - - if len(expired_tasks): - logger.info(f"Removed {len(expired_tasks)} expired tasks from task tracker") - - @classmethod - async def get_task_results_from_miner( - cls, miner_hotkey: str, task_id: str - ) -> list[TaskResult]: - """Fetch task results from the miner's Axon using Dendrite.""" - try: - logger.info( - f"Fetching task 
result from miner {miner_hotkey} for task {task_id}" - ) - - validator = ObjectManager.get_validator() - - dendrite: bt.dendrite = validator.dendrite - metagraph = validator.metagraph - - if not dendrite: - raise ValueError("Dendrite not initialized") - - # Prepare the synapse (data request) that will be sent via Dendrite - task_synapse = TaskResultRequest(task_id=task_id) - - # Use Dendrite to communicate with the Axon - miner_axon = metagraph.axons[metagraph.hotkeys.index(miner_hotkey)] - if not miner_axon: - raise ValueError(f"Miner Axon not found for hotkey: {miner_hotkey}") - - # Send the request via Dendrite and get the response - response = await dendrite.forward( - axons=[miner_axon], synapse=task_synapse, deserialize=False - ) - - logger.debug(f"TaskResult Response from miner {miner_hotkey}: {response}") - - if response and response[0]: - logger.info( - f"Received task result from miner {miner_hotkey} for task {task_id}" - ) - return response[0].task_results - else: - logger.warning( - f"No task results found from miner {miner_hotkey} for task {task_id}" - ) - return [] - - except Exception as e: - logger.error(f"Error fetching task result from miner {miner_hotkey}: {e}") - return [] - - @classmethod - async def monitor_task_completions(cls): - SLEEP_SECONDS = 30 - await asyncio.sleep(dojo.DOJO_TASK_MONITORING) - - while not cls._should_exit: - try: - if len(cls._rid_to_mhotkey_to_task_id.keys()) == 0: - await asyncio.sleep(SLEEP_SECONDS) - continue - - logger.info( - f"Monitoring task completions {get_epoch_time()} for {len(cls._rid_to_mhotkey_to_task_id.keys())} requests" - ) - - # Clean up expired tasks before processing - await cls.remove_expired_tasks() - await DataManager.remove_expired_tasks_from_storage() - - if not cls._rid_to_mhotkey_to_task_id: - await asyncio.sleep(SLEEP_SECONDS) - continue - - for request_id in list(cls._rid_to_mhotkey_to_task_id.keys()): - miner_to_task_id = cls._rid_to_mhotkey_to_task_id[request_id] - processed_hotkeys = set() - - data: ( - DendriteQueryResponse | None - ) = await DataManager.get_by_request_id(request_id) - - if not data or not data.request: - logger.error( - f"No request on disk found for request id: {request_id}" - ) - continue - - for miner_hotkey, task_id in miner_to_task_id.items(): - if not task_id: - logger.warning( - f"No task ID found for miner hotkey: {miner_hotkey}" - ) - continue - - task_results = await cls.get_task_results_from_miner( - miner_hotkey, task_id - ) - - if not task_results and not len(task_results) > 0: - logger.warning( - f"Task ID: {task_id} by miner: {miner_hotkey} has not been completed yet or no task results." 
- ) - continue - - logger.trace( - f"Request id: {request_id}, miner hotkey: {miner_hotkey}, task id: {task_id}" - ) - - # calculate average rank/scores across a single miner's workers - model_id_to_avg_rank = defaultdict(float) - model_id_to_avg_score = defaultdict(float) - # keep track so we average across the miner's worker pool - num_ranks_by_workers, num_scores_by_workers = 0, 0 - for result in task_results: - for result_data in result.result_data: - type = result_data.type - value = result_data.value - if type == CriteriaTypeEnum.RANKING_CRITERIA: - for model_id, rank in value.items(): - real_model_id = cls._rid_to_model_map.get( - request_id - ).get(model_id) - model_id_to_avg_rank[real_model_id] += rank - num_ranks_by_workers += 1 - elif type == CriteriaTypeEnum.MULTI_SCORE: - for model_id, score in value.items(): - real_model_id = cls._rid_to_model_map.get( - request_id - ).get(model_id) - model_id_to_avg_score[real_model_id] += score - num_scores_by_workers += 1 - - # dvide all sums by the number of ranks and scores - for model_id in model_id_to_avg_rank: - model_id_to_avg_rank[model_id] /= num_ranks_by_workers - for model_id in model_id_to_avg_score: - model_id_to_avg_score[model_id] /= num_scores_by_workers - - # mimic miners responding to the dendrite call - miner_response = copy.deepcopy(data.request) - miner_response.axon = bt.TerminalInfo( - hotkey=miner_hotkey, - ) - miner_response.dojo_task_id = task_id - for completion in miner_response.completion_responses: - model_id = completion.model - - for criteria in miner_response.criteria_types: - if isinstance(criteria, RankingCriteria): - completion.rank_id = int( - model_id_to_avg_rank[model_id] - ) - elif isinstance(criteria, MultiScoreCriteria): - completion.score = model_id_to_avg_score[model_id] - - if model_id_to_avg_rank: - logger.trace( - f"Parsed request with ranks data: {model_id_to_avg_rank}" - ) - if model_id_to_avg_score: - logger.trace( - f"Parsed request with scores data: {model_id_to_avg_score}" - ) - - # miner would have originally responded with the right task id - found_response = next( - ( - r - for r in data.miner_responses - if r.axon.hotkey == miner_hotkey - ), - None, - ) - if not found_response: - logger.warning( - "Miner response not found in data, this should never happen" - ) - data.miner_responses.append(miner_response) - else: - data.miner_responses.remove(found_response) - data.miner_responses.append(miner_response) - - status = ( - await DataManager.overwrite_miner_responses_by_request_id( - request_id, data.miner_responses - ) - ) - logger.trace( - f"Appending Dojo task results for request id: {request_id}, was successful? 
{status}"
-                    )
-                    if status:
-                        processed_hotkeys.add(miner_hotkey)
-
-                # determine if we should completely remove the request from the tracker
-                async with cls._lock:
-                    if processed_hotkeys == set(miner_to_task_id.keys()):
-                        del cls._rid_to_mhotkey_to_task_id[request_id]
-                        del cls._rid_to_model_map[request_id]
-                    else:
-                        for hotkey in processed_hotkeys:
-                            del cls._rid_to_mhotkey_to_task_id[request_id][hotkey]
-
-                ObjectManager.get_validator().save_state()
-
-            except Exception as e:
-                traceback.print_exc()
-                logger.error(f"Error during Dojo task monitoring {str(e)}")
-                pass
-            await asyncio.sleep(SLEEP_SECONDS)
diff --git a/commons/exceptions.py b/commons/exceptions.py
new file mode 100644
index 00000000..26686549
--- /dev/null
+++ b/commons/exceptions.py
@@ -0,0 +1,62 @@
+class NoNewUnexpiredTasksYet(Exception):
+    """Exception raised when no unexpired tasks are found for processing."""
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
+
+class UnexpiredTasksAlreadyProcessed(Exception):
+    """Exception raised when all unexpired tasks have already been processed."""
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
+
+class InvalidValidatorRequest(Exception):
+    """Exception raised when a validator request is invalid."""
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
+
+class InvalidMinerResponse(Exception):
+    """Exception raised when a miner response is invalid."""
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
+
+class InvalidCompletion(Exception):
+    """Exception raised when a completion response is invalid."""
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
+
+class InvalidTask(Exception):
+    """Exception raised when a task is invalid."""
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
+
+class EmptyScores(Exception):
+    """Exception raised when scores are empty."""
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
+
+class CreateTaskFailed(Exception):
+    """Exception raised when creating a task fails."""
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
diff --git a/commons/human_feedback/dojo.py b/commons/human_feedback/dojo.py
index 07b8c903..60bb7076 100644
--- a/commons/human_feedback/dojo.py
+++ b/commons/human_feedback/dojo.py
@@ -5,13 +5,14 @@
 from bittensor.btlogging import logging as logger
 
 import dojo
+from commons.exceptions import CreateTaskFailed
 from commons.utils import loaddotenv, set_expire_time
 from dojo import get_dojo_api_base_url
 from dojo.protocol import FeedbackRequest, MultiScoreCriteria, RankingCriteria
 
 DOJO_API_BASE_URL = get_dojo_api_base_url()
 # to be able to get the curlify requests
-DEBUG = False
+# DEBUG = False
@@ -57,7 +58,7 @@ async def get_task_results_by_task_id(cls, task_id: str) -> List[Dict] | None:
         return task_results
 
     @staticmethod
-    def serialize_feedback_request(data: FeedbackRequest) -> Dict[str, str]:
+    def serialize_feedback_request(data: FeedbackRequest):
         output = dict(
             prompt=data.prompt,
             responses=[],
@@ -83,66 +84,87 @@ async def create_task(
         cls,
         feedback_request: FeedbackRequest,
     ):
-        logger.debug("Creating Task....")
-        path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks"
-        taskData = cls.serialize_feedback_request(feedback_request)
-        for criteria_type in
feedback_request.criteria_types: - if isinstance(criteria_type, RankingCriteria) or isinstance( - criteria_type, MultiScoreCriteria - ): - taskData["criteria"].append( - { - **criteria_type.model_dump(), - "options": [ - option - for option in criteria_type.model_dump().get("options", []) - ], - } - ) - else: - logger.error(f"Unrecognized criteria type: {type(criteria_type)}") - - expire_at = set_expire_time(dojo.TASK_DEADLINE) - - form_body = { - "title": ("", "LLM Code Generation Task"), - "body": ("", feedback_request.prompt), - "expireAt": ("", expire_at), - "taskData": ("", json.dumps([taskData])), - "maxResults": ("", "1"), - } - - DOJO_API_KEY = loaddotenv("DOJO_API_KEY") - - response = await cls._http_client.post( - path, - files=form_body, - headers={ - "x-api-key": DOJO_API_KEY, - }, - timeout=15.0, - ) + response_text = "" + response_json = {} + try: + path = f"{DOJO_API_BASE_URL}/api/v1/tasks/create-tasks" + taskData = cls.serialize_feedback_request(feedback_request) + for criteria_type in feedback_request.criteria_types: + if isinstance(criteria_type, RankingCriteria) or isinstance( + criteria_type, MultiScoreCriteria + ): + taskData["criteria"].append( + { + **criteria_type.model_dump(), + "options": [ + option + for option in criteria_type.model_dump().get( + "options", [] + ) + ], + } + ) + else: + logger.error(f"Unrecognized criteria type: {type(criteria_type)}") + + expire_at = set_expire_time(dojo.TASK_DEADLINE) + + form_body = { + "title": ("", "LLM Code Generation Task"), + "body": ("", feedback_request.prompt), + "expireAt": ("", expire_at), + "taskData": ("", json.dumps([taskData])), + "maxResults": ("", "1"), + } + + DOJO_API_KEY = loaddotenv("DOJO_API_KEY") + + response = await cls._http_client.post( + path, + files=form_body, + headers={ + "x-api-key": DOJO_API_KEY, + }, + timeout=15.0, + ) - if DEBUG is True: - try: - from curlify2 import Curlify - - curl_req = Curlify(response.request) - print("CURL REQUEST >>> ") - print(curl_req.to_curl()) - except ImportError: - print("Curlify not installed") - except Exception as e: - print("Tried to export create task request as curl, but failed.") - print(f"Exception: {e}") - - task_ids = [] - if response.status_code == 200: - task_ids = response.json()["body"] - logger.success(f"Successfully created task with\ntask ids:{task_ids}") - else: + response_text = response.text + response_json = response.json() + # if DEBUG is True: + # try: + # from curlify2 import Curlify + + # curl_req = Curlify(response.request) + # print("CURL REQUEST >>> ") + # print(curl_req.to_curl()) + # except ImportError: + # print("Curlify not installed") + # except Exception as e: + # print("Tried to export create task request as curl, but failed.") + # print(f"Exception: {e}") + + task_ids = [] + if response.status_code == 200: + task_ids = response.json()["body"] + logger.success(f"Successfully created task with\ntask ids:{task_ids}") + else: + logger.error( + f"Error occurred when trying to create task\nErr:{response.json()['error']}" + ) + response.raise_for_status() + return task_ids + except json.JSONDecodeError as e1: + message = f"While trying to create task got JSON decode error: {e1}, response_text: {response_text}" + logger.error(message) + raise CreateTaskFailed("Failed to create task due to JSON decode error") + except httpx.HTTPStatusError as e: logger.error( - f"Error occurred when trying to create task\nErr:{response.json()['error']}" + f"HTTP error occurred: {e}. Status code: {e.response.status_code}. 
Response content: {e.response.text}"
+            )
+            raise CreateTaskFailed(
+                f"Failed to create task due to HTTP error: {e}, response_text: {response_text}, response_json: {response_json}"
+            )
+        except Exception as e:
+            raise CreateTaskFailed(
+                f"Failed to create task due to unexpected error: {e}, response_text: {response_text}, response_json: {response_json}"
             )
-        response.raise_for_status()
-        return task_ids
diff --git a/commons/orm.py b/commons/orm.py
new file mode 100644
index 00000000..641c72cc
--- /dev/null
+++ b/commons/orm.py
@@ -0,0 +1,523 @@
+import json
+from datetime import datetime, timedelta, timezone
+from typing import AsyncGenerator, List
+
+import torch
+from bittensor.btlogging import logging as logger
+
+from commons.exceptions import (
+    InvalidCompletion,
+    InvalidMinerResponse,
+    InvalidTask,
+    NoNewUnexpiredTasksYet,
+    UnexpiredTasksAlreadyProcessed,
+)
+from commons.utils import datetime_as_utc
+from database.client import transaction
+from database.mappers import (
+    map_child_feedback_request_to_model,
+    map_completion_response_to_model,
+    map_criteria_type_to_model,
+    map_feedback_request_model_to_feedback_request,
+    map_parent_feedback_request_to_model,
+)
+from database.prisma import Json
+from database.prisma.errors import PrismaError
+from database.prisma.models import (
+    Feedback_Request_Model,
+    Ground_Truth_Model,
+    Score_Model,
+)
+from database.prisma.types import (
+    Feedback_Request_ModelInclude,
+    Feedback_Request_ModelWhereInput,
+    Ground_Truth_ModelCreateInput,
+    Score_ModelCreateInput,
+    Score_ModelUpdateInput,
+)
+from dojo import TASK_DEADLINE
+from dojo.protocol import (
+    CodeAnswer,
+    CompletionResponses,
+    DendriteQueryResponse,
+    FeedbackRequest,
+)
+
+
+class ORM:
+    @staticmethod
+    async def get_last_expire_at_cutoff(
+        validator_hotkeys: list[str],
+        expire_at: datetime = datetime_as_utc(
+            datetime.now(timezone.utc) - timedelta(seconds=TASK_DEADLINE)
+        ),
+    ) -> datetime:
+        """
+        Get the `expire_at` cutoff for the query in `get_expired_tasks`.
+        We subtract TASK_DEADLINE from the current time so the cutoff overlaps with the
+        `expire_at` field in the database and we don't miss out on any data.
+
+        Args:
+            validator_hotkeys (list[str]): List of validator hotkeys.
+            expire_at (datetime, optional): Earliest `expire_at` to consider. Defaults to
+                datetime_as_utc(datetime.now(timezone.utc) - timedelta(seconds=TASK_DEADLINE)).
+
+        Raises:
+            ValueError: Unable to determine expire at cutoff
+
+        Returns:
+            datetime: Expire at cutoff
+        """
+        logger.debug(f"Expire at cutoff: {expire_at}")
+        vali_where_query_unprocessed = Feedback_Request_ModelWhereInput(
+            {
+                "hotkey": {"in": validator_hotkeys, "mode": "insensitive"},
+                "child_requests": {"some": {}},
+                "expire_at": {"gt": expire_at},
+                "is_processed": {"equals": False},
+            }
+        )
+
+        found = await Feedback_Request_Model.prisma().find_first(
+            where=vali_where_query_unprocessed,
+            order={"expire_at": "asc"},
+        )
+        if found:
+            return datetime_as_utc(found.expire_at)
+
+        raise ValueError("Unable to determine expire at cutoff")
+
+    @staticmethod
+    async def get_expired_tasks(
+        validator_hotkeys: list[str],
+        batch_size: int = 10,
+        expire_at: datetime | None = None,
+    ) -> AsyncGenerator[tuple[List[DendriteQueryResponse], bool], None]:
+        """Returns a batch of DendriteQueryResponse and a boolean indicating if there are more batches.
+        Depending on the `expire_at` provided, it will return different results.
+
+        YOUR LOGIC ON WHETHER TASKS ARE EXPIRED OR NON-EXPIRED SHOULD BE HANDLED BY SETTING EXPIRE_AT YOURSELF.
+
+        Args:
+            validator_hotkeys (list[str]): List of validator hotkeys.
+            batch_size (int, optional): Number of tasks to return in a batch. Defaults to 10.
+                1 task == 1 validator request, N miner responses
+            expire_at: (datetime | None) If provided, only tasks with expire_at after the provided datetime will be returned.
+                You must determine the `expire_at` cutoff yourself, otherwise it defaults to the current time in UTC.
+
+        Raises:
+            NoNewUnexpiredTasksYet: If no unexpired tasks are found for processing.
+            UnexpiredTasksAlreadyProcessed: If all unexpired tasks have already been processed.
+
+        Yields:
+            tuple[List[DendriteQueryResponse], bool]:
+                A batch of DendriteQueryResponse and a boolean indicating if there are more batches
+
+        """
+
+        # find all validator requests first
+        include_query = Feedback_Request_ModelInclude(
+            {
+                "completions": True,
+                "criteria_types": True,
+                "ground_truths": True,
+                "parent_request": True,
+            }
+        )
+
+        now = datetime_as_utc(datetime.now(timezone.utc))
+        if expire_at:
+            now = expire_at
+
+        vali_where_query_unprocessed = Feedback_Request_ModelWhereInput(
+            {
+                "hotkey": {"in": validator_hotkeys, "mode": "insensitive"},
+                "child_requests": {"some": {}},
+                # only check for expire at since miner may lie
+                "expire_at": {
+                    "gt": now,
+                },
+                "is_processed": {"equals": False},
+            }
+        )
+
+        vali_where_query_processed = Feedback_Request_ModelWhereInput(
+            {
+                "hotkey": {"in": validator_hotkeys, "mode": "insensitive"},
+                "child_requests": {"some": {}},
+                # only check for expire at since miner may lie
+                "expire_at": {
+                    "gt": now,
+                },
+                "is_processed": {"equals": True},
+            }
+        )
+
+        # count the unprocessed and processed tasks first
+        task_count_unprocessed = await Feedback_Request_Model.prisma().count(
+            where=vali_where_query_unprocessed,
+        )
+
+        task_count_processed = await Feedback_Request_Model.prisma().count(
+            where=vali_where_query_processed,
+        )
+
+        logger.debug(
+            f"Count of unprocessed tasks: {task_count_unprocessed}, count of processed tasks: {task_count_processed}"
+        )
+
+        if not task_count_unprocessed:
+            if task_count_processed:
+                raise UnexpiredTasksAlreadyProcessed(
+                    f"No remaining unexpired tasks found for processing, but don't worry as you have processed {task_count_processed} tasks."
+                )
+            else:
+                raise NoNewUnexpiredTasksYet(
+                    f"No unexpired tasks found for processing, please wait for tasks to pass the task deadline of {TASK_DEADLINE} seconds."
+ ) + + for i in range(0, task_count_unprocessed, batch_size): + # find all validator requests + validator_requests = await Feedback_Request_Model.prisma().find_many( + include=include_query, + where=vali_where_query_unprocessed, + order={"created_at": "desc"}, + skip=i, + take=batch_size, + ) + + # find all miner responses + validator_request_ids = [r.id for r in validator_requests] + + miner_responses = await Feedback_Request_Model.prisma().find_many( + include=include_query, + where={ + "parent_id": {"in": validator_request_ids}, + "is_processed": {"equals": False}, + }, + order={"created_at": "desc"}, + ) + + responses: list[DendriteQueryResponse] = [] + for validator_request in validator_requests: + vali_request = map_feedback_request_model_to_feedback_request( + validator_request + ) + + m_responses = list( + map( + lambda x: map_feedback_request_model_to_feedback_request( + x, is_miner=True + ), + [ + m + for m in miner_responses + if m.parent_id == validator_request.id + ], + ) + ) + + responses.append( + DendriteQueryResponse( + request=vali_request, miner_responses=m_responses + ) + ) + + # yield responses, so caller can do something + has_more_batches = True + yield responses, has_more_batches + + yield [], False + + @staticmethod + async def get_real_model_ids(request_id: str) -> dict[str, str]: + """Fetches a mapping of obfuscated model IDs to real model IDs for a given request ID.""" + ground_truths = await Ground_Truth_Model.prisma().find_many( + where={"request_id": request_id} + ) + return {gt.obfuscated_model_id: gt.real_model_id for gt in ground_truths} + + @staticmethod + async def mark_tasks_processed_by_request_ids(request_ids: list[str]) -> None: + """Mark records associated with validator's request and miner's responses as processed. + + Args: + request_ids (list[str]): List of request ids. 
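+
+        Example (request ids shown are placeholders):
+            >>> await ORM.mark_tasks_processed_by_request_ids(["req-abc", "req-def"])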
+        """
+        if not request_ids:
+            logger.error("No request ids provided to mark as processed")
+            return
+
+        try:
+            async with transaction() as tx:
+                num_updated = await tx.feedback_request_model.update_many(
+                    data={"is_processed": True},
+                    where={"request_id": {"in": request_ids}},
+                )
+                logger.success(
+                    f"Marked {num_updated} records associated with {len(request_ids)} tasks as processed"
+                )
+        except PrismaError as exc:
+            logger.error(f"Prisma error occurred: {exc}")
+        except Exception as exc:
+            logger.error(f"Unexpected error occurred: {exc}")
+
+    @staticmethod
+    async def get_task_by_request_id(request_id: str) -> DendriteQueryResponse | None:
+        try:
+            # find the parent id first
+            include_query = Feedback_Request_ModelInclude(
+                {
+                    "completions": True,
+                    "criteria_types": True,
+                    "ground_truths": True,
+                    "parent_request": True,
+                    "child_requests": True,
+                }
+            )
+            all_requests = await Feedback_Request_Model.prisma().find_many(
+                where={
+                    "request_id": request_id,
+                },
+                include=include_query,
+            )
+
+            validator_requests = [r for r in all_requests if r.parent_id is None]
+            assert len(validator_requests) == 1, "Expected only one validator request"
+            validator_request = validator_requests[0]
+            if not validator_request.child_requests:
+                raise InvalidTask(
+                    f"Validator request {validator_request.id} must have child requests"
+                )
+
+            miner_responses = [
+                map_feedback_request_model_to_feedback_request(r, is_miner=True)
+                for r in validator_request.child_requests
+            ]
+            return DendriteQueryResponse(
+                request=map_feedback_request_model_to_feedback_request(
+                    model=validator_request, is_miner=False
+                ),
+                miner_responses=miner_responses,
+            )
+
+        except Exception as e:
+            logger.error(f"Failed to get feedback request by request_id: {e}")
+            return None
+
+    @staticmethod
+    async def get_num_processed_tasks() -> int:
+        return await Feedback_Request_Model.prisma().count(
+            where={"is_processed": True, "parent_id": None}
+        )
+
+    @staticmethod
+    async def update_miner_completions_by_request_id(
+        request_id: str, miner_responses: List[FeedbackRequest]
+    ) -> bool:
+        """Update the rank_ids / scores that miners provided for a given request id.
+        This exists because, over the course of a task, a miner may recruit multiple
+        workers, so we need to recalculate the average score / rank_id across all workers.
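+
+        Usage sketch (the response objects are assumed to come from a fresh dendrite query):
+            >>> ok = await ORM.update_miner_completions_by_request_id(request_id, miner_responses)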
+        """
+        try:
+            async with transaction() as tx:
+                # find the feedback request ids
+                miner_hotkeys = []
+                for miner_response in miner_responses:
+                    if not miner_response.axon or not miner_response.axon.hotkey:
+                        raise InvalidMinerResponse(
+                            f"Miner response {miner_response} must have a hotkey"
+                        )
+                    miner_hotkeys.append(miner_response.axon.hotkey)
+
+                found_responses = await tx.feedback_request_model.find_many(
+                    where={"request_id": request_id, "hotkey": {"in": miner_hotkeys}}
+                )
+
+                # delete the completions for all of these miners
+                await tx.completion_response_model.delete_many(
+                    where={
+                        "feedback_request_id": {"in": [r.id for r in found_responses]}
+                    }
+                )
+
+                # reconstruct the completion_responses data
+                for miner_response in miner_responses:
+                    # find the particular request
+                    hotkey = miner_response.axon.hotkey  # type: ignore
+                    curr_miner_response = await tx.feedback_request_model.find_first(
+                        where=Feedback_Request_ModelWhereInput(
+                            request_id=request_id,
+                            hotkey=hotkey,  # type: ignore
+                        )
+                    )
+
+                    if not curr_miner_response:
+                        raise ValueError("Miner response not found")
+
+                    # recreate completions
+                    for completion in miner_response.completion_responses:
+                        await tx.completion_response_model.create(
+                            data=map_completion_response_to_model(
+                                completion, curr_miner_response.id
+                            )
+                        )
+
+            logger.success(
+                f"Successfully updated completion data for miners: {miner_hotkeys}"
+            )
+            return True
+        except Exception as e:
+            logger.error(f"Failed to update completion data for miner responses: {e}")
+            return False
+
+    @staticmethod
+    async def save_task(
+        validator_request: FeedbackRequest,
+        miner_responses: List[FeedbackRequest],
+        ground_truth: dict[str, int],
+    ) -> Feedback_Request_Model | None:
+        """Saves a task, which consists of both the validator's request and the miners' responses.
+
+        Args:
+            validator_request (FeedbackRequest): The request made by the validator.
+            miner_responses (List[FeedbackRequest]): The responses made by the miners.
+            ground_truth (dict[str, int]): The ground truth for the task, mapping each
+                obfuscated model id to its rank id.
+
+        Returns:
+            Feedback_Request_Model | None: The validator's feedback request model, or None if saving failed.
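+
+        Example (model ids and ranks are illustrative):
+            >>> await ORM.save_task(vali_request, miner_responses, ground_truth={"model_a": 0, "model_b": 1})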
+ """ + try: + feedback_request_model: Feedback_Request_Model | None = None + async with transaction() as tx: + logger.trace("Starting transaction for saving task.") + + feedback_request_model = await tx.feedback_request_model.create( + data=map_parent_feedback_request_to_model(validator_request) + ) + + # Create related criteria types + criteria_create_input = [ + map_criteria_type_to_model(criteria, feedback_request_model.id) + for criteria in validator_request.criteria_types + ] + await tx.criteria_type_model.create_many(criteria_create_input) + + # Create related miner responses (child) and their completion responses + created_miner_models: list[Feedback_Request_Model] = [] + for miner_response in miner_responses: + try: + create_miner_model_input = map_child_feedback_request_to_model( + miner_response, + feedback_request_model.id, + expire_at=feedback_request_model.expire_at, + ) + + created_miner_model = await tx.feedback_request_model.create( + data=create_miner_model_input + ) + + created_miner_models.append(created_miner_model) + + criteria_create_input = [ + map_criteria_type_to_model(criteria, created_miner_model.id) + for criteria in miner_response.criteria_types + ] + await tx.criteria_type_model.create_many(criteria_create_input) + + # Create related completions for miner responses + for completion in miner_response.completion_responses: + # remove the completion field, since the miner receives an obfuscated completion_response anyways + # therefore it is useless for training + try: + completion_copy = completion.model_dump() + completion_copy["completion"] = CodeAnswer(files=[]) + except KeyError: + pass + completion_input = map_completion_response_to_model( + CompletionResponses.model_validate(completion_copy), + created_miner_model.id, + ) + await tx.completion_response_model.create( + data=completion_input + ) + logger.trace( + f"Created completion response: {completion_input}" + ) + + # we catch exceptions here because whether a miner responds well should not affect other miners + except InvalidMinerResponse as e: + miner_hotkey = ( + miner_response.axon.hotkey if miner_response.axon else "??" + ) + logger.debug( + f"Miner response from hotkey: {miner_hotkey} is invalid: {e}" + ) + except InvalidCompletion as e: + miner_hotkey = ( + miner_response.axon.hotkey if miner_response.axon else "??" 
+ ) + logger.debug( + f"Completion response from hotkey: {miner_hotkey} is invalid: {e}" + ) + + if len(created_miner_models) == 0: + raise InvalidTask( + "A task must consist of at least one miner response, along with validator's request" + ) + + # this is dependent on how we obfuscate in `validator.send_request` + for completion_id, rank_id in ground_truth.items(): + gt_create_input = { + "rank_id": rank_id, + "obfuscated_model_id": completion_id, + "request_id": validator_request.request_id, + "real_model_id": completion_id, + "feedback_request_id": feedback_request_model.id, + } + await tx.ground_truth_model.create( + data=Ground_Truth_ModelCreateInput(**gt_create_input) + ) + for vali_completion in validator_request.completion_responses: + vali_completion_input = map_completion_response_to_model( + vali_completion, + feedback_request_model.id, + ) + await tx.completion_response_model.create( + data=vali_completion_input + ) + + feedback_request_model.child_requests = created_miner_models + return feedback_request_model + except Exception as e: + logger.error(f"Failed to save dendrite query response: {e}") + return None + + @staticmethod + async def create_or_update_validator_score(scores: torch.Tensor) -> None: + # Save scores as a single record + score_model = await Score_Model.prisma().find_first() + scores_list = scores.tolist() + if score_model: + await Score_Model.prisma().update( + where={"id": score_model.id}, + data=Score_ModelUpdateInput(score=Json(json.dumps(scores_list))), + ) + else: + await Score_Model.prisma().create( + data=Score_ModelCreateInput( + score=Json(json.dumps(scores_list)), + ) + ) + + @staticmethod + async def get_validator_score() -> torch.Tensor | None: + score_record = await Score_Model.prisma().find_first( + order={"created_at": "desc"} + ) + if not score_record: + return None + + return torch.tensor(json.loads(score_record.score)) diff --git a/commons/scoring.py b/commons/scoring.py index 7ef818e8..abec67b9 100644 --- a/commons/scoring.py +++ b/commons/scoring.py @@ -197,21 +197,28 @@ def consensus_score( # this works because we are calculating ICC for each rater VS the avg for rater_id in rater_ids: - data_by_rater = df[["subject", rater_id, "avg"]] - # only use the columns for the current rater and avg - data_by_rater = data_by_rater.melt( - id_vars=["subject"], var_name=rater_id, value_name="score" - ) - icc = pg.intraclass_corr( - data=data_by_rater, - targets="subject", - raters=rater_id, - ratings="score", - ) + try: + data_by_rater = df[["subject", rater_id, "avg"]] + # only use the columns for the current rater and avg + data_by_rater = data_by_rater.melt( + id_vars=["subject"], var_name=rater_id, value_name="score" + ) + icc = pg.intraclass_corr( + data=data_by_rater, + targets="subject", + raters=rater_id, + ratings="score", + ) + + # take ICC(2,1) + icc2_value = icc[icc["Type"] == "ICC2"]["ICC"].iloc[0] + icc_arr.append(icc2_value) + + except Exception as e: + logger.error(f"Error calculating ICC for rater {rater_id}: {e}") + logger.debug(f"Data by rater: {data_by_rater}") + continue - # take ICC(2,1) - icc2_value = icc[icc["Type"] == "ICC2"]["ICC"].iloc[0] - icc_arr.append(icc2_value) # already in the range [0, 1] icc_arr: torch.Tensor = torch.tensor(np.array(icc_arr)) diff --git a/commons/utils.py b/commons/utils.py index 41ba0467..2a01dcd9 100644 --- a/commons/utils.py +++ b/commons/utils.py @@ -28,6 +28,18 @@ def get_epoch_time(): return time.time() +def datetime_as_utc(dt: datetime) -> datetime: + return dt.replace(tzinfo=timezone.utc) 
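+# NOTE: `replace(tzinfo=...)` only relabels the datetime as UTC, it does not convert it,
+# so callers are assumed to pass datetimes whose wall-clock time is already UTC, e.g.
+#     datetime_as_utc(datetime(2024, 1, 1, 12, 0))  # -> 2024-01-01 12:00:00+00:00
+# (converting an aware local timestamp would require `.astimezone(timezone.utc)` instead).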
+ + +def datetime_to_iso8601_str(dt: datetime) -> str: + return dt.replace(tzinfo=timezone.utc).isoformat() + + +def iso8601_str_to_datetime(dt_str: str) -> datetime: + return datetime.fromisoformat(dt_str).replace(tzinfo=timezone.utc) + + def loaddotenv(varname: str): """Wrapper to get env variables for sanity checking""" value = os.getenv(varname) @@ -68,7 +80,6 @@ def init_wandb(config: bt.config, my_uid, wallet: bt.wallet): # Manually deepcopy neuron and data_manager, otherwise it is referenced to the same object config.neuron = copy.deepcopy(config.neuron) - config.data_manager = copy.deepcopy(config.data_manager) project_name = None @@ -86,11 +97,6 @@ def init_wandb(config: bt.config, my_uid, wallet: bt.wallet): if config.neuron.full_path else None ) - config.data_manager.base_path = ( - hide_sensitive_path(config.data_manager.base_path) - if config.data_manager.base_path - else None - ) config.uid = my_uid config.hotkey = wallet.hotkey.ss58_address @@ -331,10 +337,6 @@ def ttl_get_block(subtensor) -> int: return subtensor.get_current_block() -def get_current_utc_time_iso(): - return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") - - def set_expire_time(expire_in_seconds: int) -> str: """ Sets the expiration time based on the current UTC time and the given number of seconds. @@ -347,7 +349,7 @@ def set_expire_time(expire_in_seconds: int) -> str: """ return ( (datetime.now(timezone.utc) + timedelta(seconds=expire_in_seconds)) - .replace(microsecond=0, tzinfo=timezone.utc) + .replace(tzinfo=timezone.utc) .isoformat() .replace("+00:00", "Z") ) diff --git a/database/mappers.py b/database/mappers.py index 24fddb46..4c0c750c 100644 --- a/database/mappers.py +++ b/database/mappers.py @@ -1,24 +1,27 @@ import json -from datetime import timedelta +from datetime import datetime, timezone import bittensor as bt from loguru import logger -import dojo -from commons.utils import is_valid_expiry, set_expire_time +from commons.exceptions import InvalidMinerResponse, InvalidValidatorRequest +from commons.utils import ( + datetime_as_utc, + datetime_to_iso8601_str, + iso8601_str_to_datetime, +) from database.prisma import Json -from database.prisma.enums import Criteria_Type_Enum_Model +from database.prisma.enums import CriteriaTypeEnum from database.prisma.models import Criteria_Type_Model, Feedback_Request_Model from database.prisma.types import ( Completion_Response_ModelCreateInput, Criteria_Type_ModelCreateInput, + Criteria_Type_ModelCreateWithoutRelationsInput, Feedback_Request_ModelCreateInput, - Miner_Response_ModelCreateInput, ) from dojo.protocol import ( CompletionResponses, CriteriaType, - DendriteQueryResponse, FeedbackRequest, MultiScoreCriteria, MultiSelectCriteria, @@ -26,35 +29,40 @@ ScoreCriteria, ) +# ---------------------------------------------------------------------------- # +# MAP PROTOCOL OBJECTS TO DATABASE MODEL INPUTS # +# ---------------------------------------------------------------------------- # + def map_criteria_type_to_model( - criteria: CriteriaType, request_id: str -) -> Criteria_Type_ModelCreateInput: + criteria: CriteriaType, feedback_request_id: str +) -> Criteria_Type_ModelCreateWithoutRelationsInput: try: if isinstance(criteria, RankingCriteria): return Criteria_Type_ModelCreateInput( - type=Criteria_Type_Enum_Model.RANKING_CRITERIA, - request_id=request_id, + type=CriteriaTypeEnum.RANKING_CRITERIA, + feedback_request_id=feedback_request_id, # this is parent_id # options=cast(Json, json.dumps(criteria.options)), 
options=Json(json.dumps(criteria.options)), ) elif isinstance(criteria, ScoreCriteria): return Criteria_Type_ModelCreateInput( - type=Criteria_Type_Enum_Model.SCORE, - request_id=request_id, + type=CriteriaTypeEnum.SCORE, + feedback_request_id=feedback_request_id, min=criteria.min, max=criteria.max, + options=Json(json.dumps([])), ) elif isinstance(criteria, MultiSelectCriteria): return Criteria_Type_ModelCreateInput( - type=Criteria_Type_Enum_Model.MULTI_SELECT, - request_id=request_id, + type=CriteriaTypeEnum.MULTI_SELECT, + feedback_request_id=feedback_request_id, options=Json(json.dumps(criteria.options)), ) elif isinstance(criteria, MultiScoreCriteria): return Criteria_Type_ModelCreateInput( - type=Criteria_Type_Enum_Model.MULTI_SCORE, - request_id=request_id, + type=CriteriaTypeEnum.MULTI_SCORE, + feedback_request_id=feedback_request_id, options=Json(json.dumps(criteria.options)), min=criteria.min, max=criteria.max, @@ -69,20 +77,20 @@ def map_criteria_type_model_to_criteria_type( model: Criteria_Type_Model, ) -> CriteriaType: try: - if model.type == Criteria_Type_Enum_Model.RANKING_CRITERIA: + if model.type == CriteriaTypeEnum.RANKING_CRITERIA: return RankingCriteria( options=json.loads(model.options) if model.options else [] ) - elif model.type == Criteria_Type_Enum_Model.SCORE: + elif model.type == CriteriaTypeEnum.SCORE: return ScoreCriteria( min=model.min if model.min is not None else 0.0, max=model.max if model.max is not None else 0.0, ) - elif model.type == Criteria_Type_Enum_Model.MULTI_SELECT: + elif model.type == CriteriaTypeEnum.MULTI_SELECT: return MultiSelectCriteria( options=json.loads(model.options) if model.options else [] ) - elif model.type == Criteria_Type_Enum_Model.MULTI_SCORE: + elif model.type == CriteriaTypeEnum.MULTI_SCORE: return MultiScoreCriteria( options=json.loads(model.options) if model.options else [], min=model.min if model.min is not None else 0.0, @@ -96,123 +104,142 @@ def map_criteria_type_model_to_criteria_type( def map_completion_response_to_model( - response: CompletionResponses, miner_response_id: str + response: CompletionResponses, feedback_request_id: str ) -> Completion_Response_ModelCreateInput: - try: - result = Completion_Response_ModelCreateInput( - completion_id=response.completion_id, - model=response.model, - # completion=cast(Json, json.dumps(response.completion)), - completion=Json(json.dumps(response.completion, default=vars)), - rank_id=response.rank_id, - score=response.score, - miner_response_id=miner_response_id, - ) - return result - except Exception as e: - raise ValueError(f"Failed to map completion response to model {e}") + result = Completion_Response_ModelCreateInput( + completion_id=response.completion_id, + model=response.model, + completion=Json(json.dumps(response.completion, default=vars)), + rank_id=response.rank_id, + score=response.score, + feedback_request_id=feedback_request_id, + ) + return result + + +def map_parent_feedback_request_to_model( + request: FeedbackRequest, +) -> Feedback_Request_ModelCreateInput: + if not request.dendrite or not request.dendrite.hotkey: + raise InvalidValidatorRequest("Validator Hotkey is required") + if not request.expire_at: + raise InvalidValidatorRequest("Expire at is required") -def map_miner_response_to_model( - response: FeedbackRequest, request_id: str -) -> Miner_Response_ModelCreateInput: - try: - # Ensure expire_at is set and is reasonable, this will prevent exploits where miners can set their own expiry times - expire_at = response.expire_at - if expire_at is None 
or is_valid_expiry(expire_at) is not True: - expire_at = set_expire_time(dojo.TASK_DEADLINE) - - if response.dojo_task_id is None: - raise ValueError("Dojo task id is required") - - result = Miner_Response_ModelCreateInput( - request_id=request_id, - miner_hotkey=response.axon.hotkey - if response.axon and response.axon.hotkey - else "", - dojo_task_id=response.dojo_task_id, - expire_at=expire_at, - ) + expire_at = iso8601_str_to_datetime(request.expire_at) + if expire_at < datetime.now(timezone.utc): + raise InvalidValidatorRequest("Expire at must be in the future") - return result - except Exception as e: - raise ValueError(f"Failed to map miner response to model {e}") + result = Feedback_Request_ModelCreateInput( + request_id=request.request_id, + task_type=request.task_type, + prompt=request.prompt, + hotkey=request.dendrite.hotkey, + expire_at=expire_at, + ) + return result -def map_feedback_request_to_model( - request: FeedbackRequest, + +def map_child_feedback_request_to_model( + request: FeedbackRequest, parent_id: str, expire_at: datetime ) -> Feedback_Request_ModelCreateInput: - try: - result = Feedback_Request_ModelCreateInput( - request_id=request.request_id, - task_type=request.task_type, - prompt=request.prompt, - ground_truth=Json(json.dumps(request.ground_truth)), - ) + if not request.axon or not request.axon.hotkey: + raise InvalidMinerResponse("Miner Hotkey is required") - return result - except Exception as e: - raise ValueError(f"Failed to map feedback request to model {e}") + if not parent_id: + raise InvalidMinerResponse("Parent ID is required") + + if not request.dojo_task_id: + raise InvalidMinerResponse("Dojo Task ID is required") + + result = Feedback_Request_ModelCreateInput( + request_id=request.request_id, + task_type=request.task_type, + prompt=request.prompt, + hotkey=request.axon.hotkey, + expire_at=datetime_as_utc(expire_at), + dojo_task_id=request.dojo_task_id, + parent_id=parent_id, + ) + + return result -def map_model_to_dendrite_query_response( - model: Feedback_Request_Model, -) -> DendriteQueryResponse: +# ---------------------------------------------------------------------------- # +# MAPPING DATABASE OBJECTS TO OUR PROTOCOL OBJECTS # +# ---------------------------------------------------------------------------- # + + +def map_feedback_request_model_to_feedback_request( + model: Feedback_Request_Model, is_miner: bool = False +) -> FeedbackRequest: + """Smaller function to map Feedback_Request_Model to FeedbackRequest, meant to be used when reading from database. + + Args: + model (Feedback_Request_Model): Feedback_Request_Model from database. + is_miner (bool, optional): If we're converting for a validator request or miner response. + Defaults to False. + + Raises: + ValueError: If failed to map. + + Returns: + FeedbackRequest: FeedbackRequest object. 
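+
+        Note: `is_miner=True` attaches the stored hotkey as `axon` (a miner's response), while
+        `is_miner=False` attaches it as `dendrite` and also includes the ground truth mapping.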
+ """ + try: + # Map criteria types criteria_types = [ map_criteria_type_model_to_criteria_type(criteria) for criteria in model.criteria_types or [] ] - completions: list[CompletionResponses] = [] - if model.miner_responses is not None: - completions = [ - CompletionResponses( - completion_id=completion.completion_id, - model=completion.model, - completion=completion.completion, - rank_id=completion.rank_id, - score=completion.score, - ) - for completion in model.miner_responses[0].completions or [] - ] - - # Add TASK_DEADLINE to created_at - expire_at_dt = model.created_at + timedelta(seconds=dojo.TASK_DEADLINE) - - request: FeedbackRequest = FeedbackRequest( - request_id=model.request_id, - prompt=model.prompt, - completion_responses=completions, - task_type=model.task_type, - criteria_types=criteria_types, - ground_truth=json.loads(model.ground_truth), - expire_at=expire_at_dt.isoformat().replace("+00:00", "Z"), - ) + # Map completion responses + completion_responses = [ + CompletionResponses( + completion_id=completion.completion_id, + model=completion.model, + completion=json.loads(completion.completion), + rank_id=completion.rank_id, + score=completion.score, + ) + for completion in model.completions or [] + ] - miner_responses: list[FeedbackRequest] = [ - FeedbackRequest( - request_id=miner_response.request_id, + ground_truth: dict[str, int] = {} + + if model.ground_truths: + for gt in model.ground_truths: + ground_truth[gt.obfuscated_model_id] = gt.rank_id + + if is_miner: + # Create FeedbackRequest object + feedback_request = FeedbackRequest( + request_id=model.request_id, prompt=model.prompt, + task_type=model.task_type, criteria_types=criteria_types, + completion_responses=completion_responses, + dojo_task_id=model.dojo_task_id, + expire_at=datetime_to_iso8601_str(model.expire_at), + axon=bt.TerminalInfo(hotkey=model.hotkey), + ) + else: + feedback_request = FeedbackRequest( + request_id=model.request_id, + prompt=model.prompt, task_type=model.task_type, - dojo_task_id=miner_response.dojo_task_id, - expire_at=miner_response.expire_at, - completion_responses=[ - CompletionResponses( - completion_id=completion.completion_id, - model=completion.model, - completion=completion.completion, - rank_id=completion.rank_id, - score=completion.score, - ) - for completion in miner_response.completions or [] - ], - axon=bt.TerminalInfo(hotkey=miner_response.miner_hotkey), + criteria_types=criteria_types, + completion_responses=completion_responses, + dojo_task_id=model.dojo_task_id, + expire_at=datetime_to_iso8601_str(model.expire_at), + dendrite=bt.TerminalInfo(hotkey=model.hotkey), + ground_truth=ground_truth, ) - for miner_response in (model.miner_responses or []) - ] - return DendriteQueryResponse(request=request, miner_responses=miner_responses) + return feedback_request except Exception as e: - raise ValueError(f"Failed to map model to dendrite query response: {e}") + raise ValueError( + f"Failed to map Feedback_Request_Model to FeedbackRequest: {e}" + ) diff --git a/docker-compose.validator.yaml b/docker-compose.validator.yaml index 2e61cf9b..0cb6b891 100644 --- a/docker-compose.validator.yaml +++ b/docker-compose.validator.yaml @@ -127,7 +127,7 @@ services: - WANDB_PROJECT_NAME=dojo-testnet - NETUID=98 - SUBTENSOR_NETWORK=test - - SUBTENSOR_ENDPOINT=ws://test.finney.opentensor.ai + - SUBTENSOR_ENDPOINT=wss://test.finney.opentensor.ai:443/ - PRISMA_QUERY_ENGINE_BINARY=/root/prisma-python/node_modules/prisma/query-engine-debian-openssl-3.0.x volumes: - ./:/app @@ -177,4 +177,4 @@ 
services: condition: service_healthy prisma-setup-vali: condition: service_completed_successfully - logging: *default-logging \ No newline at end of file + logging: *default-logging diff --git a/docker/Dockerfile b/docker/Dockerfile index 58617238..34ad611f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -29,6 +29,8 @@ ARG TARGETPLATFORM RUN echo "Building for TARGETPLATFORM: $TARGETPLATFORM" +RUN git config --global --add safe.directory /app + # jank because pytorch has different versions for cpu for darwin VS linux, see pyproject.toml for specifics RUN if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \ uv pip install --no-cache -e . --find-links https://download.pytorch.org/whl/torch_stable.html; \ diff --git a/dojo/__init__.py b/dojo/__init__.py index f47fb30b..b074ce9c 100644 --- a/dojo/__init__.py +++ b/dojo/__init__.py @@ -32,12 +32,12 @@ def get_latest_git_tag(): # Import all submodules. VALIDATOR_MIN_STAKE = 20000 -TASK_DEADLINE = 8 * 60 * 60 +TASK_DEADLINE = 6 * 60 * 60 # Define the time intervals for various tasks. VALIDATOR_RUN = 300 VALIDATOR_HEARTBEAT = 150 -VALIDATOR_UPDATE_SCORE = 1200 +VALIDATOR_UPDATE_SCORE = 3600 VALIDATOR_STATUS = 60 MINER_STATUS = 60 DOJO_TASK_MONITORING = 60 diff --git a/dojo/protocol.py b/dojo/protocol.py index e52f74db..6c2dff4e 100644 --- a/dojo/protocol.py +++ b/dojo/protocol.py @@ -175,6 +175,7 @@ class Heartbeat(bt.Synapse): ack: bool = Field(description="Acknowledgement of the heartbeat", default=False) +# TODO rename this to be a Task or something class DendriteQueryResponse(BaseModel): model_config = ConfigDict(frozen=False) request: FeedbackRequest diff --git a/dojo/utils/config.py b/dojo/utils/config.py index e9629a57..a278777f 100644 --- a/dojo/utils/config.py +++ b/dojo/utils/config.py @@ -118,7 +118,7 @@ def add_args(parser): Adds relevant arguments to the parser for operation. """ # Netuid Arg: The netuid of the subnet to connect to. 
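    # e.g. the new default below targets mainnet (netuid 52), while the testnet
    # compose file above uses NETUID=98 with SUBTENSOR_NETWORK=test; a
    # hypothetical testnet run might look like:
    #   python main_validator.py --netuid 98 --subtensor.network test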
- parser.add_argument("--netuid", type=int, help="Subnet netuid", default=1) + parser.add_argument("--netuid", type=int, help="Subnet netuid", default=52) neuron_types = ["miner", "validator"] parser.add_argument( @@ -127,7 +127,7 @@ def add_args(parser): type=str, help="Whether running a miner or validator", ) - args, unknown = parser.parse_known_args() + args, _ = parser.parse_known_args() neuron_type = None if known_args := vars(args): neuron_type = known_args["neuron.type"] @@ -164,19 +164,18 @@ def add_args(parser): help="Path to the environment file to use.", ) - if neuron_type == "validator": - parser.add_argument( - "--data_manager.base_path", - type=str, - help="Base path to store data to.", - default=base_path, - ) + parser.add_argument( + "--ignore_min_stake", + action="store_true", + help="Whether to always include self in monitoring queries, mainly for testing", + ) + if neuron_type == "validator": parser.add_argument( "--neuron.sample_size", type=int, help="The number of miners to query per dendrite call.", - default=10, + default=8, ) parser.add_argument( diff --git a/dojo_cli.py b/dojo_cli.py index b2bdd569..3de8406b 100644 --- a/dojo_cli.py +++ b/dojo_cli.py @@ -9,11 +9,13 @@ from rich.console import Console from dojo import get_dojo_api_base_url -from dojo.utils.config import source_dotenv +from dojo.utils.config import get_config, source_dotenv + +get_config() +source_dotenv() DOJO_API_BASE_URL = get_dojo_api_base_url() -source_dotenv() console = Console() diff --git a/main_validator.py b/main_validator.py index 2fa72f9f..0a46e7a5 100644 --- a/main_validator.py +++ b/main_validator.py @@ -13,7 +13,6 @@ from commons.objects import ObjectManager from database.client import connect_db, disconnect_db from dojo.utils.config import source_dotenv -from neurons.validator import DojoTaskTracker source_dotenv() @@ -27,7 +26,6 @@ async def lifespan(app: FastAPI): yield logger.info("Performing shutdown tasks...") validator._should_exit = True - DojoTaskTracker()._should_exit = True wandb.finish() validator.save_state() await SyntheticAPI._session.close() @@ -56,7 +54,7 @@ async def main(): asyncio.create_task(validator.log_validator_status()), asyncio.create_task(validator.run()), asyncio.create_task(validator.update_score_and_send_feedback()), - asyncio.create_task(DojoTaskTracker.monitor_task_completions()), + asyncio.create_task(validator.monitor_task_completions()), asyncio.create_task(validator.send_heartbeats()), ] diff --git a/neurons/miner.py b/neurons/miner.py index 9968fe5e..41574501 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -14,6 +14,7 @@ from dojo import MINER_STATUS, VALIDATOR_MIN_STAKE from dojo.base.miner import BaseMinerNeuron from dojo.protocol import FeedbackRequest, Heartbeat, ScoringResult, TaskResultRequest +from dojo.utils.config import get_config from dojo.utils.uids import is_miner @@ -108,6 +109,7 @@ async def forward_feedback_request( synapse.dojo_task_id = task_ids[0] except Exception: + traceback.print_exc() logger.error( f"Error occurred while processing request id: {synapse.request_id}, error: {traceback.format_exc()}" ) @@ -134,6 +136,7 @@ async def forward_task_result_request( return synapse except Exception as e: + traceback.print_exc() logger.error(f"Error handling TaskResultRequest: {e}") return synapse @@ -157,6 +160,15 @@ async def blacklist_feedback_request( if is_miner(self.metagraph, caller_uid): return True, "Not a validator" + if get_config().ignore_min_stake: + message = f"""Ignoring min stake stake required: 
{VALIDATOR_MIN_STAKE} \ + for {caller_hotkey}, YOU SHOULD NOT SEE THIS when you are running a miner on mainnet""" + logger.warning(message) + return ( + False, + f"Ignored minimum validator stake requirement of {VALIDATOR_MIN_STAKE}", + ) + if validator_neuron.stake.tao < float(VALIDATOR_MIN_STAKE): logger.warning( f"Blacklisting hotkey: {caller_hotkey} with insufficient stake, minimum stake required: {VALIDATOR_MIN_STAKE}, current stake: {validator_neuron.stake.tao}" diff --git a/neurons/validator.py b/neurons/validator.py index a99fb57e..a3988feb 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -1,8 +1,10 @@ import asyncio import copy +import gc import random import time import traceback +from collections import defaultdict from datetime import datetime, timezone from traceback import print_exception from typing import List @@ -18,25 +20,38 @@ from torch.nn import functional as F import dojo -from commons.data_manager import DataManager, ValidatorStateKeys from commons.dataset.synthetic import SyntheticAPI -from commons.dojo_task_tracker import DojoTaskTracker +from commons.exceptions import ( + EmptyScores, + InvalidMinerResponse, + NoNewUnexpiredTasksYet, +) from commons.obfuscation.obfuscation_utils import obfuscate_html_and_js +from commons.orm import ORM from commons.scoring import Scoring -from commons.utils import get_epoch_time, get_new_uuid, init_wandb, set_expire_time +from commons.utils import ( + datetime_as_utc, + get_epoch_time, + get_new_uuid, + init_wandb, + set_expire_time, +) from database.client import connect_db from dojo.base.neuron import BaseNeuron from dojo.protocol import ( CompletionResponses, + CriteriaTypeEnum, DendriteQueryResponse, FeedbackRequest, Heartbeat, MultiScoreCriteria, ScoringResult, + TaskResult, + TaskResultRequest, TaskType, ) from dojo.utils.config import get_config -from dojo.utils.uids import MinerUidSelector, extract_miner_uids +from dojo.utils.uids import MinerUidSelector, extract_miner_uids, is_miner class Validator(BaseNeuron): @@ -52,7 +67,9 @@ def __init__(self): self.dendrite = bt.dendrite(wallet=self.wallet) logger.info(f"Dendrite: {self.dendrite}") # Set up initial scoring weights for validation - self.scores = torch.zeros(self.metagraph.n.item(), dtype=torch.float32) + self.scores: torch.Tensor = torch.zeros( + self.metagraph.n.item(), dtype=torch.float32 + ) self.load_state() # manually always register and always sync metagraph when application starts @@ -76,127 +93,153 @@ async def send_scores(self, synapse: ScoringResult, hotkeys: List[str]): ) async def update_score_and_send_feedback(self): - """While this function is triggered every X time period in AsyncIOScheduler, - only relevant data that has passed the deadline of 8 hours will be scored and sent feedback. - """ while True: await asyncio.sleep(dojo.VALIDATOR_UPDATE_SCORE) + logger.info("📝 performing scoring ...") try: - data: List[DendriteQueryResponse] | None = await DataManager.load() - if not data: - logger.debug( - "Skipping scoring as no feedback data found, this means either all have been processed or you are running the validator for the first time." 
-                    )
-                    continue
-
-                current_time = datetime.now(timezone.utc)
-                # allow enough time for human feedback
-                non_expired_data: List[DendriteQueryResponse] = [
-                    d
-                    for d in data
-                    if d.request.expire_at
-                    and datetime.fromisoformat(d.request.expire_at) < current_time
+                validator_hotkeys = [
+                    hotkey
+                    for uid, hotkey in enumerate(self.metagraph.hotkeys)
+                    if not is_miner(self.metagraph, uid)
                 ]
-                if not non_expired_data:
-                    logger.warning(
-                        "Skipping scoring as no feedback data is due for scoring."
-                    )
-                logger.info(
-                    f"Got {len(non_expired_data)} requests past deadline and ready to score"
-                )
-                for d in non_expired_data:
-                    criteria_to_miner_score, hotkey_to_score = Scoring.calculate_score(
-                        criteria_types=d.request.criteria_types,
-                        request=d.request,
-                        miner_responses=d.miner_responses,
-                    )
-                    logger.trace(f"Got hotkey to score: {hotkey_to_score}")
+                if get_config().ignore_min_stake:
+                    validator_hotkeys.append(self.wallet.hotkey.ss58_address)

-                    if not hotkey_to_score:
-                        request_id = d.request.request_id
-                        try:
-                            del DojoTaskTracker._rid_to_mhotkey_to_task_id[request_id]
-                        except KeyError:
-                            pass
-                        await DataManager.remove_responses([d])
-                        continue
+                batch_id = 0
+                # number of tasks to process in a batch
+                batch_size = 10
+                processed_request_ids = []

-                    logger.trace(
-                        f"Initially had {len(d.miner_responses)} responses from miners, but only {len(hotkey_to_score.keys())} valid responses"
+                # figure out an expire_at cutoff time to determine those requests ready for scoring
+                try:
+                    expire_at = await ORM.get_last_expire_at_cutoff(validator_hotkeys)
+                except ValueError:
+                    logger.warning(
+                        f"No tasks for scoring yet, please wait for tasks to pass the deadline of {dojo.TASK_DEADLINE} seconds"
                     )
+                    continue

-                    self.update_scores(hotkey_to_scores=hotkey_to_score)
-                    await self.send_scores(
-                        synapse=ScoringResult(
-                            request_id=d.request.request_id,
-                            hotkey_to_scores=hotkey_to_score,
-                        ),
-                        hotkeys=list(hotkey_to_score.keys()),
-                    )
+                async for (
+                    task_batch,
+                    has_more_batches,
+                ) in ORM.get_expired_tasks(
+                    validator_hotkeys, batch_size=batch_size, expire_at=expire_at
+                ):
+                    if not has_more_batches:
+                        logger.success(
+                            f"📝 All tasks processed, total batches: {batch_id}, batch size: {batch_size}"
+                        )
+                        gc.collect()
+                        break

-                    async def log_wandb():
-                        # calculate mean across all criteria
+                    if not task_batch:
+                        break

-                        if not criteria_to_miner_score.values() or not hotkey_to_score:
-                            logger.warning(
-                                "No criteria to miner scores available. Skipping calculating averages for wandb."
- ) - return - - mean_weighted_consensus_scores = ( - torch.stack( - [ - miner_scores.consensus.score - for miner_scores in criteria_to_miner_score.values() - ] + batch_id += 1 + logger.info( + f"📝 Processing batch {batch_id}, batch size: {batch_size}" + ) + for task in task_batch: + criteria_to_miner_score, hotkey_to_score = ( + Scoring.calculate_score( + criteria_types=task.request.criteria_types, + request=task.request, + miner_responses=task.miner_responses, ) - .mean(dim=0) - .tolist() ) - mean_weighted_gt_scores = ( - torch.stack( - [ - miner_scores.ground_truth.score - for miner_scores in criteria_to_miner_score.values() - ] - ) - .mean(dim=0) - .tolist() + logger.debug(f"📝 Got hotkey to score: {hotkey_to_score}") + logger.debug( + f"📝 Initially had {len(task.miner_responses)} responses from miners, but only {len(hotkey_to_score.keys())} valid responses" ) - logger.info( - f"mean miner scores across differerent criteria: consensus shape:{mean_weighted_consensus_scores}, gt shape:{mean_weighted_gt_scores}" + if not hotkey_to_score: + logger.info( + "📝 Did not manage to generate a dict of hotkey to score" + ) + # append it anyways so we can cut off later + processed_request_ids.append(task.request.request_id) + continue + + self.update_scores(hotkey_to_scores=hotkey_to_score) + await self.send_scores( + synapse=ScoringResult( + request_id=task.request.request_id, + hotkey_to_scores=hotkey_to_score, + ), + hotkeys=list(hotkey_to_score.keys()), ) - score_data = {} - # update the scores based on the rewards - score_data["scores_by_hotkey"] = hotkey_to_score - score_data["mean"] = { - "consensus": mean_weighted_consensus_scores, - "ground_truth": mean_weighted_gt_scores, - } - - wandb_data = jsonable_encoder( - { - "task": d.request.task_type, - "criteria": d.request.criteria_types, - "prompt": d.request.prompt, - "completions": jsonable_encoder( - d.request.completion_responses - ), - "num_completions": len(d.request.completion_responses), - "scores": score_data, - "num_responses": len(d.miner_responses), + async def log_wandb(): + # calculate mean across all criteria + + if ( + not criteria_to_miner_score.values() + or not hotkey_to_score + ): + logger.warning( + "📝 No criteria to miner scores available. Skipping calculating averages for wandb." 
+ ) + return + + mean_weighted_consensus_scores = ( + torch.stack( + [ + miner_scores.consensus.score + for miner_scores in criteria_to_miner_score.values() + ] + ) + .mean(dim=0) + .tolist() + ) + mean_weighted_gt_scores = ( + torch.stack( + [ + miner_scores.ground_truth.score + for miner_scores in criteria_to_miner_score.values() + ] + ) + .mean(dim=0) + .tolist() + ) + + logger.info( + f"📝 Mean miner scores across different criteria: consensus shape:{mean_weighted_consensus_scores}, gt shape:{mean_weighted_gt_scores}" + ) + + score_data = {} + # update the scores based on the rewards + score_data["scores_by_hotkey"] = hotkey_to_score + score_data["mean"] = { + "consensus": mean_weighted_consensus_scores, + "ground_truth": mean_weighted_gt_scores, } - ) - wandb.log(wandb_data, commit=True) + wandb_data = jsonable_encoder( + { + "task": task.request.task_type, + "criteria": task.request.criteria_types, + "prompt": task.request.prompt, + "completions": jsonable_encoder( + task.request.completion_responses + ), + "num_completions": len( + task.request.completion_responses + ), + "scores": score_data, + "num_responses": len(task.miner_responses), + } + ) - asyncio.create_task(log_wandb()) + wandb.log(wandb_data, commit=True) - # once we have scored a response, just remove it - await DataManager.remove_responses([d]) + asyncio.create_task(log_wandb()) + + # once we have scored a response, just remove it + processed_request_ids.append(task.request.request_id) + + if processed_request_ids: + await ORM.mark_tasks_processed_by_request_ids(processed_request_ids) except Exception: traceback.print_exc() @@ -224,14 +267,14 @@ async def send_heartbeats(self): try: all_miner_uids = extract_miner_uids(metagraph=self.metagraph) logger.debug(f"Sending heartbeats to {len(all_miner_uids)} miners") - axons: List[bt.AxonInfo] = [ + axons: list[bt.AxonInfo] = [ self.metagraph.axons[uid] for uid in all_miner_uids if self.metagraph.axons[uid].hotkey.casefold() != self.wallet.hotkey.ss58_address.casefold() ] - responses: List[Heartbeat] = await self.dendrite.forward( + responses: List[Heartbeat] = await self.dendrite.forward( # type: ignore axons=axons, synapse=Heartbeat(), deserialize=False, timeout=12 ) active_hotkeys = [r.axon.hotkey for r in responses if r.ack and r.axon] @@ -421,8 +464,11 @@ async def send_request( valid_miner_responses: List[FeedbackRequest] = [] try: for miner_response in miner_responses: + miner_hotkey = ( + miner_response.axon.hotkey if miner_response.axon else "??" + ) logger.debug( - f"Received response from miner: {miner_response.axon.hotkey, miner_response.dojo_task_id}" + f"Received response from miner: {miner_hotkey, miner_response.dojo_task_id}" ) # map obfuscated model names back to the original model names real_model_ids = [] @@ -441,13 +487,11 @@ async def send_request( continue if miner_response.dojo_task_id is None: - logger.debug( - "Miner must provide the dojo task id for scoring method dojo" - ) + logger.debug(f"Miner {miner_hotkey} must provide the dojo task id") continue logger.debug( - f"Successfully mapped obfuscated model names for {miner_response.axon.hotkey}" + f"Successfully mapped obfuscated model names for {miner_hotkey}" ) # update the miner response with the real model ids @@ -459,32 +503,28 @@ async def send_request( logger.info(f"⬇️ Got {len(valid_miner_responses)} valid responses") if valid_miner_responses is None or len(valid_miner_responses) == 0: - logger.warning("No valid miner responses to process... 
skipping") + logger.info("No valid miner responses to process... skipping") return # include the ground_truth to keep in data manager synapse.ground_truth = data.ground_truth + synapse.dendrite.hotkey = self.wallet.hotkey.ss58_address response_data = DendriteQueryResponse( request=synapse, miner_responses=valid_miner_responses, ) logger.debug("Attempting to saving dendrite response") - fb_request_model = await DataManager.save_dendrite_response( - response=response_data + vali_request_model = await ORM.save_task( + validator_request=synapse, + miner_responses=valid_miner_responses, + ground_truth=data.ground_truth, ) - if fb_request_model is None: + if vali_request_model is None: logger.error("Failed to save dendrite response") return - logger.debug("Attempting to update task map") - await DojoTaskTracker.update_task_map( - synapse.request_id, - fb_request_model, - obfuscated_model_to_model, - ) - # saving response logger.success( f"Saved dendrite response for request id: {response_data.request.request_id}" @@ -513,6 +553,7 @@ async def run(self): self.step += 1 except Exception as e: + traceback.print_exc() logger.error(f"Error during validator run: {e}") pass await asyncio.sleep(dojo.VALIDATOR_RUN) @@ -618,15 +659,18 @@ def resync_metagraph(self): # If so, we need to add new hotkeys and moving averages. if len(previous_metagraph.hotkeys) < len(self.metagraph.hotkeys): # Update the size of the moving average scores. - new_moving_average = np.zeros(self.metagraph.n) + new_moving_average = torch.zeros(self.metagraph.n) min_len = min(len(previous_metagraph.hotkeys), len(self.scores)) new_moving_average[:min_len] = self.scores[:min_len] self.scores = new_moving_average - def update_scores(self, hotkey_to_scores): + def update_scores(self, hotkey_to_scores: dict[str, float]): """Performs exponential moving average on the scores based on the rewards received from the miners, after setting the self.scores variable here, `set_weights` will be called to set the weights on chain. """ + if not hotkey_to_scores: + logger.warning("hotkey_to_scores is empty, skipping score update") + return nan_value_indices = np.isnan(list(hotkey_to_scores.values())) if nan_value_indices.any(): @@ -640,7 +684,7 @@ def update_scores(self, hotkey_to_scores): for index, (key, value) in enumerate(hotkey_to_scores.items()): # handle nan values if nan_value_indices[index]: - rewards[key] = 0.0 + rewards[key] = 0.0 # type: ignore # search metagraph for hotkey and grab uid try: uid = neuron_hotkeys.index(key) @@ -657,50 +701,243 @@ def update_scores(self, hotkey_to_scores): # Update scores with rewards produced by this step. 
        # shape: [ metagraph.n ]
        alpha: float = self.config.neuron.moving_average_alpha
-        self.scores: torch.Tensor = alpha * rewards + (1 - alpha) * self.scores
+        self.scores = alpha * rewards + (1 - alpha) * self.scores
        logger.debug(f"Updated scores: {self.scores}")

+    async def _save_state(
+        self,
+    ):
+        """Saves the state of the validator to the database."""
+        if self.step == 0:
+            return
+
+        try:
+            if np.count_nonzero(self.scores) == 0:
+                raise EmptyScores("Skipping save as scores are all empty")
+
+            await ORM.create_or_update_validator_score(self.scores)
+            logger.success(f"📦 Saved validator state with scores: {self.scores}")
+        except EmptyScores as e:
+            logger.debug(f"No need to save validator state: {e}")
+        except Exception as e:
+            logger.error(f"Failed to save validator state: {e}")
+
     def save_state(self):
         """Saves the state of the validator to a file."""
         try:
             loop = asyncio.get_event_loop()
-            loop.run_until_complete(
-                DataManager.validator_save(
-                    self.scores,
-                    DojoTaskTracker._rid_to_mhotkey_to_task_id,
-                    DojoTaskTracker._rid_to_model_map,
-                    DojoTaskTracker._task_to_expiry,
-                )
-            )
+            loop.run_until_complete(self._save_state())
         except Exception as e:
             logger.error(f"Failed to save validator state: {e}")
             pass

+    async def _load_state(self):
+        try:
+            await connect_db()
+            scores = await ORM.get_validator_score()
+
+            if not scores:
+                num_processed_tasks = await ORM.get_num_processed_tasks()
+                if num_processed_tasks > 0:
+                    logger.error(
+                        "Score record not found, but you have processed tasks."
+                    )
+                else:
+                    logger.warning(
+                        "Score record not found, and no tasks processed, this is okay if you're running for the first time."
+                    )
+                return None
+
+            logger.success(f"Loaded validator state: {scores=}")
+            self.scores = scores
+
+        except Exception as e:
+            logger.error(
+                f"Unexpected error occurred while loading validator state: {e}"
+            )
+            return None
+
     def load_state(self):
         """Loads the state of the validator from a file."""
         loop = asyncio.get_event_loop()
-        loop.run_until_complete(connect_db())
-        state_data = loop.run_until_complete(DataManager.validator_load())
-        if state_data is None:
-            if self.step == 0:
-                logger.warning(
-                    "Failed to load validator state data, this is okay on start, or if you're running for the first time."
-                )
-            else:
-                logger.error("Failed to load validator state data")
-            return
-
-        self.scores = state_data[ValidatorStateKeys.SCORES]
-        DojoTaskTracker._rid_to_mhotkey_to_task_id = state_data[
-            ValidatorStateKeys.DOJO_TASKS_TO_TRACK
-        ]
-        DojoTaskTracker._rid_to_model_map = state_data[ValidatorStateKeys.MODEL_MAP]
-        DojoTaskTracker._task_to_expiry = state_data[ValidatorStateKeys.TASK_TO_EXPIRY]
-
-        logger.info(f"Scores state: {self.scores}")
+        loop.run_until_complete(self._load_state())

     @classmethod
     async def log_validator_status(cls):
         while not cls._should_exit:
             logger.info(f"Validator running... 
{time.time()}") await asyncio.sleep(dojo.VALIDATOR_STATUS) + + async def _get_task_results_from_miner( + self, miner_hotkey: str, task_id: str + ) -> list[TaskResult]: + """Fetch task results from the miner's Axon using Dendrite.""" + try: + if not self.dendrite: + raise ValueError("Dendrite not initialized") + + # Prepare the synapse (data request) that will be sent via Dendrite + task_synapse = TaskResultRequest(task_id=task_id) + + # Use Dendrite to communicate with the Axon + miner_axon = self.metagraph.axons[ + self.metagraph.hotkeys.index(miner_hotkey) + ] + if not miner_axon: + raise ValueError(f"Miner Axon not found for hotkey: {miner_hotkey}") + + # Send the request via Dendrite and get the response + response: list[TaskResultRequest] = await self.dendrite.forward( # type: ignore + axons=[miner_axon], synapse=task_synapse, deserialize=False + ) + + if response and response[0]: + logger.debug( + f"Received task result from miner {miner_hotkey} for task {task_id}, {response}" + ) + return response[0].task_results + else: + logger.debug( + f"No task results found from miner {miner_hotkey} for task {task_id}" + ) + return [] + + except Exception as e: + logger.error(f"Error fetching task result from miner {miner_hotkey}: {e}") + return [] + + async def monitor_task_completions(self): + while not self._should_exit: + try: + validator_hotkeys = [ + hotkey + for uid, hotkey in enumerate(self.metagraph.hotkeys) + if not is_miner(self.metagraph, uid) + ] + + if get_config().ignore_min_stake: + validator_hotkeys.append(self.wallet.hotkey.ss58_address) + + batch_id = 0 + batch_size = 10 + # use current time as cutoff so we get only unexpired tasks + now = datetime_as_utc(datetime.now(timezone.utc)) + async for task_batch, has_more_batches in ORM.get_expired_tasks( + validator_hotkeys=validator_hotkeys, + batch_size=batch_size, + expire_at=now, + ): + if not has_more_batches: + logger.success( + "No more unexpired tasks found for processing, exiting task monitoring." + ) + gc.collect() + break + + if not task_batch: + continue + + batch_id += 1 + logger.info(f"Monitoring task completions, batch id: {batch_id}") + + for task in task_batch: + request_id = task.request.request_id + miner_responses = task.miner_responses + + obfuscated_to_real_model_id = await ORM.get_real_model_ids( + request_id + ) + + for miner_response in miner_responses: + if ( + not miner_response.axon + or not miner_response.axon.hotkey + or not miner_response.dojo_task_id + ): + raise InvalidMinerResponse( + f"""Missing hotkey, task_id, or axon: + axon: {miner_response.axon} + hotkey: {miner_response.axon.hotkey} + task_id: {miner_response.dojo_task_id}""" + ) + + miner_hotkey = miner_response.axon.hotkey + task_id = miner_response.dojo_task_id + task_results = await asyncio.create_task( + self._get_task_results_from_miner(miner_hotkey, task_id) + ) + + if not task_results and not len(task_results) > 0: + logger.debug( + f"Task ID: {task_id} by miner: {miner_hotkey} has not been completed yet or no task results." 
+ ) + continue + + # Process task result + model_id_to_avg_rank, model_id_to_avg_score = ( + self._calculate_averages( + task_results, obfuscated_to_real_model_id + ) + ) + + # Update the response with the new ranks and scores + for completion in miner_response.completion_responses: + model_id = completion.model + if model_id in model_id_to_avg_rank: + completion.rank_id = int( + model_id_to_avg_rank[model_id] + ) + if model_id in model_id_to_avg_score: + completion.score = model_id_to_avg_score[model_id] + + # Update miner responses in the database + success = await ORM.update_miner_completions_by_request_id( + request_id, task.miner_responses + ) + + logger.info( + f"Updating task {request_id} with miner's completion data, success ? {success}" + ) + await asyncio.sleep(0.2) + except NoNewUnexpiredTasksYet as e: + logger.info(f"No new unexpired tasks yet: {e}") + except Exception as e: + traceback.print_exc() + logger.error(f"Error during Dojo task monitoring {str(e)}") + pass + await asyncio.sleep(dojo.DOJO_TASK_MONITORING) + + @staticmethod + def _calculate_averages( + task_results: list[TaskResult], obfuscated_to_real_model_id + ): + model_id_to_avg_rank = defaultdict(float) + model_id_to_avg_score = defaultdict(float) + num_ranks_by_workers, num_scores_by_workers = 0, 0 + + for result in task_results: + for result_data in result.result_data: + type = result_data.type + value = result_data.value + if type == CriteriaTypeEnum.RANKING_CRITERIA: + for model_id, rank in value.items(): + real_model_id = obfuscated_to_real_model_id.get( + model_id, model_id + ) + model_id_to_avg_rank[real_model_id] += rank + num_ranks_by_workers += 1 + elif type == CriteriaTypeEnum.MULTI_SCORE: + for model_id, score in value.items(): + real_model_id = obfuscated_to_real_model_id.get( + model_id, model_id + ) + model_id_to_avg_score[real_model_id] += score + num_scores_by_workers += 1 + + # Average the ranks and scores + for model_id in model_id_to_avg_rank: + model_id_to_avg_rank[model_id] /= num_ranks_by_workers + for model_id in model_id_to_avg_score: + model_id_to_avg_score[model_id] /= num_scores_by_workers + + return model_id_to_avg_rank, model_id_to_avg_score diff --git a/schema.prisma b/schema.prisma index 7da06143..b656db00 100644 --- a/schema.prisma +++ b/schema.prisma @@ -1,6 +1,6 @@ generator client { provider = "prisma-client-py" - recursive_type_depth = -1 + recursive_type_depth = 5 output = "./database/prisma" } @@ -9,77 +9,78 @@ datasource db { url = env("DATABASE_URL") } -enum Criteria_Type_Enum_Model { +enum CriteriaTypeEnum { RANKING_CRITERIA MULTI_SCORE SCORE MULTI_SELECT } -model Feedback_Request_Model { - id String @id @default(uuid()) - request_id String @unique - task_type String - prompt String - ground_truth Json - criteria_types Criteria_Type_Model[] - miner_responses Miner_Response_Model[] - created_at DateTime @default(now()) - updated_at DateTime @updatedAt +model Ground_Truth_Model { + id String @id @default(uuid()) + request_id String + obfuscated_model_id String + real_model_id String + rank_id Int + feedback_request Feedback_Request_Model @relation(fields: [feedback_request_id], references: [id]) + feedback_request_id String + created_at DateTime @default(now()) + updated_at DateTime @updatedAt + + @@unique([request_id, obfuscated_model_id, rank_id]) } -model Miner_Response_Model { - id String @id @default(uuid()) - request_id String - miner_hotkey String - dojo_task_id String - expire_at String - completions Completion_Response_Model[] - feedback_request 
Feedback_Request_Model @relation(fields: [request_id], references: [request_id]) - created_at DateTime @default(now()) - updated_at DateTime @updatedAt +model Feedback_Request_Model { + id String @id @unique @default(uuid()) + request_id String + prompt String + completions Completion_Response_Model[] + task_type String + criteria_types Criteria_Type_Model[] + is_processed Boolean @default(false) + dojo_task_id String? + hotkey String + expire_at DateTime + created_at DateTime @default(now()) + updated_at DateTime @updatedAt + + ground_truths Ground_Truth_Model[] + + parent_request Feedback_Request_Model? @relation("ParentChild", fields: [parent_id], references: [id]) + parent_id String? + child_requests Feedback_Request_Model[] @relation("ParentChild") + + @@unique([request_id, hotkey]) } model Completion_Response_Model { - id String @id @default(uuid()) - completion_id String - model String - completion Json - rank_id Int? - score Float? - miner_response_id String - miner_response Miner_Response_Model @relation(fields: [miner_response_id], references: [id]) - created_at DateTime @default(now()) - updated_at DateTime @updatedAt + id String @id @default(uuid()) + completion_id String + model String + completion Json + rank_id Int? + score Float? + created_at DateTime @default(now()) + updated_at DateTime @updatedAt + feedback_request_relation Feedback_Request_Model @relation(fields: [feedback_request_id], references: [id]) + feedback_request_id String } model Criteria_Type_Model { - id String @id @default(uuid()) - type Criteria_Type_Enum_Model - options Json? - min Float? - max Float? - request_id String - feedback_request Feedback_Request_Model @relation(fields: [request_id], references: [request_id]) - created_at DateTime @default(now()) - updated_at DateTime @updatedAt -} - -// request_id, miner_hotkey, task_id are composite key -model Validator_State_Model { - id String @id @default(uuid()) - request_id String - miner_hotkey String - task_id String - expire_at String - obfuscated_model String - real_model String - created_at DateTime @default(now()) - updated_at DateTime @updatedAt + id String @id @default(uuid()) + type CriteriaTypeEnum + options Json + min Float? + max Float? + feedback_request_relation Feedback_Request_Model? @relation(fields: [feedback_request_id], references: [id]) + feedback_request_id String? + created_at DateTime @default(now()) + updated_at DateTime @updatedAt } model Score_Model { id String @id @default(uuid()) + // json array of scores score Json created_at DateTime @default(now()) updated_at DateTime @updatedAt