diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index af174f5e..401fea13 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -39,5 +39,6 @@ repos:
     rev: v3.27.0 # This specifies that we are using version 3.27.0 of the commitizen repository
     hooks:
       - id: commitizen
-      - id: commitizen-branch
         stages: [commit-msg]
+      - id: commitizen-branch
+        stages: [pre-push]
diff --git a/commons/data_manager.py b/commons/data_manager.py
index 99aa77a6..bc69fa4c 100644
--- a/commons/data_manager.py
+++ b/commons/data_manager.py
@@ -2,12 +2,13 @@
 import json
 import pickle
 from pathlib import Path
-from typing import Any, List, Optional
+from typing import Any, List
 
 import torch
-from commons.objects import ObjectManager
 from loguru import logger
 from strenum import StrEnum
+
+from commons.objects import ObjectManager
 from template.protocol import DendriteQueryResponse, FeedbackRequest
 
 
@@ -15,6 +16,7 @@ class ValidatorStateKeys(StrEnum):
     SCORES = "scores"
     DOJO_TASKS_TO_TRACK = "dojo_tasks_to_track"
     MODEL_MAP = "model_map"
+    Task_TO_EXPIRY = "task_to_expiry"
 
 
 class DataManager:
@@ -24,7 +26,7 @@ class DataManager:
 
     def __new__(cls, *args, **kwargs):
         if cls._instance is None:
-            cls._instance = super(DataManager, cls).__new__(cls)
+            cls._instance = super().__new__(cls)
             cls._ensure_paths_exist()
         return cls._instance
 
@@ -46,7 +48,7 @@ def get_validator_state_filepath() -> Path:
         return base_path / "data" / "validator_state.pt"
 
     @classmethod
-    async def _load_without_lock(cls, path) -> Optional[List[DendriteQueryResponse]]:
+    async def _load_without_lock(cls, path) -> List[DendriteQueryResponse] | None:
         try:
             with open(str(path), "rb") as file:
                 return pickle.load(file)
@@ -145,7 +147,7 @@ async def get_by_request_id(cls, request_id):
     @classmethod
     async def remove_responses(
         cls, responses: List[DendriteQueryResponse]
-    ) -> Optional[DendriteQueryResponse]:
+    ) -> DendriteQueryResponse | None:
         path = DataManager.get_requests_data_path()
         async with cls._lock:
             data = await DataManager._load_without_lock(path=path)
@@ -168,7 +170,9 @@ async def remove_responses(
         await DataManager._save_without_lock(path, new_data)
 
     @classmethod
-    async def validator_save(cls, scores, requestid_to_mhotkey_to_task_id, model_map):
+    async def validator_save(
+        cls, scores, requestid_to_mhotkey_to_task_id, model_map, task_to_expiry
+    ):
         """Saves the state of the validator to a file."""
         logger.debug("Attempting to save validator state.")
         async with cls._validator_lock:
@@ -184,6 +188,9 @@ async def validator_save(cls, scores, requestid_to_mhotkey_to_task_id, model_map
                         json.dumps(requestid_to_mhotkey_to_task_id)
                     ),
                     ValidatorStateKeys.MODEL_MAP: json.loads(json.dumps(model_map)),
+                    ValidatorStateKeys.Task_TO_EXPIRY: json.loads(
+                        json.dumps(task_to_expiry)
+                    ),
                 },
                 cls.get_validator_state_filepath(),
             )
