Skip to content

Commit

Permalink
feat: add weights classification accuracy multiplier to scoring
Browse files Browse the repository at this point in the history
  • Loading branch information
jarvis8x7b committed Feb 22, 2024
1 parent 1f3b627 commit 6c5772c
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 38 deletions.
2 changes: 1 addition & 1 deletion commons/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_ranking_data_filepath() -> Path:
return base_path / "data" / "ranking" / "data.pkl"

@staticmethod
def load(path) -> List[DendriteQueryResponse]:
def load(path) -> Optional[List[DendriteQueryResponse]]:
try:
# Load the list of Pydantic objects from the pickle file
with open(str(path), "rb") as file:
Expand Down
49 changes: 25 additions & 24 deletions commons/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,6 @@ def get_openai_webgpt_comparisons():
)


eval_datasets = {
EvalDataset.ANTHROPIC_HHRLHF: iter(get_anthropic_hhrlhf()),
EvalDataset.STANFORD_SHP: iter(get_stanford_shp()),
EvalDataset.OPENAI_WEBGPT_COMPARISONS: iter(get_openai_webgpt_comparisons()),
EvalDataset.YITINGXIE_RLHF_REWARD_DATASETS: iter(
load_dataset(
EvalDataset.YITINGXIE_RLHF_REWARD_DATASETS, split="train", streaming=True
)
),
EvalDataset.DAHOAS_SYNTHETIC_INSTRUCT_GPTJ_PAIRWISE: iter(
load_dataset(
EvalDataset.DAHOAS_SYNTHETIC_INSTRUCT_GPTJ_PAIRWISE,
split="train",
streaming=True,
)
),
}


def next_circular(dataset_dict: Dict[str, Iterable], key: str):
try:
return next(dataset_dict[key])
Expand All @@ -106,13 +87,33 @@ def next_circular(dataset_dict: Dict[str, Iterable], key: str):


class EvalDatasetManager:
    """Hands out evaluation batches drawn from a fixed set of datasets.

    The dataset iterators live on the class itself, so every caller shares
    the same iteration position for a given dataset.
    """

    # One iterator per dataset, built once when the class body is executed.
    # The last two entries are opened with `streaming=True` on the "train" split.
    _eval_datasets = {
        EvalDataset.ANTHROPIC_HHRLHF: iter(get_anthropic_hhrlhf()),
        EvalDataset.STANFORD_SHP: iter(get_stanford_shp()),
        EvalDataset.OPENAI_WEBGPT_COMPARISONS: iter(get_openai_webgpt_comparisons()),
        EvalDataset.YITINGXIE_RLHF_REWARD_DATASETS: iter(
            load_dataset(
                EvalDataset.YITINGXIE_RLHF_REWARD_DATASETS,
                split="train",
                streaming=True,
            )
        ),
        EvalDataset.DAHOAS_SYNTHETIC_INSTRUCT_GPTJ_PAIRWISE: iter(
            load_dataset(
                EvalDataset.DAHOAS_SYNTHETIC_INSTRUCT_GPTJ_PAIRWISE,
                split="train",
                streaming=True,
            )
        ),
    }

    @classmethod
    def get_batch(cls, batch_size: int = 32) -> List[Dict]:
        """Return `batch_size` consecutive items from one randomly chosen dataset.

        Args:
            batch_size: Number of items to pull; defaults to 32, the value
                that was previously hard-coded.

        Returns:
            A list of `batch_size` items, all taken from the same dataset.
        """
        key = random.choice(list(cls._eval_datasets))
        bt.logging.info(f"Using dataset: {key}, for evaluation")
        # NOTE(review): `next_circular` presumably restarts an exhausted
        # iterator; only its `try` branch is visible here — confirm.
        return [next_circular(cls._eval_datasets, key) for _ in range(batch_size)]


