Skip to content

Commit

Permalink
fix: add verification of mturk task on validator side
Browse files Browse the repository at this point in the history
  • Loading branch information
jarvis8x7b committed Mar 8, 2024
1 parent 48af1e1 commit 0506416
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 75 deletions.
3 changes: 2 additions & 1 deletion .env.miner.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ AWS_ACCESS_KEY_ID=
AWS_SECRET_KEY=
# should look like the form: arn:aws:sns:us-east-1:1234567890:sns_topic_name
AWS_SNS_ARN_ID=
AWS_ASSUME_ROLE_ARN=
TOGETHER_API_KEY=
OPENAI_API_KEY=
OPENAI_API_KEY=
108 changes: 79 additions & 29 deletions commons/human_feedback/aws_mturk.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import textwrap
from collections import defaultdict
from strenum import StrEnum
from typing import Any, Dict, List
import xml.etree.ElementTree as ET
import json

import bittensor as bt
import boto3
Expand All @@ -12,7 +13,7 @@
from commons.factory import Factory

from commons.llm.prompts import ScoreRange
from template.protocol import Completion
from template.protocol import AWSCredentials, Completion

load_dotenv()

Expand All @@ -21,15 +22,7 @@
US_EAST_REGION = "us-east-1"
# should look like the form: arn:aws:sns:us-east-1:1234567890:sns_topic_name
AWS_SNS_ARN_ID = os.getenv("AWS_SNS_ARN_ID")


class MTurkEventTypes(StrEnum):
AssignmentAccepted = "AssignmentAccepted"
AssignmentSubmitted = "AssignmentSubmitted"
AssignmentReturned = "AssignmentReturned"
AssignmentAbandoned = "AssignmentAbandoned"
HITReviewable = "HITReviewable"
HITExpired = "HITExpired"
AWS_ASSUME_ROLE_ARN = os.getenv("AWS_ASSUME_ROLE_ARN")


# ensure regions in 'endpoint' key matches
Expand Down Expand Up @@ -61,28 +54,83 @@ def get_environment_config(environment: str) -> Dict[str, Any]:
return current_env


def get_aws_client(environment_name):
env_config = get_environment_config(environment_name)
mturk_client = boto3.client(
"mturk",
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_KEY,
region_name=US_EAST_REGION,
endpoint_url=env_config["endpoint_url"],
)
return mturk_client
def parse_assignment(assignment):
result = {
# "WorkerId": assignment["WorkerId"],
"Answer": [],
"HITId": assignment["HITId"],
}

ns = {
"mt": "http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd"
}
root = ET.fromstring(assignment["Answer"])

for a in root.findall("mt:Answer", ns):
name = a.find("mt:QuestionIdentifier", ns).text
value = a.find("mt:FreeText", ns).text
result["Answer"].append({name: json.loads(value)})
return result


class STSUtils:
_sts_client = None

@classmethod
def get_client(
cls,
access_key_id: str = AWS_ACCESS_KEY_ID,
secret_access_key: str = AWS_SECRET_KEY,
):
if cls._sts_client is None:
kwargs = {
"aws_access_key_id": access_key_id,
"aws_secret_access_key": secret_access_key,
"region_name": US_EAST_REGION,
}
sts_client = boto3.client("sts", **kwargs)
cls._sts_client = sts_client
return cls._sts_client

@classmethod
def assume_role(cls, role_arn: str = AWS_ASSUME_ROLE_ARN):
assert isinstance(role_arn, str)
client = cls.get_client()
res = client.assume_role(RoleArn=role_arn, RoleSessionName="subnet_validator")

return AWSCredentials(
access_key_id=res["Credentials"]["AccessKeyId"],
secret_access_key=res["Credentials"]["SecretAccessKey"],
session_token=res["Credentials"]["SessionToken"],
access_expiration=res["Credentials"]["Expiration"],
)


class MTurkUtils:
_aws_client = None
_mturk_client = None

@classmethod
def get_client(cls):
def get_client(
cls,
access_key_id: str = AWS_ACCESS_KEY_ID,
secret_access_key: str = AWS_SECRET_KEY,
session_token: str = None,
):
config = Factory.get_config()
if MTurkUtils._aws_client is None:
MTurkUtils._aws_client = get_aws_client(config.aws_mturk_environment)

