Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Mturk refactoring #128

Merged
merged 7 commits into from
Jun 9, 2017
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 202 additions & 0 deletions parlai/mturk/core/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from parlai.core.agents import Agent
from parlai.core.worlds import display_messages

import os
import time
from datetime import datetime
import random
import string
import webbrowser
import json
import requests
from parlai.core.agents import create_agent_from_shared
from .setup_aws import setup_aws, check_mturk_balance, create_hit_type, create_hit_with_hit_type, setup_aws_credentials


def _get_new_messages(json_api_endpoint_url, task_group_id, conversation_id, after_message_id, excluded_agent_id=None, included_agent_id=None):
params = {
'method_name': 'get_new_messages',
'task_group_id': task_group_id,
'last_message_id': after_message_id,
'conversation_id': conversation_id,
}
if excluded_agent_id:
params['excluded_agent_id'] = excluded_agent_id
if included_agent_id:
params['included_agent_id'] = included_agent_id

request = requests.get(json_api_endpoint_url, params=params)
return json.loads(request.json())

def _send_new_message(json_api_endpoint_url, task_group_id, conversation_id, agent_id, message_text=None, reward=None, episode_done=False):
post_data_dict = {
'method_name': 'send_new_message',
'task_group_id': task_group_id,
'conversation_id': conversation_id,
'cur_agent_id': agent_id,
'episode_done': episode_done,
}
if message_text:
post_data_dict['text'] = message_text
if reward:
post_data_dict['reward'] = reward

request = requests.post(json_api_endpoint_url, data=json.dumps(post_data_dict))
return json.loads(request.json())

def _get_review_status_count(json_api_endpoint_url, task_group_id, conversation_id, review_status, requester_key):
params = {
'method_name': 'get_review_status_count',
'task_group_id': task_group_id,
'conversation_id': conversation_id,
'review_status': review_status,
'requester_key': requester_key
}
request = requests.get(json_api_endpoint_url, params=params)
return request.json()

class MTurkAgent(Agent):

skip_init = False
html_api_endpoint_url = None
json_api_endpoint_url = None
requester_key_gt = None

def __init__(self, opt, shared=None):
super().__init__(opt)

self.id = opt['agent_id']
self.task_name = opt['task']
self.is_sandbox = opt['is_sandbox']
self.conversation_id = opt['conversation_id']
self.mturk_agent_ids = opt['mturk_agent_ids']
self.all_agent_ids = opt['all_agent_ids']
self.hit_reward = opt['reward']
self.hit_title = opt['hit_title']
self.hit_description = opt['hit_description']
self.hit_keywords = opt['hit_keywords']
self.task_description = opt['task_description']

self.last_message_id = 0

if not self.__class__.skip_init:
print("\nYou are going to allow workers from Amazon Mechanical Turk to be an agent in ParlAI.\nDuring this process, Internet connection is required, and you should turn off your computer's auto-sleep feature.\n")
key_input = input("Please press Enter to continue... ")
print("")

setup_aws_credentials()

if not check_mturk_balance(num_hits=1, hit_reward=self.hit_reward, is_sandbox=self.is_sandbox):
return

if not self.__class__.skip_init:
print('Setting up MTurk backend...')
html_api_endpoint_url, json_api_endpoint_url, requester_key_gt = setup_aws(self.task_description, 1, self.is_sandbox)
self.__class__.html_api_endpoint_url = html_api_endpoint_url
self.__class__.json_api_endpoint_url = json_api_endpoint_url
self.__class__.requester_key_gt = requester_key_gt
print("MTurk setup done.\n")

self.__class__.skip_init = True
self.html_api_endpoint_url = self.__class__.html_api_endpoint_url
self.json_api_endpoint_url = self.__class__.json_api_endpoint_url
self.requester_key_gt = self.__class__.requester_key_gt

self.task_group_id = str(self.task_name) + '_' + str(self.conversation_id)

