Skip to content

Commit

Permalink
[basicdataset] Fixed out of bound limit (#1599)
Browse files Browse the repository at this point in the history
* [basicdataset] Fixed out of bound limit

* make preprocess method more concise

* make code more clear

Co-authored-by: Zach Kimberg <zachary@kimberg.com>

Co-authored-by: Zach Kimberg <zachary@kimberg.com>
  • Loading branch information
WHALEEYE and zachgk authored Apr 28, 2022
1 parent 6e4dec2 commit 009ca1a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,28 @@ public Object getData() throws IOException {
return data;
}

/**
* Since a question might have no answer, we need extra logic to find the last index of the
* answer in the {@code TargetTextData}. There are not many consecutive questions without
* answer, so this logic will not cause a high cost.
*
* @param questionInfoIndex the last index of the record in {@code questionInfoList} that needs
* to be preprocessed
* @return the last index of the answer in {@code TargetTextData} that needs to be preprocessed
*/
private int getLastAnswerIndex(int questionInfoIndex) {
// Go backwards through the questionInfoList until it finds one with an answer
for (; questionInfoIndex >= 0; questionInfoIndex--) {
QuestionInfo questionInfo = questionInfoList.get(questionInfoIndex);
if (!questionInfo.answerIndexList.isEmpty()) {
return questionInfo.answerIndexList.get(questionInfo.answerIndexList.size() - 1);
}
}

// Could not find a QuestionInfo with an answer
return 0;
}

/**
* Performs pre-processing steps on text data such as tokenising, applying {@link
* ai.djl.modality.nlp.preprocess.TextProcessor}s, creating vocabulary, and word embeddings.
Expand All @@ -241,13 +263,10 @@ public Object getData() throws IOException {
@Override
protected void preprocess(List<String> newTextData, boolean source) throws EmbeddingException {
TextData textData = source ? sourceTextData : targetTextData;
QuestionInfo questionInfo = questionInfoList.get(Math.toIntExact(this.limit) - 1);
int index = (int) Math.min(limit, questionInfoList.size()) - 1;
int lastIndex =
source
? questionInfo.questionIndex
: questionInfo.answerIndexList.get(questionInfo.answerIndexList.size() - 1);
textData.preprocess(
manager, newTextData.subList(0, Math.min(lastIndex + 1, newTextData.size())));
source ? questionInfoList.get(index).questionIndex : getLastAnswerIndex(index);
textData.preprocess(manager, newTextData.subList(0, lastIndex + 1));
}

/** A builder for a {@link StanfordQuestionAnsweringDataset}. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,12 @@ public void testMisc() throws TranslateException, IOException {
TestUtils.getTextEmbedding(
manager, EMBEDDING_SIZE)))
.setSampling(32, true)
.optLimit(350)
.optUsage(Dataset.Usage.TEST)
.build();

stanfordQuestionAnsweringDataset.prepare();
stanfordQuestionAnsweringDataset.prepare();
Assert.assertEquals(stanfordQuestionAnsweringDataset.size(), 350);
Assert.assertEquals(stanfordQuestionAnsweringDataset.size(), 11873);

Record record0 = stanfordQuestionAnsweringDataset.get(manager, 0);
Record record6 = stanfordQuestionAnsweringDataset.get(manager, 6);
Expand Down

0 comments on commit 009ca1a

Please sign in to comment.