Skip to content

Commit

Permalink
[ML] adds new question_answering NLP task for extracting answers to q…
Browse files Browse the repository at this point in the history
…uestions from a document (#85958)

This commit adds a new `question_answering` task.

The `question_answering` task allows supplying a `question` in the inference config update.

When storing the model config for inference:
```
"inference_config": {
  "question_answering": {
    "tokenization": {...}, // tokenization settings, recommend doing 386 max sequence length with 128 span, and no truncating
    "max_answer_length": 15 // the max answer length to consider
  }
}
```

Then, when calling `_infer` or running within a pipeline, add the `question` you want answered, given the context provided by the document text:
```
{
  "docs":[{ "text_field": <some long text field to extract answer>}],
  "inference_config": {
    "question_answering": {
      "question": <Question desiring answer>
    }
  }
}
```
The response then looks like:
```
{
    "predicted_value": <string subsection of the document that is the answer>,
    "start_offset": <Char offset in document to start>,
    "end_offset": <char offset end of the answer>,
    "prediction_probability": <prediction score>
}
```

Some models tested:

 - https://huggingface.co/distilbert-base-cased-distilled-squad
 - https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad
 - https://huggingface.co/deepset/electra-base-squad2
 - https://huggingface.co/deepset/tinyroberta-squad2
  • Loading branch information
benwtrent authored May 4, 2022
1 parent f385c65 commit b7f24bd
Show file tree
Hide file tree
Showing 33 changed files with 1,999 additions and 80 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/85958.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 85958
summary: Adds new `question_answering` NLP task for extracting answers to questions
from a document
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
Expand All @@ -47,6 +48,8 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
Expand Down Expand Up @@ -365,6 +368,20 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
ZeroShotClassificationConfig::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
StrictlyParsedInferenceConfig.class,
new ParseField(QuestionAnsweringConfig.NAME),
QuestionAnsweringConfig::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
LenientlyParsedInferenceConfig.class,
new ParseField(QuestionAnsweringConfig.NAME),
QuestionAnsweringConfig::fromXContentLenient
)
);

// Inference Configs Update
namedXContent.add(
Expand Down Expand Up @@ -423,6 +440,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
ZeroShotClassificationConfigUpdate::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
InferenceConfigUpdate.class,
new ParseField(QuestionAnsweringConfigUpdate.NAME),
QuestionAnsweringConfigUpdate::fromXContentStrict
)
);

// Inference models
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
Expand Down Expand Up @@ -548,6 +572,13 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
NlpClassificationInferenceResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceResults.class,
QuestionAnsweringInferenceResults.NAME,
QuestionAnsweringInferenceResults::new
)
);
// Inference Configs
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new)
Expand All @@ -565,6 +596,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfig.class, ZeroShotClassificationConfig.NAME, ZeroShotClassificationConfig::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfig.class, QuestionAnsweringConfig.NAME, QuestionAnsweringConfig::new)
);

// Inference Configs Updates
namedWriteables.add(
Expand Down Expand Up @@ -609,6 +643,13 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
ZeroShotClassificationConfigUpdate::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceConfigUpdate.class,
QuestionAnsweringConfigUpdate.NAME,
QuestionAnsweringConfigUpdate::new
)
);

