Skip to content

Commit

Permalink
feat: add dojo task tracker, use synthetic data generator
Browse files Browse the repository at this point in the history
  • Loading branch information
jarvis8x7b committed Apr 25, 2024
1 parent 2c802e3 commit d074fc0
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 34 deletions.
3 changes: 2 additions & 1 deletion main_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from commons.api.reward_route import reward_router
from commons.factory import Factory
from commons.logging.patch_logging import apply_patch
from neurons.validator import log_validator_status
from neurons.validator import DojoTaskTracker, log_validator_status

load_dotenv()
apply_patch()
Expand Down Expand Up @@ -60,6 +60,7 @@ async def main():
validator.calculate_miner_classification_accuracy, trigger=every_30_min_trigger
)
scheduler.add_job(validator.reset_accuracy, trigger=daily_trigger)
scheduler.add_job(validator.poll_dojo_tasks, trigger=IntervalTrigger(minutes=2))
scheduler.start()

config = uvicorn.Config(
Expand Down
119 changes: 86 additions & 33 deletions neurons/validator.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import asyncio
import copy
import os
import threading
import time
from collections import defaultdict
from traceback import print_exception
from typing import Dict, List, Tuple

import bittensor as bt
from itsdangerous import URLSafeTimedSerializer
import httpx
import numpy as np
import torch
from fastapi.encoders import jsonable_encoder
from torch.nn import functional as F

from commons.data_manager import DataManager
from commons.dataset.dataset import SeedDataManager
from commons.dataset.synthetic import build_prompt_responses_pair
from commons.evals import EvalUtils
from commons.human_feedback.aws_mturk import MTurkUtils, parse_assignment
from commons.logging.wandb_logging import wandb_log
Expand All @@ -26,27 +27,61 @@
SCORING_METHOD_PRIORITY,
AWSCredentials,
Completion,
DendriteQueryResponse,
Modality,
MTurkResponse,
Rank,
RankingRequest,
RankingResult,
ScoringMethod,
TaskType,
)
from template.utils.config import get_config
from template.utils.uids import (
get_random_miner_uids,
is_miner,
extract_miner_uids,
MinerUidSelector,
extract_miner_uids,
is_miner,
)


def _filter_valid_responses(responses: List[RankingRequest]) -> List[RankingRequest]:
return [response for response in responses if len(response.ranks) > 0]


class DojoTaskTracker:
_instance = None
request_id_to_miner_to_task_ids = defaultdict(lambda: defaultdict(list))
lock = asyncio.Lock()
_api_url = os.getenv("DOJO_WORKER_API_URL")

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(DojoTaskTracker, cls).__new__(cls)
return cls._instance

@classmethod
def update_task_map(cls, responses: List[RankingRequest]):
with cls.lock:
dojo_responses = filter(
lambda r: r.scoring_method == ScoringMethod.DOJO, responses
)
for r in dojo_responses:
cls.request_id_to_miner_to_task_ids[r.request_id][r.axon.hotkey].append(
r.dojo_task_id
)
bt.logging.info(
f"Processed N={len(dojo_responses)} miner responses using Dojo for request id: {r.request_id}"
)

@classmethod
async def check_task_completion(cls, task_id: str):
async with httpx.AsyncClient() as client:
response = await client.get(cls._api_url)
response_data = response.json()
# TODO process task data JSON
# if response_data['status'].lower() == 'completed':
# TODO append to DendriteQueryResponse for post-processing later
return response_data


class Validator(BaseNeuron):
def __init__(self):
super(Validator, self).__init__()
Expand Down Expand Up @@ -424,20 +459,39 @@ async def update_score_and_send_feedback(self):
# once we have scored certain responses, just remove them
await DataManager.remove_responses(consumed_responses)

async def poll_dojo_tasks(self):
# TODO prioritise tasks that have an earlier created_at
data = DojoTaskTracker().request_id_to_miner_to_task_ids
if not data:
bt.logging.warning("No Dojo task IDs to poll")
return

for _, miner_hotkey_to_task_ids in data.items():
for miner_hotkey, task_ids in miner_hotkey_to_task_ids.items():
results = asyncio.gather(
*[
DojoTaskTracker.check_task_completion(task_id)
for task_id in task_ids
]
)
# TODO parse results

async def send_request(
self, synapse: RankingRequest = None
) -> DendriteQueryResponse:
self,
synapse: RankingRequest = None,
# ) -> DendriteQueryResponse:
):
# typically the request may come from an external source however,
# initially will seed it with some data for miners to get started
if synapse is None:
prompt, completions = SeedDataManager.get_prompt_and_completions()
data = build_prompt_responses_pair()
synapse = RankingRequest(
modality=Modality.TEXT,
n_completions=len(completions),
task=TaskType.TEXT,
pid=get_new_uuid(),
prompt=prompt,
completions=[Completion(text=c) for c in completions],
prompt=data["prompt"],
completions=[Completion.parse_obj(d) for d in data["responses"]],
)
bt.logging.info(f"Parsed synapse: {synapse.dict()}")

all_miner_uids = extract_miner_uids(metagraph=self.metagraph)
sel_miner_uids = MinerUidSelector(all_miner_uids).get_target_uids(
Expand All @@ -455,27 +509,26 @@ async def send_request(

# The dendrite client queries the network.
responses: List[RankingRequest] = await self.dendrite.forward(
axons=axons, synapse=synapse, deserialize=False, timeout=30
axons=axons, synapse=synapse, deserialize=False, timeout=24
)

valid_responses = [
response
for response in responses
if response.scoring_method in [method for method in ScoringMethod]
]

if not len(valid_responses):
bt.logging.error("No valid responses received from miners.")
return
else:
bt.logging.success(f"Received {len(valid_responses)} valid responses")

response_data = DendriteQueryResponse(
request=synapse,
responses=valid_responses,
)
await DataManager.append_responses(response=response_data)
return response_data
DojoTaskTracker().update_task_map(responses)
bt.logging.success("Successfully sent requests... now we wait.")

# NOTE we don't need to filter responses here as tasks will take longer than normal
# valid_responses = [
# response
# for response in responses
# if response.scoring_method in [method for method in ScoringMethod]
# ]

# TODO manage data better
# response_data = DendriteQueryResponse(
# request=synapse,
# responses=valid_responses,
# )
# await DataManager.append_responses(response=response_data)
return

async def run(self):
"""
Expand Down

0 comments on commit d074fc0

Please sign in to comment.