Skip to content

Commit

Permalink
fix: added expiredAt in DojoTaskTracker and DataManager
Browse files Browse the repository at this point in the history
  • Loading branch information
codebender37 committed Jul 5, 2024
1 parent 8deb576 commit ad46be8
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 26 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ repos:
rev: v3.27.0 # This specifies that we are using version 3.27.0 of the commitizen repository
hooks:
- id: commitizen
- id: commitizen-branch
stages: [commit-msg]
- id: commitizen-branch
stages: [pre-push]
19 changes: 13 additions & 6 deletions commons/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@
import json
import pickle
from pathlib import Path
from typing import Any, List, Optional
from typing import Any, List

import torch
from commons.objects import ObjectManager
from loguru import logger
from strenum import StrEnum

from commons.objects import ObjectManager
from template.protocol import DendriteQueryResponse, FeedbackRequest


class ValidatorStateKeys(StrEnum):
SCORES = "scores"
DOJO_TASKS_TO_TRACK = "dojo_tasks_to_track"
MODEL_MAP = "model_map"
Task_TO_EXPIRY = "task_to_expiry"


class DataManager:
Expand All @@ -24,7 +26,7 @@ class DataManager:

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(DataManager, cls).__new__(cls)
cls._instance = super().__new__(cls)
cls._ensure_paths_exist()
return cls._instance

Expand All @@ -46,7 +48,7 @@ def get_validator_state_filepath() -> Path:
return base_path / "data" / "validator_state.pt"

