diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java
index 7fe2abdf16c..7716d5ab1a9 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java
@@ -228,6 +228,28 @@ public Object getData() throws IOException {
         return data;
     }
 
+    /**
+     * Finds the index of the last answer in {@code targetTextData} that needs to be preprocessed.
+     * A question might have no answer, so this walks backwards from the given question until it
+     * finds one with an answer. Few questions in a row lack an answer, so the scan stays cheap.
+     *
+     * @param questionInfoIndex the last index of the record in {@code questionInfoList} that needs
+     *     to be preprocessed
+     * @return the index of that last answer, or {@code 0} if no question up to the index has one
+     */
+    private int getLastAnswerIndex(int questionInfoIndex) {
+        // Go backwards through the questionInfoList until it finds one with an answer
+        for (; questionInfoIndex >= 0; questionInfoIndex--) {
+            QuestionInfo questionInfo = questionInfoList.get(questionInfoIndex);
+            if (!questionInfo.answerIndexList.isEmpty()) {
+                return questionInfo.answerIndexList.get(questionInfo.answerIndexList.size() - 1);
+            }
+        }
+
+        // Could not find a QuestionInfo with an answer
+        return 0;
+    }
+
     /**
      * Performs pre-processing steps on text data such as tokenising, applying {@link
      * ai.djl.modality.nlp.preprocess.TextProcessor}s, creating vocabulary, and word embeddings.
@@ -241,13 +263,12 @@ public Object getData() throws IOException {
     @Override
     protected void preprocess(List newTextData, boolean source) throws EmbeddingException {
         TextData textData = source ? sourceTextData : targetTextData;
-        QuestionInfo questionInfo = questionInfoList.get(Math.toIntExact(this.limit) - 1);
+        int index = (int) Math.min(limit, questionInfoList.size()) - 1;
         int lastIndex =
-                source
-                        ? questionInfo.questionIndex
-                        : questionInfo.answerIndexList.get(questionInfo.answerIndexList.size() - 1);
+                source ? questionInfoList.get(index).questionIndex : getLastAnswerIndex(index);
+        // Clamp to the list size: getLastAnswerIndex falls back to 0 even when there is no answer
         textData.preprocess(
                 manager, newTextData.subList(0, Math.min(lastIndex + 1, newTextData.size())));
     }
 
     /** A builder for a {@link StanfordQuestionAnsweringDataset}. */
diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/StanfordQuestionAnsweringDatasetTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/StanfordQuestionAnsweringDatasetTest.java
index 0e4ae2b61f7..55f6f28375c 100644
--- a/basicdataset/src/test/java/ai/djl/basicdataset/StanfordQuestionAnsweringDatasetTest.java
+++ b/basicdataset/src/test/java/ai/djl/basicdataset/StanfordQuestionAnsweringDatasetTest.java
@@ -122,13 +122,12 @@ public void testMisc() throws TranslateException, IOException {
                                         TestUtils.getTextEmbedding(
                                                 manager, EMBEDDING_SIZE)))
                         .setSampling(32, true)
-                        .optLimit(350)
                         .optUsage(Dataset.Usage.TEST)
                         .build();
 
         stanfordQuestionAnsweringDataset.prepare();
         stanfordQuestionAnsweringDataset.prepare();
-        Assert.assertEquals(stanfordQuestionAnsweringDataset.size(), 350);
+        Assert.assertEquals(stanfordQuestionAnsweringDataset.size(), 11873);
 
         Record record0 = stanfordQuestionAnsweringDataset.get(manager, 0);
         Record record6 = stanfordQuestionAnsweringDataset.get(manager, 6);