Skip to content

Commit

Permalink
feat: add dojo task monitor on validator
Browse files Browse the repository at this point in the history
  • Loading branch information
jarvis8x7b committed Apr 29, 2024
1 parent cfea834 commit 8c8ff24
Showing 1 changed file with 68 additions and 34 deletions.
102 changes: 68 additions & 34 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from commons.dataset.synthetic import build_prompt_responses_pair
from commons.evals import EvalUtils
from commons.human_feedback.aws_mturk import MTurkUtils, parse_assignment
from commons.human_feedback.dojo import DojoAPI
from commons.logging.wandb_logging import wandb_log
from commons.reward_model.models import ModelUtils
from commons.scoring import Scoring
Expand All @@ -27,6 +28,7 @@
SCORING_METHOD_PRIORITY,
AWSCredentials,
Completion,
DendriteQueryResponse,
MTurkResponse,
Rank,
RankingRequest,
Expand Down Expand Up @@ -57,13 +59,34 @@ def __new__(cls, *args, **kwargs):
cls._instance = super(DojoTaskTracker, cls).__new__(cls)
return cls._instance

@staticmethod
def filter_dojo_responses(
responses: List[RankingRequest],
) -> List[RankingRequest]:
return list(filter(lambda r: r.scoring_method == ScoringMethod.DOJO, responses))

@classmethod
def update_task_map(cls, responses: List[RankingRequest]):
dojo_responses = DojoTaskTracker.filter_dojo_responses(responses)
with cls.lock:
dojo_responses = filter(
lambda r: r.scoring_method == ScoringMethod.DOJO, responses
)
for r in dojo_responses:
if not r.request_id:
bt.logging.error(
f"Request ID not found in response from miner: {r.axon.hotkey}"
)
continue
if not r.axon.hotkey:
bt.logging.error(
f"Hotkey not found in response from miner: {r.axon.hotkey}"
)
continue

if not r.dojo_task_id:
bt.logging.error(
f"Dojo task ID not found in response from miner: {r.axon.hotkey}"
)
continue

cls.request_id_to_miner_to_task_ids[r.request_id][r.axon.hotkey].append(
r.dojo_task_id
)
Expand All @@ -72,14 +95,35 @@ def update_task_map(cls, responses: List[RankingRequest]):
)

@classmethod
async def check_task_completion(cls, task_id: str):
async with httpx.AsyncClient() as client:
response = await client.get(cls._api_url)
response_data = response.json()
# TODO process task data JSON
# if response_data['status'].lower() == 'completed':
# TODO append to DendriteQueryResponse for post-processing later
return response_data
async def monitor_task_completions(cls):
# TODO implement locking
while True:
bt.logging.info(f"Monitoring Dojo Task completions... {get_epoch_time()}")
for (
request_id,
miner_to_task_ids,
) in cls.request_id_to_miner_to_task_ids.items():
for miner, task_ids in miner_to_task_ids.items():
if not task_ids:
continue
task_responses: List[List | None] = await asyncio.gather(
*[DojoAPI.get_task_and_results(task_id) for task_id in task_ids]
)

# filter None, since get_task_and_results returns None if task is not completed
ranking_results: List[List] = list(
filter(
lambda r: r,
task_responses,
)
)
# TODO for each ranking result append to our Datamanager
await asyncio.sleep(10)
return


def parse_dojo_results(results: List[Dict]):
pass


class Validator(BaseNeuron):
Expand Down Expand Up @@ -470,7 +514,7 @@ async def poll_dojo_tasks(self):
for miner_hotkey, task_ids in miner_hotkey_to_task_ids.items():
results = asyncio.gather(
*[
DojoTaskTracker.check_task_completion(task_id)
DojoTaskTracker.monitor_task_completions(task_id)
for task_id in task_ids
]
)
Expand All @@ -479,15 +523,13 @@ async def poll_dojo_tasks(self):
async def send_request(
self,
synapse: RankingRequest = None,
# ) -> DendriteQueryResponse:
):
# typically the request may come from an external source however,
# initially will seed it with some data for miners to get started
if synapse is None:
data = build_prompt_responses_pair()
data = await build_prompt_responses_pair()
synapse = RankingRequest(
task=TaskType.TEXT,
pid=get_new_uuid(),
task=TaskType.CODE_GENERATION,
prompt=data["prompt"],
completions=[Completion.parse_obj(d) for d in data["responses"]],
)
Expand All @@ -507,27 +549,19 @@ async def send_request(
bt.logging.warning("No axons to query ... skipping")
return

# The dendrite client queries the network.
responses: List[RankingRequest] = await self.dendrite.forward(
axons=axons, synapse=synapse, deserialize=False, timeout=24
)

DojoTaskTracker().update_task_map(responses)
bt.logging.success("Successfully sent requests... now we wait.")

# NOTE we don't need to filter responses here as tasks will take longer than normal
# valid_responses = [
# response
# for response in responses
# if response.scoring_method in [method for method in ScoringMethod]
# ]

# TODO manage data better
# response_data = DendriteQueryResponse(
# request=synapse,
# responses=valid_responses,
# )
# await DataManager.append_responses(response=response_data)
dojo_responses = DojoTaskTracker.filter_dojo_responses(responses)
DojoTaskTracker().update_task_map(dojo_responses)
non_dojo_responses = list(filter(lambda r: r not in dojo_responses, responses))

response_data = DendriteQueryResponse(
request=synapse,
responses=non_dojo_responses,
)
await DataManager.append_responses(response=response_data)
return

async def run(self):
Expand Down Expand Up @@ -576,7 +610,7 @@ async def run(self):
self.sync()

self.step += 1
await asyncio.sleep(24)
await asyncio.sleep(12)

# If someone intentionally stops the validator, it'll safely terminate operations.
except KeyboardInterrupt:
Expand Down

0 comments on commit 8c8ff24

Please sign in to comment.