# # NOTE this serves as a start for prompt/completion pairs to be generated because at first there will be no requests coming in
Expand Down Expand Up @@ -151,7 +152,7 @@ def get_batch() -> List[Dict]:
# )
# }

# TODO change name to actual datset name
# TODO @dev change name to actual dataset name
seed_dataset_name = "prooompt/test_dataset"


Expand Down
15 changes: 12 additions & 3 deletions main_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,19 @@ async def main():
validator = Factory.get_validator()
config = Factory.get_config()
scheduler = AsyncIOScheduler(
job_defaults={"max_instances": 1, "misfire_grace_time": 3}
job_defaults={"max_instances": 2, "misfire_grace_time": 3}
)
trigger = OrTrigger([CronTrigger(second=0), CronTrigger(second=30)])
scheduler.add_job(validator.update_score_and_send_feedback, trigger=trigger)

scheduler.add_job(
validator.update_score_and_send_feedback,
trigger=OrTrigger([CronTrigger(second=0), CronTrigger(second=30)]),
)
hourly_trigger = CronTrigger(minute=0)
daily_trigger = CronTrigger(hour=0, minute=0)
scheduler.add_job(
validator.calculate_miner_classification_accuracy, trigger=hourly_trigger
)
scheduler.add_job(validator.reset_accuracy, trigger=daily_trigger)
scheduler.start()

config = uvicorn.Config(
Expand Down
7 changes: 5 additions & 2 deletions neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from template.base.miner import BaseMinerNeuron
from template.protocol import (
LLMConfig,
ModelConfig,
MTurkResponse,
Rank,
RankingRequest,
Expand Down Expand Up @@ -64,6 +64,7 @@ async def forward_ranking_request(self, synapse: RankingRequest) -> RankingReque
)
)
synapse.scoring_method = ScoringMethod.HF_MODEL
synapse.model_config = ModelConfig(model_name=self.config.model_name)

elif scoring_method.casefold() == ScoringMethod.LLM_API:
llm_provider = Provider(self.config.llm_provider)
Expand Down Expand Up @@ -92,7 +93,9 @@ async def forward_ranking_request(self, synapse: RankingRequest) -> RankingReque
)
)
synapse.scoring_method = ScoringMethod.LLM_API
synapse.llm_config = LLMConfig(provider=llm_provider, model_name=model_name)
synapse.model_config = ModelConfig(
provider=llm_provider, model_name=model_name
)

elif scoring_method.casefold() == ScoringMethod.AWS_MTURK:
# send off to MTurk workers... will timeout on validator side
Expand Down
54 changes: 46 additions & 8 deletions neurons/validator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import asyncio
from collections import defaultdict
import threading
import time
from traceback import print_exception
from typing import List, Tuple
from datetime import datetime, timedelta
import copy
from torch.nn import functional as F

import bittensor as bt
import numpy as np
import torch
from commons.data_manager import DataManager
from commons.dataset import SeedDataManager
from commons.evals import EvalUtils
from commons.objects import DendriteQueryResponse
from commons.scoring import Scoring

Expand Down Expand Up @@ -69,6 +72,8 @@ def __init__(self):
self.thread: threading.Thread = None
self.lock = asyncio.Lock()
self.moving_averaged_scores = None
self.hotkey_to_accuracy = defaultdict(float)
# self.lock = asyncio.Lock()

