diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b6e4761 --- /dev/null +++ b/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..0d31b1f --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,5 @@ +# Code of Conduct + +Facebook has adopted a Code of Conduct that we expect project participants to adhere to. +Please read the [full text](https://code.fb.com/codeofconduct/) +so that you can understand what actions will and will not be tolerated. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..fb166ff --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to Zero-Shot-DST +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `master`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. 
+ +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to Zero-Shot-DST, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..4fae8d5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. 
More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. 
NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. 
You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. 
You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. 
+ + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..7016478 --- /dev/null +++ b/README.md @@ -0,0 +1,25 @@ +# Zero-Shot-DST +This repository includes the implementation of the **NAACL 2021** paper: + +**Leveraging Slot Descriptions for Zero-Shot Cross-Domain Dialogue StateTracking**. 
+
+**Authors**: Zhaojiang Lin, Bing Liu, Seungwhan Moon, Paul Crook, Zhenpeng Zhou, Zhiguang Wang, Zhou Yu, Andrea Madotto, Eunjoon Cho, Rajen Subba
+
+## Citations
+If you want to publish experimental results with our source code, please cite the following article:
+```
+@inproceedings{lin2021leveraging,
+  title={Leveraging Slot Descriptions for Zero-Shot Cross-Domain Dialogue State Tracking},
+  author={Lin, Zhaojiang and Liu, Bing and Moon, Seungwhan and Crook, Paul A and Zhou, Zhenpeng and Wang, Zhiguang and Yu, Zhou and Madotto, Andrea and Cho, Eunjoon and Subba, Rajen},
+  booktitle={Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
+  pages={5640--5648},
+  year={2021}
+}
+```
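+
+## Quick Start
+The implementation, including data setup and training scripts, lives in the `T5DST` folder. As a minimal sketch of a zero-shot run (see `T5DST/README.md` for all options; the held-out domain and description type below are example values):
+```console
+❱❱❱ cd T5DST
+❱❱❱ pip install -r utils/requirements.txt
+❱❱❱ python create_data.py
+❱❱❱ python T5.py --train_batch_size 16 --GPU 8 --except_domain hotel --slot_lang slottype
+```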
+
+
+## Bug Report
+Feel free to create an issue or send an email to zlinao@connect.ust.hk
+
+## License
+The majority of Zero-Shot DST is licensed under [CC-BY-NC-SA-4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode); however, portions of the project are available under separate license terms: Transformers is licensed under the Apache 2.0 license; SOM-DST and MultiWOZ are licensed under the MIT license; and license information for TRADE is available at https://github.com/jasonwu0731/trade-dst#license.
diff --git a/T5DST/README.md b/T5DST/README.md
new file mode 100644
index 0000000..5bb2b4d
--- /dev/null
+++ b/T5DST/README.md
@@ -0,0 +1,46 @@
+# Leveraging Slot Descriptions for Zero-Shot Cross-Domain Dialogue State Tracking
+
+## Abstract:
+Zero-shot cross-domain dialogue state tracking (DST) enables us to handle task-oriented dialogue in unseen domains without the expense of collecting in-domain data. In this paper, we propose a slot description enhanced generative approach for zero-shot cross-domain DST. Specifically, our model first encodes dialogue context and slots with a pre-trained self-attentive encoder, and generates slot values in an auto-regressive manner. In addition, we incorporate Slot Type Informed Descriptions that capture the shared information across slots to facilitate cross-domain knowledge transfer. Experimental results on the MultiWOZ dataset show that our proposed method significantly improves existing state-of-the-art results in the zero-shot cross-domain setting.
+
+## Method:
+
+*(Figures: T5DST model overview (left) and slot description examples (right).)*
+
+a) Left figure: High-level description of the T5DST. The model (T5) takes dialogue history and slot name (or slot descriptions) as input, and generates the value. +b) Right figure: Slot description examples. + + +## Dependency +Check the packages needed or simply run the command +```console +❱❱❱ pip install -r utils/requirements.txt +``` + +## Experiments +**Dataset** +```console +❱❱❱ python create_data.py +``` +use create_data_2_1.py if want to run with multiwoz2.1 + +**Zero-shot cross-domain** +```console +❱❱❱ python T5.py --train_batch_size 16 --GPU 8 --except_domain ${domain} --slot_lang ${description type} +``` +* --GPU: the number of gpu to use +* --except_domain: hold out domain, choose one from [hotel, train, attraction, restaurant, taxi] +* --slot_lang: slot description type, choose one from [none, human, naive, value, question, slottype] + +**Few-shot cross-domain** +```console +❱❱❱ python T5.py --train_batch_size 16 --GPU 8 --slot_lang slottype --model_checkpoint ${checkpoint} --n_epochs 15 --fewshot 0.01 --mode finetune +``` +* --model_checkpoint: saved checkpoint of zero-shot model +* --fewshot: ratio of in-domain data, choose one from [0.01, 0.05, 0.1] + +**Full-shot** +```console +❱❱❱ python T5.py --train_batch_size 16 --GPU 8 --slot_lang slottype --n_epochs 15 +``` diff --git a/T5DST/T5.py b/T5DST/T5.py new file mode 100644 index 0000000..134a054 --- /dev/null +++ b/T5DST/T5.py @@ -0,0 +1,247 @@ +# Copyright (c) Facebook, Inc. and its affiliates + +import os, random +import torch +import argparse +import pytorch_lightning as pl +from pytorch_lightning import Trainer, seed_everything +from transformers import (AdamW, T5Tokenizer, BartTokenizer, BartForConditionalGeneration, T5ForConditionalGeneration, WEIGHTS_NAME,CONFIG_NAME) +from data_loader import prepare_data +from config import get_args +from evaluate import evaluate_metrics +import json +from tqdm import tqdm +from copy import deepcopy +import numpy as np +from collections import Counter + +# def consistency_cross_entropy(lm_logits1, lm_logits2, threshold=0.4): +# logsoftmax = torch.nn.LogSoftmax(dim=1) +# softmax = torch.nn.Softmax(dim=1) + +# lm_logits1 = lm_logits1.squeeze() +# lm_logits2 = lm_logits2.squeeze() +# # (batch, vocab_size) +# # give threshold +# prob2 = softmax(lm_logits2) +# # the result tuple of two output tensors (max, max_indices) +# # print(torch.max(prob2, dim=1)) +# prob2_max, prob2_index = torch.max(prob2, dim=1) +# valid = [] +# for i in range(prob2_max.shape[0]): +# if (prob2_index[i]==5839 and prob2_max[i]>0.9) or (prob2_index[i]!=5839 and prob2_max[i]>threshold): +# valid.append(1) +# else: +# valid.append(0) + +# #sharpening +# soft_targets = softmax(lm_logits2/0.5) + +# loss_temp = torch.sum(- soft_targets * logsoftmax(lm_logits1), 1) +# for i in range(prob2_max.shape[0]): +# if valid[i]==0: +# loss_temp[i]=0 + +# return torch.mean(loss_temp) + + + +class DST_Seq2Seq(pl.LightningModule): + + def __init__(self,args, tokenizer, model): + super().__init__() + self.tokenizer = tokenizer + self.model = model + self.lr = args["lr"] + + + def training_step(self, batch, batch_idx): + self.model.train() + (loss), *_ = self.model(input_ids=batch["encoder_input"], + attention_mask=batch["attention_mask"], + lm_labels=batch["decoder_output"] + ) + + # result = pl.TrainResult(loss) + # result.log('train_loss', loss, on_epoch=True) + return {'loss': loss, 'log': {'train_loss': loss}} + # return result + + def validation_step(self, batch, batch_idx): + self.model.eval() + (loss), *_ = 
self.model(input_ids=batch["encoder_input"], + attention_mask=batch["attention_mask"], + lm_labels=batch["decoder_output"] + ) + + + return {'val_loss': loss, 'log': {'val_loss': loss}} + # return result + + def validation_epoch_end(self, outputs): + val_loss_mean = sum([o['val_loss'] for o in outputs]) / len(outputs) + # show val_loss in progress bar but only log val_loss + results = {'progress_bar': {'val_loss': val_loss_mean.item()}, 'log': {'val_loss': val_loss_mean.item()}, + 'val_loss': val_loss_mean.item()} + return results + + def configure_optimizers(self): + return AdamW(self.parameters(), lr=self.lr, correct_bias=True) + + + +def train(args, *more): + args = vars(args) + args["model_name"] = args["model_checkpoint"]+args["model_name"]+"_except_domain_"+args["except_domain"]+ "_slotlang_" +str(args["slot_lang"]) + "_lr_" +str(args["lr"]) + "_epoch_" + str(args["n_epochs"]) + "_seed_" + str(args["seed"]) + # train! + seed_everything(args["seed"]) + + + if "t5" in args["model_name"]: + model = T5ForConditionalGeneration.from_pretrained(args["model_checkpoint"]) + tokenizer = T5Tokenizer.from_pretrained(args["model_checkpoint"], bos_token="[bos]", eos_token="[eos]", sep_token="[sep]") + model.resize_token_embeddings(new_num_tokens=len(tokenizer)) + elif "bart" in args["model_name"]: + model = BartForConditionalGeneration.from_pretrained(args["model_checkpoint"]) + tokenizer = BartTokenizer.from_pretrained(args["model_checkpoint"], bos_token="[bos]", eos_token="[eos]", sep_token="[sep]") + model.resize_token_embeddings(new_num_tokens=len(tokenizer)) + + task = DST_Seq2Seq(args, tokenizer, model) + + train_loader, val_loader, test_loader, ALL_SLOTS, fewshot_loader_dev, fewshot_loader_test = prepare_data(args, task.tokenizer) + + #save model path + save_path = os.path.join(args["saving_dir"],args["model_name"]) + if not os.path.exists(save_path): + os.makedirs(save_path) + + trainer = Trainer( + default_root_dir=save_path, + accumulate_grad_batches=args["gradient_accumulation_steps"], + gradient_clip_val=args["max_norm"], + max_epochs=args["n_epochs"], + callbacks=[pl.callbacks.EarlyStopping(monitor='val_loss',min_delta=0.00, patience=5,verbose=False, mode='min')], + gpus=args["GPU"], + deterministic=True, + num_nodes=1, + #precision=16, + accelerator="ddp" + ) + + trainer.fit(task, train_loader, val_loader) + + task.model.save_pretrained(save_path) + task.tokenizer.save_pretrained(save_path) + + print("test start...") + #evaluate model + _ = evaluate_model(args, task.tokenizer, task.model, test_loader, save_path, ALL_SLOTS) + +def evaluate_model(args, tokenizer, model, test_loader, save_path, ALL_SLOTS, prefix="zeroshot"): + save_path = os.path.join(save_path,"results") + if not os.path.exists(save_path): + os.makedirs(save_path) + predictions = {} + # to gpu + # gpu = args["GPU"][0] + device = torch.device("cuda:0") + model.to(device) + model.eval() + + slot_logger = {slot_name:[0,0,0] for slot_name in ALL_SLOTS} + + for batch in tqdm(test_loader): + dst_outputs = model.generate(input_ids=batch["encoder_input"].to(device), + attention_mask=batch["attention_mask"].to(device), + eos_token_id=tokenizer.eos_token_id, + max_length=200, + ) + + value_batch = tokenizer.batch_decode(dst_outputs, skip_special_tokens=True) + + for idx, value in enumerate(value_batch): + dial_id = batch["ID"][idx] + if dial_id not in predictions: + predictions[dial_id] = {} + predictions[dial_id]["domain"] = batch["domains"][idx][0] + predictions[dial_id]["turns"] = {} + if batch["turn_id"][idx] not in 
predictions[dial_id]["turns"]: + predictions[dial_id]["turns"][batch["turn_id"][idx]] = {"turn_belief":batch["turn_belief"][idx], "pred_belief":[]} + + if value!="none": + predictions[dial_id]["turns"][batch["turn_id"][idx]]["pred_belief"].append(str(batch["slot_text"][idx])+'-'+str(value)) + + # analyze slot acc: + if str(value)==str(batch["value_text"][idx]): + slot_logger[str(batch["slot_text"][idx])][1]+=1 # hit + slot_logger[str(batch["slot_text"][idx])][0]+=1 # total + + for slot_log in slot_logger.values(): + slot_log[2] = slot_log[1]/slot_log[0] + + with open(os.path.join(save_path, f"{prefix}_slot_acc.json"), 'w') as f: + json.dump(slot_logger,f, indent=4) + + with open(os.path.join(save_path, f"{prefix}_prediction.json"), 'w') as f: + json.dump(predictions,f, indent=4) + + joint_acc_score, F1_score, turn_acc_score = evaluate_metrics(predictions, ALL_SLOTS) + + evaluation_metrics = {"Joint Acc":joint_acc_score, "Turn Acc":turn_acc_score, "Joint F1":F1_score} + print(f"{prefix} result:",evaluation_metrics) + + with open(os.path.join(save_path, f"{prefix}_result.json"), 'w') as f: + json.dump(evaluation_metrics,f, indent=4) + + return predictions + + +def fine_tune(args, *more): + args = vars(args) + seed_everything(args["seed"]) + domains = ["hotel", "train", "restaurant", "attraction", "taxi"] + for domain in domains: + if domain in args["model_checkpoint"]: + args["only_domain"] = domain + assert args["only_domain"]!="none" + # args["model_checkpoint"] = os.path.join(args["saving_dir"],args["model_name"]) + print(args) + + if "t5" in args["model_name"]: + model = T5ForConditionalGeneration.from_pretrained(args["model_checkpoint"]) + tokenizer = T5Tokenizer.from_pretrained(args["model_checkpoint"], bos_token="[bos]", eos_token="[eos]", sep_token="[sep]") + elif "bart" in args["model_name"]: + model = BartForConditionalGeneration.from_pretrained(args["model_checkpoint"]) + tokenizer = BartTokenizer.from_pretrained(args["model_checkpoint"], bos_token="[bos]", eos_token="[eos]", sep_token="[sep]") + + task = DST_Seq2Seq(args, tokenizer, model) + train_loader, val_loader, test_loader, ALL_SLOTS, fewshot_loader_dev, fewshot_loader_test = prepare_data(args, tokenizer) + + trainer = Trainer( + default_root_dir=args["model_checkpoint"], + accumulate_grad_batches=args["gradient_accumulation_steps"], + gradient_clip_val=args["max_norm"], + max_epochs=20, + callbacks=[pl.callbacks.EarlyStopping(monitor='val_loss',min_delta=0.00, patience=8,verbose=False, mode='min')], + gpus=args["GPU"], + deterministic=True, + num_nodes=1, + # precision=16, + accelerator="ddp" + ) + + trainer.fit(task, train_loader, val_loader) + + print("test start...") + #evaluate model + ratio = "ratio_" + str(args["fewshot"]) + "_seed_" + str(args["seed"]) + _ = evaluate_model(args, task.tokenizer, task.model, test_loader, args["model_checkpoint"], ALL_SLOTS, prefix=ratio) + + + +if __name__ == "__main__": + args = get_args() + if args.mode=="train": + train(args) + if args.mode=="finetune": + fine_tune(args) diff --git a/T5DST/config.py b/T5DST/config.py new file mode 100644 index 0000000..ae0cf8c --- /dev/null +++ b/T5DST/config.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates
+
+import argparse
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_checkpoint", type=str, default="t5-small", help="Path, URL or short name of the model")
+    parser.add_argument("--saving_dir", type=str, default="save", help="Path for saving")
+    parser.add_argument("--train_batch_size", type=int, default=16, help="Batch size for training")
+    parser.add_argument("--meta_batch_size", type=int, default=1, help="Batch size for meta training")
+    parser.add_argument("--dev_batch_size", type=int, default=8, help="Batch size for validation")
+    parser.add_argument("--test_batch_size", type=int, default=8, help="Batch size for test")
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients over several steps")
+    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
+    parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping norm")
+    parser.add_argument("--n_epochs", type=int, default=5, help="Number of training epochs")
+    parser.add_argument("--seed", type=int, default=557, help="Random seed")
+    parser.add_argument("--verbose", action='store_true', help="Verbose output")
+    parser.add_argument("--length", type=int, default=50, help="Maximum output length")
+    parser.add_argument("--max_history", type=int, default=2, help="Max number of turns in the dialogue")
+    parser.add_argument("--GPU", type=int, default=8, help="Number of GPUs to use")
+    parser.add_argument("--model_name", type=str, default="t5", help="Use t5 or bart")
+    parser.add_argument("--slot_lang", type=str, default="none", help="Slot description type: 'none', 'human', 'naive', 'value', 'question', or 'slottype'")
+    parser.add_argument("--fewshot", type=float, default=0.0, help="Data ratio for the few-shot experiment")
+    parser.add_argument("--fix_label", action='store_true')
+    parser.add_argument("--except_domain", type=str, default="none", help="Held-out domain: hotel, train, restaurant, attraction or taxi")
+    parser.add_argument("--only_domain", type=str, default="none", help="Train on a single domain: hotel, train, restaurant, attraction or taxi")
+    parser.add_argument("--threshold", type=float, default=0.4)
+    parser.add_argument("--semi", action='store_true')
+    parser.add_argument("--mode", type=str, default="train", help="train or finetune")
+
+    args = parser.parse_args()
+    return args
diff --git a/T5DST/create_data.py b/T5DST/create_data.py
new file mode 100644
index 0000000..e5d5709
--- /dev/null
+++ b/T5DST/create_data.py
@@ -0,0 +1,524 @@
+# Copyright (c) Facebook, Inc.
and its affiliates + +# -*- coding: utf-8 -*- +import copy +import json +import os +import re +import shutil +import urllib.request +from collections import OrderedDict +from io import BytesIO +from zipfile import ZipFile +import difflib +import numpy as np + +np.set_printoptions(precision=3) + +np.random.seed(2) + + +''' +Most of the codes are from https://github.com/budzianowski/multiwoz +''' + + +# GLOBAL VARIABLES +DICT_SIZE = 400 +MAX_LENGTH = 50 +IGNORE_KEYS_IN_GOAL = ['eod', 'topic', 'messageLen', 'message'] + +fin = open('utils/mapping.pair','r') +replacements = [] +for line in fin.readlines(): + tok_from, tok_to = line.replace('\n', '').split('\t') + replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' ')) + + + +def clean_value(): + value_map = {["moderate -ly", ]} + +def is_ascii(s): + return all(ord(c) < 128 for c in s) + +def insertSpace(token, text): + sidx = 0 + while True: + sidx = text.find(token, sidx) + if sidx == -1: + break + if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \ + re.match('[0-9]', text[sidx + 1]): + sidx += 1 + continue + if text[sidx - 1] != ' ': + text = text[:sidx] + ' ' + text[sidx:] + sidx += 1 + if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ': + text = text[:sidx + 1] + ' ' + text[sidx + 1:] + sidx += 1 + return text + +def normalize(text, clean_value=True): + # lower case every word + text = text.lower() + + # replace white spaces in front and end + text = re.sub(r'^\s*|\s*$', '', text) + + # hotel domain pfb30 + text = re.sub(r"b&b", "bed and breakfast", text) + text = re.sub(r"b and b", "bed and breakfast", text) + + if clean_value: + # normalize phone number + ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m[0], sidx) + if text[sidx - 1] == '(': + sidx -= 1 + eidx = text.find(m[-1], sidx) + len(m[-1]) + text = text.replace(text[sidx:eidx], ''.join(m)) + + # normalize postcode + ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})', + text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m, sidx) + eidx = sidx + len(m) + text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:] + + # weird unicode bug + text = re.sub(u"(\u2018|\u2019)", "'", text) + + if clean_value: + # replace time and and price + text = re.sub(timepat, ' [value_time] ', text) + text = re.sub(pricepat, ' [value_price] ', text) + #text = re.sub(pricepat2, '[value_price]', text) + + # replace st. + text = text.replace(';', ',') + text = re.sub('$\/', '', text) + text = text.replace('/', ' and ') + + # replace other special characters + text = text.replace('-', ' ') + text = re.sub('[\"\<>@\(\)]', '', text) # remove + + # insert white space before and after tokens: + for token in ['?', '.', ',', '!']: + text = insertSpace(token, text) + + # insert white space for 's + text = insertSpace('\'s', text) + + # replace it's, does't, you'd ... 
etc + text = re.sub('^\'', '', text) + text = re.sub('\'$', '', text) + text = re.sub('\'\s', ' ', text) + text = re.sub('\s\'', ' ', text) + for fromx, tox in replacements: + text = ' ' + text + ' ' + text = text.replace(fromx, tox)[1:-1] + + # remove multiple spaces + text = re.sub(' +', ' ', text) + + # concatenate numbers + tmp = text + tokens = text.split() + i = 1 + while i < len(tokens): + if re.match(u'^\d+$', tokens[i]) and \ + re.match(u'\d+$', tokens[i - 1]): + tokens[i - 1] += tokens[i] + del tokens[i] + else: + i += 1 + text = ' '.join(tokens) + + return text + +def fixDelex(filename, data, data2, idx, idx_acts): + """Given system dialogue acts fix automatic delexicalization.""" + try: + turn = data2[filename.strip('.json')][str(idx_acts)] + except: + return data + + if not isinstance(turn, str):# and not isinstance(turn, unicode): + for k, act in turn.items(): + if 'Attraction' in k: + if 'restaurant_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("restaurant", "attraction") + if 'hotel_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("hotel", "attraction") + if 'Hotel' in k: + if 'attraction_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("attraction", "hotel") + if 'restaurant_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("restaurant", "hotel") + if 'Restaurant' in k: + if 'attraction_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("attraction", "restaurant") + if 'hotel_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("hotel", "restaurant") + + return data + + +def getDialogueAct(filename, data, data2, idx, idx_acts): + """Given system dialogue acts fix automatic delexicalization.""" + acts = [] + try: + turn = data2[filename.strip('.json')][str(idx_acts)] + except: + return acts + + if not isinstance(turn, str): # and not isinstance(turn, unicode): + for k in turn.keys(): + + if k.split('-')[1].lower() == 'request': + for a in turn[k]: + acts.append(a[0].lower()) + elif k.split('-')[1].lower() == 'inform': + for a in turn[k]: + acts.append([a[0].lower(), normalize(a[1].lower(), False)]) + + return acts + + +def get_summary_bstate(bstate, get_domain=False): + """Based on the mturk annotations we form multi-domain belief state""" + domains = [u'taxi',u'restaurant', u'hospital', u'hotel',u'attraction', u'train', u'police'] + summary_bstate = [] + summary_bvalue = [] + active_domain = [] + for domain in domains: + domain_active = False + + booking = [] + #print(domain,len(bstate[domain]['book'].keys())) + for slot in sorted(bstate[domain]['book'].keys()): + if slot == 'booked': + if len(bstate[domain]['book']['booked'])!=0: + booking.append(1) + # summary_bvalue.append("book {} {}:{}".format(domain, slot, "Yes")) + else: + booking.append(0) + else: + if bstate[domain]['book'][slot] != "": + booking.append(1) + summary_bvalue.append(["{}-book {}".format(domain, slot.strip().lower()), normalize(bstate[domain]['book'][slot].strip().lower(), False)]) #(["book", domain, slot, bstate[domain]['book'][slot]]) + else: + booking.append(0) + if domain == 'train': + if 'people' not in bstate[domain]['book'].keys(): + booking.append(0) + if 'ticket' not in bstate[domain]['book'].keys(): + booking.append(0) + summary_bstate += booking + + for slot in bstate[domain]['semi']: + slot_enc = [0, 0, 0] # not mentioned, dontcare, 
filled + if bstate[domain]['semi'][slot] == 'not mentioned': + slot_enc[0] = 1 + elif bstate[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care", "do not care"]: + slot_enc[1] = 1 + summary_bvalue.append(["{}-{}".format(domain, slot.strip().lower()), "dontcare"]) #(["semi", domain, slot, "dontcare"]) + elif bstate[domain]['semi'][slot]: + summary_bvalue.append(["{}-{}".format(domain, slot.strip().lower()), normalize(bstate[domain]['semi'][slot].strip().lower(), False)]) #(["semi", domain, slot, bstate[domain]['semi'][slot]]) + if slot_enc != [0, 0, 0]: + domain_active = True + summary_bstate += slot_enc + + # quasi domain-tracker + if domain_active: + summary_bstate += [1] + active_domain.append(domain) + else: + summary_bstate += [0] + + #print(len(summary_bstate)) + assert len(summary_bstate) == 94 + if get_domain: + return active_domain + else: + return summary_bstate, summary_bvalue + + +def analyze_dialogue(dialogue, maxlen): + """Cleaning procedure for all kinds of errors in text and annotation.""" + d = dialogue + # do all the necessary postprocessing + if len(d['log']) % 2 != 0: + #print path + print('odd # of turns') + return None # odd number of turns, wrong dialogue + d_pp = {} + d_pp['goal'] = d['goal'] # for now we just copy the goal + usr_turns = [] + sys_turns = [] + # last_bvs = [] + for i in range(len(d['log'])): + if len(d['log'][i]['text'].split()) > maxlen: + # print('too long') + return None # too long sentence, wrong dialogue + if i % 2 == 0: # usr turn + text = d['log'][i]['text'] + if not is_ascii(text): + # print('not ascii') + return None + usr_turns.append(d['log'][i]) + else: # sys turn + text = d['log'][i]['text'] + if not is_ascii(text): + # print('not ascii') + return None + belief_summary, belief_value_summary = get_summary_bstate(d['log'][i]['metadata']) + d['log'][i]['belief_summary'] = str(belief_summary) + d['log'][i]['belief_value_summary'] = belief_value_summary + sys_turns.append(d['log'][i]) + d_pp['usr_log'] = usr_turns + d_pp['sys_log'] = sys_turns + + return d_pp + + +def get_dial(dialogue): + """Extract a dialogue from the file""" + dial = [] + d_orig = analyze_dialogue(dialogue, MAX_LENGTH) # max turn len is 50 words + if d_orig is None: + return None + usr = [t['text'] for t in d_orig['usr_log']] + sys = [t['text'] for t in d_orig['sys_log']] + sys_a = [t['dialogue_acts'] for t in d_orig['sys_log']] + bvs = [t['belief_value_summary'] for t in d_orig['sys_log']] + domain = [t['domain'] for t in d_orig['usr_log']] + for item in zip(usr, sys, sys_a, domain, bvs): + dial.append({'usr':item[0],'sys':item[1], 'sys_a':item[2], 'domain':item[3], 'bvs':item[4]}) + return dial + + +def loadData(): + data_url = "data/multi-woz/data.json" + dataset_url = "https://www.repository.cam.ac.uk/bitstream/handle/1810/280608/MULTIWOZ2.zip?sequence=3&isAllowed=y" + if not os.path.exists("data"): + os.makedirs("data") + os.makedirs("data/multi-woz") + + if not os.path.exists(data_url): + print("Downloading and unzipping the MultiWOZ dataset") + resp = urllib.request.urlopen(dataset_url) + zip_ref = ZipFile(BytesIO(resp.read())) + zip_ref.extractall("data/multi-woz") + zip_ref.close() + shutil.copy('data/multi-woz/MULTIWOZ2 2/data.json', 'data/multi-woz/') + shutil.copy('data/multi-woz/MULTIWOZ2 2/valListFile.json', 'data/multi-woz/') + shutil.copy('data/multi-woz/MULTIWOZ2 2/testListFile.json', 'data/multi-woz/') + shutil.copy('data/multi-woz/MULTIWOZ2 2/dialogue_acts.json', 'data/multi-woz/') + + +def getDomain(idx, log, domains, last_domain): + if idx 
== 1: + active_domains = get_summary_bstate(log[idx]["metadata"], True) + crnt_doms = active_domains[0] if len(active_domains)!=0 else domains[0] + return crnt_doms + else: + ds_diff = get_ds_diff(log[idx-2]["metadata"], log[idx]["metadata"]) + if len(ds_diff.keys()) == 0: # no clues from dialog states + crnt_doms = last_domain + else: + crnt_doms = list(ds_diff.keys()) + # print(crnt_doms) + return crnt_doms[0] # How about multiple domains in one sentence senario ? + + +def get_ds_diff(prev_d, crnt_d): + diff = {} + # Sometimes, metadata is an empty dictionary, bug? + if not prev_d or not crnt_d: + return diff + + for ((k1, v1), (k2, v2)) in zip(prev_d.items(), crnt_d.items()): + assert k1 == k2 + if v1 != v2: # updated + diff[k2] = v2 + return diff + + +def createData(): + # download the data + loadData() + + # create dictionary of delexicalied values that then we will search against, order matters here! + # dic = delexicalize.prepareSlotValuesIndependent() + delex_data = {} + + fin1 = open('data/multi-woz/data.json', 'r') + data = json.load(fin1) + + fin2 = open('data/multi-woz/dialogue_acts.json', 'r') + data2 = json.load(fin2) + + for didx, dialogue_name in enumerate(data): + + dialogue = data[dialogue_name] + + domains = [] + for dom_k, dom_v in dialogue['goal'].items(): + if dom_v and dom_k not in IGNORE_KEYS_IN_GOAL: # check whether contains some goal entities + domains.append(dom_k) + + idx_acts = 1 + last_domain, last_slot_fill = "", [] + for idx, turn in enumerate(dialogue['log']): + # normalization, split and delexicalization of the sentence + origin_text = normalize(turn['text'], False) + # origin_text = delexicalize.markEntity(origin_text, dic) + dialogue['log'][idx]['text'] = origin_text + + if idx % 2 == 1: # if it's a system turn + + cur_domain = getDomain(idx, dialogue['log'], domains, last_domain) + last_domain = [cur_domain] + + dialogue['log'][idx - 1]['domain'] = cur_domain + dialogue['log'][idx]['dialogue_acts'] = getDialogueAct(dialogue_name, dialogue, data2, idx, idx_acts) + idx_acts += 1 + + # FIXING delexicalization: + dialogue = fixDelex(dialogue_name, dialogue, data2, idx, idx_acts) + + delex_data[dialogue_name] = dialogue + + # if didx > 10: + # break + + # with open('data/multi-woz/woz2like_data.json', 'w') as outfile: + # json.dump(delex_data, outfile) + + return delex_data + + +def buildDelexDict(origin_sent, delex_sent): + dictionary = {} + s = difflib.SequenceMatcher(None, delex_sent.split(), origin_sent.split()) + bs = s.get_matching_blocks() + for i, b in enumerate(bs): + if i < len(bs)-2: + a_start = b.a + b.size + b_start = b.b + b.size + b_end = bs[i+1].b + dictionary[a_start] = " ".join(origin_sent.split()[b_start:b_end]) + return dictionary + + +def divideData(data): + """Given test and validation sets, divide + the data for three different sets""" + testListFile = [] + fin = open('data/multi-woz/testListFile.json', 'r') + for line in fin: + testListFile.append(line[:-1]) + fin.close() + + valListFile = [] + fin = open('data/multi-woz/valListFile.json', 'r') + for line in fin: + valListFile.append(line[:-1]) + fin.close() + + trainListFile = open('data/trainListFile', 'w') + + test_dials = [] + val_dials = [] + train_dials = [] + + # dictionaries + word_freqs_usr = OrderedDict() + word_freqs_sys = OrderedDict() + + count_train, count_val, count_test = 0, 0, 0 + + ontology = {} + + for dialogue_name in data: + # print dialogue_name + dial_item = data[dialogue_name] + domains = [] + for dom_k, dom_v in dial_item['goal'].items(): + if dom_v and dom_k 
not in IGNORE_KEYS_IN_GOAL: # check whether contains some goal entities + domains.append(dom_k) + + turn_exmaple = {"system":"none", "user":"none", "state":{"active_intent":"none", "slot_values":{} } } + dial = get_dial(data[dialogue_name]) + if dial: + dial_example = {"dial_id":dialogue_name, "domains":list(set(domains)) ,"turns":[]} + # dialogue = {} + # dialogue['dialogue_idx'] = dialogue_name + # dialogue['domains'] = list(set(domains)) #list(set([d['domain'] for d in dial])) + # last_bs = [] + # dialogue['dialogue'] = [] + + for turn_i, turn in enumerate(dial): + # usr, usr_o, sys, sys_o, sys_a, domain + turn_exmaple = {"system":"none", "user":"none", "state":{"active_intent":"none", "slot_values":{} } } + turn_exmaple['system'] = dial[turn_i-1]['sys'] if turn_i > 0 else "none" + turn_exmaple['state']["slot_values"] = {s[0]:s[1] for s in turn['bvs']} + turn_exmaple['user'] = turn['usr'] + dial_example['turns'].append(turn_exmaple) + + for ss, vv in turn_exmaple['state']["slot_values"].items(): + if ss not in ontology: + ontology[ss] = [] + if vv not in ontology[ss]: + ontology[ss].append(vv) + + if dialogue_name in testListFile: + test_dials.append(dial_example) + count_test += 1 + elif dialogue_name in valListFile: + val_dials.append(dial_example) + count_val += 1 + else: + trainListFile.write(dialogue_name + '\n') + train_dials.append(dial_example) + count_train += 1 + + print("# of dialogues: Train {}, Val {}, Test {}".format(count_train, count_val, count_test)) + + # save all dialogues + with open('data/dev_dials.json', 'w') as f: + json.dump(val_dials, f, indent=4) + + with open('data/test_dials.json', 'w') as f: + json.dump(test_dials, f, indent=4) + + with open('data/train_dials.json', 'w') as f: + json.dump(train_dials, f, indent=4) + + with open('data/ontology.json', 'w') as f: + json.dump(ontology, f, indent=4) + + # return word_freqs_usr, word_freqs_sys + + +def main(): + print('Create WOZ-like dialogues. Get yourself a coffee, this might take a while.') + delex_data = createData() + print('Divide dialogues...') + divideData(delex_data) + # print('Building dictionaries') + # buildDictionaries(word_freqs_usr, word_freqs_sys) + + +if __name__ == "__main__": + main() diff --git a/T5DST/create_data_2_1.py b/T5DST/create_data_2_1.py new file mode 100644 index 0000000..39078a5 --- /dev/null +++ b/T5DST/create_data_2_1.py @@ -0,0 +1,534 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates +# -*- coding: utf-8 -*- +import copy +import json +import os +import re +import shutil +import urllib.request +from collections import OrderedDict +from io import BytesIO +from zipfile import ZipFile +import difflib +import numpy as np +import argparse +from shutil import copyfile + +np.set_printoptions(precision=3) + +np.random.seed(2) + + +''' +Most of the codes are from https://github.com/budzianowski/multiwoz +''' + + +# GLOBAL VARIABLES +DICT_SIZE = 400 +MAX_LENGTH = 50 +IGNORE_KEYS_IN_GOAL = ['eod', 'topic', 'messageLen', 'message'] + +fin = open('utils/mapping.pair','r') +replacements = [] +for line in fin.readlines(): + tok_from, tok_to = line.replace('\n', '').split('\t') + replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' ')) + + +def is_ascii(s): + return all(ord(c) < 128 for c in s) + +def insertSpace(token, text): + sidx = 0 + while True: + sidx = text.find(token, sidx) + if sidx == -1: + break + if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \ + re.match('[0-9]', text[sidx + 1]): + sidx += 1 + continue + if text[sidx - 1] != ' ': + text = text[:sidx] + ' ' + text[sidx:] + sidx += 1 + if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ': + text = text[:sidx + 1] + ' ' + text[sidx + 1:] + sidx += 1 + return text + +def normalize(text, clean_value=True): + # lower case every word + text = text.lower() + + # replace white spaces in front and end + text = re.sub(r'^\s*|\s*$', '', text) + + # hotel domain pfb30 + text = re.sub(r"b&b", "bed and breakfast", text) + text = re.sub(r"b and b", "bed and breakfast", text) + + if clean_value: + # normalize phone number + ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m[0], sidx) + if text[sidx - 1] == '(': + sidx -= 1 + eidx = text.find(m[-1], sidx) + len(m[-1]) + text = text.replace(text[sidx:eidx], ''.join(m)) + + # normalize postcode + ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})', + text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m, sidx) + eidx = sidx + len(m) + text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:] + + # weird unicode bug + text = re.sub(u"(\u2018|\u2019)", "'", text) + + if clean_value: + # replace time and and price + text = re.sub(timepat, ' [value_time] ', text) + text = re.sub(pricepat, ' [value_price] ', text) + #text = re.sub(pricepat2, '[value_price]', text) + + # replace st. + text = text.replace(';', ',') + text = re.sub('$\/', '', text) + text = text.replace('/', ' and ') + + # replace other special characters + text = text.replace('-', ' ') + text = re.sub('[\"\<>@\(\)]', '', text) # remove + + # insert white space before and after tokens: + for token in ['?', '.', ',', '!']: + text = insertSpace(token, text) + + # insert white space for 's + text = insertSpace('\'s', text) + + # replace it's, does't, you'd ... 
etc + text = re.sub('^\'', '', text) + text = re.sub('\'$', '', text) + text = re.sub('\'\s', ' ', text) + text = re.sub('\s\'', ' ', text) + for fromx, tox in replacements: + text = ' ' + text + ' ' + text = text.replace(fromx, tox)[1:-1] + + # remove multiple spaces + text = re.sub(' +', ' ', text) + + # concatenate numbers + tmp = text + tokens = text.split() + i = 1 + while i < len(tokens): + if re.match(u'^\d+$', tokens[i]) and \ + re.match(u'\d+$', tokens[i - 1]): + tokens[i - 1] += tokens[i] + del tokens[i] + else: + i += 1 + text = ' '.join(tokens) + + return text + +def fixDelex(filename, data, data2, idx, idx_acts): + """Given system dialogue acts fix automatic delexicalization.""" + try: + turn = data2[filename.strip('.json')][str(idx_acts)] + except: + return data + + if not isinstance(turn, str):# and not isinstance(turn, unicode): + for k, act in turn.items(): + if 'Attraction' in k: + if 'restaurant_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("restaurant", "attraction") + if 'hotel_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("hotel", "attraction") + if 'Hotel' in k: + if 'attraction_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("attraction", "hotel") + if 'restaurant_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("restaurant", "hotel") + if 'Restaurant' in k: + if 'attraction_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("attraction", "restaurant") + if 'hotel_' in data['log'][idx]['text']: + data['log'][idx]['text'] = data['log'][idx]['text'].replace("hotel", "restaurant") + + return data + + +def getDialogueAct(filename, data, data2, idx, idx_acts): + """Given system dialogue acts fix automatic delexicalization.""" + acts = [] + try: + turn = data2[filename.strip('.json')][str(idx_acts)] + except: + return acts + + if not isinstance(turn, str): # and not isinstance(turn, unicode): + for k in turn.keys(): + # temp = [k.split('-')[0].lower(), k.split('-')[1].lower()] + # for a in turn[k]: + # acts.append(temp + [a[0].lower()]) + + if k.split('-')[1].lower() == 'request': + for a in turn[k]: + acts.append(a[0].lower()) + elif k.split('-')[1].lower() == 'inform': + for a in turn[k]: + acts.append([a[0].lower(), normalize(a[1].lower(), False)]) + + return acts + + +def get_summary_bstate(bstate, get_domain=False): + """Based on the mturk annotations we form multi-domain belief state""" + domains = [u'taxi',u'restaurant', u'hospital', u'hotel',u'attraction', u'train', u'police'] + summary_bstate = [] + summary_bvalue = [] + active_domain = [] + for domain in domains: + domain_active = False + + booking = [] + #print(domain,len(bstate[domain]['book'].keys())) + for slot in sorted(bstate[domain]['book'].keys()): + if slot == 'booked': + if len(bstate[domain]['book']['booked'])!=0: + booking.append(1) + # summary_bvalue.append("book {} {}:{}".format(domain, slot, "Yes")) + else: + booking.append(0) + else: + if bstate[domain]['book'][slot] != "": + booking.append(1) + summary_bvalue.append(["{}-book {}".format(domain, slot.strip().lower()), normalize(bstate[domain]['book'][slot].strip().lower(), False)]) #(["book", domain, slot, bstate[domain]['book'][slot]]) + else: + booking.append(0) + if domain == 'train': + if 'people' not in bstate[domain]['book'].keys(): + booking.append(0) + if 'ticket' not in bstate[domain]['book'].keys(): + 
booking.append(0) + summary_bstate += booking + + for slot in bstate[domain]['semi']: + slot_enc = [0, 0, 0] # not mentioned, dontcare, filled + if bstate[domain]['semi'][slot] == 'not mentioned': + slot_enc[0] = 1 + elif bstate[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care", "do not care"]: + slot_enc[1] = 1 + summary_bvalue.append(["{}-{}".format(domain, slot.strip().lower()), "dontcare"]) #(["semi", domain, slot, "dontcare"]) + elif bstate[domain]['semi'][slot]: + summary_bvalue.append(["{}-{}".format(domain, slot.strip().lower()), normalize(bstate[domain]['semi'][slot].strip().lower(), False)]) #(["semi", domain, slot, bstate[domain]['semi'][slot]]) + if slot_enc != [0, 0, 0]: + domain_active = True + summary_bstate += slot_enc + + # quasi domain-tracker + if domain_active: + summary_bstate += [1] + active_domain.append(domain) + else: + summary_bstate += [0] + + #print(len(summary_bstate)) + assert len(summary_bstate) == 94 + if get_domain: + return active_domain + else: + return summary_bstate, summary_bvalue + + +def analyze_dialogue(dialogue, maxlen): + """Cleaning procedure for all kinds of errors in text and annotation.""" + d = dialogue + # do all the necessary postprocessing + if len(d['log']) % 2 != 0: + #print path + print('odd # of turns') + return None # odd number of turns, wrong dialogue + d_pp = {} + d_pp['goal'] = d['goal'] # for now we just copy the goal + usr_turns = [] + sys_turns = [] + # last_bvs = [] + for i in range(len(d['log'])): + if len(d['log'][i]['text'].split()) > maxlen: + # print('too long') + return None # too long sentence, wrong dialogue + if i % 2 == 0: # usr turn + text = d['log'][i]['text'] + if not is_ascii(text): + # print('not ascii') + return None + usr_turns.append(d['log'][i]) + else: # sys turn + text = d['log'][i]['text'] + if not is_ascii(text): + # print('not ascii') + return None + belief_summary, belief_value_summary = get_summary_bstate(d['log'][i]['metadata']) + d['log'][i]['belief_summary'] = str(belief_summary) + d['log'][i]['belief_value_summary'] = belief_value_summary + sys_turns.append(d['log'][i]) + d_pp['usr_log'] = usr_turns + d_pp['sys_log'] = sys_turns + + return d_pp + + +def get_dial(dialogue): + """Extract a dialogue from the file""" + dial = [] + d_orig = analyze_dialogue(dialogue, MAX_LENGTH) # max turn len is 50 words + if d_orig is None: + return None + usr = [t['text'] for t in d_orig['usr_log']] + sys = [t['text'] for t in d_orig['sys_log']] + sys_a = [t['dialogue_acts'] for t in d_orig['sys_log']] + bvs = [t['belief_value_summary'] for t in d_orig['sys_log']] + domain = [t['domain'] for t in d_orig['usr_log']] + for item in zip(usr, sys, sys_a, domain, bvs): + dial.append({'usr':item[0],'sys':item[1], 'sys_a':item[2], 'domain':item[3], 'bvs':item[4]}) + return dial + + +def loadData(args): + data_url = os.path.join(args.main_dir, "data.json") + if args.mwz_ver == '2.1': + dataset_url = "https://www.repository.cam.ac.uk/bitstream/handle/1810/294507/MULTIWOZ2.1.zip?sequence=1&isAllowed=y" + else: + dataset_url = "https://www.repository.cam.ac.uk/bitstream/handle/1810/280608/MULTIWOZ2.zip?sequence=3&isAllowed=y" + if not os.path.exists(args.main_dir): + os.makedirs(args.main_dir) + + if not os.path.exists(data_url): + print("Downloading and unzipping the MultiWOZ %s dataset" % args.mwz_ver) + resp = urllib.request.urlopen(dataset_url) + zip_ref = ZipFile(BytesIO(resp.read())) + zip_ref.extractall(args.main_dir) + zip_ref.close() + dir_name = 'MULTIWOZ2.1' if args.mwz_ver == '2.1' else 'MULTIWOZ2 2' + 
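+        # copy the annotation files (data, ontology, split lists, dialogue acts) out of the
+        # unzipped folder so that later steps can load them directly from main_dir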
shutil.copy(os.path.join(args.main_dir, dir_name, 'data.json'), args.main_dir)
+        shutil.copy(os.path.join(args.main_dir, dir_name, 'ontology.json'), args.main_dir)
+        shutil.copy(os.path.join(args.main_dir, dir_name, 'valListFile.json'), args.main_dir)
+        shutil.copy(os.path.join(args.main_dir, dir_name, 'testListFile.json'), args.main_dir)
+        shutil.copy(os.path.join(args.main_dir, dir_name, 'dialogue_acts.json'), args.main_dir)
+
+
+def getDomain(idx, log, domains, last_domain):
+    if idx == 1:
+        active_domains = get_summary_bstate(log[idx]["metadata"], True)
+        crnt_doms = active_domains[0] if len(active_domains)!=0 else domains[0]
+        return crnt_doms
+    else:
+        ds_diff = get_ds_diff(log[idx-2]["metadata"], log[idx]["metadata"])
+        if len(ds_diff.keys()) == 0: # no clues from dialog states
+            crnt_doms = last_domain
+        else:
+            crnt_doms = list(ds_diff.keys())
+        return crnt_doms[0] # what about the scenario with multiple domains in one sentence?
+
+
+def get_ds_diff(prev_d, crnt_d):
+    diff = {}
+    # Sometimes, metadata is an empty dictionary, bug?
+    if not prev_d or not crnt_d:
+        return diff
+
+    for ((k1, v1), (k2, v2)) in zip(prev_d.items(), crnt_d.items()):
+        assert k1 == k2
+        if v1 != v2: # updated
+            diff[k2] = v2
+    return diff
+
+
+def createData(args):
+    # download the data
+    loadData(args)
+
+    delex_data = {}
+
+    fin1 = open(os.path.join(args.main_dir, 'data.json'), 'r')
+    data = json.load(fin1)
+
+    fin2 = open(os.path.join(args.main_dir, 'dialogue_acts.json'), 'r')
+    data2 = json.load(fin2)
+
+    for didx, dialogue_name in enumerate(data):
+
+        dialogue = data[dialogue_name]
+
+        domains = []
+        for dom_k, dom_v in dialogue['goal'].items():
+            if dom_v and dom_k not in IGNORE_KEYS_IN_GOAL: # check whether the goal mentions this domain
+                domains.append(dom_k)
+
+        idx_acts = 1
+        last_domain, last_slot_fill = "", []
+        for idx, turn in enumerate(dialogue['log']):
+            # normalization, split and delexicalization of the sentence
+            origin_text = normalize(turn['text'], False)
+            dialogue['log'][idx]['text'] = origin_text
+
+            if idx % 2 == 1: # if it's a system turn
+
+                cur_domain = getDomain(idx, dialogue['log'], domains, last_domain)
+                last_domain = [cur_domain]
+
+                dialogue['log'][idx - 1]['domain'] = cur_domain
+                dialogue['log'][idx]['dialogue_acts'] = getDialogueAct(dialogue_name, dialogue, data2, idx, idx_acts)
+                idx_acts += 1
+
+                # FIXING delexicalization:
+                dialogue = fixDelex(dialogue_name, dialogue, data2, idx, idx_acts)
+
+        delex_data[dialogue_name] = dialogue
+
+    return delex_data
+
+
+def buildDelexDict(origin_sent, delex_sent):
+    # align the delexicalized sentence with the original one to recover the replaced spans
+    dictionary = {}
+    s = difflib.SequenceMatcher(None, delex_sent.split(), origin_sent.split())
+    bs = s.get_matching_blocks()
+    for i, b in enumerate(bs):
+        if i < len(bs)-2:
+            a_start = b.a + b.size
+            b_start = b.b + b.size
+            b_end = bs[i+1].b
+            dictionary[a_start] = " ".join(origin_sent.split()[b_start:b_end])
+    return dictionary
+
+
+def divideData(data, args):
+    """Split the dialogues into train, dev and test sets using the provided file lists."""
+    os.makedirs(args.target_path, exist_ok=True)
+
+    copyfile(os.path.join(args.main_dir,'ontology.json'), os.path.join(args.target_path,'ontology.json'))
+
+    testListFile = []
+    fin = open(os.path.join(args.main_dir,'testListFile.json'), 'r')
+    for line in fin:
+        testListFile.append(line[:-1])
+    fin.close()
+
+    valListFile = []
+    fin = open(os.path.join(args.main_dir,'valListFile.json'), 'r')
+    for line in fin:
+        valListFile.append(line[:-1])
+    fin.close()
+
+    trainListFile = open(os.path.join(args.target_path,'trainListFile'), 'w')
+
+    test_dials = []
+    val_dials = []
+    train_dials = []
+
+    count_train, count_val, count_test = 0, 0, 0
+
+    ontology = {}
+
+    for dialogue_name in data:
+        dial_item = data[dialogue_name]
+        domains = []
+        for dom_k, dom_v in dial_item['goal'].items():
+            if dom_v and dom_k not in IGNORE_KEYS_IN_GOAL: # check whether the goal mentions this domain
+                domains.append(dom_k)
+
+        turn_example = {"system":"none", "user":"none", "state":{"active_intent":"none", "slot_values":{} } }
+        dial = get_dial(data[dialogue_name])
+        if dial:
+            dial_example = {"dial_id":dialogue_name, "domains":list(set(domains)), "turns":[]}
+            for turn_i, turn in enumerate(dial):
+                # usr, usr_o, sys, sys_o, sys_a, domain
+                turn_example = {"system":"none", "user":"none", "state":{"active_intent":"none", "slot_values":{} } }
+                turn_example['system'] = dial[turn_i-1]['sys'] if turn_i > 0 else "none"
+                turn_example['state']["slot_values"] = {s[0]:s[1] for s in turn['bvs']}
+                turn_example['user'] = turn['usr']
+                dial_example['turns'].append(turn_example)
+
+                # collect every value observed for each slot to build the ontology
+                for ss, vv in turn_example['state']["slot_values"].items():
+                    if ss not in ontology:
+                        ontology[ss] = []
+                    if vv not in ontology[ss]:
+                        ontology[ss].append(vv)
+
+            if dialogue_name in testListFile:
+                test_dials.append(dial_example)
+                count_test += 1
+            elif dialogue_name in valListFile:
+                val_dials.append(dial_example)
+                count_val += 1
+            else:
+                trainListFile.write(dialogue_name + '\n')
+                train_dials.append(dial_example)
+                count_train += 1
+
+    print("# of dialogues: Train {}, Val {}, Test {}".format(count_train, count_val, count_test))
+
+    # save all dialogues
+    with open('data/dev_dials.json', 'w') as f:
+        json.dump(val_dials, f, indent=4)
+
+    with open('data/test_dials.json', 'w') as f:
+        json.dump(test_dials, f, indent=4)
+
+    with open('data/train_dials.json', 'w') as f:
+        json.dump(train_dials, f, indent=4)
+
+    with open('data/ontology.json', 'w') as f:
+        json.dump(ontology, f, indent=4)
+
+
+def main(args):
+    print('Create WOZ-like dialogues. Get yourself a coffee, this might take a while.')
+    delex_data = createData(args)
+    print('Divide dialogues...')
+    divideData(delex_data, args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--main_dir", type=str, default='data')
+    parser.add_argument("--mwz_ver", type=str, default='2.1')
+    parser.add_argument("--target_path", type=str, default='data/mwz2.1')
+    args = parser.parse_args()
+    main(args)
diff --git a/T5DST/data_loader.py b/T5DST/data_loader.py
new file mode 100644
index 0000000..89290ef
--- /dev/null
+++ b/T5DST/data_loader.py
@@ -0,0 +1,244 @@
+# Copyright (c) Facebook, Inc.
and its affiliates + +import json +import torch +from torch.utils.data import DataLoader, TensorDataset, Dataset +import ast +from tqdm import tqdm +import os +import random +from functools import partial +from utils.fix_label import fix_general_label_error +from collections import OrderedDict +EXPERIMENT_DOMAINS = ["hotel", "train", "restaurant", "attraction", "taxi"] + +random.seed(577) +HISTORY_MAX_LEN = 450 +GPT_MAX_LEN = 1024 + +class DSTDataset(Dataset): + """Custom data.Dataset compatible with data.DataLoader.""" + def __init__(self, data, args): + """Reads source and target sequences from txt files.""" + self.data = data + self.args = args + + def __getitem__(self, index): + """Returns one data pair (source and target).""" + item_info = self.data[index] + if self.args["slot_lang"] == "value": + random.shuffle(item_info["value_list"]) + item_info["intput_text"] += " is " + " or ".join(item_info["value_list"]) + " or none?" + return item_info + + def __len__(self): + return len(self.data) + + + +def read_data(args, path_name, SLOTS, tokenizer, description, dataset=None): + slot_lang_list = ["description_human", "rule_description", "value_description", "rule2", "rule3"] + print(("Reading all files from {}".format(path_name))) + data = [] + domain_counter = {} + # read files + with open(path_name) as f: + dials = json.load(f) + + if dataset=="train" and args["fewshot"]>0: + random.Random(args["seed"]).shuffle(dials) + dials = dials[:int(len(dials)*args["fewshot"])] + + for dial_dict in dials: + dialog_history = "" + + # Counting domains + for domain in dial_dict["domains"]: + if domain not in EXPERIMENT_DOMAINS: + continue + if domain not in domain_counter.keys(): + domain_counter[domain] = 0 + domain_counter[domain] += 1 + + # Unseen domain setting + if args["only_domain"] != "none" and args["only_domain"] not in dial_dict["domains"]: + continue + if (args["except_domain"] != "none" and dataset == "test" and args["except_domain"] not in dial_dict["domains"]) or \ + (args["except_domain"] != "none" and dataset != "test" and [args["except_domain"]] == dial_dict["domains"]): + continue + + # Reading data + for ti, turn in enumerate(dial_dict["turns"]): + turn_id = ti + + # accumulate dialogue utterances + dialog_history += (" System: " + turn["system"] + " User: " + turn["user"]) + if args["fix_label"]: + slot_values = fix_general_label_error(turn["state"]["slot_values"],SLOTS) + else: + slot_values = turn["state"]["slot_values"] + # input: dialogue history + slot + # output: value + + # Generate domain-dependent slot list + slot_temp = SLOTS + if dataset == "train" or dataset == "dev": + if args["except_domain"] != "none": + slot_temp = [k for k in SLOTS if args["except_domain"] not in k] + slot_values = OrderedDict([(k, v) for k, v in slot_values.items() if args["except_domain"] not in k]) + elif args["only_domain"] != "none": + slot_temp = [k for k in SLOTS if args["only_domain"] in k] + slot_values = OrderedDict([(k, v) for k, v in slot_values.items() if args["only_domain"] in k]) + else: + if args["except_domain"] != "none": + slot_temp = [k for k in SLOTS if args["except_domain"] in k] + slot_values = OrderedDict([(k, v) for k, v in slot_values.items() if args["except_domain"] in k]) + elif args["only_domain"] != "none": + slot_temp = [k for k in SLOTS if args["only_domain"] in k] + slot_values = OrderedDict([(k, v) for k, v in slot_values.items() if args["only_domain"] in k]) + + + turn_belief_list = [str(k)+'-'+str(v) for k,v in slot_values.items()] + + # baseline gpt have 
different preprocessing, e.g., output: (slot1-value1, slot2-value2, slot3-value3, ...) + if "gpt" in args["model_name"]: + turn_slots = [] + turn_slot_values = [] + if len(dialog_history.split())>800: + continue + for slot in slot_temp: + # skip unrelevant slots for out of domain setting + if args["except_domain"] != "none" and dataset !="test": + if slot.split("-")[0] not in dial_dict["domains"]: + continue + input_text = dialog_history + f" {tokenizer.sep_token} {slot}" + " " + tokenizer.bos_token + output_text = input_text+ " " + turn["state"]["slot_values"].get(slot, 'none').strip() + " " + tokenizer.eos_token + slot_text = slot + value_text = turn["state"]["slot_values"].get(slot, 'none').strip() + + data_detail = { + "ID":dial_dict["dial_id"], + "domains":dial_dict["domains"], + "turn_id":turn_id, + "dialog_history":dialog_history, + "turn_belief":turn_belief_list, + "intput_text":input_text, + "output_text":output_text, + "slot_text":slot_text, + "value_text":value_text + } + data.append(data_detail) + + else: + for slot in slot_temp: + + # skip unrelevant slots for out of domain setting + if args["except_domain"] != "none" and dataset !="test": + if slot.split("-")[0] not in dial_dict["domains"]: + continue + + output_text = slot_values.get(slot, 'none').strip() + f" {tokenizer.eos_token}" + slot_text = slot + value_text = slot_values.get(slot, 'none').strip() + + if args["slot_lang"]=="human": + slot_lang = description[slot]["description_human"] + input_text = dialog_history + f" {tokenizer.sep_token} {slot_lang}?" + elif args["slot_lang"]=="naive": + slot_lang = description[slot]["naive"] + input_text = dialog_history + f" {tokenizer.sep_token} {slot_lang}?" + elif args["slot_lang"]=="value": + slot_lang = description[slot]["naive"] + input_text = dialog_history + f" {tokenizer.sep_token} {slot_lang}" + elif args["slot_lang"]=="question": + slot_lang = description[slot]["question"] + input_text = dialog_history + f" {tokenizer.sep_token} {slot_lang}" + elif args["slot_lang"]=="slottype": + slot_lang = description[slot]["slottype"] + input_text = dialog_history + f" {tokenizer.sep_token} {slot_lang}?" 
+ else: + input_text = dialog_history + f" {tokenizer.sep_token} {slot}" + + data_detail = { + "ID":dial_dict["dial_id"], + "domains":dial_dict["domains"], + "turn_id":turn_id, + "dialog_history":dialog_history, + "turn_belief":turn_belief_list, + "intput_text":input_text, + "output_text":output_text, + "slot_text":slot_text, + "value_text":value_text, + "value_list":description[slot]["values"] + } + data.append(data_detail) + # print(len(data)) + for idx in range(10): + print(data[idx]) + print("domain_counter", domain_counter) + return data, slot_temp + + + +def get_slot_information(ontology): + ontology_domains = dict([(k, v) for k, v in ontology.items() if k.split("-")[0] in EXPERIMENT_DOMAINS]) + SLOTS = [k.replace(" ","").lower() if ("book" not in k) else k.lower() for k in ontology_domains.keys()] + + return SLOTS + + +def gpt_collate_fn(data,tokenizer): + batch_data = {} + for key in data[0]: + batch_data[key] = [d[key] for d in data] + + output_batch = tokenizer(batch_data["output_text"], padding=True, return_tensors="pt", add_special_tokens=False, return_attention_mask=False, truncation=True, max_length=1000) + batch_data["input_ids"] = output_batch['input_ids'] + return batch_data + + +def collate_fn(data, tokenizer): + batch_data = {} + for key in data[0]: + batch_data[key] = [d[key] for d in data] + + input_batch = tokenizer(batch_data["intput_text"], padding=True, return_tensors="pt", add_special_tokens=False, verbose=False) + batch_data["encoder_input"] = input_batch["input_ids"] + batch_data["attention_mask"] = input_batch["attention_mask"] + output_batch = tokenizer(batch_data["output_text"], padding=True, return_tensors="pt", add_special_tokens=False, return_attention_mask=False) + # replace the padding id to -100 for cross-entropy + output_batch['input_ids'].masked_fill_(output_batch['input_ids']==tokenizer.pad_token_id, -100) + batch_data["decoder_output"] = output_batch['input_ids'] + + return batch_data + + +def prepare_data(args, tokenizer): + path_train = 'data/train_dials.json' + path_dev = 'data/dev_dials.json' + path_test = 'data/test_dials.json' + + ontology = json.load(open("data/multi-woz/MULTIWOZ2 2/ontology.json", 'r')) + ALL_SLOTS = get_slot_information(ontology) + description = json.load(open("utils/slot_description.json", 'r')) + + data_train, _ = read_data(args, path_train, ALL_SLOTS, tokenizer, description, "train") + data_dev, _ = read_data(args, path_dev, ALL_SLOTS, tokenizer, description, "dev") + data_test, ALL_SLOTS = read_data(args, path_test, ALL_SLOTS, tokenizer, description, "test") + + + train_dataset = DSTDataset(data_train, args) + dev_dataset = DSTDataset(data_dev, args) + test_dataset = DSTDataset(data_test, args) + + if "gpt" in args["model_name"]: + train_loader = DataLoader(train_dataset, batch_size=args["train_batch_size"], shuffle=True, collate_fn=partial(gpt_collate_fn, tokenizer=tokenizer), num_workers=16) + test_loader = DataLoader(test_dataset, batch_size=args["test_batch_size"], shuffle=False, collate_fn=partial(gpt_collate_fn, tokenizer=tokenizer), num_workers=16) + dev_loader = DataLoader(dev_dataset, batch_size=args["dev_batch_size"], shuffle=False, collate_fn=partial(gpt_collate_fn, tokenizer=tokenizer), num_workers=16) + else: + train_loader = DataLoader(train_dataset, batch_size=args["train_batch_size"], shuffle=True, collate_fn=partial(collate_fn, tokenizer=tokenizer), num_workers=16) + test_loader = DataLoader(test_dataset, batch_size=args["test_batch_size"], shuffle=False, collate_fn=partial(collate_fn, 
tokenizer=tokenizer), num_workers=16)
+        dev_loader = DataLoader(dev_dataset, batch_size=args["dev_batch_size"], shuffle=False, collate_fn=partial(collate_fn, tokenizer=tokenizer), num_workers=16)
+    fewshot_loader_dev = None
+    fewshot_loader_test = None
+    return train_loader, dev_loader, test_loader, ALL_SLOTS, fewshot_loader_dev, fewshot_loader_test
diff --git a/T5DST/evaluate.py b/T5DST/evaluate.py
new file mode 100644
index 0000000..cf0804f
--- /dev/null
+++ b/T5DST/evaluate.py
@@ -0,0 +1,86 @@
+# Copyright (c) Facebook, Inc. and its affiliates
+
+import json
+# Strict match evaluation from https://github.com/jasonwu0731/trade-dst/blob/master/models/TRADE.py
+# check utils/prediction_sample.json for the format of predictions
+EXPERIMENT_DOMAINS = ["hotel", "train", "restaurant", "attraction", "taxi"]
+
+def compute_acc(gold, pred, slot_temp):
+    miss_gold = 0
+    miss_slot = []
+    for g in gold:
+        if g not in pred:
+            miss_gold += 1
+            miss_slot.append(g.rsplit("-", 1)[0])
+    wrong_pred = 0
+    for p in pred:
+        if p not in gold and p.rsplit("-", 1)[0] not in miss_slot:
+            wrong_pred += 1
+    ACC_TOTAL = len(slot_temp)
+    ACC = len(slot_temp) - miss_gold - wrong_pred
+    ACC = ACC / float(ACC_TOTAL)
+    return ACC
+
+def compute_prf(gold, pred):
+    TP, FP, FN = 0, 0, 0
+    if len(gold) != 0:
+        count = 1
+        for g in gold:
+            if g in pred:
+                TP += 1
+            else:
+                FN += 1
+        for p in pred:
+            if p not in gold:
+                FP += 1
+        precision = TP / float(TP+FP) if (TP+FP)!=0 else 0
+        recall = TP / float(TP+FN) if (TP+FN)!=0 else 0
+        F1 = 2 * precision * recall / float(precision + recall) if (precision+recall)!=0 else 0
+    else:
+        if len(pred) == 0:
+            precision, recall, F1, count = 1, 1, 1, 1
+        else:
+            precision, recall, F1, count = 0, 0, 0, 1
+    return F1, recall, precision, count
+
+def evaluate_metrics(all_prediction, SLOT_LIST):
+    total, turn_acc, joint_acc, F1_pred, F1_count = 0, 0, 0, 0, 0
+    for idx, dial in all_prediction.items():
+        for k, cv in dial["turns"].items():
+            if set(cv["turn_belief"]) == set(cv["pred_belief"]):
+                joint_acc += 1
+            else:
+                # log mismatched states for inspection
+                print(cv["turn_belief"])
+                print(cv["pred_belief"])
+                print("==================")
+            total += 1
+
+            # Compute prediction slot accuracy
+            temp_acc = compute_acc(set(cv["turn_belief"]), set(cv["pred_belief"]), SLOT_LIST)
+            turn_acc += temp_acc
+
+            # Compute prediction joint F1 score
+            temp_f1, temp_r, temp_p, count = compute_prf(set(cv["turn_belief"]), set(cv["pred_belief"]))
+            F1_pred += temp_f1
+            F1_count += count
+
+    joint_acc_score = joint_acc / float(total) if total!=0 else 0
+    turn_acc_score = turn_acc / float(total) if total!=0 else 0
+    F1_score = F1_pred / float(F1_count) if F1_count!=0 else 0
+    return joint_acc_score, F1_score, turn_acc_score
+
+def get_slot_information(ontology):
+    ontology_domains = dict([(k, v) for k, v in ontology.items() if k.split("-")[0] in EXPERIMENT_DOMAINS])
+    SLOTS = [k.replace(" ","").lower() if ("book" not in k) else k.lower() for k in ontology_domains.keys()]
+    return SLOTS
+
+if __name__ == "__main__":
+    ontology = json.load(open("data/multi-woz/MULTIWOZ2 2/ontology.json", 'r'))
+    ALL_SLOTS = get_slot_information(ontology)
+    with open("save/t5/results/zeroshot_prediction.json") as f:
+        prediction = json.load(f)
+
+    # pass the flattened slot list, not the raw ontology dict, so per-slot accuracy is computed correctly
+    joint_acc_score, F1_score, turn_acc_score = evaluate_metrics(prediction, ALL_SLOTS)
+
+    evaluation_metrics = {"Joint Acc":joint_acc_score, "Turn Acc":turn_acc_score, "Joint F1":F1_score}
+    print(evaluation_metrics)
diff --git a/T5DST/figures/diagram.png b/T5DST/figures/diagram.png
new file mode 100644
index 0000000..153a804
Binary files
/dev/null and b/T5DST/figures/diagram.png differ diff --git a/T5DST/figures/slotdesc.png b/T5DST/figures/slotdesc.png new file mode 100644 index 0000000..00137b1 Binary files /dev/null and b/T5DST/figures/slotdesc.png differ diff --git a/T5DST/train_GPT.py b/T5DST/train_GPT.py new file mode 100644 index 0000000..78a4118 --- /dev/null +++ b/T5DST/train_GPT.py @@ -0,0 +1,188 @@ +# Copyright (c) Facebook, Inc. and its affiliates + +import os, random +import torch +import argparse +import pytorch_lightning as pl +from pytorch_lightning import Trainer, seed_everything +from transformers import (AdamW, GPT2Tokenizer, GPT2LMHeadModel) +from data_loader import prepare_data +from config import get_args +from evaluate import evaluate_metrics +import json +from tqdm import tqdm +from copy import deepcopy +import numpy as np + + +def greedy_decode(input_text, tokenizer, model, args, max_length, current_output=None): + if current_output is None: + current_output = [] + + input_ids = tokenizer.encode(input_text, add_special_tokens=False) + with torch.no_grad(): + for i in range(max_length): + input_tensor = torch.tensor(input_ids+current_output, device=torch.device("cuda:0")).unsqueeze(0) + logits = model(input_tensor) + if isinstance(logits, tuple): # for gpt2 and maybe others + logits = logits[0] + predicted_index = torch.argmax(logits[0, -1, :]).item() + + if predicted_index==tokenizer.eos_token_id: + break + current_output.append(predicted_index) + + output_text = tokenizer.decode(current_output) + return output_text + + + + +class DST_GPT(pl.LightningModule): + + def __init__(self,args, tokenizer, model): + super().__init__() + self.tokenizer = tokenizer + self.model = model + self.lr = args["lr"] + self.args = args + + def training_step(self, batch, batch_idx): + self.model.train() + # follow https://github.com/salesforce/simpletod/blob/917f66afe7f37e75de246949423fc4470a2427c4/main.py#L148 + (loss), *_ = self.model(input_ids=batch["input_ids"], labels=batch["input_ids"]) + + return {'loss': loss, 'log': {'train_loss': loss}} + + + def validation_step(self, batch, batch_idx): + self.model.eval() + (loss), *_ = self.model(input_ids=batch["input_ids"], labels=batch["input_ids"]) + + return {'val_loss': loss, 'log': {'val_loss': loss}} + # return result + + def validation_epoch_end(self, outputs): + val_loss_mean = sum([o['val_loss'] for o in outputs]) / len(outputs) + # show val_loss in progress bar but only log val_loss + results = {'progress_bar': {'val_loss': val_loss_mean.item()}, 'log': {'val_loss': val_loss_mean.item()}, + 'val_loss': val_loss_mean.item()} + return results + + + def configure_optimizers(self): + return AdamW(self.parameters(), lr=self.lr, correct_bias=True) + + + +def train(args, *more): + # # train! + args = vars(args) + args["model_name"] = args["model_checkpoint"]+args["model_name"]+"_except_domain_"+args["except_domain"]+ "_slotlang_" +str(args["slot_lang"]) + "_lr_" +str(args["lr"]) + "_epoch_" + str(args["n_epochs"]) + "_seed_" + str(args["seed"]) + # train! 
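+    # the run name composed above encodes the full experiment configuration:
+    # <model_checkpoint><model_name>_except_domain_<domain>_slotlang_<slot_lang>_lr_<lr>_epoch_<n_epochs>_seed_<seed>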
+ seed_everything(args["seed"]) + + model = GPT2LMHeadModel.from_pretrained(args["model_checkpoint"]) + tokenizer = GPT2Tokenizer.from_pretrained(args["model_checkpoint"], bos_token="[bos]", eos_token="[eos]", sep_token="[sep]", pad_token = "[pad]") + model.resize_token_embeddings(new_num_tokens=len(tokenizer)) + + task = DST_GPT(args, tokenizer, model) + train_loader, val_loader, test_loader, ALL_SLOTS, fewshot_loader_dev, fewshot_loader_test = prepare_data(args, task.tokenizer) + + #save model + save_path = os.path.join(args["saving_dir"],args["model_name"]) + if not os.path.exists(save_path): + os.makedirs(save_path) + + trainer = Trainer( + default_root_dir=save_path, + accumulate_grad_batches=args["gradient_accumulation_steps"], + gradient_clip_val=args["max_norm"], + max_epochs=args["n_epochs"], + callbacks=[pl.callbacks.EarlyStopping(monitor='val_loss',min_delta=0.00, patience=5,verbose=False, mode='min')], + gpus=args["GPU"], + deterministic=True, + num_nodes=1, + accelerator="ddp" + ) + + trainer.fit(task, train_loader, val_loader) + + + task.model.save_pretrained(save_path) + task.tokenizer.save_pretrained(save_path) + + print("test start...") + #evaluate model + _ = evaluate_model(args, task.tokenizer, task.model, test_loader, save_path, ALL_SLOTS) + + + +def evaluate_model(args, tokenizer, model, test_loader, save_path, ALL_SLOTS, prefix="zeroshot"): + save_path = os.path.join(save_path,"results") + if not os.path.exists(save_path): + os.makedirs(save_path) + predictions = {} + # to gpu + device = torch.device("cuda:0") + model.to(device) + model.eval() + + slot_logger = {slot_name:[0,0,0] for slot_name in ALL_SLOTS} + + for batch in tqdm(test_loader): + for idx, input_text in enumerate(batch["intput_text"]): + dst_text = greedy_decode(input_text, tokenizer, model, args, max_length=200) + value = dst_text.strip() + dial_id = batch["ID"][idx] + + if dial_id not in predictions: + predictions[dial_id] = {} + predictions[dial_id]["domain"] = batch["domains"][idx][0] + predictions[dial_id]["turns"] = {} + if batch["turn_id"][idx] not in predictions[dial_id]["turns"]: + predictions[dial_id]["turns"][batch["turn_id"][idx]] = {"turn_belief":batch["turn_belief"][idx], "pred_belief":[]} + + if value!="none": + predictions[dial_id]["turns"][batch["turn_id"][idx]]["pred_belief"].append(str(batch["slot_text"][idx])+'-'+str(value)) + + # dst_text = greedy_decode(input_text, tokenizer, model, args, max_length=200) + # slot_values = dst_text.strip() + # dial_id = batch["ID"][idx] + + # if dial_id not in predictions: + # predictions[dial_id] = {} + # predictions[dial_id]["domain"] = batch["domains"][idx][0] + # predictions[dial_id]["turns"] = {} + # if batch["turn_id"][idx] not in predictions[dial_id]["turns"]: + # predictions[dial_id]["turns"][batch["turn_id"][idx]] = {"turn_belief":batch["turn_belief"][idx], "pred_belief":[]} + # # print(slot_values) + # # print(slot_values.split(", ")) + # for slot_value in slot_values.split(", "): + # value = slot_value.split("-")[-1] + # # print(value) + # if value!="none": + # predictions[dial_id]["turns"][batch["turn_id"][idx]]["pred_belief"].append(slot_value) + + + with open(os.path.join(save_path, f"{prefix}_prediction.json"), 'w') as f: + json.dump(predictions,f, indent=4) + + joint_acc_score, F1_score, turn_acc_score = evaluate_metrics(predictions, ALL_SLOTS) + + evaluation_metrics = {"Joint Acc":joint_acc_score, "Turn Acc":turn_acc_score, "Joint F1":F1_score} + print(f"{prefix} result:",evaluation_metrics) + + with open(os.path.join(save_path, 
f"{prefix}_result.json"), 'w') as f: + json.dump(evaluation_metrics,f, indent=4) + + return predictions + + + +if __name__ == "__main__": + args = get_args() + train(args) + + + # evaluate() diff --git a/T5DST/utils/analysis.py b/T5DST/utils/analysis.py new file mode 100644 index 0000000..7329e97 --- /dev/null +++ b/T5DST/utils/analysis.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates + +import os, random +import torch +import argparse +import pytorch_lightning as pl +from pytorch_lightning import Trainer, seed_everything +from transformers import (AdamW, T5Tokenizer, T5ForConditionalGeneration) +from data_loader import prepare_data +from config import get_args +from evaluate import evaluate_metrics +import json +from tqdm import tqdm +from copy import deepcopy +import numpy as np +from collections import Counter + +class DST_Seq2Seq(pl.LightningModule): + + def __init__(self,args, tokenizer, model): + super().__init__() + self.tokenizer = tokenizer + self.model = model + self.lr = args["lr"] + + + def training_step(self, batch, batch_idx): + self.model.train() + (loss), *_ = self.model(input_ids=batch["encoder_input"], + attention_mask=batch["attention_mask"], + lm_labels=batch["decoder_output"] + ) + + # result = pl.TrainResult(loss) + # result.log('train_loss', loss, on_epoch=True) + return {'loss': loss, 'log': {'train_loss': loss}} + # return result + + def validation_step(self, batch, batch_idx): + self.model.eval() + (loss), *_ = self.model(input_ids=batch["encoder_input"], + attention_mask=batch["attention_mask"], + lm_labels=batch["decoder_output"] + ) + + + return {'val_loss': loss, 'log': {'val_loss': loss}} + # return result + + def validation_epoch_end(self, outputs): + val_loss_mean = sum([o['val_loss'] for o in outputs]) / len(outputs) + # show val_loss in progress bar but only log val_loss + results = {'progress_bar': {'val_loss': val_loss_mean.item()}, 'log': {'val_loss': val_loss_mean.item()}, + 'val_loss': val_loss_mean.item()} + return results + + def configure_optimizers(self): + return AdamW(self.parameters(), lr=self.lr, correct_bias=True) + +def analysis(args): + args = vars(args) + # args["model_checkpoint"] = "trained/t5-smallt5_except_domain_train_slotlang_rule2_lr_0.0001_epoch_5_seed_555" + model = T5ForConditionalGeneration.from_pretrained(args["model_checkpoint"]) + tokenizer = T5Tokenizer.from_pretrained(args["model_checkpoint"], bos_token="[bos]", eos_token="[eos]", sep_token="[sep]") + model.resize_token_embeddings(new_num_tokens=len(tokenizer)) + train_loader, val_loader, test_loader, ALL_SLOTS, fewshot_loader_dev, fewshot_loader_test = prepare_data(args, tokenizer) + device = torch.device("cuda:0") + + # model.load_state_dict(torch.load("trained/t5-smallt5_except_domain_train_slotlang_none_lr_0.0001_epoch_5_seed_555/pytorch_model.bin")) + model.to(device) + model.eval() + count = 0 + for batch in test_loader: + decoder_input = torch.full((batch["encoder_input"].shape[0], 1), model.config.decoder_start_token_id, dtype=torch.long, device=device) + + # dst_outputs = model.generate(input_ids=batch["encoder_input"].to(device), + # attention_mask=batch["attention_mask"].to(device), + # eos_token_id=tokenizer.eos_token_id, + # max_length=200, + # ) + # if batch["value_text"][0]!="none": + # print(batch["intput_text"][0]) + # value_batch = tokenizer.batch_decode(dst_outputs, skip_special_tokens=True) + # print(value_batch) + outputs = model(input_ids=batch["encoder_input"].to(device), + attention_mask=batch["attention_mask"].to(device), + 
decoder_input_ids=decoder_input,
+                        return_dict=True,
+                        output_attentions=True,
+                        )
+        if batch["value_text"][0] != "none":
+            print(batch["intput_text"][0])
+            tokens = tokenizer.convert_ids_to_tokens(batch["encoder_input"][0])
+            # head-summed cross-attention of the second decoder layer (index 1) over the input tokens
+            max_id = torch.argmax(torch.sum(outputs.cross_attentions[1], 1).squeeze()).item()
+
+            weights = torch.sum(outputs.cross_attentions[1], 1).squeeze().cpu().tolist()
+            buckets = []
+            for i in range(len(tokens)):
+                buckets.append((tokens[i], weights[i]))
+            # print the most-attended input token together with its neighbours
+            print(buckets[max(max_id-1, 0)])
+            print(buckets[max_id])
+            print(buckets[min(max_id+1, len(buckets)-1)])
+            count += 1
+            if count > 30:
+                exit(0)
+
+
+if __name__ == "__main__":
+    args = get_args()
+    analysis(args)
+
+# python analysis.py --test_batch_size 1 --model_checkpoint trained/t5-smallt5_except_domain_train_slotlang_rule2_lr_0.0001_epoch_5_seed_555 --except_domain train --slot_lang rule2
diff --git a/T5DST/utils/fix_label.py b/T5DST/utils/fix_label.py
new file mode 100644
index 0000000..d72bae3
--- /dev/null
+++ b/T5DST/utils/fix_label.py
@@ -0,0 +1,67 @@
+# Copyright (c) Facebook, Inc. and its affiliates
+
+# from TRADE
+def fix_general_label_error(labels, slots):
+    label_dict = labels
+    GENERAL_TYPO = {
+        # type
+        "guesthouse":"guest house", "guesthouses":"guest house", "guest":"guest house", "mutiple sports":"multiple sports",
+        "sports":"multiple sports", "mutliple sports":"multiple sports", "swimmingpool":"swimming pool", "concerthall":"concert hall",
+        "concert":"concert hall", "pool":"swimming pool", "night club":"nightclub", "mus":"museum", "ol":"architecture",
+        "colleges":"college", "coll":"college", "architectural":"architecture", "musuem":"museum", "churches":"church",
+        # area
+        "center":"centre", "center of town":"centre", "near city center":"centre", "in the north":"north", "cen":"centre", "east side":"east",
+        "east area":"east", "west part of town":"west", "ce":"centre", "town center":"centre", "centre of cambridge":"centre",
+        "city center":"centre", "the south":"south", "scentre":"centre", "town centre":"centre", "in town":"centre", "north part of town":"north",
+        "centre of town":"centre", "cb30aq": "none",
+        # price
+        "mode":"moderate", "moderate -ly": "moderate", "mo":"moderate",
+        # day
+        "next friday":"friday", "monda": "monday",
+        # parking
+        "free parking":"free",
+        # internet
+        "free internet":"yes",
+        # star
+        "4 star":"4", "4 stars":"4", "0 star rarting":"none",
+        # others
+        "y":"yes", "any":"dontcare", "n":"no", "does not care":"dontcare", "not men":"none", "not":"none", "not mentioned":"none",
+        '':"none", "not mendtioned":"none", "3 .":"3", "does not":"no", "fun":"none", "art":"none",
+    }
+
+    for slot in slots:
+        if slot in label_dict.keys():
+            # general typos
+            if label_dict[slot] in GENERAL_TYPO.keys():
+                label_dict[slot] = GENERAL_TYPO[label_dict[slot]]
+
+            # mismatched slot-value pairs
+            if slot == "hotel-type" and label_dict[slot] in ["nigh", "moderate -ly priced", "bed and breakfast", "centre", "venetian", "intern", "a cheap -er hotel"] or \
+                slot == "hotel-internet" and label_dict[slot] == "4" or \
+                slot == "hotel-pricerange" and label_dict[slot] == "2" or \
+                slot == "attraction-type" and label_dict[slot] in ["gastropub", "la raza", "galleria", "gallery", "science", "m"] or \
+                "area" in slot and label_dict[slot] in ["moderate"] or \
+                "day" in slot and label_dict[slot] == "t":
+                label_dict[slot] = "none"
+            elif slot == "hotel-type" and label_dict[slot] in ["hotel with free parking and free wifi", "4", "3 star hotel"]:
+                label_dict[slot] = "hotel"
+            elif slot == "hotel-star" and label_dict[slot] == "3 star hotel":
+                label_dict[slot] = "3"
+            elif "area" in slot:
+                if label_dict[slot] == "no": label_dict[slot] = "north"
+                elif label_dict[slot] == "we": label_dict[slot] = "west"
+                elif label_dict[slot] == "cent": label_dict[slot] = "centre"
+            elif "day" in slot:
+                if label_dict[slot] == "we": label_dict[slot] = "wednesday"
+                elif label_dict[slot] == "no": label_dict[slot] = "none"
+            elif "price" in slot and label_dict[slot] == "ch":
+                label_dict[slot] = "cheap"
+            elif "internet" in slot and label_dict[slot] == "free":
+                label_dict[slot] = "yes"
+
+            # values outside the defined classification sets
+            if slot == "restaurant-area" and label_dict[slot] in ["stansted airport", "cambridge", "silver street"] or \
+                slot == "attraction-area" and label_dict[slot] in ["norwich", "ely", "museum", "same area as hotel"]:
+                label_dict[slot] = "none"
+
+    return label_dict
diff --git a/T5DST/utils/generate_slot_desp.py b/T5DST/utils/generate_slot_desp.py
new file mode 100644
index 0000000..3b8cb56
--- /dev/null
+++ b/T5DST/utils/generate_slot_desp.py
@@ -0,0 +1,55 @@
+# Copyright (c) Facebook, Inc. and its affiliates
+
+import json
+
+with open("slot_description.json", 'r') as f:
+    ontology = json.load(f)
+
+
+slot_map = {"pricerange": "price range", "arriveby": "arrive by", "leaveat": "leave at"}
+# keys are str(list); the membership checks below are substring tests on that string
+slot_types = {str(["book stay", "book people", "stars"]):"number of ", str(["parking", "internet"]):"whether have ", str(["destination", "departure"]):"location of ", str(["arriveby", "leaveat"]):"time of "}
+
+
+# Naive descriptions
+for domain_slot in ontology:
+    domain, slot = domain_slot.split("-")
+    if slot in slot_map:
+        slot = slot_map[slot]
+    if "book" in domain_slot:
+        slot = slot.replace("book ", "")
+        ontology[domain_slot]["naive"] = f"{slot} for the {domain} booking"
+    else:
+        ontology[domain_slot]["naive"] = f"{slot} of the {domain}"
+
+
+# question
+for domain_slot in ontology:
+    domain, slot = domain_slot.split("-")
+    ontology[domain_slot]["question"] = f"What is the {slot} of the {domain} that the user is interested in?"
+ + +# Slot Type +for domain_slot in ontology: + domain, slot = domain_slot.split("-") + slot_name = slot + if slot in slot_map: + slot_name = slot_map[slot] + prefix = "" + for slot_list, slot_type in slot_types.items(): + if slot in slot_list: + prefix = slot_type + + if "book" in domain_slot: + slot_name = slot_name.replace("book ", "") + ontology[domain_slot]["slottype"] = f"{prefix}{slot_name} for the {domain} booking" + elif prefix=="whether have ": + ontology[domain_slot]["slottype"] = f"{prefix}{slot_name} in the {domain}" + else: + ontology[domain_slot]["slottype"] = f"{prefix}{slot_name} of the {domain}" + + +with open('slot_description.json', 'w') as f: + json.dump(ontology, f, indent=4) diff --git a/T5DST/utils/mapping.pair b/T5DST/utils/mapping.pair new file mode 100644 index 0000000..34df41d --- /dev/null +++ b/T5DST/utils/mapping.pair @@ -0,0 +1,83 @@ +it's it is +don't do not +doesn't does not +didn't did not +you'd you would +you're you are +you'll you will +i'm i am +they're they are +that's that is +what's what is +couldn't could not +i've i have +we've we have +can't cannot +i'd i would +i'd i would +aren't are not +isn't is not +wasn't was not +weren't were not +won't will not +there's there is +there're there are +. . . +restaurants restaurant -s +hotels hotel -s +laptops laptop -s +cheaper cheap -er +dinners dinner -s +lunches lunch -s +breakfasts breakfast -s +expensively expensive -ly +moderately moderate -ly +cheaply cheap -ly +prices price -s +places place -s +venues venue -s +ranges range -s +meals meal -s +locations location -s +areas area -s +policies policy -s +children child -s +kids kid -s +kidfriendly kid friendly +cards card -s +upmarket expensive +inpricey cheap +inches inch -s +uses use -s +dimensions dimension -s +driverange drive range +includes include -s +computers computer -s +machines machine -s +families family -s +ratings rating -s +constraints constraint -s +pricerange price range +batteryrating battery rating +requirements requirement -s +drives drive -s +specifications specification -s +weightrange weight range +harddrive hard drive +batterylife battery life +businesses business -s +hours hour -s +one 1 +two 2 +three 3 +four 4 +five 5 +six 6 +seven 7 +eight 8 +nine 9 +ten 10 +eleven 11 +twelve 12 +anywhere any where +good bye goodbye diff --git a/T5DST/utils/requirements.txt b/T5DST/utils/requirements.txt new file mode 100644 index 0000000..e9e022e --- /dev/null +++ b/T5DST/utils/requirements.txt @@ -0,0 +1,3 @@ +torch==1.7.0 +transformers==3.1.0 +pytorch-lightning==1.0.3 diff --git a/T5DST/utils/slot_description.json b/T5DST/utils/slot_description.json new file mode 100644 index 0000000..0d1cc1e --- /dev/null +++ b/T5DST/utils/slot_description.json @@ -0,0 +1,344 @@ +{ + "hotel-pricerange": { + "description_human": "price budget of the hotel", + "values": [ + "cheap", + "dontcare", + "expensive", + "moderate" + ], + "naive": "price range of the hotel", + "question": "What is the pricerange of the hotel that the user in interested in?", + "slottype": "price range of the hotel" + }, + "hotel-type": { + "description_human": "what is the type of the hotel", + "values": [ + "guesthouse", + "hotel" + ], + "naive": "type of the hotel", + "question": "What is the type of the hotel that the user in interested in?", + "slottype": "type of the hotel" + }, + "hotel-parking": { + "description_human": "whether the hotel has parking", + "values": [ + "dontcare", + "free", + "no", + "yes" + ], + "naive": "parking of the hotel", + "question": "What is the 
parking of the hotel that the user in interested in?", + "slottype": "whether have parking in the hotel" + }, + "hotel-book stay": { + "description_human": "length of stay at the hotel", + "values": [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8" + ], + "naive": "stay for the hotel booking", + "question": "What is the book stay of the hotel that the user in interested in?", + "slottype": "number of stay for the hotel booking" + }, + "hotel-book day": { + "description_human": "day of the hotel booking", + "values": [ + "friday", + "monday", + "saturday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "naive": "day for the hotel booking", + "question": "What is the book day of the hotel that the user in interested in?", + "slottype": "day for the hotel booking" + }, + "hotel-book people": { + "description_human": "number of people for the hotel booking", + "values": [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8" + ], + "naive": "people for the hotel booking", + "question": "What is the book people of the hotel that the user in interested in?", + "slottype": "number of people for the hotel booking" + }, + "hotel-area": { + "description_human": "area or place of the hotel", + "values": [ + "centre", + "dontcare", + "east", + "north", + "south", + "west" + ], + "naive": "area of the hotel", + "question": "What is the area of the hotel that the user in interested in?", + "slottype": "area of the hotel" + }, + "hotel-stars": { + "description_human": "star rating of the hotel", + "values": [ + "0", + "1", + "2", + "3", + "4", + "5", + "dontcare" + ], + "naive": "stars of the hotel", + "question": "What is the stars of the hotel that the user in interested in?", + "slottype": "number of stars of the hotel" + }, + "hotel-internet": { + "description_human": "whether the hotel has internet", + "values": [ + "dontcare", + "no", + "yes" + ], + "naive": "internet of the hotel", + "question": "What is the internet of the hotel that the user in interested in?", + "slottype": "whether have internet in the hotel" + }, + "train-destination": { + "description_human": "destination of the train", + "values": [], + "naive": "destination of the train", + "question": "What is the destination of the train that the user in interested in?", + "slottype": "location of destination of the train" + }, + "train-day": { + "description_human": "day of the train", + "values": [ + "dontcare", + "friday", + "monday", + "saturday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "naive": "day of the train", + "question": "What is the day of the train that the user in interested in?", + "slottype": "day of the train" + }, + "train-departure": { + "description_human": "departure location of the train", + "values": [], + "naive": "departure of the train", + "question": "What is the departure of the train that the user in interested in?", + "slottype": "location of departure of the train" + }, + "train-arriveby": { + "description_human": "arrival time of the train", + "values": [], + "naive": "arrive by of the train", + "question": "What is the arriveby of the train that the user in interested in?", + "slottype": "time of arrive by of the train" + }, + "train-book people": { + "description_human": "how many train tickets you need", + "values": [ + "0", + "1", + "10", + "15", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9" + ], + "naive": "people for the train booking", + "question": "What is the book people of the train that the user in interested in?", + "slottype": "number of 
people for the train booking" + }, + "train-leaveat": { + "description_human": "leaving time for the train", + "values": [], + "naive": "leave at of the train", + "question": "What is the leaveat of the train that the user in interested in?", + "slottype": "time of leave at of the train" + }, + "attraction-area": { + "description_human": "area to search for attractions", + "values": [ + "cambridge", + "centre", + "dontcare", + "east", + "north", + "south", + "west" + ], + "naive": "area of the attraction", + "question": "What is the area of the attraction that the user in interested in?", + "slottype": "area of the attraction" + }, + "restaurant-food": { + "description_human": "the cuisine of the restaurant you are looking for", + "values": [], + "naive": "food of the restaurant", + "question": "What is the food of the restaurant that the user in interested in?", + "slottype": "food of the restaurant" + }, + "restaurant-pricerange": { + "description_human": "price budget for the restaurant", + "values": [ + "cheap", + "dontcare", + "expensive", + "moderate" + ], + "naive": "price range of the restaurant", + "question": "What is the pricerange of the restaurant that the user in interested in?", + "slottype": "price range of the restaurant" + }, + "restaurant-area": { + "description_human": "area or place of the restaurant", + "values": [ + "centre", + "east", + "north", + "south", + "west" + ], + "naive": "area of the restaurant", + "question": "What is the area of the restaurant that the user in interested in?", + "slottype": "area of the restaurant" + }, + "attraction-name": { + "description_human": "name of the attraction", + "values": [], + "naive": "name of the attraction", + "question": "What is the name of the attraction that the user in interested in?", + "slottype": "name of the attraction" + }, + "restaurant-name": { + "description_human": "name of the restaurant", + "values": [], + "naive": "name of the restaurant", + "question": "What is the name of the restaurant that the user in interested in?", + "slottype": "name of the restaurant" + }, + "attraction-type": { + "description_human": "type of the attraction", + "values": [ + "architecture", + "boat", + "church", + "cinema", + "college", + "concerthall", + "entertainment", + "hotspot", + "multiple sports", + "museum", + "nightclub", + "park", + "special", + "swimmingpool", + "theatre" + ], + "naive": "type of the attraction", + "question": "What is the type of the attraction that the user in interested in?", + "slottype": "type of the attraction" + }, + "hotel-name": { + "description_human": "name of the hotel", + "values": [], + "naive": "name of the hotel", + "question": "What is the name of the hotel that the user in interested in?", + "slottype": "name of the hotel" + }, + "taxi-leaveat": { + "description_human": "leaving time of taxi", + "values": [], + "naive": "leave at of the taxi", + "question": "What is the leaveat of the taxi that the user in interested in?", + "slottype": "time of leave at of the taxi" + }, + "taxi-destination": { + "description_human": "destination of taxi", + "values": [], + "naive": "destination of the taxi", + "question": "What is the destination of the taxi that the user in interested in?", + "slottype": "location of destination of the taxi" + }, + "taxi-departure": { + "description_human": "departure location of taxi", + "values": [], + "naive": "departure of the taxi", + "question": "What is the departure of the taxi that the user in interested in?", + "slottype": "location of departure of 
the taxi" + }, + "restaurant-book time": { + "description_human": "time of the restaurant booking", + "values": [], + "naive": "time for the restaurant booking", + "question": "What is the book time of the restaurant that the user in interested in?", + "slottype": "time for the restaurant booking" + }, + "restaurant-book day": { + "description_human": "day of the restaurant booking", + "values": [ + "friday", + "monday", + "saturday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "naive": "day for the restaurant booking", + "question": "What is the book day of the restaurant that the user in interested in?", + "slottype": "day for the restaurant booking" + }, + "restaurant-book people": { + "description_human": "how many people for the restaurant reservation", + "values": [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8" + ], + "naive": "people for the restaurant booking", + "question": "What is the book people of the restaurant that the user in interested in?", + "slottype": "number of people for the restaurant booking" + }, + "taxi-arriveby": { + "description_human": "arrival time of taxi", + "values": [], + "naive": "arrive by of the taxi", + "question": "What is the arriveby of the taxi that the user in interested in?", + "slottype": "time of arrive by of the taxi" + } +} \ No newline at end of file