// Location
namedWriteables.add(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

/**
 * Inference result produced by the {@code question_answering} task.
 * <p>
 * Holds the answer text, the start and end offsets reported for it, a score, and an
 * optional list of further candidate answers ({@link TopAnswerEntry}).
 */
public class QuestionAnsweringInferenceResults extends NlpInferenceResults {

    public static final String NAME = "question_answering";
    public static final ParseField START_OFFSET = new ParseField("start_offset");
    public static final ParseField END_OFFSET = new ParseField("end_offset");

    private final String resultsField;
    private final String answer;
    private final int startOffset;
    private final int endOffset;
    private final double score;
    private final List<TopAnswerEntry> topClasses;

    /**
     * Builds a result from its individual parts.
     *
     * @param answer       the answer text; must not be {@code null}
     * @param startOffset  offset at which the answer starts
     * @param endOffset    offset at which the answer ends
     * @param topClasses   additional candidate answers; {@code null} is treated as an empty list
     * @param resultsField name of the field the answer is reported under
     * @param score        score of the answer
     * @param isTruncated  forwarded unchanged to the {@link NlpInferenceResults} constructor
     */
    public QuestionAnsweringInferenceResults(
        String answer,
        int startOffset,
        int endOffset,
        List<TopAnswerEntry> topClasses,
        String resultsField,
        double score,
        boolean isTruncated
    ) {
        super(isTruncated);
        this.answer = Objects.requireNonNull(answer);
        this.startOffset = startOffset;
        this.endOffset = endOffset;
        this.score = score;
        this.resultsField = resultsField;
        if (topClasses == null) {
            this.topClasses = Collections.emptyList();
        } else {
            // Read-only view over the caller's list.
            this.topClasses = Collections.unmodifiableList(topClasses);
        }
    }

    /**
     * Reads a result from the stream. The read order mirrors {@link #doWriteTo(StreamOutput)}.
     */
    public QuestionAnsweringInferenceResults(StreamInput in) throws IOException {
        super(in);
        this.answer = in.readString();
        this.startOffset = in.readVInt();
        this.endOffset = in.readVInt();
        List<TopAnswerEntry> entries = in.readList(TopAnswerEntry::fromStream);
        this.topClasses = Collections.unmodifiableList(entries);
        this.resultsField = in.readString();
        this.score = in.readDouble();
    }

    /** Writes the fields in the order expected by the {@link StreamInput} constructor. */
    @Override
    public void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(answer);
        out.writeVInt(startOffset);
        out.writeVInt(endOffset);
        out.writeCollection(topClasses);
        out.writeString(resultsField);
        out.writeDouble(score);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** @return the answer text */
    public String getAnswer() {
        return answer;
    }

    /** @return the start offset of the answer */
    public int getStartOffset() {
        return startOffset;
    }

    /** @return the end offset of the answer */
    public int getEndOffset() {
        return endOffset;
    }

    /** @return the score of the answer */
    public double getScore() {
        return score;
    }

    /** @return the unmodifiable list of additional candidate answers; never {@code null} */
    public List<TopAnswerEntry> getTopClasses() {
        return topClasses;
    }

    @Override
    public String getResultsField() {
        return resultsField;
    }

    /** The predicted value of this result is the answer text itself. */
    @Override
    public String predictedValue() {
        return answer;
    }

    /**
     * Adds the answer (keyed by the results field), both offsets, the candidate answers when
     * there are any, and the score to {@code map}, in that order.
     */
    @Override
    void addMapFields(Map<String, Object> map) {
        map.put(resultsField, answer);
        map.put(START_OFFSET.getPreferredName(), startOffset);
        map.put(END_OFFSET.getPreferredName(), endOffset);
        if (topClasses.isEmpty() == false) {
            List<Map<String, Object>> candidates = topClasses.stream().map(TopAnswerEntry::asValueMap).collect(Collectors.toList());
            map.put(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, candidates);
        }
        map.put(PREDICTION_PROBABILITY, score);
    }

    /**
     * Emits the same fields as {@link #addMapFields(Map)}, in the same order, to {@code builder}.
     */
    @Override
    public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
        builder.field(resultsField, answer)
            .field(START_OFFSET.getPreferredName(), startOffset)
            .field(END_OFFSET.getPreferredName(), endOffset);
        if (topClasses.isEmpty() == false) {
            builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
        }
        builder.field(PREDICTION_PROBABILITY, score);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass() || super.equals(o) == false) {
            return false;
        }
        QuestionAnsweringInferenceResults other = (QuestionAnsweringInferenceResults) o;
        // Double.compare(...) == 0 keeps the boxed Double.equals semantics (NaN equals NaN, 0.0 differs from -0.0).
        return startOffset == other.startOffset
            && endOffset == other.endOffset
            && Double.compare(score, other.score) == 0
            && Objects.equals(resultsField, other.resultsField)
            && Objects.equals(answer, other.answer)
            && Objects.equals(topClasses, other.topClasses);
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), resultsField, answer, score, topClasses, startOffset, endOffset);
    }

    /**
     * One candidate answer: its text, score, and start and end offsets.
     */
    public record TopAnswerEntry(String answer, double score, int startOffset, int endOffset) implements Writeable, ToXContentObject {

        public static final String NAME = "top_answer";
        public static final ParseField ANSWER = new ParseField("answer");
        public static final ParseField SCORE = new ParseField("score");

        // Constructor arguments arrive in the order they are declared in the static block below.
        private static final ConstructingObjectParser<TopAnswerEntry, Void> PARSER = new ConstructingObjectParser<>(
            NAME,
            args -> new TopAnswerEntry((String) args[0], (Double) args[1], (Integer) args[2], (Integer) args[3])
        );

        static {
            PARSER.declareString(constructorArg(), ANSWER);
            PARSER.declareDouble(constructorArg(), SCORE);
            PARSER.declareInt(constructorArg(), START_OFFSET);
            PARSER.declareInt(constructorArg(), END_OFFSET);
        }

        /** Reads an entry from the stream in the order written by {@link #writeTo(StreamOutput)}. */
        public static TopAnswerEntry fromStream(StreamInput in) throws IOException {
            String text = in.readString();
            double entryScore = in.readDouble();
            int start = in.readVInt();
            int end = in.readVInt();
            return new TopAnswerEntry(text, entryScore, start, end);
        }

        /** Parses an entry from XContent using the {@code answer}, {@code score} and offset fields. */
        public static TopAnswerEntry fromXContent(XContentParser parser) throws IOException {
            return PARSER.parse(parser, null);
        }

        /** @return a new map holding the answer, both offsets and the score keyed by their field names */
        public Map<String, Object> asValueMap() {
            Map<String, Object> values = Maps.newMapWithExpectedSize(4);
            values.put(ANSWER.getPreferredName(), answer);
            values.put(START_OFFSET.getPreferredName(), startOffset);
            values.put(END_OFFSET.getPreferredName(), endOffset);
            values.put(SCORE.getPreferredName(), score);
            return values;
        }

        /** Renders the entry as an object with the answer, both offsets and the score. */
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return builder.startObject()
                .field(ANSWER.getPreferredName(), answer)
                .field(START_OFFSET.getPreferredName(), startOffset)
                .field(END_OFFSET.getPreferredName(), endOffset)
                .field(SCORE.getPreferredName(), score)
                .endObject();
        }

        /** Writes the answer, score, start offset and end offset, in that order. */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(answer);
            out.writeDouble(score);
            out.writeVInt(startOffset);
            out.writeVInt(endOffset);
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
public interface InferenceConfig extends NamedXContentObject, VersionedNamedWriteable {

String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes";
String DEFAULT_RESULTS_FIELD = "predicted_value";
Expand Down
Loading

0 comments on commit b7f24bd

Please sign in to comment.