async def blacklist_mturk_response(
self, synapse: MTurkResponse
Expand Down Expand Up @@ -152,6 +157,32 @@ async def send_consensus(self, synapse: RankingResult, hotkeys: List[str]):
axons = [axon for axon in self.metagraph.axons if axon.hotkey in hotkeys]
await self.dendrite(axons=axons, synapse=synapse, deserialize=False, timeout=12)

async def calculate_miner_classification_accuracy(self):
    """Compute a classification accuracy for each miner found in the ranking data.

    Loads the persisted ranking data and, for every response whose hotkey has
    no entry in `self.hotkey_to_accuracy` yet, awaits
    `EvalUtils.classification_accuracy` and stores the result under that
    hotkey. Hotkeys that already have an entry are logged and left untouched.
    Does nothing (beyond a debug log) when no ranking data is available.
    """
    ranking_data = DataManager.load(path=DataManager.get_ranking_data_filepath())
    if not ranking_data:
        bt.logging.debug("Skipping classification accuracy as no ranking data found.")
        return

    for query_response in ranking_data:
        for response in query_response.responses:
            participant = response.axon.hotkey
            # NOTE(review): `hotkey_to_accuracy` is a defaultdict(float); a
            # subscript read of a missing hotkey elsewhere inserts 0.0 and
            # would make this membership check skip that hotkey — confirm.
            if participant not in self.hotkey_to_accuracy:
                accuracy = await EvalUtils.classification_accuracy(
                    scoring_method=response.scoring_method,
                    model_config=response.model_config,
                )
                self.hotkey_to_accuracy[participant] = accuracy
                continue

            bt.logging.warning(
                f"Participant {participant} already has an accuracy score... skipping"
            )
    return

async def reset_accuracy(self):
    """Forget every stored per-hotkey accuracy.

    The mapping is emptied in place, so existing references to
    `self.hotkey_to_accuracy` observe the reset as well.
    """
    accuracy_by_hotkey = self.hotkey_to_accuracy
    accuracy_by_hotkey.clear()

async def update_score_and_send_feedback(self):
bt.logging.debug(
f"Scheduled update score and send feedback triggered at time: {time.time()}"
Expand Down Expand Up @@ -373,9 +404,11 @@ def set_weights(self):

# Calculate the average reward for each uid across non-zero values.
# Replace any NaN values with 0.
raw_weights = torch.nn.functional.normalize(
self.moving_averaged_scores, p=1, dim=0
)
raw_weights = F.normalize(self.moving_averaged_scores, p=1, dim=0)
if torch.all(raw_weights == 0):
bt.logging.warning(
"All weights are zero, therefore no valid weights to set"
)

bt.logging.debug("raw_weights", raw_weights)
bt.logging.debug("raw_weight_uids", self.metagraph.uids.to("cpu"))
Expand Down Expand Up @@ -451,24 +484,29 @@ def resync_metagraph(self):
self.scores = updated_scores
self.hotkeys = self.metagraph.hotkeys

def update_scores(self, hotkey_to_rewards):
def update_scores(self, hotkey_to_scores):
"""Performs exponential moving average on the scores based on the rewards received from the miners,
after setting the self.scores variable here, `set_weights` will be called to set the weights on chain."""

nan_value_indices = np.isnan(list(hotkey_to_rewards.values()))
nan_value_indices = np.isnan(list(hotkey_to_scores.values()))
if nan_value_indices.any():
bt.logging.warning(f"NaN values detected in rewards: {hotkey_to_rewards}")
bt.logging.warning(f"NaN values detected in rewards: {hotkey_to_scores}")

# Compute forward pass rewards, assumes uids are mutually exclusive.
# scores dimensions might have been updated after resyncing... len(uids) != len(self.scores)
rewards = torch.zeros((len(self.hotkeys),))
for index, (key, value) in enumerate(hotkey_to_rewards.items()):
for index, (key, value) in enumerate(hotkey_to_scores.items()):
# handle nan values
if nan_value_indices[index]:
rewards[key] = 0.0
# search metagraph for hotkey and grab uid
uid = self.hotkeys.index(key)
rewards[uid] = value

# multiply by the classification accuracy
if (accuracy := self.hotkey_to_accuracy[key]) and accuracy == 0.0:
bt.logging.warning(f"Classification accuracy for hotkey {key} is 0")

rewards[uid] = value * accuracy
bt.logging.debug(f"Rewards: {rewards}")

# Update scores with rewards produced by this step.
Expand Down

0 comments on commit 6c5772c

Please sign in to comment.