Skip to content

Commit

Permalink
fix: refactored from MR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
codebender37 committed Jul 11, 2024
1 parent 2a8cb96 commit cc230e0
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 31 deletions.
4 changes: 2 additions & 2 deletions commons/data_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
import pickle
from datetime import datetime
from pathlib import Path
from typing import Any, List

Expand All @@ -10,6 +9,7 @@
from strenum import StrEnum

from commons.objects import ObjectManager
from commons.utils import get_current_utc_time_iso
from template.protocol import DendriteQueryResponse, FeedbackRequest


Expand Down Expand Up @@ -225,7 +225,7 @@ async def remove_expired_tasks_from_storage():
return

# Identify expired tasks
current_time = datetime.utcnow().isoformat() + "Z"
current_time = get_current_utc_time_iso()
task_to_expiry = state_data.get(ValidatorStateKeys.TASK_TO_EXPIRY, {})
expired_tasks = [
task_id
Expand Down
20 changes: 16 additions & 4 deletions commons/dojo_task_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
import copy
import traceback
from collections import defaultdict
from datetime import datetime
from typing import Dict, List

import bittensor as bt
from loguru import logger

import template
from commons.data_manager import DataManager
from commons.human_feedback.dojo import DojoAPI
from commons.objects import ObjectManager
from commons.utils import get_epoch_time
from commons.utils import (
get_current_utc_time_iso,
get_epoch_time,
is_valid_expiry,
set_expire_time,
)
from template.protocol import (
CriteriaTypeEnum,
FeedbackRequest,
Expand Down Expand Up @@ -75,15 +80,22 @@ async def update_task_map(
cls._rid_to_mhotkey_to_task_id[request_id][r.axon.hotkey] = (
r.dojo_task_id
)
cls._task_to_expiry[r.dojo_task_id] = r.expireAt

# Ensure expire_at is set and is reasonable
expire_at = r.expire_at
if expire_at is None or is_valid_expiry(expire_at) is not True:
expire_at = set_expire_time(template.TASK_DEADLINE)

cls._task_to_expiry[r.dojo_task_id] = expire_at

cls._rid_to_model_map[request_id] = obfuscated_model_to_model
logger.debug("released lock for task tracker")
return

@classmethod
async def remove_expired_tasks(cls):
# Identify expired tasks
current_time = datetime.utcnow().isoformat() + "Z"
current_time = get_current_utc_time_iso()
expired_tasks = [
task_id
for task_id, expiry_time in cls._task_to_expiry.items()
Expand Down
17 changes: 4 additions & 13 deletions commons/human_feedback/dojo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime
import json
from typing import Dict, List

Expand All @@ -7,7 +6,7 @@
from requests_toolbelt import MultipartEncoder

import template
from commons.utils import loaddotenv
from commons.utils import loaddotenv, set_expire_time
from template import DOJO_API_BASE_URL
from template.protocol import (
FeedbackRequest,
Expand Down Expand Up @@ -92,20 +91,12 @@ async def create_task(
else:
logger.error(f"Unrecognized criteria type: {type(criteria_type)}")

expireAt = (
(
datetime.datetime.utcnow()
+ datetime.timedelta(seconds=template.TASK_DEADLINE)
)
.replace(microsecond=0, tzinfo=datetime.timezone.utc)
.isoformat()
.replace("+00:00", "Z")
)
expire_at = set_expire_time(template.TASK_DEADLINE)

body = {
"title": "LLM Code Generation Task",
"body": ranking_request.prompt,
"expireAt": expireAt,
"expireAt": expire_at,
"taskData": json.dumps([taskData]),
"maxResults": "1",
}
Expand Down Expand Up @@ -144,4 +135,4 @@ async def create_task(
f"Error occurred when trying to create task\nErr:{response.json()['error']}"
)
response.raise_for_status()
return task_ids, expireAt
return task_ids
55 changes: 54 additions & 1 deletion commons/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from collections import OrderedDict
from collections.abc import Mapping
from datetime import datetime, timedelta, timezone
from functools import lru_cache, update_wrapper
from math import floor
from typing import Any, Callable, Tuple, Type, get_origin
Expand All @@ -12,12 +13,13 @@
import jsonref
import requests
import torch
import wandb
from Crypto.Hash import keccak
from loguru import logger
from pydantic import BaseModel
from tenacity import RetryError, Retrying, stop_after_attempt, wait_exponential_jitter

import wandb


def get_new_uuid():
return str(uuid.uuid4())
Expand Down Expand Up @@ -329,3 +331,54 @@ def ttl_get_block(subtensor) -> int:
Note: self here is the miner or validator instance
"""
return subtensor.get_current_block()


def get_current_utc_time_iso():
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


def set_expire_time(expire_in_seconds: int) -> str:
"""
Sets the expiration time based on the current UTC time and the given number of seconds.
Args:
expire_in_seconds (int): The number of seconds from now when the expiration should occur.
Returns:
str: The expiration time in ISO 8601 format with 'Z' as the UTC indicator.
"""
return (
(datetime.utcnow() + timedelta(seconds=expire_in_seconds))
.replace(microsecond=0, tzinfo=timezone.utc)
.isoformat()
.replace("+00:00", "Z")
)


def is_valid_expiry(expire_at: str) -> bool:
"""
Checks if the given expiry time is not None and falls within a reasonable time period.
Args:
expire_at (str): The expiry time in ISO format.
Returns:
bool: True if the expiry time is valid, False otherwise.
"""
if expire_at is None:
return False

try:
expiry_time = datetime.fromisoformat(expire_at)
except ValueError:
logger.error(f"Invalid expiry time format: {expire_at}")
return False

current_time = datetime.now(timezone.utc)
max_reasonable_time = current_time + timedelta(days=5)

if current_time <= expiry_time <= max_reasonable_time:
return True
else:
logger.warning(f"Expiry time {expire_at} is out of the reasonable range.")
return False
6 changes: 3 additions & 3 deletions neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ async def forward_feedback_request(
scoring_method = self.config.scoring_method
if scoring_method.casefold() == ScoringMethod.DOJO:
synapse.scoring_method = ScoringMethod.DOJO
task_ids, expireAt = await DojoAPI.create_task(synapse)
task_ids = await DojoAPI.create_task(synapse)
assert len(task_ids) == 1
synapse.dojo_task_id = task_ids[0]
synapse.expireAt = expireAt

else:
logger.error("Unrecognized scoring method!")
except Exception:
Expand Down Expand Up @@ -136,7 +136,7 @@ async def blacklist_feedback_request(
async def priority_ranking(self, synapse: FeedbackRequest) -> float:
"""
The priority function determines the order in which requests are handled. Higher-priority
requests are processed before others. Miners may recieve messages from multiple entities at
requests are processed before others. Miners may receive messages from multiple entities at
once. This function determines which request should be processed first.
Higher values indicate that the request should be processed first.
Lower values indicate that the request should be processed later.
Expand Down
7 changes: 5 additions & 2 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from commons.dataset.synthetic import SyntheticAPI
from commons.dojo_task_tracker import DojoTaskTracker
from commons.scoring import Scoring
from commons.utils import get_epoch_time, get_new_uuid, init_wandb
from commons.utils import get_epoch_time, get_new_uuid, init_wandb, set_expire_time
from template.base.neuron import BaseNeuron
from template.protocol import (
DendriteQueryResponse,
Expand Down Expand Up @@ -121,7 +121,7 @@ async def update_score_and_send_feedback(self):
continue

logger.debug(
f"Initiailly had {len(d.miner_responses)} responses from miners, but only {len(hotkey_to_score.keys())} valid responses"
f"Initially had {len(d.miner_responses)} responses from miners, but only {len(hotkey_to_score.keys())} valid responses"
)

self.update_scores(hotkey_to_scores=hotkey_to_score)
Expand Down Expand Up @@ -216,6 +216,8 @@ async def send_request(
obfuscated_model_to_model[new_uuid] = completion.model
completion.model = new_uuid

expire_at = set_expire_time(template.TASK_DEADLINE)

synapse = FeedbackRequest(
task_type=str(TaskType.CODE_GENERATION),
criteria_types=[
Expand All @@ -227,6 +229,7 @@ async def send_request(
],
prompt=data.prompt,
responses=data.responses,
expire_at=expire_at,
)

all_miner_uids = extract_miner_uids(metagraph=self.metagraph)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ py-modules = ["dojo_cli"]

[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}
optional-dependencies-dev = {file = ["requirements-dev.txt"]}
optional-dependencies.dev = {file = ["requirements-dev.txt"]}

[tool.setuptools_scm]
version_scheme = "only-version"
Expand Down
3 changes: 1 addition & 2 deletions template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def get_latest_git_tag():

# TODO @dev change before live
VALIDATOR_MIN_STAKE = 99
TASK_DEADLINE = 5 * 60
# TASK_DEADLINE = 8 * 60 * 60
TASK_DEADLINE = 8 * 60 * 60

DOJO_API_BASE_URL = os.getenv("DOJO_API_BASE_URL")
if DOJO_API_BASE_URL is None:
Expand Down
4 changes: 2 additions & 2 deletions template/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ class FeedbackRequest(bt.Synapse):
allow_mutation=False,
)
scoring_method: str | None = Field(
decscription="Method to use for scoring completions"
description="Method to use for scoring completions"
)
dojo_task_id: str | None = Field(description="Dojo task ID for the request")
expireAt: str | None = Field(description="Expired time for Dojo task")
expire_at: str | None = Field(description="Expired time for Dojo task")


class ScoringResult(bt.Synapse):
Expand Down

0 comments on commit cc230e0

Please sign in to comment.