From b28ca8328f1cff32d22e10ef448b9135ae013fce Mon Sep 17 00:00:00 2001 From: jarvis8x7b <157810922+jarvis8x7b@users.noreply.github.com> Date: Thu, 12 Dec 2024 15:38:45 +0800 Subject: [PATCH] feat: add ground truth to dataset collection --- scripts/extract_dataset.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/scripts/extract_dataset.py b/scripts/extract_dataset.py index 0ccd043..ad5c1f9 100644 --- a/scripts/extract_dataset.py +++ b/scripts/extract_dataset.py @@ -51,17 +51,21 @@ class Row(BaseModel): raw_scores: list[list[float]] # shape (num_completions) mean_scores: list[float] + # ground truth ranks + cid_to_ground_truth_rank: dict[str, int] class Config: arbitrary_types_allowed = True @model_serializer def serialize_model(self): + """Custom serializer method to ensure that types such as np.ndarray and torch.tensor are serialized correctly""" return { "prompt": self.prompt, "completions": self.completions, "raw_scores": self.raw_scores, "mean_scores": self.mean_scores, + "cid_to_ground_truth_rank": self.cid_to_ground_truth_rank, } @@ -103,6 +107,7 @@ async def build_jsonl(filename: str): completions=completions, raw_scores=raw_scores, mean_scores=mean_scores.tolist(), + cid_to_ground_truth_rank=task.request.ground_truth, ) else: jsonl_row = Row( @@ -110,6 +115,7 @@ async def build_jsonl(filename: str): completions=completions, raw_scores=[], mean_scores=[], + cid_to_ground_truth_rank={}, ) # Write the entry as a JSON line @@ -193,6 +199,7 @@ async def get_processed_tasks( vali_request = map_feedback_request_model_to_feedback_request( validator_request ) + logger.info(f"Vali request ground truths: {vali_request.ground_truth}") m_responses = list( map(