diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java new file mode 100644 index 00000000000..7fe2abdf16c --- /dev/null +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java @@ -0,0 +1,308 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset.nlp; + +import ai.djl.Application.NLP; +import ai.djl.basicdataset.RawDataset; +import ai.djl.basicdataset.utils.TextData; +import ai.djl.modality.nlp.embedding.EmbeddingException; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.Artifact; +import ai.djl.repository.MRL; +import ai.djl.training.dataset.Record; +import ai.djl.util.JsonUtils; +import ai.djl.util.Progress; +import com.google.gson.reflect.TypeToken; +import java.io.BufferedReader; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of + * questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every + * question is a segment of text, or span, from the corresponding reading passage, or the question + 
* might be unanswerable.
+ */
+@SuppressWarnings("unchecked")
+public class StanfordQuestionAnsweringDataset extends TextDataset implements RawDataset<Object> {
+
+    private static final String VERSION = "2.0";
+    private static final String ARTIFACT_ID = "stanford-question-answer";
+
+    /**
+     * Store the information of each question, so that when function {@code get()} is called, we can
+     * find the question corresponding to the index.
+     */
+    private List<QuestionInfo> questionInfoList;
+
+    /**
+     * Creates a new instance of {@link StanfordQuestionAnsweringDataset}.
+     *
+     * @param builder the builder object to build from
+     */
+    protected StanfordQuestionAnsweringDataset(Builder builder) {
+        super(builder);
+        this.usage = builder.usage;
+        mrl = builder.getMrl();
+    }
+
+    /**
+     * Creates a new builder to build a {@link StanfordQuestionAnsweringDataset}.
+     *
+     * @return a new builder
+     */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /**
+     * Resolves the JSON file path for the configured {@code usage}, downloading the artifact first
+     * if necessary. Only TRAIN and TEST splits exist for SQuAD 2.0.
+     *
+     * @param progress the progress tracker, may be {@code null}
+     * @return the path of the JSON file for this usage
+     * @throws IOException if the artifact cannot be prepared
+     */
+    private Path prepareUsagePath(Progress progress) throws IOException {
+        Artifact artifact = mrl.getDefaultArtifact();
+        mrl.prepare(artifact, progress);
+        Path root = mrl.getRepository().getResourceDirectory(artifact);
+
+        Path usagePath;
+        switch (usage) {
+            case TRAIN:
+                usagePath = Paths.get("train-v2.0.json");
+                break;
+            case TEST:
+                usagePath = Paths.get("dev-v2.0.json");
+                break;
+            case VALIDATION:
+            default:
+                throw new UnsupportedOperationException("Validation data not available.");
+        }
+        return root.resolve(usagePath);
+    }
+
+    /**
+     * Prepares the dataset for use with tracked progress. In this method the JSON file will be
+     * parsed. The question, context, title will be added to {@code sourceTextData} and the answers
+     * will be added to {@code targetTextData}. Both of them will then be preprocessed.
+     *
+     * @param progress the progress tracker
+     * @throws IOException for various exceptions depending on the dataset
+     * @throws EmbeddingException if there are exceptions during the embedding process
+     */
+    @Override
+    public void prepare(Progress progress) throws IOException, EmbeddingException {
+        if (prepared) {
+            return;
+        }
+        Path usagePath = prepareUsagePath(progress);
+
+        Map<String, Object> data;
+        try (BufferedReader reader = Files.newBufferedReader(usagePath)) {
+            data =
+                    JsonUtils.GSON_PRETTY.fromJson(
+                            reader, new TypeToken<Map<String, Object>>() {}.getType());
+        }
+        List<Map<String, Object>> articles = (List<Map<String, Object>>) data.get("data");
+
+        questionInfoList = new ArrayList<>();
+        List<String> sourceTextData = new ArrayList<>();
+        List<String> targetTextData = new ArrayList<>();
+
+        // a nested loop to handle the nested json object
+        List<Map<String, Object>> paragraphs;
+        List<Map<String, Object>> questions;
+        List<Map<String, Object>> answers;
+
+        int titleIndex;
+        int contextIndex;
+        int questionIndex;
+        int answerIndex;
+        QuestionInfo questionInfo;
+        for (Map<String, Object> article : articles) {
+            titleIndex = sourceTextData.size();
+            sourceTextData.add(article.get("title").toString());
+
+            // iterate through the paragraphs
+            paragraphs = (List<Map<String, Object>>) article.get("paragraphs");
+            for (Map<String, Object> paragraph : paragraphs) {
+                contextIndex = sourceTextData.size();
+                sourceTextData.add(paragraph.get("context").toString());
+
+                // iterate through the questions
+                questions = (List<Map<String, Object>>) paragraph.get("qas");
+                for (Map<String, Object> question : questions) {
+                    questionIndex = sourceTextData.size();
+                    sourceTextData.add(question.get("question").toString());
+                    questionInfo = new QuestionInfo(questionIndex, titleIndex, contextIndex);
+                    questionInfoList.add(questionInfo);
+
+                    // iterate through the answers
+                    answers = (List<Map<String, Object>>) question.get("answers");
+                    for (Map<String, Object> answer : answers) {
+                        answerIndex = targetTextData.size();
+                        targetTextData.add(answer.get("text").toString());
+                        questionInfo.addAnswer(answerIndex);
+                    }
+                }
+            }
+        }
+
+        preprocess(sourceTextData, true);
+        preprocess(targetTextData, false);
+
+        prepared = true;
+    }
+
+    /**
+     * Gets the {@link Record} for the given index from the dataset.
+     *
+     * @param manager the manager used to create the arrays
+     * @param index the index of the requested data item
+     * @return a {@link Record} that contains the data and label of the requested data item. The
+     *     data {@link NDList} contains three {@link NDArray}s representing the embedded title,
+     *     context and question, which are named accordingly. The label {@link NDList} contains
+     *     multiple {@link NDArray}s corresponding to each embedded answer.
+     */
+    @Override
+    public Record get(NDManager manager, long index) {
+        NDList data = new NDList();
+        NDList labels = new NDList();
+        QuestionInfo questionInfo = questionInfoList.get(Math.toIntExact(index));
+
+        NDArray title = sourceTextData.getEmbedding(manager, questionInfo.titleIndex);
+        title.setName("title");
+        NDArray context = sourceTextData.getEmbedding(manager, questionInfo.contextIndex);
+        context.setName("context");
+        NDArray question = sourceTextData.getEmbedding(manager, questionInfo.questionIndex);
+        question.setName("question");
+
+        data.add(title);
+        data.add(context);
+        data.add(question);
+
+        for (Integer answerIndex : questionInfo.answerIndexList) {
+            labels.add(targetTextData.getEmbedding(manager, answerIndex));
+        }
+
+        return new Record(data, labels);
+    }
+
+    /**
+     * Returns the number of records available to be read in this {@code Dataset}. In this
+     * implementation, the actual size of available records are the size of {@code
+     * questionInfoList}.
+     *
+     * @return the number of records available to be read in this {@code Dataset}
+     */
+    @Override
+    protected long availableSize() {
+        return questionInfoList.size();
+    }
+
+    /**
+     * Get data from the SQuAD dataset. This method will directly return the whole dataset as an
+     * object
+     *
+     * @return an object of {@link Object} class in the structure of JSON, e.g. {@code Map<String,
+     *     List<Map<String, Object>>>}
+     */
+    @Override
+    public Object getData() throws IOException {
+        Path usagePath = prepareUsagePath(null);
+        Object data;
+        try (BufferedReader reader = Files.newBufferedReader(usagePath)) {
+            data = JsonUtils.GSON_PRETTY.fromJson(reader, new TypeToken<Object>() {}.getType());
+        }
+        return data;
+    }
+
+    /**
+     * Performs pre-processing steps on text data such as tokenising, applying {@link
+     * ai.djl.modality.nlp.preprocess.TextProcessor}s, creating vocabulary, and word embeddings.
+     * Since the record number in this dataset is not equivalent to the length of {@code
+     * sourceTextData} and {@code targetTextData}, the limit should be processed.
+     *
+     * @param newTextData list of all unprocessed sentences in the dataset
+     * @param source whether the text data provided is source or target
+     * @throws EmbeddingException if there is an error while embedding input
+     */
+    @Override
+    protected void preprocess(List<String> newTextData, boolean source) throws EmbeddingException {
+        TextData textData = source ? sourceTextData : targetTextData;
+        // The limit counts questions (records), not raw text entries, so map the last kept
+        // question back to the last source/target entry it references before truncating.
+        QuestionInfo questionInfo = questionInfoList.get(Math.toIntExact(this.limit) - 1);
+        int lastIndex =
+                source
+                        ? questionInfo.questionIndex
+                        : questionInfo.answerIndexList.get(questionInfo.answerIndexList.size() - 1);
+        textData.preprocess(
+                manager, newTextData.subList(0, Math.min(lastIndex + 1, newTextData.size())));
+    }
+
+    /** A builder for a {@link StanfordQuestionAnsweringDataset}. */
+    public static class Builder extends TextDataset.Builder<Builder> {
+
+        /** Constructs a new builder. */
+        public Builder() {
+            artifactId = ARTIFACT_ID;
+        }
+
+        /**
+         * Returns this {@link Builder} object.
+         *
+         * @return this {@code BaseBuilder}
+         */
+        @Override
+        public Builder self() {
+            return this;
+        }
+
+        /**
+         * Builds the {@link StanfordQuestionAnsweringDataset}.
+         *
+         * @return the {@link StanfordQuestionAnsweringDataset}
+         */
+        public StanfordQuestionAnsweringDataset build() {
+            return new StanfordQuestionAnsweringDataset(this);
+        }
+
+        MRL getMrl() {
+            return repository.dataset(NLP.ANY, groupId, artifactId, VERSION);
+        }
+    }
+
+    /**
+     * This class stores the information of one question. {@code sourceTextData} stores not only the
+     * questions, but also the titles and the contexts, and {@code targetTextData} stores right
+     * answers and plausible answers. Also, there are some mapping relationships between questions
+     * and the other entries, so we need this class to help us assemble the right record.
+     */
+    private static class QuestionInfo {
+        Integer questionIndex;
+        Integer titleIndex;
+        Integer contextIndex;
+        List<Integer> answerIndexList;
+
+        QuestionInfo(Integer questionIndex, Integer titleIndex, Integer contextIndex) {
+            this.questionIndex = questionIndex;
+            this.titleIndex = titleIndex;
+            this.contextIndex = contextIndex;
+            this.answerIndexList = new ArrayList<>();
+        }
+
+        void addAnswer(Integer answerIndex) {
+            this.answerIndexList.add(answerIndex);
+        }
+    }
+}
diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/StanfordQuestionAnsweringDatasetTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/StanfordQuestionAnsweringDatasetTest.java
new file mode 100644
index 00000000000..0e4ae2b61f7
--- /dev/null
+++ b/basicdataset/src/test/java/ai/djl/basicdataset/StanfordQuestionAnsweringDatasetTest.java
@@ -0,0 +1,199 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ *     http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.basicdataset;
+
+import ai.djl.basicdataset.nlp.StanfordQuestionAnsweringDataset;
+import ai.djl.basicdataset.utils.TextData;
+import ai.djl.ndarray.NDManager;
+import ai.djl.training.dataset.Dataset;
+import ai.djl.training.dataset.Record;
+import ai.djl.translate.TranslateException;
+import java.io.IOException;
+import java.util.Map;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+@SuppressWarnings("unchecked")
+public class StanfordQuestionAnsweringDatasetTest {
+
+    private static final int EMBEDDING_SIZE = 15;
+
+    @Test
+    public void testGetDataWithPreTrainedEmbedding() throws TranslateException, IOException {
+
+        try (NDManager manager = NDManager.newBaseManager()) {
+            StanfordQuestionAnsweringDataset stanfordQuestionAnsweringDataset =
+                    StanfordQuestionAnsweringDataset.builder()
+                            .setSourceConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setTargetConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setSampling(32, true)
+                            .optLimit(10)
+                            .optUsage(Dataset.Usage.TEST)
+                            .build();
+
+            stanfordQuestionAnsweringDataset.prepare();
+            Record record = stanfordQuestionAnsweringDataset.get(manager, 0);
+            Assert.assertEquals(record.getData().get("title").getShape().get(0), 1);
+            Assert.assertEquals(record.getData().get("question").getShape().get(0), 7);
+            Assert.assertEquals(record.getLabels().size(), 4);
+        }
+    }
+
+    @Test
+    public void testGetDataWithTrainableEmbedding() throws IOException, TranslateException {
+        try (NDManager manager = NDManager.newBaseManager()) {
+            StanfordQuestionAnsweringDataset stanfordQuestionAnsweringDataset =
+                    StanfordQuestionAnsweringDataset.builder()
+                            .setSourceConfiguration(
+                                    new TextData.Configuration().setEmbeddingSize(EMBEDDING_SIZE))
+                            .setTargetConfiguration(
+                                    new TextData.Configuration().setEmbeddingSize(EMBEDDING_SIZE))
+                            .setSampling(32, true)
+                            .optLimit(10)
+                            .build();
+
+            stanfordQuestionAnsweringDataset.prepare();
+            Record record = stanfordQuestionAnsweringDataset.get(manager, 0);
+            Assert.assertEquals(record.getData().get("title").getShape().dimension(), 1);
+            Assert.assertEquals(record.getData().get("context").getShape().get(0), 156);
+            Assert.assertEquals(record.getLabels().size(), 1);
+        }
+    }
+
+    @Test
+    public void testInvalidUsage() throws TranslateException, IOException {
+
+        try (NDManager manager = NDManager.newBaseManager()) {
+            StanfordQuestionAnsweringDataset stanfordQuestionAnsweringDataset =
+                    StanfordQuestionAnsweringDataset.builder()
+                            .setSourceConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setTargetConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setSampling(32, true)
+                            .optLimit(10)
+                            .optUsage(Dataset.Usage.VALIDATION)
+                            .build();
+
+            stanfordQuestionAnsweringDataset.prepare();
+        } catch (UnsupportedOperationException uoe) {
+            Assert.assertEquals(uoe.getMessage(), "Validation data not available.");
+        }
+    }
+
+    @Test
+    public void testMisc() throws TranslateException, IOException {
+
+        try (NDManager manager = NDManager.newBaseManager()) {
+            StanfordQuestionAnsweringDataset stanfordQuestionAnsweringDataset =
+                    StanfordQuestionAnsweringDataset.builder()
+                            .setSourceConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setTargetConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setSampling(32, true)
+                            .optLimit(350)
+                            .optUsage(Dataset.Usage.TEST)
+                            .build();
+
+            stanfordQuestionAnsweringDataset.prepare();
+            stanfordQuestionAnsweringDataset.prepare();
+            Assert.assertEquals(stanfordQuestionAnsweringDataset.size(), 350);
+
+            Record record0 = stanfordQuestionAnsweringDataset.get(manager, 0);
+            Record record6 = stanfordQuestionAnsweringDataset.get(manager, 6);
+            Assert.assertEquals(record6.getData().get("title").getShape().dimension(), 2);
+            Assert.assertEquals(
+                    record0.getData().get("context").getShape().get(0),
+                    record6.getData().get("context").getShape().get(0));
+            Assert.assertEquals(record6.getLabels().size(), 0);
+        }
+    }
+
+    @Test
+    public void testLimitBoundary() throws TranslateException, IOException {
+
+        try (NDManager manager = NDManager.newBaseManager()) {
+            StanfordQuestionAnsweringDataset stanfordQuestionAnsweringDataset =
+                    StanfordQuestionAnsweringDataset.builder()
+                            .setSourceConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setTargetConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setSampling(32, true)
+                            .optLimit(3)
+                            .optUsage(Dataset.Usage.TEST)
+                            .build();
+
+            stanfordQuestionAnsweringDataset.prepare();
+            Assert.assertEquals(stanfordQuestionAnsweringDataset.size(), 3);
+            Record record = stanfordQuestionAnsweringDataset.get(manager, 2);
+            Assert.assertEquals(record.getData().get("title").getShape().dimension(), 2);
+            Assert.assertEquals(record.getData().get("context").getShape().get(0), 140);
+            Assert.assertEquals(record.getLabels().size(), 4);
+        }
+    }
+
+    @Test
+    public void testRawData() throws IOException {
+
+        try (NDManager manager = NDManager.newBaseManager()) {
+            StanfordQuestionAnsweringDataset stanfordQuestionAnsweringDataset =
+                    StanfordQuestionAnsweringDataset.builder()
+                            .setSourceConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setTargetConfiguration(
+                                    new TextData.Configuration()
+                                            .setTextEmbedding(
+                                                    TestUtils.getTextEmbedding(
+                                                            manager, EMBEDDING_SIZE)))
+                            .setSampling(32, true)
+                            .optLimit(350)
+                            .optUsage(Dataset.Usage.TEST)
+                            .build();
+
+            Map<String, Object> data =
+                    (Map<String, Object>) stanfordQuestionAnsweringDataset.getData();
+            Assert.assertEquals(data.get("version").toString(), "v2.0");
+        }
+    }
+}
diff --git a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/stanford-question-answer/metadata.json b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/stanford-question-answer/metadata.json
new file mode 100644
index 00000000000..9eb409e401f
--- /dev/null
+++ b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/stanford-question-answer/metadata.json
@@ -0,0 +1,34 @@
+{
+  "metadataVersion": "0.2",
+  "resourceType": "dataset",
+  "application": "nlp",
+  "groupId": "ai.djl.basicdataset",
+  "artifactId": "stanford-question-answer",
+  "name": "stanford-question-answer",
+  "description": "A reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles",
+  "website": "https://rajpurkar.github.io/SQuAD-explorer/",
+  "licenses": {
+    "license": {
+      "name": "Creative Commons Attribution-ShareAlike License",
+      "url": "https://creativecommons.org/licenses/by-sa/4.0/legalcode"
+    }
+  },
+  "artifacts": [
+    {
+      "version": "2.0",
+      "snapshot": false,
+      "files": {
+        "train": {
+          "uri": "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json",
+          "sha1Hash": "ceb2acdea93b9d82ab1829c7b1e03bee9e302c99",
+          "size": 42123633
+        },
+        "test": {
+          "uri": "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
+          "sha1Hash": "53ebaeb15bc5cab36645150f6f65d074348e2f3d",
+          "size": 4370528
+        }
+      }
+    }
+  ]
+}