return MTurkUtils._aws_client
if cls._mturk_client is None:
env_config = get_environment_config(config.aws_mturk_environment)
kwargs = {
"aws_access_key_id": access_key_id,
"aws_secret_access_key": secret_access_key,
"region_name": US_EAST_REGION,
"endpoint_url": env_config["endpoint_url"],
}
if session_token:
kwargs["aws_session_token"] = session_token
mturk_client = boto3.client("mturk", **kwargs)
cls._mturk_client = mturk_client

return cls._mturk_client

@staticmethod
def encode_task_key(completion_id: str):
Expand Down Expand Up @@ -120,6 +168,7 @@ def create_mturk_task(
"""Create a human intellgence task to send to AWS MTurk workers."""
payout_auto_approval_seconds = 3600 * 24
success = False
hit_id = None
try:
new_hit = MTurkUtils.get_client().create_hit(
Title=title,
Expand Down Expand Up @@ -148,6 +197,7 @@ def create_mturk_task(
"HITID = " + new_hit["HIT"]["HITId"] + " (Use to Get Results)"
)

hit_id = new_hit["HIT"]["HITId"]
try:
hit_type_id = new_hit["HIT"]["HITTypeId"]
MTurkUtils.get_client().update_notification_settings(
Expand All @@ -167,12 +217,12 @@ def create_mturk_task(
bt.logging.error("Failed to update notification settings: " + str(e))
pass

return success
return success, hit_id
except botocore.exceptions.ClientError as e:
bt.logging.error(
f"Error occurred while trying to create hit... exception: {e}"
)
return False
return False, None

@staticmethod
async def handle_mturk_event(event_payload: Dict):
Expand Down
53 changes: 12 additions & 41 deletions neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,15 @@
from concurrent.futures import ThreadPoolExecutor

import bittensor as bt
from commons.factory import Factory
from commons.llm.openai_proxy import Provider
from commons.human_feedback.aws_mturk import MTurkUtils
from commons.reward_model.models import ModelUtils
from commons.scoring import Scoring
from commons.human_feedback.aws_mturk import MTurkUtils, STSUtils
from commons.utils import get_epoch_time

from template.base.miner import BaseMinerNeuron
from template.protocol import (
ModelConfig,
MTurkResponse,
Rank,
RankingRequest,
RankingResult,
ScoringMethod,
Expand Down Expand Up @@ -60,48 +58,14 @@ async def forward_ranking_request(self, synapse: RankingRequest) -> RankingReque

scoring_method = self.config.scoring_method
if scoring_method.casefold() == ScoringMethod.HF_MODEL:
for completion in synapse.completions:
score = ModelUtils._hf_score(
self.config.model_name, synapse.prompt, completion.text
)
synapse.ranks.append(
Rank(
cid=completion.cid,
score=score,
)
)
synapse.scoring_method = ScoringMethod.HF_MODEL
synapse.model_config = ModelConfig(model_name=self.config.model_name)

elif scoring_method.casefold() == ScoringMethod.LLM_API:
llm_provider = Provider(self.config.llm_provider)
model_name = self.config.model_name
scores_response = await ModelUtils._llm_api_score(
provider=llm_provider,
model_name=model_name,
prompt=synapse.prompt,
completions=synapse.completions,
)
for completion in synapse.completions:
matching_score_item = next(
(
item
for item in scores_response.scores
if item.completion_id == completion.cid
),
None,
)

if matching_score_item:
synapse.ranks.append(
Rank(
cid=completion.cid,
score=matching_score_item.score,
)
)
synapse.scoring_method = ScoringMethod.LLM_API
synapse.model_config = ModelConfig(
provider=llm_provider, model_name=model_name
provider=Provider(self.config.llm_provider),
model_name=self.config.model_name,
)

elif scoring_method.casefold() == ScoringMethod.AWS_MTURK:
Expand All @@ -115,7 +79,9 @@ async def forward_ranking_request(self, synapse: RankingRequest) -> RankingReque
reward_in_dollars=0.01,
)
synapse.scoring_method = ScoringMethod.AWS_MTURK
await loop.run_in_executor(self.executor, task)
success, hit_id = await loop.run_in_executor(self.executor, task)
if success:
synapse.mturk_hit_id = hit_id
else:
bt.logging.error("Unrecognized scoring method!")
except:
Expand All @@ -132,6 +98,11 @@ async def send_mturk_response(self, synapse: MTurkResponse):
f"No hotkey found for completion ids: {synapse.completion_id_to_score.keys()}"
)
return
# generate temporary credentials to send to the validator
if not synapse.aws_credentials:
credentials = STSUtils.assume_role()
credentials.environment = Factory.get_config().aws_mturk_environment
synapse.aws_credentials = credentials

uid = self.metagraph.hotkeys.index(hotkey)
axon = self.metagraph.axons[uid]
Expand Down
63 changes: 59 additions & 4 deletions neurons/validator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import asyncio
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import functools
import xml.etree.ElementTree as ET
import json
import threading
import time
from traceback import print_exception
import traceback
from typing import List, Tuple
from typing import List, Tuple, Dict
import copy
from torch.nn import functional as F

Expand All @@ -18,12 +18,15 @@
from commons.dataset.hf_utils import HuggingFaceUtils
from commons.evals import EvalUtils
from commons.factory import Factory
from commons.human_feedback.aws_mturk import MTurkUtils, parse_assignment
from commons.objects import DendriteQueryResponse
from commons.reward_model.models import ModelUtils
from commons.scoring import Scoring

from commons.utils import get_epoch_time, get_new_uuid, serve_axon
from template.base.neuron import BaseNeuron
from template.protocol import (
AWSCredentials,
Completion,
MTurkResponse,
Rank,
Expand Down Expand Up @@ -85,6 +88,48 @@ async def blacklist_mturk_response(

return False, "Valid request received from miner"

@staticmethod
def verify_mturk_task(
aws_credentials: AWSCredentials,
hit_id: str,
completion_id_to_score: Dict[str, float],
):
temp_client = MTurkUtils.get_client(
access_key_id=aws_credentials.access_key_id,
secret_access_key=aws_credentials.secret_access_key,
session_token=aws_credentials.session_token,
)
res = temp_client.list_assignments_for_hit(HITId=hit_id)
answers = [parse_assignment(assignment) for assignment in res["Assignments"]]
if not answers:
return False, "No assignments found for HIT"

cids_to_check = set(completion_id_to_score.keys())
# example of answers variable
# [{'Answer': [{'taskAnswers': [{'cid_a6f41ad1-d8a8-4bf3-a698-1b431bf2edac': 5.68, 'cid_f95cae4d-38ed-4911-b97a-f92a0c3bad9a': 7.49}]}], 'HITId': '3MG8450X3U3J7MRIYLXCI5SO1CIUPJ'}]
for answer in answers:
task_answers = answer.get("Answer", [])
for task_answer in task_answers:
for completion_id, score in task_answer.get("taskAnswers", {}).items():
if completion_id not in cids_to_check:
bt.logging.warning(
f"Completion ID {completion_id} found in MTurk task answers is not in the expected set."
)
return (
False,
f"Unexpected completion ID {completion_id} in MTurk task answers.",
)
elif completion_id_to_score[completion_id] != score:
bt.logging.warning(
f"Score mismatch for completion ID {completion_id}: expected {completion_id_to_score[completion_id]}, got {score} from MTurk task answers."
)
return (
False,
f"Score mismatch for completion ID {completion_id}.",
)

return True, "All checks passed"

async def forward_mturk_response(self, synapse: MTurkResponse):
"""Receives MTurk responses from miners after delayed response to allow for human feedback loop"""
# 1. check request from RankingRequest
Expand Down Expand Up @@ -135,6 +180,17 @@ async def forward_mturk_response(self, synapse: MTurkResponse):
)
continue

is_verified, reason = self.verify_mturk_task(
synapse.aws_credentials,
synapse.mturk_hit_id,
synapse.completion_id_to_score,
)
if not is_verified:
bt.logging.error(
f"MTurk task verification failed due to reason:{reason}"
)
return

request_copy = copy.deepcopy(d.request)
for cid in found_cids:
# similar to scoring method in Miner.forward(...)
Expand Down Expand Up @@ -321,7 +377,6 @@ async def send_request(
for response in responses
if len(response.ranks) > 0
and response.scoring_method in [method for method in ScoringMethod]
and response.scoring_method != ScoringMethod.AWS_MTURK
]

if not len(valid_responses):
Expand Down

0 comments on commit 0506416

Please sign in to comment.