diff --git a/commons/human_feedback/dojo.py b/commons/human_feedback/dojo.py
index 879751e2..46fa8d51 100644
--- a/commons/human_feedback/dojo.py
+++ b/commons/human_feedback/dojo.py
@@ -92,16 +92,20 @@ async def create_task(
             else:
                 logger.error(f"Unrecognized criteria type: {type(criteria_type)}")
 
-        body = {
-            "title": "LLM Code Generation Task",
-            "body": ranking_request.prompt,
-            "expireAt": (
+        expireAt = (
+            (
                 datetime.datetime.utcnow()
                 + datetime.timedelta(seconds=template.TASK_DEADLINE)
             )
             .replace(microsecond=0, tzinfo=datetime.timezone.utc)
             .isoformat()
-            .replace("+00:00", "Z"),
+            .replace("+00:00", "Z")
+        )
+
+        body = {
+            "title": "LLM Code Generation Task",
+            "body": ranking_request.prompt,
+            "expireAt": expireAt,
             "taskData": json.dumps([taskData]),
             "maxResults": "1",
         }
@@ -140,4 +144,4 @@ async def create_task(
                 f"Error occurred when trying to create task\nErr:{response.json()['error']}"
             )
         response.raise_for_status()
-        return task_ids
+        return task_ids, expireAt
diff --git a/neurons/miner.py b/neurons/miner.py
index 9de3e585..22c66e78 100644
--- a/neurons/miner.py
+++ b/neurons/miner.py
@@ -87,9 +87,10 @@ async def forward_feedback_request(
             scoring_method = self.config.scoring_method
             if scoring_method.casefold() == ScoringMethod.DOJO:
                 synapse.scoring_method = ScoringMethod.DOJO
-                task_ids = await DojoAPI.create_task(synapse)
+                task_ids, expireAt = await DojoAPI.create_task(synapse)
                 assert len(task_ids) == 1
                 synapse.dojo_task_id = task_ids[0]
+                synapse.expireAt = expireAt
             else:
                 logger.error("Unrecognized scoring method!")
         except Exception:
diff --git a/neurons/validator.py b/neurons/validator.py
index 86b16eb9..32c318c7 100644
--- a/neurons/validator.py
+++ b/neurons/validator.py
@@ -10,12 +10,12 @@
 import bittensor as bt
 import numpy as np
 import torch
-import wandb
 from fastapi.encoders import jsonable_encoder
 from loguru import logger
 from torch.nn import functional as F
 
 import template
+import wandb
 from commons.data_manager import DataManager, ValidatorStateKeys
 from commons.dataset.synthetic import SyntheticAPI
 from commons.human_feedback.dojo import DojoAPI
@@ -48,6 +48,7 @@ class DojoTaskTracker:
         lambda: defaultdict(str)
     )
     _rid_to_model_map: Dict[str, Dict[str, str]] = defaultdict(lambda: defaultdict(str))
+    _task_to_expiry: Dict[str, str] = defaultdict(str)
     _lock = asyncio.Lock()
     _should_exit: bool = False
 
@@ -75,7 +76,7 @@ async def update_task_map(
 
         logger.debug("update_task_map attempting to acquire lock")
         async with cls._lock:
-            valid_responses = list(
+            valid_responses: List[FeedbackRequest] = list(
                 filter(
                     lambda r: r.request_id == request_id
                     and r.axon.hotkey
@@ -94,6 +95,7 @@ async def update_task_map(
                 cls._rid_to_mhotkey_to_task_id[request_id][r.axon.hotkey] = (
                     r.dojo_task_id
                 )
+                cls._task_to_expiry[r.dojo_task_id] = r.expireAt
             cls._rid_to_model_map[request_id] = obfuscated_model_to_model
         logger.debug("released lock for task tracker")
         return
@@ -107,6 +109,7 @@ async def monitor_task_completions(cls):
             logger.info(
                 f"Monitoring Dojo Task completions... {get_epoch_time()} for {len(cls._rid_to_mhotkey_to_task_id)} requests"
             )
+            logger.info(f"print task_to_expiry {cls._task_to_expiry}")
             if not cls._rid_to_mhotkey_to_task_id:
                 await asyncio.sleep(SLEEP_SECONDS)
                 continue
@@ -685,6 +688,7 @@ def save_state(self):
                     self.scores,
                     DojoTaskTracker._rid_to_mhotkey_to_task_id,
                     DojoTaskTracker._rid_to_model_map,
+                    DojoTaskTracker._task_to_expiry,
                 )
             )
         except Exception as e:
@@ -705,11 +709,13 @@ def load_state(self):
             ValidatorStateKeys.DOJO_TASKS_TO_TRACK
         ]
         DojoTaskTracker._rid_to_model_map = state_data[ValidatorStateKeys.MODEL_MAP]
+        DojoTaskTracker._task_to_expiry = state_data[ValidatorStateKeys.Task_TO_EXPIRY]
 
         logger.info(f"Scores state: {self.scores}")
         logger.info(
             f"Dojo Tasks to track: {DojoTaskTracker._rid_to_mhotkey_to_task_id}"
         )
+        logger.info(f"Task to expiry {DojoTaskTracker._task_to_expiry}")
 
     @classmethod
     async def log_validator_status(cls):
diff --git a/pyproject.toml b/pyproject.toml
index ddef10ba..614c1073 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -137,7 +137,7 @@ py-modules = ["dojo_cli"]
 
 [tool.setuptools.dynamic]
 dependencies = {file = ["requirements.txt"]}
-optional-dependencies.dev = { file = ["requirements-dev.txt"] }
+optional-dependencies-dev = { file = ["requirements-dev.txt"] }
 
 [tool.commitizen]
 name = "cz_conventional_commits"
diff --git a/template/protocol.py b/template/protocol.py
index 254f9ce2..b7d1fa8d 100644
--- a/template/protocol.py
+++ b/template/protocol.py
@@ -1,10 +1,10 @@
-from typing import Dict, List, Optional, Union
+from typing import Dict, List
 
-from commons.utils import get_epoch_time, get_new_uuid
+import bittensor as bt
 from pydantic import BaseModel, Field
 from strenum import StrEnum
 
-import bittensor as bt
+from commons.utils import get_epoch_time, get_new_uuid
 
 
 class TaskType(StrEnum):
@@ -40,7 +40,7 @@ class Config:
     max: float = Field(description="Maximum score for the task")
 
 
-CriteriaType = Union[MultiScoreCriteria, RankingCriteria]
+CriteriaType = MultiScoreCriteria | RankingCriteria
 
 
 class ScoringMethod(StrEnum):
@@ -61,7 +61,7 @@ class CodeAnswer(BaseModel):
     installation_commands: str = Field(
         description="Terminal commands for the code to be able to run to install any third-party packages for the code to be able to run"
     )
-    additional_notes: Optional[str] = Field(
+    additional_notes: str | None = Field(
         description="Any additional notes or comments about the code solution"
     )
 
@@ -74,10 +74,10 @@ class Response(BaseModel):
         default_factory=get_new_uuid,
         description="Unique identifier for the completion",
     )
-    rank_id: Optional[int] = Field(
+    rank_id: int | None = Field(
         description="Rank of the completion", examples=[1, 2, 3, 4]
     )
-    score: Optional[float] = Field(description="Score of the completion")
+    score: float | None = Field(description="Score of the completion")
 
 
 class SyntheticQA(BaseModel):
@@ -108,10 +108,11 @@ class FeedbackRequest(bt.Synapse):
         description="Types of criteria for the task",
         allow_mutation=False,
     )
-    scoring_method: Optional[str] = Field(
+    scoring_method: str | None = Field(
         decscription="Method to use for scoring completions"
     )
-    dojo_task_id: Optional[str] = Field(description="Dojo task ID for the request")
+    dojo_task_id: str | None = Field(description="Dojo task ID for the request")
+    expireAt: str | None = Field(description="Expired time for Dojo task")
 
 
 class ScoringResult(bt.Synapse):