print('Creating HITs...')
hit_type_id = create_hit_type(
hit_title=self.hit_title,
hit_description=self.hit_description + ' (ID: ' + self.task_group_id + ', Role: ' + self.id + ')',
hit_keywords=self.hit_keywords,
hit_reward=self.hit_reward,
is_sandbox=self.is_sandbox
)
all_agent_ids_string = str(self.all_agent_ids).replace("'", '''"''')
mturk_chat_url = self.html_api_endpoint_url + "?method_name=chat_index&task_group_id="+str(self.task_group_id)+"&conversation_id="+str(self.conversation_id)+"&all_agent_ids="+all_agent_ids_string+"&cur_agent_id="+str(self.id)
mturk_page_url = create_hit_with_hit_type(
page_url=mturk_chat_url,
hit_type_id=hit_type_id,
is_sandbox=self.is_sandbox
)

print("Link to HIT for " + self.id + ": " + mturk_page_url + "\n")
print("Waiting for Turkers to respond... (Please don't close your laptop or put your computer into sleep or standby mode.)\n")

def observe(self, msg):
if msg['id'] not in self.mturk_agent_ids: # If the message sender is an mturk agent, then there is no need to upload this message to db since it's already been done on the message sender side.
conversation_dict = _get_new_messages(
json_api_endpoint_url=self.json_api_endpoint_url,
task_group_id=self.task_group_id,
conversation_id=self.conversation_id,
after_message_id=self.last_message_id,
included_agent_id=msg['id'])['conversation_dict']
if self.conversation_id in conversation_dict:
agent_last_message_in_db = conversation_dict[self.conversation_id][0]
agent_last_message_in_db.pop('message_id', None)
if 'episode_done' not in msg:
msg['episode_done'] = False
if agent_last_message_in_db == msg:
return

_send_new_message(
json_api_endpoint_url=self.json_api_endpoint_url,
task_group_id=self.task_group_id,
conversation_id=self.conversation_id,
agent_id=msg['id'],
message_text=msg.get('text', None),
reward=msg.get('reward', None),
episode_done=msg.get('episode_done', False),
)

def act(self):
while True:
ret = _get_new_messages(
json_api_endpoint_url=self.json_api_endpoint_url,
task_group_id=self.task_group_id,
conversation_id=self.conversation_id,
after_message_id=self.last_message_id,
included_agent_id=self.id
)
conversation_dict = ret['conversation_dict']

if str(self.conversation_id) in conversation_dict:
new_last_message_id = ret['last_message_id']
if new_last_message_id:
self.last_message_id = new_last_message_id

new_messages = conversation_dict[str(self.conversation_id)]

return new_messages[0]

time.sleep(1) # Wait for 1 second, so that we are not polling too frequently.

def episode_done(self):
return False

def shutdown(self):
if _get_review_status_count(json_api_endpoint_url=self.json_api_endpoint_url, task_group_id=self.task_group_id, conversation_id=self.conversation_id, review_status='approved', requester_key=self.requester_key_gt) + \
_get_review_status_count(json_api_endpoint_url=self.json_api_endpoint_url, task_group_id=self.task_group_id, conversation_id=self.conversation_id, review_status='rejected', requester_key=self.requester_key_gt) > 0:
return
else:
# Loop to ensure all HITs are done
while _get_review_status_count(json_api_endpoint_url=self.json_api_endpoint_url, task_group_id=self.task_group_id, conversation_id=self.conversation_id, review_status='pending', requester_key=self.requester_key_gt) < len(self.mturk_agent_ids):
time.sleep(2)

mturk_agent_ids_string = str(self.mturk_agent_ids).replace("'", '''"''')
mturk_approval_url = self.html_api_endpoint_url + "?method_name=approval_index&task_group_id="+str(self.task_group_id)+"&conversation_id="+str(self.conversation_id)+"&mturk_agent_ids="+mturk_agent_ids_string+"&requester_key="+self.requester_key_gt
print("\nAll HITs are done! Please go to the following link to approve/reject them (or they will be auto-approved in 4 weeks if no action is taken):\n")
print(mturk_approval_url)
print("")