@classmethod
async def _load_without_lock(cls, path) -> Optional[List[DendriteQueryResponse]]:
async def _load_without_lock(cls, path) -> List[DendriteQueryResponse] | None:
try:
with open(str(path), "rb") as file:
return pickle.load(file)
Expand Down Expand Up @@ -145,7 +147,7 @@ async def get_by_request_id(cls, request_id):
@classmethod
async def remove_responses(
cls, responses: List[DendriteQueryResponse]
) -> Optional[DendriteQueryResponse]:
) -> DendriteQueryResponse | None:
path = DataManager.get_requests_data_path()
async with cls._lock:
data = await DataManager._load_without_lock(path=path)
Expand All @@ -168,7 +170,9 @@ async def remove_responses(
await DataManager._save_without_lock(path, new_data)

@classmethod
async def validator_save(cls, scores, requestid_to_mhotkey_to_task_id, model_map):
async def validator_save(
cls, scores, requestid_to_mhotkey_to_task_id, model_map, task_to_expiry
):
"""Saves the state of the validator to a file."""
logger.debug("Attempting to save validator state.")
async with cls._validator_lock:
Expand All @@ -184,6 +188,9 @@ async def validator_save(cls, scores, requestid_to_mhotkey_to_task_id, model_map
json.dumps(requestid_to_mhotkey_to_task_id)
),
ValidatorStateKeys.MODEL_MAP: json.loads(json.dumps(model_map)),
ValidatorStateKeys.Task_TO_EXPIRY: json.loads(
json.dumps(task_to_expiry)
),
},
cls.get_validator_state_filepath(),
)
Expand Down
16 changes: 10 additions & 6 deletions commons/human_feedback/dojo.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,20 @@ async def create_task(
else:
logger.error(f"Unrecognized criteria type: {type(criteria_type)}")

body = {
"title": "LLM Code Generation Task",
"body": ranking_request.prompt,
"expireAt": (
expireAt = (
(
datetime.datetime.utcnow()
+ datetime.timedelta(seconds=template.TASK_DEADLINE)
)
.replace(microsecond=0, tzinfo=datetime.timezone.utc)
.isoformat()
.replace("+00:00", "Z"),
.replace("+00:00", "Z")
)

body = {
"title": "LLM Code Generation Task",
"body": ranking_request.prompt,
"expireAt": expireAt,
"taskData": json.dumps([taskData]),
"maxResults": "1",
}
Expand Down Expand Up @@ -140,4 +144,4 @@ async def create_task(
f"Error occurred when trying to create task\nErr:{response.json()['error']}"
)
response.raise_for_status()
return task_ids
return task_ids, expireAt
3 changes: 2 additions & 1 deletion neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ async def forward_feedback_request(
scoring_method = self.config.scoring_method
if scoring_method.casefold() == ScoringMethod.DOJO:
synapse.scoring_method = ScoringMethod.DOJO
task_ids = await DojoAPI.create_task(synapse)
task_ids, expireAt = await DojoAPI.create_task(synapse)
assert len(task_ids) == 1
synapse.dojo_task_id = task_ids[0]
synapse.expireAt = expireAt
else:
logger.error("Unrecognized scoring method!")
except Exception:
Expand Down
10 changes: 8 additions & 2 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
import bittensor as bt
import numpy as np
import torch
import wandb
from fastapi.encoders import jsonable_encoder
from loguru import logger
from torch.nn import functional as F

import template
import wandb
from commons.data_manager import DataManager, ValidatorStateKeys
from commons.dataset.synthetic import SyntheticAPI
from commons.human_feedback.dojo import DojoAPI
Expand Down Expand Up @@ -48,6 +48,7 @@ class DojoTaskTracker:
lambda: defaultdict(str)
)
_rid_to_model_map: Dict[str, Dict[str, str]] = defaultdict(lambda: defaultdict(str))
_task_to_expiry: Dict[str, str] = defaultdict(str)
_lock = asyncio.Lock()
_should_exit: bool = False

Expand Down Expand Up @@ -75,7 +76,7 @@ async def update_task_map(

logger.debug("update_task_map attempting to acquire lock")
async with cls._lock:
valid_responses = list(
valid_responses: List[FeedbackRequest] = list(
filter(
lambda r: r.request_id == request_id
and r.axon.hotkey
Expand All @@ -94,6 +95,7 @@ async def update_task_map(
cls._rid_to_mhotkey_to_task_id[request_id][r.axon.hotkey] = (
r.dojo_task_id
)
cls._task_to_expiry[r.dojo_task_id] = r.expireAt
cls._rid_to_model_map[request_id] = obfuscated_model_to_model
logger.debug("released lock for task tracker")
return
Expand All @@ -107,6 +109,7 @@ async def monitor_task_completions(cls):
logger.info(
f"Monitoring Dojo Task completions... {get_epoch_time()} for {len(cls._rid_to_mhotkey_to_task_id)} requests"
)
logger.info(f"print task_to_expiry {cls._task_to_expiry}")
if not cls._rid_to_mhotkey_to_task_id:
await asyncio.sleep(SLEEP_SECONDS)
continue
Expand Down Expand Up @@ -685,6 +688,7 @@ def save_state(self):
self.scores,
DojoTaskTracker._rid_to_mhotkey_to_task_id,
DojoTaskTracker._rid_to_model_map,
DojoTaskTracker._task_to_expiry,
)
)
except Exception as e:
Expand All @@ -705,11 +709,13 @@ def load_state(self):
ValidatorStateKeys.DOJO_TASKS_TO_TRACK
]
DojoTaskTracker._rid_to_model_map = state_data[ValidatorStateKeys.MODEL_MAP]
DojoTaskTracker._task_to_expiry = state_data[ValidatorStateKeys.Task_TO_EXPIRY]

logger.info(f"Scores state: {self.scores}")
logger.info(
f"Dojo Tasks to track: {DojoTaskTracker._rid_to_mhotkey_to_task_id}"
)
logger.info(f"Task to expiry {DojoTaskTracker._task_to_expiry}")

@classmethod
async def log_validator_status(cls):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ py-modules = ["dojo_cli"]

[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}
optional-dependencies.dev = { file = ["requirements-dev.txt"] }
optional-dependencies-dev = { file = ["requirements-dev.txt"] }

[tool.commitizen]
name = "cz_conventional_commits"
19 changes: 10 additions & 9 deletions template/protocol.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List

from commons.utils import get_epoch_time, get_new_uuid
import bittensor as bt
from pydantic import BaseModel, Field
from strenum import StrEnum

import bittensor as bt
from commons.utils import get_epoch_time, get_new_uuid


class TaskType(StrEnum):
Expand Down Expand Up @@ -40,7 +40,7 @@ class Config:
max: float = Field(description="Maximum score for the task")


CriteriaType = Union[MultiScoreCriteria, RankingCriteria]
CriteriaType = MultiScoreCriteria | RankingCriteria


class ScoringMethod(StrEnum):
Expand All @@ -61,7 +61,7 @@ class CodeAnswer(BaseModel):
installation_commands: str = Field(
description="Terminal commands for the code to be able to run to install any third-party packages for the code to be able to run"
)
additional_notes: Optional[str] = Field(
additional_notes: str | None = Field(
description="Any additional notes or comments about the code solution"
)

Expand All @@ -74,10 +74,10 @@ class Response(BaseModel):
default_factory=get_new_uuid,
description="Unique identifier for the completion",
)
rank_id: Optional[int] = Field(
rank_id: int | None = Field(
description="Rank of the completion", examples=[1, 2, 3, 4]
)
score: Optional[float] = Field(description="Score of the completion")
score: float | None = Field(description="Score of the completion")


class SyntheticQA(BaseModel):
Expand Down Expand Up @@ -108,10 +108,11 @@ class FeedbackRequest(bt.Synapse):
description="Types of criteria for the task",
allow_mutation=False,
)
scoring_method: Optional[str] = Field(
scoring_method: str | None = Field(
decscription="Method to use for scoring completions"
)
dojo_task_id: Optional[str] = Field(description="Dojo task ID for the request")
dojo_task_id: str | None = Field(description="Dojo task ID for the request")
expireAt: str | None = Field(description="Expired time for Dojo task")


class ScoringResult(bt.Synapse):
Expand Down

0 comments on commit ad46be8

Please sign in to comment.