diff --git a/README.md b/README.md index f06fcac44cf..afc0ae6eb74 100644 --- a/README.md +++ b/README.md @@ -263,36 +263,36 @@ To add your own task: ### MTurk An important part of ParlAI is seamless integration with Mechanical Turk for data collection, training and evaluation. + Human Turkers are also viewed as agents in ParlAI and hence person-person, person-bot, or multiple people and bots in group chat can all converse within the standard framework, switching out the roles as desired with no code changes to the agents. This is because Turkers also receive and send via a (pretty printed) version of the same interface, using the fields of the observation/action dict. -We provide two examples in the first release, collecting data, and human evaluation of a bot. + +We currently provide three examples: collecting data, human evaluation of a bot, and round-robin chat between local humans and remote Turkers.

-The mturk library contains the following directories and files: +The mturk library contains the following directories: -- **core**: this directory contains the core code for setting up AWS backend that supports the MTurk chat interface, and code for HIT creation and approval. -- **tasks**: this directory contains two sample MTurk tasks that are provided in the first release. +- **core**: this directory contains the core code for setting up AWS backend that supports the MTurk chat interface, code for HIT creation and approval, and the wrapper class `MTurkAgent` which encapsulates the MTurk interface into a standard `Agent` class. +- **tasks**: this directory contains three sample MTurk tasks. - **_qa\_data\_collection_**: get questions and answers from turkers, given a random paragraph from SQuAD. - - **_model\_evaluator_**: evaluate the information retrieval baseline model on the Reddit movie dialog dataset. -- **run_mturk.py**: file for calling mturk core code with user-specified task module, dialog model agent, number of HITs, and reward for each HIT. + - **_model\_evaluator_**: ask turkers to evaluate the information retrieval baseline model on the Reddit movie dialog dataset. + - **_multi\_agent\_dialog_**: round-robin chat between two local human agents and two Turkers. -To run sample MTurk task and agent: -- In __run\_mturk.py__, uncomment the task module and the agent class you want to use -- For `create_hits` method, change `num_hits` and `hit_reward` if needed. Set `is_sandbox` to `True` if you want to run the sample in MTurk sandbox only, or set it to `False` to allow turkers to work on it and potentially get paid for it. -- Run `python run_mturk.py` +To run an MTurk task: +- Go into the directory for the task you want to run. +- Run `python run.py -nh -r [--sandbox]/[--live]`, with `` and `` set appropriately. Use `--sandbox` to run the task in MTurk sandbox mode before pushing it live. -To add your own MTurk task and dialog model: +To add your own MTurk task: - create a new folder within the mturk/tasks directory for your new task -- implement __task\_config.py__, with at least the following fields in the task_config dictionary: +- implement __task\_config.py__, with at least the following fields in the `task_config` dictionary: - `hit_title`: a short and descriptive title about the kind of task the HIT contains. On the Amazon Mechanical Turk web site, the HIT title appears in search results, and everywhere the HIT is mentioned. - `hit_description`: a description includes detailed information about the kind of task the HIT contains. On the Amazon Mechanical Turk web site, the HIT description appears in the expanded view of search results, and in the HIT and assignment screens. - `hit_keywords`: one or more words or phrases that describe the HIT, separated by commas. On MTurk website, these words are used in searches to find HITs. - - `worker_agent_id`: a short name indicating the turker's role in the conversation. - `task_description`: a detailed task description that will be shown on the HIT task preview page and on the left side of the chat page. Supports HTML formatting. -- implement __agents.py__, with at least an agent class that extends from Agent - - write your own `__init__()` method that wraps your dialog model agent. (Please see mturk/tasks/model_evaluator/agents.py file for a concrete example.) - - write your own `act()` method that returns your dialog model's response as well as helpful text to the turker for what action they should take next. -- import your task module and agent class in __run\_mturk.py__ file, and then run `python run_mturk.py`. +- implement __run.py__, with code for setting up and running the world where `MTurkAgent` lives in. +- (Optional) implement __worlds.py__, with a world class that extends from `World`. + +Please see [the MTurk tutorial](http://parl.ai/static/docs/mturk.html) to learn more about the MTurk examples and how to create and run your own task. ## Support If you have any questions, bug reports or feature requests, please don't hesitate to post on our [Github Issues page](https://github.com/facebookresearch/ParlAI/issues). diff --git a/docs/source/mturk_new.rst b/docs/source/mturk_new.rst new file mode 100644 index 00000000000..5cec6b111ed --- /dev/null +++ b/docs/source/mturk_new.rst @@ -0,0 +1,120 @@ +.. + Copyright (c) 2017-present, Facebook, Inc. + All rights reserved. + This source code is licensed under the BSD-style license found in the + LICENSE file in the root directory of this source tree. An additional grant + of patent rights can be found in the PATENTS file in the same directory. + +Using Mechanical Turk +===================== + +In ParlAI, you can use Amazon Mechanical Turk for **data collection**, **training** and **evaluation** of your dialog model. + +Human Turkers are viewed as just another type of agent in ParlAI, and hence person-to-person, person-to-bot, or multiple people and bots in group chat can all talk to each other within the same framework. + +The human Turkers communicate in observation/action dict format, the same as all other agents in ParlAI. During the conversation, the message that human Turkers receive is rendered on the live chat webpage in a pretty printed format, similar to the following: + +.. figure:: _static/img/mturk-small.png + :align: center + + Example: Human Turker participating in a QA data collection task + +Each MTurk task has at least one human Turker that connects to ParlAI via the Mechanical Turk Live Chat interface, encapsulated as an ``MTurkAgent`` object. + +Each MTurk task also consists of a ``World`` where all agents live and interact within. + +Example Tasks +------------- + +We provide a few examples of using Mechanical Turk with ParlAI: + +- `QA Data Collection `__: collect questions and answers from Turkers, given a random Wikipedia paragraph from SQuAD. +- `Model Evaluator `__: ask Turkers to evaluate the information retrieval baseline model on the Reddit movie dialog dataset. +- `Multi-Agent Dialog `__: round-robin chat between two local human agents and two Turkers. + +Task 1: Collecting Data +^^^^^^^^^^^^^^^^^^^^^^^ + +One of the biggest use cases of Mechanical Turk is to collect natural language data from human Turkers. + +As an example, the `QA Data Collection task `__ does the following: + +1. Pick a random Wikipedia paragraph from SQuAD dataset. +2. Ask a Turker to provide a question given the paragraph. +3. Ask the same Turker to provide an answer to their question. + +In ``QADataCollectionWorld``, there are two agents: one is the human Turker (``MTurkAgent``), the other is the task agent (``DefaultTeacher`` from SQuAD) that provides the Wikipedia paragraph. + +The ``QADataCollectionWorld`` uses ``turn_index`` to denote what stage the conversation is at. One *turn* means that the world has been called ``parley()`` once. + +After two turns, the task is finished, and the Turker's work is submitted for your review. + + +Task 2: Evaluating a Dialog Model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You can easily evaluate your dialog model's performance with human Turkers using ParlAI. As an example, the `Model Evaluator task `__ does the following: + +1. Initialize a task world with a dialog model agent (`ir_baseline `__) and a dataset (`MovieDD-Reddit `__). +2. Let all the agents in the task world ``observe()`` and ``act()`` once, by calling ``parley()`` on the world. +3. Ask the human Turker to rate the dialog model agent's response on a scale of 0-10. + +In ``ModelEvaluatorWorld``, there are two main components: one is the ``task_world`` that contains the task and the dialog model we are evaluating, the other is the ``MTurkAgent`` which is an interface to the human Turker. + +Note that since the human Turker speaks only once to provide the rating, the ``ModelEvaluatorWorld`` doesn't need to use ``turn_index`` to keep track of the turns. + +After one turn, the task is finished, and the Turker's work is submitted for your review. + + +Creating Your Own Task +---------------------- + +ParlAI provides a generic MTurk dialog interface that one can use to implement any kind of dialog tasks. To create your own task, start with reading the tutorials on the provided examples, and then copy and modify the example ``worlds.py``, ``run.py`` and ``task_config.py`` files to create your task. + +A few things to keep in mind: + +1. To end a conversation, you should send a message with ``episode_done = True`` from the first non-MTurk agent, and the conversation is ended after all MTurk agents respond. +2. Make sure to test your dialog task using MTurk's sandbox mode before pushing it live, by using the ``--sandbox`` flag (enabled by default) when running ``run.py``. + + +Running a Task +-------------- + +If you have not used Mechanical Turk before, you will need an MTurk Requester Account and an AWS account (these are two separate accounts). Follow the steps below: + +- Sign up for an AWS account at `aws.amazon.com `__ + +- Sign up for an MTurk account at `requester.mturk.com `__ + +- Go to the developer tab (`https://requester.mturk.com/developer `__) and link your AWS account to your MTurk account (Step 2 on that screen) + +- MTurk also has a “Sandbox” which is a test version of the MTurk marketplace. You can use it to test publishing and completing tasks without paying any money. ParlAI supports the Sandbox. To use the Sandbox, you need to sign up for a `Sandbox account `__. You will then also need to `link your AWS account `__ to your Sandbox account. In order to test faster, you will also want to create a `Sandbox Worker account `__. You can then view tasks your publish from ParlAI and complete them yourself. + +- ParlAI will connect to your AWS account and set up some supporting resources including a Lambda function, an API Gateway and an RDS database. It will also use your AWS account to connect to the MTurk API. In order to do this, it will require credentials to access your AWS account. To set this up, you will need to create an `IAM user `__ with programmatic access and an AdministratorAccess policy. You can learn more about how to set up IAM users `here `__. Once you have created the account, keep its access key and the secret key handy as you will need it next. + +Then, to run an MTurk task, first ensure that the task directory is in `parlai/mturk/tasks/ `__. Then, run its ``run.py`` file with proper flags: + +.. code-block:: python + + python run.py -nh -r [--sandbox]/[--live] + +E.g. to create 2 HITs for the `QA Data Collection `__ example with $0.05 each in sandbox mode, first go into the task directory and then run: + +.. code-block:: python + + python run.py -nh 2 -r 0.05 --sandbox + +Please make sure to test your task in MTurk sandbox mode first (``--sandbox``) before pushing it live (``--live``). + + +Reviewing Turker's Work +----------------------- + +After all HITs are completed, you will be provided a webpage link to review them. + +If you don't take any action in 4 weeks, all HITs will be auto-approved and Turkers will be paid. + + +------- + +\* Turker icon credit: `Amazon Mechanical Turk `__. Robot icon credit: `Icons8 `__. \ No newline at end of file diff --git a/parlai/mturk/core/agents.py b/parlai/mturk/core/agents.py new file mode 100644 index 00000000000..0a40cb7a767 --- /dev/null +++ b/parlai/mturk/core/agents.py @@ -0,0 +1,202 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. An additional grant +# of patent rights can be found in the PATENTS file in the same directory. + +from parlai.core.agents import Agent +from parlai.core.worlds import display_messages + +import os +import time +from datetime import datetime +import random +import string +import webbrowser +import json +import requests +from parlai.core.agents import create_agent_from_shared +from .setup_aws import setup_aws, check_mturk_balance, create_hit_type, create_hit_with_hit_type, setup_aws_credentials + + +def _get_new_messages(json_api_endpoint_url, task_group_id, conversation_id, after_message_id, excluded_agent_id=None, included_agent_id=None): + params = { + 'method_name': 'get_new_messages', + 'task_group_id': task_group_id, + 'last_message_id': after_message_id, + 'conversation_id': conversation_id, + } + if excluded_agent_id: + params['excluded_agent_id'] = excluded_agent_id + if included_agent_id: + params['included_agent_id'] = included_agent_id + + request = requests.get(json_api_endpoint_url, params=params) + return json.loads(request.json()) + +def _send_new_message(json_api_endpoint_url, task_group_id, conversation_id, agent_id, message_text=None, reward=None, episode_done=False): + post_data_dict = { + 'method_name': 'send_new_message', + 'task_group_id': task_group_id, + 'conversation_id': conversation_id, + 'cur_agent_id': agent_id, + 'episode_done': episode_done, + } + if message_text: + post_data_dict['text'] = message_text + if reward: + post_data_dict['reward'] = reward + + request = requests.post(json_api_endpoint_url, data=json.dumps(post_data_dict)) + return json.loads(request.json()) + +def _get_review_status_count(json_api_endpoint_url, task_group_id, conversation_id, review_status, requester_key): + params = { + 'method_name': 'get_review_status_count', + 'task_group_id': task_group_id, + 'conversation_id': conversation_id, + 'review_status': review_status, + 'requester_key': requester_key + } + request = requests.get(json_api_endpoint_url, params=params) + return request.json() + +class MTurkAgent(Agent): + + skip_init = False + html_api_endpoint_url = None + json_api_endpoint_url = None + requester_key_gt = None + + def __init__(self, opt, shared=None): + super().__init__(opt) + + self.id = opt['agent_id'] + self.task_name = opt['task'] + self.is_sandbox = opt['is_sandbox'] + self.conversation_id = opt['conversation_id'] + self.mturk_agent_ids = opt['mturk_agent_ids'] + self.all_agent_ids = opt['all_agent_ids'] + self.hit_reward = opt['reward'] + self.hit_title = opt['hit_title'] + self.hit_description = opt['hit_description'] + self.hit_keywords = opt['hit_keywords'] + self.task_description = opt['task_description'] + + self.last_message_id = 0 + + if not self.__class__.skip_init: + print("\nYou are going to allow workers from Amazon Mechanical Turk to be an agent in ParlAI.\nDuring this process, Internet connection is required, and you should turn off your computer's auto-sleep feature.\n") + key_input = input("Please press Enter to continue... ") + print("") + + setup_aws_credentials() + + if not check_mturk_balance(num_hits=1, hit_reward=self.hit_reward, is_sandbox=self.is_sandbox): + return + + if not self.__class__.skip_init: + print('Setting up MTurk backend...') + html_api_endpoint_url, json_api_endpoint_url, requester_key_gt = setup_aws(self.task_description, 1, self.is_sandbox) + self.__class__.html_api_endpoint_url = html_api_endpoint_url + self.__class__.json_api_endpoint_url = json_api_endpoint_url + self.__class__.requester_key_gt = requester_key_gt + print("MTurk setup done.\n") + + self.__class__.skip_init = True + self.html_api_endpoint_url = self.__class__.html_api_endpoint_url + self.json_api_endpoint_url = self.__class__.json_api_endpoint_url + self.requester_key_gt = self.__class__.requester_key_gt + + self.task_group_id = str(self.task_name) + '_' + str(self.conversation_id) + + print('Creating HITs...') + hit_type_id = create_hit_type( + hit_title=self.hit_title, + hit_description=self.hit_description + ' (ID: ' + self.task_group_id + ', Role: ' + self.id + ')', + hit_keywords=self.hit_keywords, + hit_reward=self.hit_reward, + is_sandbox=self.is_sandbox + ) + all_agent_ids_string = str(self.all_agent_ids).replace("'", '''"''') + mturk_chat_url = self.html_api_endpoint_url + "?method_name=chat_index&task_group_id="+str(self.task_group_id)+"&conversation_id="+str(self.conversation_id)+"&all_agent_ids="+all_agent_ids_string+"&cur_agent_id="+str(self.id) + mturk_page_url = create_hit_with_hit_type( + page_url=mturk_chat_url, + hit_type_id=hit_type_id, + is_sandbox=self.is_sandbox + ) + + print("Link to HIT for " + self.id + ": " + mturk_page_url + "\n") + print("Waiting for Turkers to respond... (Please don't close your laptop or put your computer into sleep or standby mode.)\n") + + def observe(self, msg): + if msg['id'] not in self.mturk_agent_ids: # If the message sender is an mturk agent, then there is no need to upload this message to db since it's already been done on the message sender side. + conversation_dict = _get_new_messages( + json_api_endpoint_url=self.json_api_endpoint_url, + task_group_id=self.task_group_id, + conversation_id=self.conversation_id, + after_message_id=self.last_message_id, + included_agent_id=msg['id'])['conversation_dict'] + if self.conversation_id in conversation_dict: + agent_last_message_in_db = conversation_dict[self.conversation_id][0] + agent_last_message_in_db.pop('message_id', None) + if 'episode_done' not in msg: + msg['episode_done'] = False + if agent_last_message_in_db == msg: + return + + _send_new_message( + json_api_endpoint_url=self.json_api_endpoint_url, + task_group_id=self.task_group_id, + conversation_id=self.conversation_id, + agent_id=msg['id'], + message_text=msg.get('text', None), + reward=msg.get('reward', None), + episode_done=msg.get('episode_done', False), + ) + + def act(self): + while True: + ret = _get_new_messages( + json_api_endpoint_url=self.json_api_endpoint_url, + task_group_id=self.task_group_id, + conversation_id=self.conversation_id, + after_message_id=self.last_message_id, + included_agent_id=self.id + ) + conversation_dict = ret['conversation_dict'] + + if str(self.conversation_id) in conversation_dict: + new_last_message_id = ret['last_message_id'] + if new_last_message_id: + self.last_message_id = new_last_message_id + + new_messages = conversation_dict[str(self.conversation_id)] + + return new_messages[0] + + time.sleep(1) # Wait for 1 second, so that we are not polling too frequently. + + def episode_done(self): + return False + + def shutdown(self): + if _get_review_status_count(json_api_endpoint_url=self.json_api_endpoint_url, task_group_id=self.task_group_id, conversation_id=self.conversation_id, review_status='approved', requester_key=self.requester_key_gt) + \ + _get_review_status_count(json_api_endpoint_url=self.json_api_endpoint_url, task_group_id=self.task_group_id, conversation_id=self.conversation_id, review_status='rejected', requester_key=self.requester_key_gt) > 0: + return + else: + # Loop to ensure all HITs are done + while _get_review_status_count(json_api_endpoint_url=self.json_api_endpoint_url, task_group_id=self.task_group_id, conversation_id=self.conversation_id, review_status='pending', requester_key=self.requester_key_gt) < len(self.mturk_agent_ids): + time.sleep(2) + + mturk_agent_ids_string = str(self.mturk_agent_ids).replace("'", '''"''') + mturk_approval_url = self.html_api_endpoint_url + "?method_name=approval_index&task_group_id="+str(self.task_group_id)+"&conversation_id="+str(self.conversation_id)+"&mturk_agent_ids="+mturk_agent_ids_string+"&requester_key="+self.requester_key_gt + print("\nAll HITs are done! Please go to the following link to approve/reject them (or they will be auto-approved in 4 weeks if no action is taken):\n") + print(mturk_approval_url) + print("") + + # Loop for checking review status + while _get_review_status_count(json_api_endpoint_url=self.json_api_endpoint_url, task_group_id=self.task_group_id, conversation_id=self.conversation_id, review_status='pending', requester_key=self.requester_key_gt) > 0: + time.sleep(2) + + print("All reviews are done!") \ No newline at end of file diff --git a/parlai/mturk/core/data_model.py b/parlai/mturk/core/data_model.py index 9f204440926..f93e3579bee 100644 --- a/parlai/mturk/core/data_model.py +++ b/parlai/mturk/core/data_model.py @@ -100,7 +100,7 @@ def send_new_message(db_session, task_group_id, conversation_id, agent_id, messa return new_message_object -def get_new_messages(db_session, task_group_id, conversation_id=None, after_message_id=None, excluded_agent_id=None, populate_meta_info=False): +def get_new_messages(db_session, task_group_id, conversation_id=None, after_message_id=None, excluded_agent_id=None, included_agent_id=None, populate_meta_info=False): """ Return: conversation_dict = { @@ -123,6 +123,10 @@ def get_new_messages(db_session, task_group_id, conversation_id=None, after_mess if not after_message_id: after_message_id = -1 + included_agent_ids = [] + if included_agent_id: + included_agent_ids = [included_agent_id] + excluded_agent_ids = [] if excluded_agent_id: excluded_agent_ids = [excluded_agent_id] @@ -130,6 +134,8 @@ def get_new_messages(db_session, task_group_id, conversation_id=None, after_mess last_message_id = None query = db_session.query(Message).filter(Message.task_group_id==task_group_id).filter(Message.id > after_message_id) + if len(included_agent_ids) > 0: + query = query.filter(Message.agent_id.in_(included_agent_ids)) if len(excluded_agent_ids) > 0: query = query.filter(~Message.agent_id.in_(excluded_agent_ids)) if conversation_id: @@ -164,7 +170,12 @@ def get_new_messages(db_session, task_group_id, conversation_id=None, after_mess def set_hit_info(db_session, task_group_id, conversation_id, assignment_id, hit_id, worker_id, is_sandbox, approval_status='pending'): - existing_object = db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.conversation_id==conversation_id).first() + existing_object = db_session.query(MTurkHITInfo) \ + .filter(MTurkHITInfo.task_group_id==task_group_id) \ + .filter(MTurkHITInfo.conversation_id==conversation_id) \ + .filter(MTurkHITInfo.assignment_id==assignment_id) \ + .filter(MTurkHITInfo.hit_id==hit_id) \ + .first() if not existing_object: new_hit_info_object = MTurkHITInfo( task_group_id=task_group_id, @@ -187,12 +198,12 @@ def set_hit_info(db_session, task_group_id, conversation_id, assignment_id, hit_ db_session.commit() -def get_hit_info(db_session, task_group_id, conversation_id): - existing_object = db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.conversation_id==conversation_id).first() - return existing_object +def get_all_matching_hit_infos(db_session, task_group_id, conversation_id): + matching_hit_infos = list(db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.conversation_id==conversation_id).all()) + return matching_hit_infos -def get_pending_review_count(db_session, task_group_id): - return db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.approval_status=='pending').count() +def get_review_status_count(db_session, task_group_id, conversation_id, review_status): + return db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).filter(MTurkHITInfo.conversation_id==conversation_id).filter(MTurkHITInfo.approval_status==review_status).count() def get_all_review_status(db_session, task_group_id): return db_session.query(MTurkHITInfo).filter(MTurkHITInfo.task_group_id==task_group_id).order_by(MTurkHITInfo.conversation_id).all() \ No newline at end of file diff --git a/parlai/mturk/core/handler_template.py b/parlai/mturk/core/handler_template.py index ce5a0d3b0d0..240f4ff53c9 100755 --- a/parlai/mturk/core/handler_template.py +++ b/parlai/mturk/core/handler_template.py @@ -53,6 +53,7 @@ def chat_index(event, context): try: task_group_id = event['query']['task_group_id'] conversation_id = event['query']['conversation_id'] + all_agent_ids = event['query']['all_agent_ids'] cur_agent_id = event['query']['cur_agent_id'] assignment_id = event['query']['assignmentId'] # from mturk @@ -63,6 +64,7 @@ def chat_index(event, context): template_context['task_group_id'] = task_group_id template_context['conversation_id'] = conversation_id template_context['cur_agent_id'] = cur_agent_id + template_context['all_agent_ids'] = all_agent_ids template_context['task_description'] = task_description template_context['mturk_submit_url'] = mturk_submit_url template_context['is_cover_page'] = False @@ -112,6 +114,7 @@ def get_new_messages(event, context): if 'conversation_id' in event['query']: conversation_id = int(event['query']['conversation_id']) excluded_agent_id = event['query'].get('excluded_agent_id', None) + included_agent_id = event['query'].get('included_agent_id', None) conversation_dict, new_last_message_id = data_model.get_new_messages( db_session=db_session, @@ -119,6 +122,7 @@ def get_new_messages(event, context): conversation_id=conversation_id, after_message_id=last_message_id, excluded_agent_id=excluded_agent_id, + included_agent_id=included_agent_id, populate_meta_info=True ) @@ -179,12 +183,12 @@ def approval_index(event, context): task_group_id = event['query']['task_group_id'] conversation_id = event['query']['conversation_id'] - cur_agent_id = event['query']['cur_agent_id'] + mturk_agent_ids = event['query']['mturk_agent_ids'] template_context = {} template_context['task_group_id'] = task_group_id template_context['conversation_id'] = conversation_id - template_context['cur_agent_id'] = cur_agent_id + template_context['mturk_agent_ids'] = mturk_agent_ids template_context['task_description'] = task_description template_context['is_cover_page'] = False template_context['is_approval_page'] = True @@ -212,39 +216,41 @@ def review_hit(event, context): conversation_id = int(params['conversation_id']) action = params['action'] # 'approve' or 'reject' - hit_info = data_model.get_hit_info( + hit_infos = data_model.get_all_matching_hit_infos( db_session=db_session, task_group_id=task_group_id, conversation_id=conversation_id ) - if hit_info: - assignment_id = hit_info.assignment_id - client = boto3.client( - service_name = 'mturk', - region_name = 'us-east-1', - endpoint_url = 'https://mturk-requester-sandbox.us-east-1.amazonaws.com' - ) - # Region is always us-east-1 - if not hit_info.is_sandbox: - client = boto3.client(service_name = 'mturk', region_name='us-east-1') - - if action == 'approve': - client.approve_assignment(AssignmentId=assignment_id) - hit_info.approval_status = 'approved' - elif action == 'reject': - client.reject_assignment(AssignmentId=assignment_id, RequesterFeedback='') - hit_info.approval_status = 'rejected' - db_session.add(hit_info) - db_session.commit() + if len(hit_infos) > 0: + for hit_info in hit_infos: + assignment_id = hit_info.assignment_id + client = boto3.client( + service_name = 'mturk', + region_name = 'us-east-1', + endpoint_url = 'https://mturk-requester-sandbox.us-east-1.amazonaws.com' + ) + # Region is always us-east-1 + if not hit_info.is_sandbox: + client = boto3.client(service_name = 'mturk', region_name='us-east-1') + + if action == 'approve': + client.approve_assignment(AssignmentId=assignment_id) + hit_info.approval_status = 'approved' + elif action == 'reject': + client.reject_assignment(AssignmentId=assignment_id, RequesterFeedback='') + hit_info.approval_status = 'rejected' + db_session.add(hit_info) + db_session.commit() + except KeyError: raise Exception('400') -def get_pending_review_count(event, context): +def get_review_status_count(event, context): if event['method'] == 'GET': """ Handler for getting the number of pending reviews. - Expects , as query parameters. + Expects , , as query parameters. """ try: requester_key = event['query']['requester_key'] @@ -252,9 +258,13 @@ def get_pending_review_count(event, context): raise Exception('403') task_group_id = event['query']['task_group_id'] - return data_model.get_pending_review_count( + conversation_id = event['query']['conversation_id'] + review_status = event['query']['review_status'] + return data_model.get_review_status_count( db_session=db_session, - task_group_id=task_group_id + task_group_id=task_group_id, + conversation_id=conversation_id, + review_status=review_status ) except KeyError: raise Exception('400') diff --git a/parlai/mturk/core/manage_hit.py b/parlai/mturk/core/manage_hit.py deleted file mode 100644 index 244a22cd7b9..00000000000 --- a/parlai/mturk/core/manage_hit.py +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. An additional grant -# of patent rights can be found in the PATENTS file in the same directory. - -import os -import time -from datetime import datetime -import random -import string -import webbrowser -import json -import requests -from parlai.core.agents import create_agent_from_shared -from .setup_aws import setup_aws, check_mturk_balance, create_hit_type, create_hit_with_hit_type, setup_aws_credentials - - -def _get_random_alphanumeric_string(N): - return ''.join(random.SystemRandom().choice(string.ascii_letters + string.digits) for _ in range(N)) - - -def _setup_relay(task_config, num_hits, is_sandbox): - """Sets up relay server - """ - # set up relay server - html_api_endpoint_url, json_api_endpoint_url, requester_key_gt = setup_aws(task_config, num_hits, is_sandbox) - - return html_api_endpoint_url, json_api_endpoint_url, requester_key_gt - -def _send_new_message(json_api_endpoint_url, task_group_id, conversation_id, agent_id, message_text=None, reward=None, episode_done=False): - post_data_dict = { - 'method_name': 'send_new_message', - 'task_group_id': task_group_id, - 'conversation_id': conversation_id, - 'cur_agent_id': agent_id, - 'episode_done': episode_done, - } - if message_text: - post_data_dict['text'] = message_text - if reward: - post_data_dict['reward'] = reward - - request = requests.post(json_api_endpoint_url, data=json.dumps(post_data_dict)) - return json.loads(request.json()) - -def _get_new_messages(json_api_endpoint_url, task_group_id, after_message_id, excluded_agent_id=None): - params = { - 'method_name': 'get_new_messages', - 'task_group_id': task_group_id, - 'last_message_id': after_message_id, - } - if excluded_agent_id: - params['excluded_agent_id'] = excluded_agent_id - - request = requests.get(json_api_endpoint_url, params=params) - return json.loads(request.json()) - -def _get_pending_review_count(json_api_endpoint_url, task_group_id, requester_key): - params = { - 'method_name': 'get_pending_review_count', - 'task_group_id': task_group_id, - 'requester_key': requester_key - } - request = requests.get(json_api_endpoint_url, params=params) - return request.json() - -def _get_all_review_status(json_api_endpoint_url, task_group_id, requester_key): - params = { - 'method_name': 'get_all_review_status', - 'task_group_id': task_group_id, - 'requester_key': requester_key - } - request = requests.get(json_api_endpoint_url, params=params) - return request.json() - -def create_hits(opt, task_config, task_module_name, bot, chat_page_only=False): - num_hits = opt['num_hits'] - hit_reward = opt['reward'] - is_sandbox = opt['is_sandbox'] - verbose = opt['verbose'] - - print("\nYou are going to allow workers from Amazon Mechanical Turk to chat with your dialog model running on your local machine.\nDuring this process, Internet connection is required, and you should turn off your computer's auto-sleep feature.\n") - key_input = input("Please press Enter to continue... ") - print("") - - setup_aws_credentials() - if not check_mturk_balance(num_hits=num_hits, hit_reward=hit_reward, is_sandbox=is_sandbox): - return - - task_group_id = str(int(time.time())) + '_' + _get_random_alphanumeric_string(10) # Random string to further avoid collision - - print('Setting up MTurk backend...') - html_api_endpoint_url, json_api_endpoint_url, requester_key_gt = _setup_relay(task_config, num_hits, is_sandbox) - - approval_index_url_template = html_api_endpoint_url + "?method_name=approval_index&task_group_id={{task_group_id}}&conversation_id=1&cur_agent_id={{cur_agent_id}}&requester_key="+requester_key_gt - - worker_agent_id = task_config['worker_agent_id'] - bot_agent_id = bot.getID() - cids = range(1, num_hits+1) - cid_map = {cid: i for i, cid in enumerate(cids)} - c_done_map = {cid: False for cid in cids} - logs = {cid: [] for cid in cids} - - shared = bot.share() - bots = [] - last_message_id = -1 - - # If the bot needs to send the first message of the conversation, it will send it here - for cid in cids: - new_bot = create_agent_from_shared(shared) - new_bot.conversation_id = cid - bots.append(new_bot) - response = new_bot.act() - if response: - if response.get('episode_done', False): - c_done_map[cid] = True - if verbose: - print('Conversation '+str(cid)+' - Bot says: ' + str(response)) - logs[cid].append(response) - new_message = _send_new_message( - json_api_endpoint_url=json_api_endpoint_url, - task_group_id=task_group_id, - conversation_id=cid, - agent_id=bot_agent_id, - message_text=response.get('text', None), - reward=response.get('reward', None), - episode_done=response.get('episode_done', False), - ) - if new_message['message_id'] > last_message_id: - last_message_id = new_message['message_id'] - - hits_created = False - conversations_remaining = set(cids) - - # Main loop for polling and handling new messages - while len(conversations_remaining) > 0: - ret = _get_new_messages( - json_api_endpoint_url=json_api_endpoint_url, - task_group_id=task_group_id, - after_message_id=last_message_id, - excluded_agent_id=bot_agent_id, - ) - conversation_dict = ret['conversation_dict'] - new_last_message_id = ret['last_message_id'] - - if new_last_message_id: - last_message_id = new_last_message_id - - time.sleep(1) - - for conversation_id, new_messages in conversation_dict.items(): - conversation_id = int(conversation_id) - if conversation_id in conversations_remaining and len(new_messages) > 0: - agent = bots[cid_map[conversation_id]] - for new_message in new_messages: - if verbose: - print('Conversation '+str(conversation_id)+' - Bot received: ' + str(new_message)) - logs[conversation_id].append(new_message) - agent.observe(new_message) - if new_message.get('episode_done', False) or c_done_map[conversation_id]: - # We're done here - conversations_remaining.remove(conversation_id) - print('Conversation '+str(conversation_id)+' is DONE!\n') - else: - # Agent still needs to reply - response = agent.act() - if response: - if response.get('episode_done', False): - c_done_map[conversation_id] = True - if verbose: - print('Conversation '+str(conversation_id)+' - Bot says: ' + str(response)) - logs[conversation_id].append(response) - _send_new_message( - json_api_endpoint_url=json_api_endpoint_url, - task_group_id=task_group_id, - conversation_id=conversation_id, - agent_id=bot_agent_id, - message_text=response.get('text', None), - reward=response.get('reward', None), - episode_done=response.get('episode_done', False), - ) - - # We don't create new HITs until this point, so that the HIT page will always have the conversation fully populated. - if not hits_created: - print('Creating HITs...') - hit_type_id = create_hit_type( - hit_title=task_config['hit_title'], - hit_description=task_config['hit_description'] + ' (ID: ' + task_group_id + ')', - hit_keywords=task_config['hit_keywords'], - hit_reward=hit_reward, - is_sandbox=is_sandbox - ) - mturk_chat_url = None - mturk_page_url = None - for cid in cids: - mturk_chat_url = html_api_endpoint_url + "?method_name=chat_index&task_group_id="+str(task_group_id)+"&conversation_id="+str(cid)+"&cur_agent_id="+str(worker_agent_id) - if not chat_page_only: - mturk_page_url = create_hit_with_hit_type( - page_url=mturk_chat_url, - hit_type_id=hit_type_id, - is_sandbox=is_sandbox - ) - - print("MTurk setup done.\n") - if chat_page_only: - webbrowser.open(mturk_chat_url) - else: - print("Link to your HIT: " + mturk_page_url + "\n") - print("Waiting for Turkers to complete the tasks... (Please don't close your laptop or put your computer into sleep or standby mode.)\n") - hits_created = True - - while _get_pending_review_count(json_api_endpoint_url=json_api_endpoint_url, task_group_id=task_group_id, requester_key=requester_key_gt) != num_hits: - time.sleep(2) - - mturk_approval_url = html_api_endpoint_url + "?method_name=approval_index&task_group_id="+str(task_group_id)+"&conversation_id=1&cur_agent_id="+worker_agent_id+"&requester_key="+requester_key_gt - print("\nAll HITs are done! Please go to the following link to approve/reject them (or they will be auto-approved in 4 weeks if no action is taken):\n") - print(mturk_approval_url) - print("") - - approval_status_dict = {cid: '' for cid in cids} - # Loop for checking approval status - while _get_pending_review_count(json_api_endpoint_url=json_api_endpoint_url, task_group_id=task_group_id, requester_key=requester_key_gt) > 0: - time.sleep(2) - - print("Approvals are done!") - - for hit_info in _get_all_review_status(json_api_endpoint_url=json_api_endpoint_url, task_group_id=task_group_id, requester_key=requester_key_gt): - conversation_id = hit_info['conversation_id'] - approval_status_dict[conversation_id] = hit_info['approval_status'] - - logs_approved = {cid:log for (cid,log) in logs.items() if approval_status_dict[cid] == 'approved'} - logs_rejected = {cid:log for (cid,log) in logs.items() if approval_status_dict[cid] == 'rejected'} - - # Saving logs to file - # Log format: {conversation_id: [list of messages in the conversation]} - mturk_log_path = opt['mturk_log_path'] - task_group_path = os.path.join(mturk_log_path, - task_module_name + '_' + - datetime.now().strftime('%Y-%m-%d_%H:%M:%S')) - os.makedirs(task_group_path) - with open(os.path.join(task_group_path, 'approved.json'), 'w') as fout: - fout.write(json.dumps(logs_approved)) - with open(os.path.join(task_group_path, 'rejected.json'), 'w') as fout: - fout.write(json.dumps(logs_rejected)) - - print("All conversations are saved to "+opt['mturk_log_path']+" in JSON format.\n") diff --git a/parlai/mturk/core/mturk_index.html b/parlai/mturk/core/mturk_index.html index 3ab11c89838..7513b73098b 100755 --- a/parlai/mturk/core/mturk_index.html +++ b/parlai/mturk/core/mturk_index.html @@ -180,6 +180,7 @@

Live Chat

show_new_messages_on_UI(new_messages); if (done_after_responding) { update_for_response_type('done'); + task_done = true; } else { check_done(new_messages); } @@ -246,6 +247,22 @@

Live Chat

var task_group_id = `{{task_group_id}}`; var conversation_id = parseInt(`{{conversation_id}}`); var cur_agent_id = `{{cur_agent_id}}`; + var mturk_agent_ids = `{{mturk_agent_ids}}`; + if (mturk_agent_ids) { + mturk_agent_ids = JSON.parse(mturk_agent_ids); + } + var all_agent_ids = `{{all_agent_ids}}`; + var cur_agent_index = null; + var previous_agent_id = null; + if (all_agent_ids) { + all_agent_ids = JSON.parse(all_agent_ids); + cur_agent_index = all_agent_ids.indexOf(cur_agent_id); + if (cur_agent_index == 0) { + previous_agent_id = all_agent_ids[all_agent_ids.length-1]; + } else { + previous_agent_id = all_agent_ids[cur_agent_index-1]; + } + } var is_cover_page = (`{{is_cover_page}}` === 'True') ? true : false; var is_approval_page = (`{{is_approval_page}}` === 'True') ? true : false; var num_hits = parseInt(`{{num_hits}}`); @@ -254,7 +271,7 @@

Live Chat

var messages_processed = {}; var messages_shown = {}; var done_after_responding = false; - var self_display_name = 'You'; + var task_done = false; function show_new_messages_on_UI(new_messages) { for (var i = 0; i < new_messages.length; i++) { @@ -269,7 +286,7 @@

Live Chat

} if (!(message_id in messages_shown)) { - if (agent_id != cur_agent_id) { + if ((!is_approval_page && agent_id != cur_agent_id) || (is_approval_page && $.inArray(agent_id, mturk_agent_ids) == -1)) { $('div#message_thread').append(`