diff --git a/scripts/extract_dataset.py b/scripts/extract_dataset.py index 29a95f1..0ccd043 100644 --- a/scripts/extract_dataset.py +++ b/scripts/extract_dataset.py @@ -95,7 +95,8 @@ async def build_jsonl(filename: str): logger.info(f"raw_scores_vec: {raw_scores_vec}") if raw_scores_vec.size > 0: - mean_scores = raw_scores_vec.mean(axis=1) + # ensure we're taking mean for each completion, across all miners + mean_scores = raw_scores_vec.mean(axis=0) logger.info(f"mean_scores shape: {mean_scores.shape}") jsonl_row = Row( prompt=prompt,