Skip to content

Commit

Permalink
feat: add ICC metric for consensus scoring
Browse files Browse the repository at this point in the history
  • Loading branch information
jarvis8x7b committed May 31, 2024
1 parent ed8be18 commit 30a5f6e
Show file tree
Hide file tree
Showing 7 changed files with 352 additions and 201 deletions.
104 changes: 30 additions & 74 deletions commons/human_feedback/dojo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import asyncio
import datetime
import json
import os
from typing import Dict, List
from typing import Dict, List, Optional
from requests_toolbelt import MultipartEncoder
import httpx
from commons import utils

from template.protocol import (
FeedbackRequest,
Expand All @@ -17,51 +16,15 @@
load_dotenv()


# Base URL of the Dojo API, read from the DOJO_API_BASE_URL environment
# variable. A missing or empty value raises ValueError as soon as this
# top-level statement runs, so the problem surfaces before any request is built.
DOJO_API_BASE_URL = os.getenv("DOJO_API_BASE_URL")
if not DOJO_API_BASE_URL:
raise ValueError("DOJO_API_BASE_URL is not set")


def get_dojo_api_key():
    """Return the Dojo API key from the ``DOJO_API_KEY`` environment variable.

    Returns:
        The non-empty key string.

    Raises:
        ValueError: If the variable is unset or set to an empty string.
    """
    key = os.getenv("DOJO_API_KEY")
    # An empty value is as unusable as a missing one, so reject both; this
    # mirrors the `if not ...` check applied to DOJO_API_BASE_URL.
    if not key:
        raise ValueError("DOJO_API_KEY is not set")
    return key


def _extract_ranking_result_data(result_data: List[Dict]) -> Dict[str, str]:
    """Return the rank -> option mapping from a task's result entries.

    Sample input:
        [{'type': 'ranking', 'value': {'1': 'Code 2', '2': 'Code 1'}}]

    Args:
        result_data: Result entries; each is a dict that may carry a "type"
            and a "value" key.

    Returns:
        The "value" of the single entry whose "type" equals
        ``RankingCriteria.type``. An empty dict is returned when there is no
        such entry, more than one such entry, or the entry has no "value".
    """
    # Use .get() rather than indexing so an entry without a "type" key is
    # simply skipped instead of raising KeyError.
    ranking_results = [
        entry for entry in result_data if entry.get("type") == RankingCriteria.type
    ]
    if len(ranking_results) == 1:
        return ranking_results[0].get("value", {})
    return {}


def parse_task_results(response_json: Dict) -> List[Dict[str, str]]:
    """Collect the non-empty ranking results from a task-results response.

    Args:
        response_json: Decoded response; task results are read from
            ``response_json["body"]["taskResults"]`` (both keys optional).

    Returns:
        One ranking mapping per task result that yielded a non-empty ranking,
        in the order the task results appear.
    """
    body = response_json.get("body", {})
    # Lazily extract the ranking of every task result, then keep only the
    # ones that actually produced data.
    extracted = (
        _extract_ranking_result_data(task_result.get("result_data", []))
        for task_result in body.get("taskResults", [])
    )
    return [ranking for ranking in extracted if ranking]


def check_task_completion_status(response_json: Dict) -> bool:
    """Report whether a task response marks the task as completed.

    Args:
        response_json: Decoded task response; the status is read from
            ``response_json["body"]["status"]`` (both keys optional).

    Returns:
        True when the status is a string equal to "completed", compared
        case-insensitively; False otherwise, including when the status is
        missing, empty, or not a string.
    """
    status = response_json.get("body", {}).get("status")
    # The isinstance check keeps the result a real bool (never None or "")
    # and avoids AttributeError when the status is not a string.
    return isinstance(status, str) and status.lower() == "completed"
# Base URL used to build every Dojo API request path in this module.
# NOTE(review): the literal below looks like a redaction placeholder rather
# than a usable URL; confirm the real value before relying on it.
DOJO_API_BASE_URL = "***REMOVED***"
# API key obtained through the project helper utils.loaddotenv.
# NOTE(review): presumably this reads DOJO_API_KEY from the environment/.env
# and fails when it is missing; verify against commons.utils.
DOJO_API_KEY = utils.loaddotenv("DOJO_API_KEY")


class DojoAPI:
# _api_key = get_dojo_api_key()
# _http_client = httpx.AsyncClient(headers={"Authorization": f"Bearer {_api_key}"})
_http_client = httpx.AsyncClient()

@classmethod
async def get_task_by_id(cls, task_id: str):
async def _get_task_by_id(cls, task_id: str):
"""Gets task by task id and checks completion status"""
url = f"{DOJO_API_BASE_URL}/api/v1/tasks/{task_id}"
async with cls._http_client as client:
Expand All @@ -70,29 +33,36 @@ async def get_task_by_id(cls, task_id: str):
return response.json()

@classmethod
async def get_task_results_by_task_id(cls, task_id: str):
"""Gets ranking task results from task id"""
async def _get_task_results_by_task_id(cls, task_id: str):
    """Fetch the raw task-results payload for a task.

    Args:
        task_id: Identifier of the task whose results are requested.

    Returns:
        The decoded JSON body of the response.

    Raises:
        httpx.HTTPStatusError: If the API answers with an error status.
    """
    url = f"{DOJO_API_BASE_URL}/api/v1/tasks/get-results/{task_id}"
    # Call the shared client directly instead of wrapping it in `async with`:
    # leaving that context closes the class-level client, and httpx refuses to
    # send on or reopen a closed client, so later requests would fail.
    response = await cls._http_client.get(url)
    response.raise_for_status()
    return response.json()

@classmethod
async def get_task_and_results(cls, task_id: str):
    """Return the ranking results of a completed task, or None.

    The result has the shape [{'1': 'model hash 1', '2': 'model hash 2'}],
    where '1' and '2' are the explicit rank integers. None is returned when
    the task is not completed or no ranking results could be parsed.
    """
    task_response = await cls.get_task_by_id(task_id)
    if not check_task_completion_status(task_response):
        return None

    results_response = await cls.get_task_results_by_task_id(task_id)
    ranking_results = parse_task_results(results_response)
    # An empty list is reported as None, the same as "not completed".
    return ranking_results if ranking_results else None
async def get_task_results_by_task_id(cls, task_id: str) -> Optional[List[Dict]]:
    """Fetch a completed task's results so they can be scored later on.

    Args:
        task_id: Identifier of the task to look up.

    Returns:
        The non-empty "taskResults" list from the results response, or None
        when the status cannot be read, the task is not completed, or the
        results are missing or empty.
    """
    task_response = await cls._get_task_by_id(task_id)
    task_status = task_response.get("body", {}).get("status", None)
    # Validate the status itself rather than a derived flag: an empty or
    # non-string status must not slip past the guards and trigger a results
    # fetch for a task whose state is unknown.
    if not isinstance(task_status, str) or not task_status:
        logger.error(f"Failed to read status field for task_id: {task_id}")
        return None

    if task_status.lower() != "completed":
        return None

    task_results_response = await cls._get_task_results_by_task_id(task_id)
    task_results = task_results_response.get("body", {}).get("taskResults")
    if task_results is None:
        logger.error(f"Failed to read task results for task_id: {task_id}")
        return None

    # A present-but-empty result list is not an error, just nothing to score.
    if not task_results:
        return None

    return task_results

@classmethod
async def create_task(
Expand All @@ -116,6 +86,7 @@ async def create_task(
{
**criteria_type.dict(),
"options": [
# TODO remove model from name
f"Model {option}"
for option in criteria_type.dict().get("options", [])
],
Expand All @@ -127,6 +98,7 @@ async def create_task(
{
**criteria_type.dict(),
"options": [
# TODO remove model from name
f"Model {option}"
for option in criteria_type.dict().get("options", [])
],
Expand All @@ -146,10 +118,6 @@ async def create_task(
"maxResults": "10",
}

DOJO_API_KEY = os.getenv("DOJO_API_KEY")
if not DOJO_API_KEY:
logger.error("DOJO_API_KEY is not set")

mp = MultipartEncoder(fields=body)
response = await client.post(
path,
Expand All @@ -169,15 +137,3 @@ async def create_task(
)
response.raise_for_status()
return task_ids


if __name__ == "__main__":

    async def main():
        """Manual smoke check: fetch and print the results of one fixed task."""
        task_results = await DojoAPI.get_task_results_by_task_id(
            "bdb56d72-dd98-40c0-a42f-312d018a0a1e"
        )
        print(task_results)

    asyncio.run(main())
2 changes: 1 addition & 1 deletion commons/reward_model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from commons.llm.prompts import PromptBuilder, ScoreRange
from commons.utils import PydanticUtils
from template.protocol import (
Response,
ModelConfig,
PreferenceResponse,
Response,
ScoresResponse,
)

Expand Down
Loading

0 comments on commit 30a5f6e

Please sign in to comment.