# Loop for checking review status
while _get_review_status_count(json_api_endpoint_url=self.json_api_endpoint_url, task_group_id=self.task_group_id, conversation_id=self.conversation_id, review_status='pending', requester_key=self.requester_key_gt) > 0:
time.sleep(2)

print("All reviews are done!")
25 changes: 18 additions & 7 deletions parlai/mturk/core/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def send_new_message(db_session, task_group_id, conversation_id, agent_id, messa
return new_message_object


def get_new_messages(db_session, task_group_id, conversation_id=None, after_message_id=None, excluded_agent_id=None, populate_meta_info=False):
def get_new_messages(db_session, task_group_id, conversation_id=None, after_message_id=None, excluded_agent_id=None, included_agent_id=None, populate_meta_info=False):
"""
Return:
conversation_dict = {
Expand All @@ -123,13 +123,19 @@ def get_new_messages(db_session, task_group_id, conversation_id=None, after_mess
if not after_message_id:
after_message_id = -1

included_agent_ids = []
if included_agent_id:
included_agent_ids = [included_agent_id]

excluded_agent_ids = []
if excluded_agent_id:
excluded_agent_ids = [excluded_agent_id]

last_message_id = None

query = db_session.query(Message).filter(Message.task_group_id==task_group_id).filter(Message.id > after_message_id)
if len(included_agent_ids) > 0:
query = query.filter(Message.agent_id.in_(included_agent_ids))
if len(excluded_agent_ids) > 0:
query = query.filter(~Message.agent_id.in_(excluded_agent_ids))
if conversation_id:
Expand Down Expand Up @@ -164,7 +170,12 @@ def get_new_messages(db_session, task_group_id, conversation_id=None, after_mess


def set_hit_info(db_session, task_group_id, conversation_id, assignment_id, hit_id, worker_id, is_sandbox, approval_status='pending'):
existing_object = db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.conversation_id==conversation_id).first()
existing_object = db_session.query(MTurkHITInfo) \
.filter(MTurkHITInfo.task_group_id==task_group_id) \
.filter(MTurkHITInfo.conversation_id==conversation_id) \
.filter(MTurkHITInfo.assignment_id==assignment_id) \
.filter(MTurkHITInfo.hit_id==hit_id) \
.first()
if not existing_object:
new_hit_info_object = MTurkHITInfo(
task_group_id=task_group_id,
Expand All @@ -187,12 +198,12 @@ def set_hit_info(db_session, task_group_id, conversation_id, assignment_id, hit_
db_session.commit()


def get_hit_info(db_session, task_group_id, conversation_id):
existing_object = db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.conversation_id==conversation_id).first()
return existing_object
def get_all_matching_hit_infos(db_session, task_group_id, conversation_id):
matching_hit_infos = list(db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.conversation_id==conversation_id).all())
return matching_hit_infos

def get_pending_review_count(db_session, task_group_id):
return db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.approval_status=='pending').count()
def get_review_status_count(db_session, task_group_id, conversation_id, review_status):
return db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.conversation_id==conversation_id).filter(MTurkHITInfo.approval_status==review_status).count()

def get_all_review_status(db_session, task_group_id):
return db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).order_by(MTurkHITInfo.conversation_id).all()
62 changes: 36 additions & 26 deletions parlai/mturk/core/handler_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def chat_index(event, context):
try:
task_group_id = event['query']['task_group_id']
conversation_id = event['query']['conversation_id']
all_agent_ids = event['query']['all_agent_ids']
cur_agent_id = event['query']['cur_agent_id']
assignment_id = event['query']['assignmentId'] # from mturk

Expand All @@ -63,6 +64,7 @@ def chat_index(event, context):
template_context['task_group_id'] = task_group_id
template_context['conversation_id'] = conversation_id
template_context['cur_agent_id'] = cur_agent_id
template_context['all_agent_ids'] = all_agent_ids
template_context['task_description'] = task_description
template_context['mturk_submit_url'] = mturk_submit_url
template_context['is_cover_page'] = False
Expand Down Expand Up @@ -112,13 +114,15 @@ def get_new_messages(event, context):
if 'conversation_id' in event['query']:
conversation_id = int(event['query']['conversation_id'])
excluded_agent_id = event['query'].get('excluded_agent_id', None)
included_agent_id = event['query'].get('included_agent_id', None)

conversation_dict, new_last_message_id = data_model.get_new_messages(
db_session=db_session,
task_group_id=task_group_id,
conversation_id=conversation_id,
after_message_id=last_message_id,
excluded_agent_id=excluded_agent_id,
included_agent_id=included_agent_id,
populate_meta_info=True
)

Expand Down Expand Up @@ -179,12 +183,12 @@ def approval_index(event, context):

task_group_id = event['query']['task_group_id']
conversation_id = event['query']['conversation_id']
cur_agent_id = event['query']['cur_agent_id']
mturk_agent_ids = event['query']['mturk_agent_ids']

template_context = {}
template_context['task_group_id'] = task_group_id
template_context['conversation_id'] = conversation_id
template_context['cur_agent_id'] = cur_agent_id
template_context['mturk_agent_ids'] = mturk_agent_ids
template_context['task_description'] = task_description
template_context['is_cover_page'] = False
template_context['is_approval_page'] = True
Expand Down Expand Up @@ -212,49 +216,55 @@ def review_hit(event, context):
conversation_id = int(params['conversation_id'])
action = params['action'] # 'approve' or 'reject'

hit_info = data_model.get_hit_info(
hit_infos = data_model.get_all_matching_hit_infos(
db_session=db_session,
task_group_id=task_group_id,
conversation_id=conversation_id
)

if hit_info:
assignment_id = hit_info.assignment_id
client = boto3.client(
service_name = 'mturk',
region_name = 'us-east-1',
endpoint_url = 'https://mturk-requester-sandbox.us-east-1.amazonaws.com'
)
# Region is always us-east-1
if not hit_info.is_sandbox:
client = boto3.client(service_name = 'mturk', region_name='us-east-1')

if action == 'approve':
client.approve_assignment(AssignmentId=assignment_id)
hit_info.approval_status = 'approved'
elif action == 'reject':
client.reject_assignment(AssignmentId=assignment_id, RequesterFeedback='')
hit_info.approval_status = 'rejected'
db_session.add(hit_info)
db_session.commit()
if len(hit_infos) > 0:
for hit_info in hit_infos:
assignment_id = hit_info.assignment_id
client = boto3.client(
service_name = 'mturk',
region_name = 'us-east-1',
endpoint_url = 'https://mturk-requester-sandbox.us-east-1.amazonaws.com'
)
# Region is always us-east-1
if not hit_info.is_sandbox:
client = boto3.client(service_name = 'mturk', region_name='us-east-1')

if action == 'approve':
client.approve_assignment(AssignmentId=assignment_id)
hit_info.approval_status = 'approved'
elif action == 'reject':
client.reject_assignment(AssignmentId=assignment_id, RequesterFeedback='')
hit_info.approval_status = 'rejected'
db_session.add(hit_info)
db_session.commit()

except KeyError:
raise Exception('400')

def get_pending_review_count(event, context):
def get_review_status_count(event, context):
if event['method'] == 'GET':
"""
Handler for getting the number of pending reviews.
Expects <requester_key>, <task_group_id> as query parameters.
Expects <requester_key>, <task_group_id>, <conversation_id> as query parameters.
"""
try:
requester_key = event['query']['requester_key']
if not requester_key == requester_key_gt:
raise Exception('403')

task_group_id = event['query']['task_group_id']
return data_model.get_pending_review_count(
conversation_id = event['query']['conversation_id']
review_status = event['query']['review_status']
return data_model.get_review_status_count(
db_session=db_session,
task_group_id=task_group_id
task_group_id=task_group_id,
conversation_id=conversation_id,
review_status=review_status
)
except KeyError:
raise Exception('400')
Expand Down
Loading