Skip to content

Commit

Permalink
fix: handle aws mturk payload properly
Browse files Browse the repository at this point in the history
  • Loading branch information
jarvis8x7b committed Feb 19, 2024
1 parent 2a4000f commit a696e5c
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 10 deletions.
3 changes: 3 additions & 0 deletions commons/api/human_feedback_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ async def task_completion_callback(request: Request):
response_json = await request.json()
bt.logging.info(f"Received task completion callback with body: {response_json}")
completion_id_to_scores = await MTurkUtils.handle_mturk_event(response_json)
if not completion_id_to_scores:
return

try:
miner = Miner()
await miner.send_mturk_response(
Expand Down
48 changes: 44 additions & 4 deletions commons/human_feedback/aws_mturk.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
import os
import textwrap
from collections import defaultdict
from enum import Enum, StrEnum
from typing import Any, Dict, List

import bittensor as bt
import boto3
import botocore.exceptions
import markdown
from dotenv import load_dotenv

from commons.llm.prompts import ScoreRange
from template.protocol import Completion
import markdown

load_dotenv()
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
US_EAST_REGION = "us-east-1"


class MTurkEventTypes(StrEnum):
AssignmentAccepted = "AssignmentAccepted"
AssignmentSubmitted = "AssignmentSubmitted"
AssignmentReturned = "AssignmentReturned"
AssignmentAbandoned = "AssignmentAbandoned"
HITReviewable = "HITReviewable"
HITExpired = "HITExpired"


# ensure regions in 'endpoint' key matches
mturk_env_dict = {
"production": {
Expand Down Expand Up @@ -93,6 +104,7 @@ def create_mturk_task(
):
"""Create a human intellgence task to send to AWS MTurk workers."""
payout_auto_approval_seconds = 3600 * 24
success = False
try:
new_hit = mturk_client.create_hit(
Title=title,
Expand All @@ -107,16 +119,37 @@ def create_mturk_task(
prompt, completions, score_range
),
)
success = True
hit_url = (
f"{env_config['preview_url']}?groupId={new_hit['HIT']['HITGroupId']}"
)
bt.logging.info(
bt.logging.success(
f"A new HIT has been created. You can preview it here:\n{hit_url}"
)
bt.logging.info(
bt.logging.success(
"HITID = " + new_hit["HIT"]["HITId"] + " (Use to Get Results)"
)
return True

try:
hit_type_id = new_hit["HIT"]["HITTypeId"]
mturk_client.update_notification_settings(
HITTypeId=hit_type_id,
Notification={
"Destination": "arn:aws:sns:us-east-1:364251527502:test_topic",
"Transport": "SNS",
"Version": "2006-05-05",
"EventTypes": [
"AssignmentSubmitted",
],
},
Active=True,
)
except Exception as e:
success = False
bt.logging.error("Failed to update notification settings: " + str(e))
pass

return success
except botocore.exceptions.ClientError as e:
bt.logging.error(
f"Error occurred while trying to create hit... exception: {e}"
Expand All @@ -136,6 +169,7 @@ async def handle_mturk_event(event_payload: Dict):
task_answers = answer.get("taskAnswers")
if task_answers is None:
bt.logging.warning("MTurk event has no task answers")
continue

for task_answer in task_answers:
for task_key, score in task_answer.items():
Expand All @@ -146,6 +180,12 @@ async def handle_mturk_event(event_payload: Dict):
bt.logging.info(
f"Processed MTurk event, completion ID to scores: {completion_id_to_scores}"
)
for k, v in completion_id_to_scores.items():
completion_id_to_scores[k] = float(sum(v) / len(v))

bt.logging.info(
f"Taking the average of set of scores: {completion_id_to_scores}"
)
return completion_id_to_scores

@staticmethod
Expand Down
7 changes: 3 additions & 4 deletions neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def __new__(cls, *args, **kwargs):
cls._instance = super(Miner, cls).__new__(cls)
return cls._instance

def __init__(self, config=None):
super(Miner, self).__init__(config=config)
def __init__(self):
super(Miner, self).__init__()

# TODO(developer): Anything specific to your use case you can do here
# Warn if allowing incoming requests from anyone.
Expand Down Expand Up @@ -124,7 +124,7 @@ async def forward_ranking_request(self, synapse: RankingRequest) -> RankingReque
async def send_mturk_response(self, synapse: MTurkResponse):
"""After receiving a response from MTurk, send the response back to the calling validator"""
# 1. figure out which validator hotkey sent the original request
hotkey = Miner._find_hotkey_by_completions(synapse.completion_id_to_score)
hotkey = self._find_hotkey_by_completions(synapse.completion_id_to_score)
if not hotkey and not self.hotkey_to_request:
bt.logging.error(
f"No hotkey found for completion ids: {synapse.completion_id_to_score.keys()}"
Expand All @@ -141,7 +141,6 @@ async def send_mturk_response(self, synapse: MTurkResponse):
)
return

@staticmethod
def _find_hotkey_by_completions(self, completion_id_to_scores: Dict[str, float]):
if not self.hotkey_to_request:
bt.logging.warning(
Expand Down
4 changes: 2 additions & 2 deletions template/base/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class BaseMinerNeuron(BaseNeuron):
Base class for Bittensor miners.
"""

def __init__(self, config=None):
super().__init__(config=config)
def __init__(self):
super(BaseMinerNeuron, self).__init__()

def run(self):
"""
Expand Down

0 comments on commit a696e5c

Please sign in to comment.