From b7f24bdc4bc8a4427d71325ddbb33ca327d182a7 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 4 May 2022 13:15:13 -0400 Subject: [PATCH] [ML] adds new question_answering NLP task for extracting answers to questions from a document (#85958) This commit adds a new `question_answering` task. The `question_answering` task allows supplying a `question` in the inference config update. When storing the model config for inference: ``` "inference_config": { "question_answering": { "tokenization": {...}, // tokenization settings, recommend doing 386 max sequence length with 128 span, and no truncating "max_answer_length": 15 // the max answer length to consider } } ``` Then when calling `_infer` or running with in a pipeline, add the `question` you want answered given the context provided by the document text ``` { "docs":[{ "text_field": } } } ``` The response then looks like: ``` { "predicted_value": "start_offset": , "end_offset": } ``` Some models tested: - https://huggingface.co/distilbert-base-cased-distilled-squad - https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad - https://huggingface.co/deepset/electra-base-squad2 - https://huggingface.co/deepset/tinyroberta-squad2 --- docs/changelog/85958.yaml | 6 + .../MlInferenceNamedXContentProvider.java | 41 +++ .../QuestionAnsweringInferenceResults.java | 215 ++++++++++++++ .../trainedmodel/InferenceConfig.java | 4 +- .../trainedmodel/QuestionAnsweringConfig.java | 250 ++++++++++++++++ .../QuestionAnsweringConfigUpdate.java | 258 +++++++++++++++++ .../InternalInferModelActionRequestTests.java | 5 + ...InternalInferModelActionResponseTests.java | 50 +++- .../InferenceConfigItemTestCase.java | 17 +- .../results/FillMaskResultsTests.java | 17 +- .../ml/inference/results/NerResultsTests.java | 17 +- .../PyTorchPassThroughResultsTests.java | 17 +- ...uestionAnsweringInferenceResultsTests.java | 85 ++++++ .../results/TextEmbeddingResultsTests.java | 17 +- .../results/TopAnswerEntryTests.java | 36 
+++ .../QuestionAnsweringConfigTests.java | 65 +++++ .../QuestionAnsweringConfigUpdateTests.java | 184 ++++++++++++ .../inference/ingest/InferenceProcessor.java | 5 + .../ml/inference/nlp/FillMaskProcessor.java | 2 +- .../xpack/ml/inference/nlp/NerProcessor.java | 8 +- .../nlp/QuestionAnsweringProcessor.java | 272 ++++++++++++++++++ .../xpack/ml/inference/nlp/TaskType.java | 7 + .../tokenizers/BertTokenizationResult.java | 13 +- .../tokenizers/MPNetTokenizationResult.java | 1 + .../nlp/tokenizers/NlpTokenizer.java | 131 ++++++++- .../tokenizers/RobertaTokenizationResult.java | 13 +- .../nlp/tokenizers/TokenizationResult.java | 49 +++- .../InferenceProcessorFactoryTests.java | 26 +- .../inference/nlp/FillMaskProcessorTests.java | 4 +- .../nlp/QuestionAnsweringProcessorTests.java | 171 +++++++++++ .../nlp/tokenizers/BertTokenizerTests.java | 89 +++++- .../nlp/tokenizers/MPNetTokenizerTests.java | 2 +- .../nlp/tokenizers/RobertaTokenizerTests.java | 2 +- 33 files changed, 1999 insertions(+), 80 deletions(-) create mode 100644 docs/changelog/85958.yaml create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdate.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TopAnswerEntryTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java create mode 100644 
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java diff --git a/docs/changelog/85958.yaml b/docs/changelog/85958.yaml new file mode 100644 index 0000000000000..82bfc8e3a80e9 --- /dev/null +++ b/docs/changelog/85958.yaml @@ -0,0 +1,6 @@ +pr: 85958 +summary: Adds new `question_answering` NLP task for extracting answers to questions + from a document +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 87dfdaea41191..107d717498a6a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.NerResults; import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; +import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; @@ -47,6 +48,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate; @@ -365,6 +368,20 @@ public List getNamedXContentParsers() { ZeroShotClassificationConfig::fromXContentStrict ) ); + namedXContent.add( + new NamedXContentRegistry.Entry( + StrictlyParsedInferenceConfig.class, + new ParseField(QuestionAnsweringConfig.NAME), + QuestionAnsweringConfig::fromXContentStrict + ) + ); + namedXContent.add( + new NamedXContentRegistry.Entry( + LenientlyParsedInferenceConfig.class, + new ParseField(QuestionAnsweringConfig.NAME), + QuestionAnsweringConfig::fromXContentLenient + ) + ); // Inference Configs Update namedXContent.add( @@ -423,6 +440,13 @@ public List getNamedXContentParsers() { ZeroShotClassificationConfigUpdate::fromXContentStrict ) ); + namedXContent.add( + new NamedXContentRegistry.Entry( + InferenceConfigUpdate.class, + new ParseField(QuestionAnsweringConfigUpdate.NAME), + QuestionAnsweringConfigUpdate::fromXContentStrict + ) + ); // Inference models namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent)); @@ -548,6 +572,13 @@ public List getNamedWriteables() { NlpClassificationInferenceResults::new ) ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + InferenceResults.class, + QuestionAnsweringInferenceResults.NAME, + QuestionAnsweringInferenceResults::new + ) + ); // Inference Configs namedWriteables.add( new NamedWriteableRegistry.Entry(InferenceConfig.class, 
ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new) @@ -565,6 +596,9 @@ public List getNamedWriteables() { namedWriteables.add( new NamedWriteableRegistry.Entry(InferenceConfig.class, ZeroShotClassificationConfig.NAME, ZeroShotClassificationConfig::new) ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(InferenceConfig.class, QuestionAnsweringConfig.NAME, QuestionAnsweringConfig::new) + ); // Inference Configs Updates namedWriteables.add( @@ -609,6 +643,13 @@ public List getNamedWriteables() { ZeroShotClassificationConfigUpdate::new ) ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + InferenceConfigUpdate.class, + QuestionAnsweringConfigUpdate.NAME, + QuestionAnsweringConfigUpdate::new + ) + ); // Location namedWriteables.add( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java new file mode 100644 index 0000000000000..293694ac8e3ca --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java @@ -0,0 +1,215 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.Maps; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; + +public class QuestionAnsweringInferenceResults extends NlpInferenceResults { + + public static final String NAME = "question_answering"; + public static final ParseField START_OFFSET = new ParseField("start_offset"); + public static final ParseField END_OFFSET = new ParseField("end_offset"); + + private final String resultsField; + private final String answer; + private final int startOffset; + private final int endOffset; + private final double score; + private final List topClasses; + + public QuestionAnsweringInferenceResults( + String answer, + int startOffset, + int endOffset, + List topClasses, + String resultsField, + double score, + boolean isTruncated + ) { + super(isTruncated); + this.startOffset = startOffset; + this.endOffset = endOffset; + this.answer = Objects.requireNonNull(answer); + this.topClasses = topClasses == null ? 
Collections.emptyList() : Collections.unmodifiableList(topClasses); + this.resultsField = resultsField; + this.score = score; + } + + public QuestionAnsweringInferenceResults(StreamInput in) throws IOException { + super(in); + this.answer = in.readString(); + this.startOffset = in.readVInt(); + this.endOffset = in.readVInt(); + this.topClasses = Collections.unmodifiableList(in.readList(TopAnswerEntry::fromStream)); + this.resultsField = in.readString(); + this.score = in.readDouble(); + } + + public String getAnswer() { + return answer; + } + + public List getTopClasses() { + return topClasses; + } + + @Override + public void doWriteTo(StreamOutput out) throws IOException { + out.writeString(answer); + out.writeVInt(startOffset); + out.writeVInt(endOffset); + out.writeCollection(topClasses); + out.writeString(resultsField); + out.writeDouble(score); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + QuestionAnsweringInferenceResults that = (QuestionAnsweringInferenceResults) o; + return Objects.equals(resultsField, that.resultsField) + && Objects.equals(answer, that.answer) + && Objects.equals(startOffset, that.startOffset) + && Objects.equals(endOffset, that.endOffset) + && Objects.equals(score, that.score) + && Objects.equals(topClasses, that.topClasses); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), resultsField, answer, score, topClasses, startOffset, endOffset); + } + + public double getScore() { + return score; + } + + @Override + public String getResultsField() { + return resultsField; + } + + @Override + public String predictedValue() { + return answer; + } + + @Override + void addMapFields(Map map) { + map.put(resultsField, answer); + map.put(START_OFFSET.getPreferredName(), startOffset); + map.put(END_OFFSET.getPreferredName(), endOffset); + if (topClasses.isEmpty() == false) { 
+ map.put( + NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, + topClasses.stream().map(TopAnswerEntry::asValueMap).collect(Collectors.toList()) + ); + } + map.put(PREDICTION_PROBABILITY, score); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.field(resultsField, answer); + builder.field(START_OFFSET.getPreferredName(), startOffset); + builder.field(END_OFFSET.getPreferredName(), endOffset); + if (topClasses.size() > 0) { + builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses); + } + builder.field(PREDICTION_PROBABILITY, score); + } + + public int getStartOffset() { + return startOffset; + } + + public int getEndOffset() { + return endOffset; + } + + public record TopAnswerEntry(String answer, double score, int startOffset, int endOffset) implements Writeable, ToXContentObject { + + public static final ParseField ANSWER = new ParseField("answer"); + public static final ParseField SCORE = new ParseField("score"); + + public static TopAnswerEntry fromStream(StreamInput in) throws IOException { + return new TopAnswerEntry(in.readString(), in.readDouble(), in.readVInt(), in.readVInt()); + } + + public static final String NAME = "top_answer"; + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + a -> new TopAnswerEntry((String) a[0], (Double) a[1], (Integer) a[2], (Integer) a[3]) + ); + + static { + PARSER.declareString(constructorArg(), ANSWER); + PARSER.declareDouble(constructorArg(), SCORE); + PARSER.declareInt(constructorArg(), START_OFFSET); + PARSER.declareInt(constructorArg(), END_OFFSET); + } + + public static TopAnswerEntry fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + public Map asValueMap() { + Map map = Maps.newMapWithExpectedSize(4); + map.put(ANSWER.getPreferredName(), answer); + 
map.put(START_OFFSET.getPreferredName(), startOffset); + map.put(END_OFFSET.getPreferredName(), endOffset); + map.put(SCORE.getPreferredName(), score); + return map; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ANSWER.getPreferredName(), answer); + builder.field(START_OFFSET.getPreferredName(), startOffset); + builder.field(END_OFFSET.getPreferredName(), endOffset); + builder.field(SCORE.getPreferredName(), score); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(answer); + out.writeDouble(score); + out.writeVInt(startOffset); + out.writeVInt(endOffset); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java index a5900a9a615b5..e3fe1a6d6576f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java @@ -7,10 +7,10 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; -public interface InferenceConfig extends NamedXContentObject, NamedWriteable { +public interface InferenceConfig extends NamedXContentObject, VersionedNamedWriteable { String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes"; String DEFAULT_RESULTS_FIELD = "predicted_value"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfig.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfig.java new file mode 100644 index 0000000000000..0ecbcdf3bee32 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfig.java @@ -0,0 +1,250 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; + +import java.io.IOException; +import java.util.Objects; +import java.util.Optional; + +/** + * Question and Answer configuration + */ +public class QuestionAnsweringConfig implements NlpConfig { + + public static final String NAME = "question_answering"; + public static final ParseField MAX_ANSWER_LENGTH = new ParseField("max_answer_length"); + public static final ParseField QUESTION = new ParseField("question"); + + public static QuestionAnsweringConfig fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null); + } + + public static QuestionAnsweringConfig fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null); + } + + public static final int 
DEFAULT_MAX_ANSWER_LENGTH = 15; + public static final int DEFAULT_NUM_TOP_CLASSES = 0; + + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + @SuppressWarnings({ "unchecked" }) + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( + NAME, + ignoreUnknownFields, + a -> new QuestionAnsweringConfig((Integer) a[0], (Integer) a[1], (VocabularyConfig) a[2], (Tokenization) a[3], (String) a[4]) + ); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_ANSWER_LENGTH); + parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> { + if (ignoreUnknownFields == false) { + throw ExceptionsHelper.badRequestException( + "illegal setting [{}] on inference model creation", + VOCABULARY.getPreferredName() + ); + } + return VocabularyConfig.fromXContentLenient(p); + }, VOCABULARY); + parser.declareNamedObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), + TOKENIZATION + ); + parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD); + return parser; + } + + private final int numTopClasses; + private final int maxAnswerLength; + private final VocabularyConfig vocabularyConfig; + private final Tokenization tokenization; + private final String resultsField; + private final String question; + + public QuestionAnsweringConfig( + @Nullable Integer numTopClasses, + @Nullable Integer maxAnswerLength, + @Nullable VocabularyConfig vocabularyConfig, + @Nullable Tokenization tokenization, + @Nullable String resultsField + ) { + this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(DEFAULT_NUM_TOP_CLASSES); + this.maxAnswerLength 
= Optional.ofNullable(maxAnswerLength).orElse(DEFAULT_MAX_ANSWER_LENGTH); + if (this.numTopClasses < 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than or equal to [0]; provided [{}]", + NUM_TOP_CLASSES.getPreferredName(), + this.numTopClasses + ); + } + if (this.maxAnswerLength <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than [0]; provided [{}]", + MAX_ANSWER_LENGTH.getPreferredName(), + this.maxAnswerLength + ); + } + this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) + .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); + this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; + this.resultsField = resultsField; + this.question = null; + } + + public QuestionAnsweringConfig( + String question, + int numTopClasses, + int maxAnswerLength, + VocabularyConfig vocabularyConfig, + Tokenization tokenization, + String resultsField + ) { + this.question = ExceptionsHelper.requireNonNull(question, QUESTION); + this.numTopClasses = numTopClasses; + this.maxAnswerLength = maxAnswerLength; + if (this.numTopClasses < 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than or equal to [0]; provided [{}]", + NUM_TOP_CLASSES.getPreferredName(), + this.numTopClasses + ); + } + if (this.maxAnswerLength <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than [0]; provided [{}]", + MAX_ANSWER_LENGTH.getPreferredName(), + this.maxAnswerLength + ); + } + this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); + this.tokenization = ExceptionsHelper.requireNonNull(tokenization, TOKENIZATION); + this.resultsField = resultsField; + } + + public QuestionAnsweringConfig(StreamInput in) throws IOException { + numTopClasses = in.readVInt(); + maxAnswerLength = in.readVInt(); + vocabularyConfig = new VocabularyConfig(in); + tokenization = in.readNamedWriteable(Tokenization.class); 
+ resultsField = in.readOptionalString(); + question = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(numTopClasses); + out.writeVInt(maxAnswerLength); + vocabularyConfig.writeTo(out); + out.writeNamedWriteable(tokenization); + out.writeOptionalString(resultsField); + out.writeOptionalString(question); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + builder.field(MAX_ANSWER_LENGTH.getPreferredName(), maxAnswerLength); + builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); + NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + if (question != null) { + builder.field(QUESTION.getPreferredName(), question); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public boolean isTargetTypeSupported(TargetType targetType) { + return false; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_8_3_0; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + QuestionAnsweringConfig that = (QuestionAnsweringConfig) o; + return Objects.equals(vocabularyConfig, that.vocabularyConfig) + && Objects.equals(tokenization, that.tokenization) + && Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(maxAnswerLength, that.maxAnswerLength) + && Objects.equals(question, that.question) + && Objects.equals(resultsField, that.resultsField); + } + + @Override + public int hashCode() { + return 
Objects.hash(vocabularyConfig, tokenization, maxAnswerLength, numTopClasses, resultsField, question); + } + + @Override + public VocabularyConfig getVocabularyConfig() { + return vocabularyConfig; + } + + @Override + public Tokenization getTokenization() { + return tokenization; + } + + public int getNumTopClasses() { + return numTopClasses; + } + + public int getMaxAnswerLength() { + return maxAnswerLength; + } + + public String getQuestion() { + return question; + } + + @Override + public String getResultsField() { + return resultsField; + } + + @Override + public boolean isAllocateOnly() { + return true; + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdate.java new file mode 100644 index 0000000000000..1ffdd337aaeb2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdate.java @@ -0,0 +1,258 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.NUM_TOP_CLASSES; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig.MAX_ANSWER_LENGTH; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig.QUESTION; + +public class QuestionAnsweringConfigUpdate extends NlpConfigUpdate implements NamedXContentObject { + + public static final String NAME = "question_answering"; + + public static QuestionAnsweringConfigUpdate fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); + } + + @SuppressWarnings({ "unchecked" }) + public static QuestionAnsweringConfigUpdate fromMap(Map map) { + Map options = new HashMap<>(map); + Integer numTopClasses = (Integer) options.remove(NUM_TOP_CLASSES.getPreferredName()); + Integer maxAnswerLength = (Integer) options.remove(MAX_ANSWER_LENGTH.getPreferredName()); + String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName()); + String question = (String) options.remove(QUESTION.getPreferredName()); + TokenizationUpdate 
tokenizationUpdate = NlpConfigUpdate.tokenizationFromMap(options); + if (options.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); + } + return new QuestionAnsweringConfigUpdate(question, numTopClasses, maxAnswerLength, resultsField, tokenizationUpdate); + } + + @SuppressWarnings({ "unchecked" }) + private static final ObjectParser STRICT_PARSER = new ObjectParser<>( + NAME, + QuestionAnsweringConfigUpdate.Builder::new + ); + + static { + STRICT_PARSER.declareString(Builder::setQuestion, QUESTION); + STRICT_PARSER.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES); + STRICT_PARSER.declareInt(Builder::setMaxAnswerLength, MAX_ANSWER_LENGTH); + STRICT_PARSER.declareString(Builder::setResultsField, RESULTS_FIELD); + STRICT_PARSER.declareNamedObject( + Builder::setTokenizationUpdate, + (p, c, n) -> p.namedObject(TokenizationUpdate.class, n, false), + TOKENIZATION + ); + } + + private final String question; + private final Integer numTopClasses; + private final Integer maxAnswerLength; + private final String resultsField; + + public QuestionAnsweringConfigUpdate( + String question, + @Nullable Integer numTopClasses, + @Nullable Integer maxAnswerLength, + @Nullable String resultsField, + @Nullable TokenizationUpdate tokenizationUpdate + ) { + super(tokenizationUpdate); + this.question = ExceptionsHelper.requireNonNull(question, QUESTION); + this.numTopClasses = numTopClasses; + this.maxAnswerLength = maxAnswerLength; + this.resultsField = resultsField; + } + + public QuestionAnsweringConfigUpdate(StreamInput in) throws IOException { + super(in); + question = in.readString(); + numTopClasses = in.readOptionalInt(); + maxAnswerLength = in.readOptionalInt(); + resultsField = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(question); + out.writeOptionalInt(numTopClasses); + out.writeOptionalInt(maxAnswerLength); + 
out.writeOptionalString(resultsField); + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + if (numTopClasses != null) { + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + } + if (maxAnswerLength != null) { + builder.field(MAX_ANSWER_LENGTH.getPreferredName(), maxAnswerLength); + } + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + builder.field(QUESTION.getPreferredName(), question); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public InferenceConfig apply(InferenceConfig originalConfig) { + if (originalConfig instanceof QuestionAnsweringConfig == false) { + throw ExceptionsHelper.badRequestException( + "Inference config of type [{}] can not be updated with a inference request of type [{}]", + originalConfig.getName(), + getName() + ); + } + + QuestionAnsweringConfig questionAnsweringConfig = (QuestionAnsweringConfig) originalConfig; + return new QuestionAnsweringConfig( + question, + Optional.ofNullable(numTopClasses).orElse(questionAnsweringConfig.getNumTopClasses()), + Optional.ofNullable(maxAnswerLength).orElse(questionAnsweringConfig.getMaxAnswerLength()), + questionAnsweringConfig.getVocabularyConfig(), + tokenizationUpdate == null + ? 
questionAnsweringConfig.getTokenization() + : tokenizationUpdate.apply(questionAnsweringConfig.getTokenization()), + Optional.ofNullable(resultsField).orElse(questionAnsweringConfig.getResultsField()) + ); + } + + boolean isNoop(QuestionAnsweringConfig originalConfig) { + return (numTopClasses == null || numTopClasses.equals(originalConfig.getNumTopClasses())) + && (maxAnswerLength == null || maxAnswerLength.equals(originalConfig.getMaxAnswerLength())) + && (resultsField == null || resultsField.equals(originalConfig.getResultsField())) + && (question == null || question.equals(originalConfig.getQuestion())) + && super.isNoop(); + } + + @Override + public boolean isSupported(InferenceConfig config) { + return config instanceof QuestionAnsweringConfig; + } + + @Override + public String getResultsField() { + return resultsField; + } + + @Override + public InferenceConfigUpdate.Builder, ? extends InferenceConfigUpdate> newBuilder() { + return new Builder().setQuestion(question) + .setNumTopClasses(numTopClasses) + .setMaxAnswerLength(maxAnswerLength) + .setResultsField(resultsField) + .setTokenizationUpdate(tokenizationUpdate); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + QuestionAnsweringConfigUpdate that = (QuestionAnsweringConfigUpdate) o; + return Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(maxAnswerLength, that.maxAnswerLength) + && Objects.equals(question, that.question) + && Objects.equals(resultsField, that.resultsField) + && Objects.equals(tokenizationUpdate, that.tokenizationUpdate); + } + + @Override + public int hashCode() { + return Objects.hash(maxAnswerLength, numTopClasses, resultsField, tokenizationUpdate, question); + } + + public Integer getNumTopClasses() { + return numTopClasses; + } + + public Integer getMaxAnswerLength() { + return maxAnswerLength; + } + + 
public String getQuestion() { + return question; + } + + public static class Builder + implements + InferenceConfigUpdate.Builder { + private Integer numTopClasses; + private Integer maxAnswerLength; + private String resultsField; + private TokenizationUpdate tokenizationUpdate; + private String question; + + @Override + public QuestionAnsweringConfigUpdate.Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public Builder setNumTopClasses(Integer numTopClasses) { + this.numTopClasses = numTopClasses; + return this; + } + + public Builder setMaxAnswerLength(Integer maxAnswerLength) { + this.maxAnswerLength = maxAnswerLength; + return this; + } + + public Builder setTokenizationUpdate(TokenizationUpdate tokenizationUpdate) { + this.tokenizationUpdate = tokenizationUpdate; + return this; + } + + public Builder setQuestion(String question) { + this.question = question; + return this; + } + + @Override + public QuestionAnsweringConfigUpdate build() { + return new QuestionAnsweringConfigUpdate(question, numTopClasses, maxAnswerLength, resultsField, tokenizationUpdate); + } + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_8_3_0; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java index 9659ffbe817f1..c74afa4114adc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java @@ -22,6 +22,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdateTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdateTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdateTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate; @@ -63,6 +65,7 @@ private static InferenceConfigUpdate randomInferenceConfigUpdate() { FillMaskConfigUpdateTests.randomUpdate(), ZeroShotClassificationConfigUpdateTests.randomUpdate(), PassThroughConfigUpdateTests.randomUpdate(), + QuestionAnsweringConfigUpdateTests.randomUpdate(), EmptyConfigUpdateTests.testInstance() ); } @@ -102,6 +105,8 @@ protected Request mutateInstanceForVersion(Request instance, Version version) { adjustedUpdate = ZeroShotClassificationConfigUpdateTests.mutateForVersion(update, version); } else if (nlpConfigUpdate instanceof PassThroughConfigUpdate update) { adjustedUpdate = PassThroughConfigUpdateTests.mutateForVersion(update, version); + } else if (nlpConfigUpdate instanceof QuestionAnsweringConfigUpdate update) { + adjustedUpdate = QuestionAnsweringConfigUpdateTests.mutateForVersion(update, version); } else { throw new IllegalArgumentException("Unknown update [" + currentUpdate.getName() + "]"); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java index 11024bc98aab0..54c885e4b5ba6 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java @@ -13,9 +13,21 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; +import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.NerResults; +import org.elasticsearch.xpack.core.ml.inference.results.NerResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; +import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResultsTests; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -24,7 +36,16 @@ public class InternalInferModelActionResponseTests extends AbstractWireSerializi @Override protected Response createTestInstance() { - String resultType = randomFrom(ClassificationInferenceResults.NAME, RegressionInferenceResults.NAME); + String resultType = 
randomFrom( + ClassificationInferenceResults.NAME, + RegressionInferenceResults.NAME, + NerResults.NAME, + TextEmbeddingResults.NAME, + PyTorchPassThroughResults.NAME, + FillMaskResults.NAME, + WarningInferenceResults.NAME, + QuestionAnsweringInferenceResults.NAME + ); return new Response( Stream.generate(() -> randomInferenceResult(resultType)).limit(randomIntBetween(0, 10)).collect(Collectors.toList()), randomAlphaOfLength(10), @@ -33,13 +54,26 @@ protected Response createTestInstance() { } private static InferenceResults randomInferenceResult(String resultType) { - if (resultType.equals(ClassificationInferenceResults.NAME)) { - return ClassificationInferenceResultsTests.createRandomResults(); - } else if (resultType.equals(RegressionInferenceResults.NAME)) { - return RegressionInferenceResultsTests.createRandomResults(); - } else { - fail("unexpected result type [" + resultType + "]"); - return null; + switch (resultType) { + case ClassificationInferenceResults.NAME: + return ClassificationInferenceResultsTests.createRandomResults(); + case RegressionInferenceResults.NAME: + return RegressionInferenceResultsTests.createRandomResults(); + case NerResults.NAME: + return NerResultsTests.createRandomResults(); + case TextEmbeddingResults.NAME: + return TextEmbeddingResultsTests.createRandomResults(); + case PyTorchPassThroughResults.NAME: + return PyTorchPassThroughResultsTests.createRandomResults(); + case FillMaskResults.NAME: + return FillMaskResultsTests.createRandomResults(); + case WarningInferenceResults.NAME: + return WarningInferenceResultsTests.createRandomResults(); + case QuestionAnsweringInferenceResults.NAME: + return QuestionAnsweringInferenceResultsTests.createRandomResults(); + default: + fail("unexpected result type [" + resultType + "]"); + return null; } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java index d040a567477cf..79157bcb5ab27 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java @@ -7,8 +7,9 @@ package org.elasticsearch.xpack.core.ml.inference; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchModule; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -18,8 +19,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; -public abstract class InferenceConfigItemTestCase extends AbstractBWCSerializationTestCase { +import static org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase.getAllBWCVersions; + +public abstract class InferenceConfigItemTestCase extends AbstractBWCSerializationTestCase< + T> { @Override protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); @@ -33,4 +38,12 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { List entries = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables()); return new NamedWriteableRegistry(entries); } + + @Override + protected List bwcVersions() { + T obj = createTestInstance(); + return getAllBWCVersions(Version.CURRENT).stream() + .filter(v -> v.onOrAfter(obj.getMinimalSupportedVersion())) + .collect(Collectors.toList()); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java index e78a174d493ee..ae7238504ef48 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java @@ -25,13 +25,8 @@ import static org.hamcrest.Matchers.nullValue; public class FillMaskResultsTests extends AbstractWireSerializingTestCase { - @Override - protected Writeable.Reader instanceReader() { - return FillMaskResults::new; - } - @Override - protected FillMaskResults createTestInstance() { + public static FillMaskResults createRandomResults() { int numResults = randomIntBetween(0, 3); List resultList = new ArrayList<>(); for (int i = 0; i < numResults; i++) { @@ -47,6 +42,16 @@ protected FillMaskResults createTestInstance() { ); } + @Override + protected Writeable.Reader instanceReader() { + return FillMaskResults::new; + } + + @Override + protected FillMaskResults createTestInstance() { + return createRandomResults(); + } + @SuppressWarnings("unchecked") public void testAsMap() { FillMaskResults testInstance = createTestInstance(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java index ab0ceea78aae6..c4b4eb9e9c0ff 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java @@ -23,13 +23,8 @@ import static org.hamcrest.Matchers.not; public class NerResultsTests extends InferenceResultsTestCase { - @Override - protected Writeable.Reader instanceReader() { - return NerResults::new; - } - @Override - protected NerResults createTestInstance() { + 
public static NerResults createRandomResults() { int numEntities = randomIntBetween(0, 3); return new NerResults( @@ -48,6 +43,16 @@ protected NerResults createTestInstance() { ); } + @Override + protected Writeable.Reader instanceReader() { + return NerResults::new; + } + + @Override + protected NerResults createTestInstance() { + return createRandomResults(); + } + @SuppressWarnings("unchecked") public void testAsMap() { NerResults testInstance = createTestInstance(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java index e33b5274231a9..517bd9c210b0d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java @@ -17,13 +17,8 @@ import static org.hamcrest.Matchers.is; public class PyTorchPassThroughResultsTests extends InferenceResultsTestCase { - @Override - protected Writeable.Reader instanceReader() { - return PyTorchPassThroughResults::new; - } - @Override - protected PyTorchPassThroughResults createTestInstance() { + public static PyTorchPassThroughResults createRandomResults() { int rows = randomIntBetween(1, 10); int columns = randomIntBetween(1, 10); double[][] arr = new double[rows][columns]; @@ -36,6 +31,16 @@ protected PyTorchPassThroughResults createTestInstance() { return new PyTorchPassThroughResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean()); } + @Override + protected Writeable.Reader instanceReader() { + return PyTorchPassThroughResults::new; + } + + @Override + protected PyTorchPassThroughResults createTestInstance() { + return createRandomResults(); + } + public void testAsMap() { PyTorchPassThroughResults testInstance = createTestInstance(); Map 
asMap = testInstance.asMap(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java new file mode 100644 index 0000000000000..be742b831cd77 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; +import static org.hamcrest.Matchers.equalTo; + +public class QuestionAnsweringInferenceResultsTests extends InferenceResultsTestCase { + + public static QuestionAnsweringInferenceResults createRandomResults() { + return new QuestionAnsweringInferenceResults( + randomAlphaOfLength(10), + randomInt(1000), + randomInt(1000), + randomBoolean() + ? 
null + : Stream.generate(TopAnswerEntryTests::createRandomTopAnswerEntry) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList()), + randomAlphaOfLength(10), + randomDoubleBetween(0.0, 1.0, false), + randomBoolean() + ); + } + + @SuppressWarnings("unchecked") + public void testWriteResultsWithTopClasses() { + List entries = Arrays.asList( + new QuestionAnsweringInferenceResults.TopAnswerEntry("foo", 0.7, 0, 3), + new QuestionAnsweringInferenceResults.TopAnswerEntry("bar", 0.2, 11, 14), + new QuestionAnsweringInferenceResults.TopAnswerEntry("baz", 0.1, 4, 7) + ); + QuestionAnsweringInferenceResults result = new QuestionAnsweringInferenceResults( + "foo", + 0, + 3, + entries, + "my_results", + 0.7, + randomBoolean() + ); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + writeResult(result, document, "result_field", "test"); + + List list = document.getFieldValue("result_field.top_classes", List.class); + assertThat(list.size(), equalTo(3)); + + for (int i = 0; i < 3; i++) { + Map map = (Map) list.get(i); + assertThat(map, equalTo(entries.get(i).asValueMap())); + } + + assertThat(document.getFieldValue("result_field.my_results", String.class), equalTo("foo")); + } + + @Override + protected QuestionAnsweringInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return QuestionAnsweringInferenceResults::new; + } + + @Override + void assertFieldValues(QuestionAnsweringInferenceResults createdInstance, IngestDocument document, String resultsField) { + String path = resultsField + "." 
+ createdInstance.getResultsField(); + assertThat(document.getFieldValue(path, String.class), equalTo(createdInstance.predictedValue())); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java index c255c3de8cbfd..90a8cce88237b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java @@ -17,13 +17,8 @@ import static org.hamcrest.Matchers.is; public class TextEmbeddingResultsTests extends InferenceResultsTestCase { - @Override - protected Writeable.Reader instanceReader() { - return TextEmbeddingResults::new; - } - @Override - protected TextEmbeddingResults createTestInstance() { + public static TextEmbeddingResults createRandomResults() { int columns = randomIntBetween(1, 10); double[] arr = new double[columns]; for (int i = 0; i < columns; i++) { @@ -33,6 +28,16 @@ protected TextEmbeddingResults createTestInstance() { return new TextEmbeddingResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean()); } + @Override + protected Writeable.Reader instanceReader() { + return TextEmbeddingResults::new; + } + + @Override + protected TextEmbeddingResults createTestInstance() { + return createRandomResults(); + } + public void testAsMap() { TextEmbeddingResults testInstance = createTestInstance(); Map asMap = testInstance.asMap(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TopAnswerEntryTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TopAnswerEntryTests.java new file mode 100644 index 0000000000000..345d410210352 --- /dev/null +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TopAnswerEntryTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +public class TopAnswerEntryTests extends AbstractSerializingTestCase { + + public static QuestionAnsweringInferenceResults.TopAnswerEntry createRandomTopAnswerEntry() { + return new QuestionAnsweringInferenceResults.TopAnswerEntry(randomAlphaOfLength(10), randomDouble(), randomInt(10), randomInt(400)); + } + + @Override + protected QuestionAnsweringInferenceResults.TopAnswerEntry doParseInstance(XContentParser parser) throws IOException { + return QuestionAnsweringInferenceResults.TopAnswerEntry.fromXContent(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return QuestionAnsweringInferenceResults.TopAnswerEntry::fromStream; + } + + @Override + protected QuestionAnsweringInferenceResults.TopAnswerEntry createTestInstance() { + return createRandomTopAnswerEntry(); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java new file mode 100644 index 0000000000000..2ad335d3cf4b0 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; + +import java.io.IOException; +import java.util.function.Predicate; + +public class QuestionAnsweringConfigTests extends InferenceConfigItemTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.isEmpty() == false; + } + + @Override + protected QuestionAnsweringConfig doParseInstance(XContentParser parser) throws IOException { + return QuestionAnsweringConfig.fromXContentLenient(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return QuestionAnsweringConfig::new; + } + + @Override + protected QuestionAnsweringConfig createTestInstance() { + return createRandom(); + } + + @Override + protected QuestionAnsweringConfig mutateInstanceForVersion(QuestionAnsweringConfig instance, Version version) { + return instance; + } + + public static QuestionAnsweringConfig createRandom() { + return new QuestionAnsweringConfig( + randomBoolean() ? null : randomIntBetween(0, 30), + randomBoolean() ? null : randomIntBetween(1, 50), + randomBoolean() ? null : VocabularyConfigTests.createRandom(), + randomBoolean() + ? null + : randomFrom( + BertTokenizationTests.createRandomWithSpan(), + MPNetTokenizationTests.createRandomWithSpan(), + RobertaTokenizationTests.createRandomWithSpan() + ), + randomBoolean() ? 
null : randomAlphaOfLength(7) + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java new file mode 100644 index 0000000000000..18c873a5db26b --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java @@ -0,0 +1,184 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.NUM_TOP_CLASSES; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig.QUESTION; +import static org.hamcrest.Matchers.equalTo; + +public class QuestionAnsweringConfigUpdateTests extends AbstractNlpConfigUpdateTestCase { + + public static QuestionAnsweringConfigUpdate randomUpdate() { + return 
new QuestionAnsweringConfigUpdate( + randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 15), + randomBoolean() ? null : randomIntBetween(1, 100), + randomBoolean() ? null : randomAlphaOfLength(5), + randomBoolean() ? null : new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null) + ); + } + + public static QuestionAnsweringConfigUpdate mutateForVersion(QuestionAnsweringConfigUpdate instance, Version version) { + if (version.before(Version.V_8_1_0)) { + return new QuestionAnsweringConfigUpdate( + instance.getQuestion(), + instance.getNumTopClasses(), + instance.getMaxAnswerLength(), + instance.getResultsField(), + null + ); + } + return instance; + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected QuestionAnsweringConfigUpdate doParseInstance(XContentParser parser) throws IOException { + return QuestionAnsweringConfigUpdate.fromXContentStrict(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return QuestionAnsweringConfigUpdate::new; + } + + @Override + protected QuestionAnsweringConfigUpdate createTestInstance() { + return createRandom(); + } + + @Override + protected QuestionAnsweringConfigUpdate mutateInstanceForVersion(QuestionAnsweringConfigUpdate instance, Version version) { + return mutateForVersion(instance, version); + } + + @Override + Tuple, QuestionAnsweringConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) { + QuestionAnsweringConfigUpdate expected = new QuestionAnsweringConfigUpdate( + "What is the meaning of life?", + 3, + 20, + "ml-results", + expectedTokenization + ); + + Map config = new HashMap<>() { + { + put(QUESTION.getPreferredName(), "What is the meaning of life?"); + put(NUM_TOP_CLASSES.getPreferredName(), 3); + put(QuestionAnsweringConfig.MAX_ANSWER_LENGTH.getPreferredName(), 20); + put(QuestionAnsweringConfig.RESULTS_FIELD.getPreferredName(), "ml-results"); + } + }; + return 
Tuple.tuple(config, expected); + } + + @Override + QuestionAnsweringConfigUpdate fromMap(Map map) { + return QuestionAnsweringConfigUpdate.fromMap(map); + } + + public void testApply() { + Tokenization tokenizationConfig = randomFrom( + BertTokenizationTests.createRandom(), + MPNetTokenizationTests.createRandom(), + RobertaTokenizationTests.createRandom() + ); + QuestionAnsweringConfig originalConfig = new QuestionAnsweringConfig( + randomBoolean() ? null : randomIntBetween(-1, 10), + randomBoolean() ? null : randomIntBetween(1, 20), + randomBoolean() ? null : VocabularyConfigTests.createRandom(), + tokenizationConfig, + randomBoolean() ? null : randomAlphaOfLength(8) + ); + assertThat( + new QuestionAnsweringConfig( + "Are you my mother?", + 4, + 40, + originalConfig.getVocabularyConfig(), + originalConfig.getTokenization(), + originalConfig.getResultsField() + ), + equalTo( + new QuestionAnsweringConfigUpdate.Builder().setQuestion("Are you my mother?") + .setNumTopClasses(4) + .setMaxAnswerLength(40) + .build() + .apply(originalConfig) + ) + ); + assertThat( + new QuestionAnsweringConfig( + "Are you my mother?", + originalConfig.getNumTopClasses(), + originalConfig.getMaxAnswerLength(), + originalConfig.getVocabularyConfig(), + originalConfig.getTokenization(), + "updated-field" + ), + equalTo( + new QuestionAnsweringConfigUpdate.Builder().setQuestion("Are you my mother?") + .setResultsField("updated-field") + .build() + .apply(originalConfig) + ) + ); + + Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values()); + Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate); + assertThat( + new QuestionAnsweringConfig( + "Are you my mother?", + originalConfig.getNumTopClasses(), + originalConfig.getMaxAnswerLength(), + originalConfig.getVocabularyConfig(), + tokenization, + originalConfig.getResultsField() + ), + equalTo( + new QuestionAnsweringConfigUpdate.Builder().setQuestion("Are you my mother?") + 
.setTokenizationUpdate(createTokenizationUpdate(originalConfig.getTokenization(), truncate, null)) + .build() + .apply(originalConfig) + ) + ); + } + + public static QuestionAnsweringConfigUpdate createRandom() { + return randomUpdate(); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index ab0da27eb526e..81e35e984f83d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -42,6 +42,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; @@ -402,6 +404,9 @@ InferenceConfigUpdate inferenceConfigUpdateFromMap(Map configMap } else if (configMap.containsKey(ZeroShotClassificationConfig.NAME)) { checkNlpSupported(ZeroShotClassificationConfig.NAME); return 
ZeroShotClassificationConfigUpdate.fromMap(valueMap); + } else if (configMap.containsKey(QuestionAnsweringConfig.NAME)) { + checkNlpSupported(QuestionAnsweringConfig.NAME); + return QuestionAnsweringConfigUpdate.fromMap(valueMap); } else { throw ExceptionsHelper.badRequestException( "unrecognized inference configuration type {}. Supported types {}", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java index aa9a658976e3a..a866d691cb7d1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java @@ -132,7 +132,7 @@ static InferenceResults processResult( String predictedValue = tokenization.decode(tokenization.getFromVocab(scoreAndIndices[0].index)); return new FillMaskResults( predictedValue, - tokenization.getTokenization(0).input().replace(tokenizer.getMaskToken(), predictedValue), + tokenization.getTokenization(0).input().get(0).replace(tokenizer.getMaskToken(), predictedValue), results, Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), scoreAndIndices[0].score, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java index e8c7253d3c5d2..dfdce1f31b34c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java @@ -207,12 +207,14 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchIn List entities = groupTaggedTokens( taggedTokens, - ignoreCase ? 
tokenization.getTokenization(0).input().toLowerCase(Locale.ROOT) : tokenization.getTokenization(0).input() + ignoreCase + ? tokenization.getTokenization(0).input().get(0).toLowerCase(Locale.ROOT) + : tokenization.getTokenization(0).input().get(0) ); return new NerResults( resultsField, - buildAnnotatedText(tokenization.getTokenization(0).input(), entities), + buildAnnotatedText(tokenization.getTokenization(0).input().get(0), entities), entities, tokenization.anyTruncated() ); @@ -255,7 +257,7 @@ static List tagTokens(TokenizationResult.Tokens tokenization, doubl int maxScoreIndex = NlpHelpers.argmax(avgScores); double score = avgScores[maxScoreIndex]; taggedTokens.add( - new TaggedToken(tokenization.tokens().get(startTokenIndex - numSpecialTokens), iobMap[maxScoreIndex], score) + new TaggedToken(tokenization.tokens().get(0).get(startTokenIndex - numSpecialTokens), iobMap[maxScoreIndex], score) ); startTokenIndex = endTokenIndex + 1; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java new file mode 100644 index 0000000000000..c29e171314af1 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java @@ -0,0 +1,272 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; +import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.IntPredicate; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; + +public class QuestionAnsweringProcessor extends NlpTask.Processor { + + QuestionAnsweringProcessor(NlpTokenizer tokenizer, QuestionAnsweringConfig config) { + super(tokenizer); + } + + @Override + public void validateInputs(List inputs) { + // nothing to validate + } + + @Override + public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) { + if (nlpConfig instanceof QuestionAnsweringConfig questionAnsweringConfig) { + return new RequestBuilder(tokenizer, questionAnsweringConfig.getQuestion()); + } + throw ExceptionsHelper.badRequestException( + "please provide configuration update for question_answering task including the desired [question]" + ); + } + + @Override + public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) { + if (nlpConfig 
instanceof QuestionAnsweringConfig questionAnsweringConfig) { + int maxAnswerLength = questionAnsweringConfig.getMaxAnswerLength(); + int numTopClasses = questionAnsweringConfig.getNumTopClasses(); + String resultsFieldValue = questionAnsweringConfig.getResultsField(); + return new ResultProcessor(questionAnsweringConfig.getQuestion(), maxAnswerLength, numTopClasses, resultsFieldValue); + } + throw ExceptionsHelper.badRequestException( + "please provide configuration update for question_answering task including the desired [question]" + ); + } + + record RequestBuilder(NlpTokenizer tokenizer, String question) implements NlpTask.RequestBuilder { + + @Override + public NlpTask.Request buildRequest(List inputs, String requestId, Tokenization.Truncate truncate, int span) + throws IOException { + if (inputs.size() > 1) { + throw ExceptionsHelper.badRequestException("Unable to do question answering on more than one text input at a time"); + } + String context = inputs.get(0); + List tokenizations = tokenizer.tokenize(question, context, truncate, span, 0); + TokenizationResult result = tokenizer.buildTokenizationResult(tokenizations); + return result.buildRequest(requestId, truncate); + } + } + + record ResultProcessor(String question, int maxAnswerLength, int numTopClasses, String resultsField) + implements + NlpTask.ResultProcessor { + + @Override + public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) { + if (pyTorchResult.getInferenceResult().length < 1) { + throw new ElasticsearchStatusException("question answering result has no data", RestStatus.INTERNAL_SERVER_ERROR); + } + // Should be a collection of "starts" and "ends" + if (pyTorchResult.getInferenceResult().length != 2) { + throw new ElasticsearchStatusException( + "question answering result has invalid dimension, expected 2 found [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + pyTorchResult.getInferenceResult().length + ); + } + double[][] starts = 
pyTorchResult.getInferenceResult()[0]; + double[][] ends = pyTorchResult.getInferenceResult()[1]; + if (starts.length != ends.length) { + throw new ElasticsearchStatusException( + "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + starts.length, + ends.length + ); + } + List tokensList = tokenization.getTokensBySequenceId().get(0); + if (starts.length != tokensList.size()) { + throw new ElasticsearchStatusException( + "question answering result has invalid dimensions; start positions number [{}] must equal batched token size [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + starts.length, + tokensList.size() + ); + } + final int numAnswersToGather = Math.max(numTopClasses, 1); + + ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather); + for (int i = 0; i < starts.length; i++) { + topScores( + starts[i], + ends[i], + numAnswersToGather, + finalEntries::insertWithOverflow, + tokensList.get(i).seqPairOffset(), + tokensList.get(i).tokenIds().length, + maxAnswerLength, + i + ); + } + QuestionAnsweringInferenceResults.TopAnswerEntry[] topAnswerList = + new QuestionAnsweringInferenceResults.TopAnswerEntry[numAnswersToGather]; + for (int i = numAnswersToGather - 1; i >= 0; i--) { + ScoreAndIndices scoreAndIndices = finalEntries.pop(); + TokenizationResult.Tokens tokens = tokensList.get(scoreAndIndices.spanIndex()); + int startOffset = tokens.tokens().get(1).get(scoreAndIndices.startToken).startOffset(); + int endOffset = tokens.tokens().get(1).get(scoreAndIndices.endToken).endOffset(); + topAnswerList[i] = new QuestionAnsweringInferenceResults.TopAnswerEntry( + tokens.input().get(1).substring(startOffset, endOffset), + scoreAndIndices.score(), + startOffset, + endOffset + ); + } + QuestionAnsweringInferenceResults.TopAnswerEntry finalAnswer = topAnswerList[0]; + return new QuestionAnsweringInferenceResults( + finalAnswer.answer(), + 
finalAnswer.startOffset(), + finalAnswer.endOffset(), + numTopClasses > 0 ? Arrays.asList(topAnswerList) : List.of(), + Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), + finalAnswer.score(), + tokenization.anyTruncated() + ); + } + } + + /** + * + * @param start The starting token index scores. May include padded tokens. + * @param end The ending token index scores. May include padded tokens. + * @param numAnswersToGather How many top answers to return + * @param topScoresCollector Called when a score is collected. May be called many more times than numAnswersToGather + * @param seq2Start The token position of where the context sequence starts. This is AFTER the sequence separation special tokens. + * @param tokenSize The true total tokenization size. This should NOT include padded tokens. + * @param maxAnswerLength The maximum answer length to consider. + * @param spanIndex Which sequence span is this. + */ + static void topScores( + double[] start, + double[] end, + int numAnswersToGather, + Consumer topScoresCollector, + int seq2Start, + int tokenSize, + int maxAnswerLength, + int spanIndex + ) { + if (start.length != end.length) { + throw new ElasticsearchStatusException( + "question answering result has invalid dimensions; possible start tokens [{}] must equal possible end tokens [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + start.length, + end.length + ); + } + // This needs to be the start of the second sequence skipping the separator tokens + // Example seq1 seq2, seq2Start should be (len(seq1) + 2) + // This predicate ensures the following + // - we include the cls token + // - we exclude the first sequence, which is always the question + // - we exclude the final token, which is a sep token + double[] startNormalized = normalizeWith(start, i -> { + if (i == 0) { + return true; + } + return i >= seq2Start && i < tokenSize - 1; + }, -10000.0); + double[] endNormalized = normalizeWith(end, i -> { + if (i == 0) { + return true; + } + return i 
>= seq2Start && i < tokenSize - 1; + }, -10000.0); + // We use CLS in the softmax, but then remove it from being considered a possible position + endNormalized[0] = startNormalized[0] = 0.0; + if (numAnswersToGather == 1) { + ScoreAndIndices toReturn = new ScoreAndIndices(0, 0, 0.0, spanIndex); + double maxScore = 0.0; + for (int i = seq2Start; i < tokenSize; i++) { + if (startNormalized[i] == 0) { + continue; + } + for (int j = i + 1; j < (maxAnswerLength + i) && j < tokenSize; j++) { + double score = startNormalized[i] * endNormalized[j]; + if (score > maxScore) { + maxScore = score; + toReturn = new ScoreAndIndices(i - seq2Start, j - seq2Start, score, spanIndex); + } + } + } + topScoresCollector.accept(toReturn); + return; + } + for (int i = seq2Start; i < tokenSize; i++) { + for (int j = i + 1; j < (maxAnswerLength + i) && j < tokenSize; j++) { + topScoresCollector.accept( + new ScoreAndIndices(i - seq2Start, j - seq2Start, startNormalized[i] * endNormalized[j], spanIndex) + ); + } + } + } + + static double[] normalizeWith(double[] values, IntPredicate mutateIndex, double predicateValue) { + double[] toReturn = new double[values.length]; + for (int i = 0; i < values.length; i++) { + toReturn[i] = values[i]; + if (mutateIndex.test(i) == false) { + toReturn[i] = predicateValue; + } + } + double expSum = 0.0; + for (double v : toReturn) { + expSum += Math.exp(v); + } + double diff = Math.log(expSum); + for (int i = 0; i < toReturn.length; i++) { + toReturn[i] = Math.exp(toReturn[i] - diff); + } + return toReturn; + } + + static class ScoreAndIndicesPriorityQueue extends PriorityQueue { + + ScoreAndIndicesPriorityQueue(int maxSize) { + super(maxSize); + } + + @Override + protected boolean lessThan(ScoreAndIndices a, ScoreAndIndices b) { + return a.compareTo(b) < 0; + } + } + + record ScoreAndIndices(int startToken, int endToken, double score, int spanIndex) implements Comparable { + @Override + public int compareTo(ScoreAndIndices o) { + return 
Double.compare(score, o.score); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java index 2f011600faa86..32c0ded38b34c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig; @@ -55,6 +56,12 @@ public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig confi public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) { return new ZeroShotClassificationProcessor(tokenizer, (ZeroShotClassificationConfig) config); } + }, + QUESTION_ANSWERING { + @Override + public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) { + return new QuestionAnsweringProcessor(tokenizer, (QuestionAnsweringConfig) config); + } }; public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizationResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizationResult.java index 2f22e8fe982d8..b9334e7911886 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizationResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizationResult.java @@ -53,6 +53,7 @@ static class BertTokensBuilder implements TokensBuilder { protected final boolean withSpecialTokens; protected final int clsTokenId; protected final int sepTokenId; + protected int seqPairOffset = 0; BertTokensBuilder(boolean withSpecialTokens, int clsTokenId, int sepTokenId) { this.withSpecialTokens = withSpecialTokens; @@ -95,6 +96,7 @@ public TokensBuilder addSequencePair( tokenIds.add(IntStream.of(sepTokenId)); tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION)); } + seqPairOffset = withSpecialTokens ? tokenId1s.size() + 2 : tokenId1s.size(); tokenIds.add(tokenId2s.stream().mapToInt(Integer::valueOf)); tokenMap.add(tokenMap2.stream().mapToInt(i -> i + previouslyFinalMap)); if (withSpecialTokens) { @@ -105,7 +107,13 @@ public TokensBuilder addSequencePair( } @Override - public Tokens build(String input, boolean truncated, List allTokens, int spanPrev, int seqId) { + public Tokens build( + List input, + boolean truncated, + List> allTokens, + int spanPrev, + int seqId + ) { return new Tokens( input, allTokens, @@ -113,7 +121,8 @@ public Tokens build(String input, boolean truncated, List i + previouslyFinalMap)); if (withSpecialTokens) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java index 53a42b08966af..3cdb3b702ae2d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java @@ -98,7 +98,8 @@ public List tokenize(String seq, Tokenization.Truncat ); // Make sure we do not end on a word if 
(splitEndPos != tokenIds.size()) { - while (Objects.equals(tokenPositionMap.get(splitEndPos), tokenPositionMap.get(splitEndPos - 1))) { + while (splitEndPos > splitStartPos + 1 + && Objects.equals(tokenPositionMap.get(splitEndPos), tokenPositionMap.get(splitEndPos - 1))) { splitEndPos--; } } @@ -110,7 +111,7 @@ public List tokenize(String seq, Tokenization.Truncat .map(DelimitedToken.Encoded::getEncoding) .collect(Collectors.toList()), tokenPositionMap.subList(splitStartPos, splitEndPos) - ).build(seq, false, innerResult.tokens, spanPrev, sequenceId) + ).build(seq, false, tokenIds.subList(splitStartPos, splitEndPos), spanPrev, sequenceId) ); spanPrev = span; int prevSplitStart = splitStartPos; @@ -207,14 +208,134 @@ public TokenizationResult.Tokens tokenize( ); } } - List tokens = new ArrayList<>(innerResultSeq1.tokens); - tokens.addAll(innerResultSeq2.tokens); return createTokensBuilder(clsTokenId(), sepTokenId(), isWithSpecialTokens()).addSequencePair( tokenIdsSeq1.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()), tokenPositionMapSeq1, tokenIdsSeq2.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()), tokenPositionMapSeq2 - ).build(seq1 + seq2, isTruncated, tokens, -1, sequenceId); + ).build(List.of(seq1, seq2), isTruncated, List.of(innerResultSeq1.tokens, innerResultSeq2.tokens), -1, sequenceId); + } + + /** + * Tokenize the two sequences, allowing for spanning of the 2nd sequence + * @param seq1 The first sequence in the pair + * @param seq2 The second sequence + * @param truncate truncate settings + * @param span the spanning settings, how many tokens to overlap. + * We split and span on seq2. 
+ * @param sequenceId Unique sequence id for this tokenization + * @return tokenization result for the sequence pair + */ + public List tokenize(String seq1, String seq2, Tokenization.Truncate truncate, int span, int sequenceId) { + var innerResultSeq1 = innerTokenize(seq1); + List tokenIdsSeq1 = innerResultSeq1.tokens; + List tokenPositionMapSeq1 = innerResultSeq1.tokenPositionMap; + var innerResultSeq2 = innerTokenize(seq2); + List tokenIdsSeq2 = innerResultSeq2.tokens; + List tokenPositionMapSeq2 = innerResultSeq2.tokenPositionMap; + if (isWithSpecialTokens() == false) { + throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens"); + } + int extraTokens = getNumExtraTokensForSeqPair(); + int numTokens = tokenIdsSeq1.size() + tokenIdsSeq2.size() + extraTokens; + + boolean isTruncated = false; + if (numTokens > maxSequenceLength() && span < 0) { + switch (truncate) { + case FIRST -> { + isTruncated = true; + if (tokenIdsSeq2.size() > maxSequenceLength() - extraTokens) { + throw ExceptionsHelper.badRequestException( + "Attempting truncation [{}] but input is too large for the second sequence. " + + "The tokenized input length [{}] exceeds the maximum sequence length [{}], " + + "when taking special tokens into account", + truncate.toString(), + tokenIdsSeq2.size(), + maxSequenceLength() - extraTokens + ); + } + tokenIdsSeq1 = tokenIdsSeq1.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq2.size()); + tokenPositionMapSeq1 = tokenPositionMapSeq1.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq2.size()); + } + case SECOND -> { + isTruncated = true; + if (tokenIdsSeq1.size() > maxSequenceLength() - extraTokens) { + throw ExceptionsHelper.badRequestException( + "Attempting truncation [{}] but input is too large for the first sequence. 
" + + "The tokenized input length [{}] exceeds the maximum sequence length [{}], " + + "when taking special tokens into account", + truncate.toString(), + tokenIdsSeq1.size(), + maxSequenceLength() - extraTokens + ); + } + tokenIdsSeq2 = tokenIdsSeq2.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq1.size()); + tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq1.size()); + } + case NONE -> throw ExceptionsHelper.badRequestException( + "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]", + numTokens, + maxSequenceLength() + ); + } + } + if (isTruncated || numTokens < maxSequenceLength()) {// indicates no spanning + return List.of( + createTokensBuilder(clsTokenId(), sepTokenId(), isWithSpecialTokens()).addSequencePair( + tokenIdsSeq1.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()), + tokenPositionMapSeq1, + tokenIdsSeq2.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()), + tokenPositionMapSeq2 + ).build(List.of(seq1, seq2), isTruncated, List.of(innerResultSeq1.tokens, innerResultSeq2.tokens), -1, sequenceId) + ); + } + List toReturn = new ArrayList<>(); + int splitEndPos = 0; + int splitStartPos = 0; + int spanPrev = -1; + List seq1TokenIds = tokenIdsSeq1.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()); + + final int trueMaxSeqLength = maxSequenceLength() - extraTokens - tokenIdsSeq1.size(); + while (splitEndPos < tokenIdsSeq2.size()) { + splitEndPos = Math.min(splitStartPos + trueMaxSeqLength, tokenIdsSeq2.size()); + // Make sure we do not end on a word + if (splitEndPos != tokenIdsSeq2.size()) { + while (splitEndPos > splitStartPos + 1 + && Objects.equals(tokenPositionMapSeq2.get(splitEndPos), tokenPositionMapSeq2.get(splitEndPos - 1))) { + splitEndPos--; + } + } + toReturn.add( + createTokensBuilder(clsTokenId(), sepTokenId(), isWithSpecialTokens()).addSequencePair( + 
seq1TokenIds, + tokenPositionMapSeq1, + tokenIdsSeq2.subList(splitStartPos, splitEndPos) + .stream() + .map(DelimitedToken.Encoded::getEncoding) + .collect(Collectors.toList()), + tokenPositionMapSeq2.subList(splitStartPos, splitEndPos) + ) + .build( + List.of(seq1, seq2), + false, + List.of(tokenIdsSeq1, tokenIdsSeq2.subList(splitStartPos, splitEndPos)), + spanPrev, + sequenceId + ) + ); + spanPrev = span; + int prevSplitStart = splitStartPos; + splitStartPos = splitEndPos - span; + // try to back up our split so that it starts at the first whole word + if (splitStartPos < tokenIdsSeq2.size()) { + while (splitStartPos > (prevSplitStart + 1) + && Objects.equals(tokenPositionMapSeq2.get(splitStartPos), tokenPositionMapSeq2.get(splitStartPos - 1))) { + splitStartPos--; + spanPrev++; + } + } + } + return toReturn; } public abstract NlpTask.RequestBuilder requestBuilder(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizationResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizationResult.java index 6167b63c34f4c..3b466776bc7ae 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizationResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizationResult.java @@ -77,6 +77,7 @@ static class RobertaTokensBuilder implements TokensBuilder { protected final boolean withSpecialTokens; protected final int clsTokenId; protected final int sepTokenId; + protected int seqPairOffset = 0; RobertaTokensBuilder(boolean withSpecialTokens, int clsTokenId, int sepTokenId) { this.withSpecialTokens = withSpecialTokens; @@ -120,6 +121,7 @@ public TokensBuilder addSequencePair( tokenIds.add(IntStream.of(sepTokenId, sepTokenId)); tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION, SPECIAL_TOKEN_POSITION)); } + seqPairOffset = withSpecialTokens ? 
tokenId1s.size() + 3 : tokenId1s.size(); tokenIds.add(tokenId2s.stream().mapToInt(Integer::valueOf)); tokenMap.add(tokenMap2.stream().mapToInt(i -> i + previouslyFinalMap)); if (withSpecialTokens) { @@ -130,7 +132,13 @@ public TokensBuilder addSequencePair( } @Override - public Tokens build(String input, boolean truncated, List allTokens, int spanPrev, int seqId) { + public Tokens build( + List input, + boolean truncated, + List> allTokens, + int spanPrev, + int seqId + ) { return new Tokens( input, allTokens, @@ -138,7 +146,8 @@ public Tokens build(String input, boolean truncated, List tokens, + List input, + List> tokens, boolean truncated, int[] tokenIds, int[] tokenMap, int spanPrev, - int sequenceId + int sequenceId, + int seqPairOffset ) { /** * - * @param input The whole sequence input + * @param input The sequence inputs * @param tokens The delimited tokens (includes original text offsets) * @param truncated Was this tokenization truncated * @param tokenIds The token ids * @param tokenMap The token positions * @param spanPrev How many of the previous sub-sequence does this tokenization include * @param sequenceId A unique sequence ID to allow sub-sequence reconstitution + * @param seqPairOffset if the tokenization is of a sequence pair, when does the second sequence start? + * This does take into account separator token ids. Meaning, the offset will indicate the actual + * start of the second sequence of the pair. */ public Tokens { assert tokenIds.length == tokenMap.length; @@ -186,13 +202,26 @@ interface TokensBuilder { /** * Builds the token object - * @param input the original sequence input, may be a simple concatenation of a sequence pair + * @param input the original sequences input, may be a single sequence or a pair of sequences * @param truncated Was this truncated when tokenized - * @param allTokens All the tokens with their values and offsets + * @param allTokens The tokens with their values and offsets. 
Should match relatively to the input provided * @param spanPrev how many tokens from the previous subsequence are in this one. Only relevant when windowing * @param seqId the sequence id, unique per tokenized sequence, useful for windowing * @return A new Tokens object */ - Tokens build(String input, boolean truncated, List allTokens, int spanPrev, int seqId); + Tokens build(List input, boolean truncated, List> allTokens, int spanPrev, int seqId); + + /** + * Build the token object accounting for a single tokenized sequence + * @param input the original sequence input + * @param truncated Was this truncated when tokenized + * @param allTokens The tokens with their values and offsets + * @param spanPrev how many tokens from the previous subsequence are in this one. Only relevant when windowing + * @param seqId the sequence id, unique per tokenized sequence, useful for windowing + * @return A new Tokens object + */ + default Tokens build(String input, boolean truncated, List allTokens, int spanPrev, int seqId) { + return build(List.of(input), truncated, List.of(allTokens), spanPrev, seqId); + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index ca0a146e614ec..6420c4e80dda4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Tuple; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.test.ESTestCase; @@ 
-37,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig; @@ -391,17 +393,21 @@ public void testCreateProcessorWithDuplicateFields() { public void testParseFromMap() { InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY); - for (String name : List.of( - ClassificationConfig.NAME.getPreferredName(), - RegressionConfig.NAME.getPreferredName(), - FillMaskConfig.NAME, - NerConfig.NAME, - PassThroughConfig.NAME, - TextClassificationConfig.NAME, - TextEmbeddingConfig.NAME, - ZeroShotClassificationConfig.NAME + for (var nameAndMap : List.of( + Tuple.tuple(ClassificationConfig.NAME.getPreferredName(), Map.of()), + Tuple.tuple(RegressionConfig.NAME.getPreferredName(), Map.of()), + Tuple.tuple(FillMaskConfig.NAME, Map.of()), + Tuple.tuple(NerConfig.NAME, Map.of()), + Tuple.tuple(PassThroughConfig.NAME, Map.of()), + Tuple.tuple(TextClassificationConfig.NAME, Map.of()), + Tuple.tuple(TextEmbeddingConfig.NAME, Map.of()), + Tuple.tuple(ZeroShotClassificationConfig.NAME, Map.of()), + Tuple.tuple(QuestionAnsweringConfig.NAME, Map.of("question", "What is the answer to life, the universe and everything?")) )) { - assertThat(processorFactory.inferenceConfigUpdateFromMap(Map.of(name, Map.of())).getName(), equalTo(name)); + assertThat( + processorFactory.inferenceConfigUpdateFromMap(Map.of(nameAndMap.v1(), nameAndMap.v2())).getName(), + equalTo(nameAndMap.v1()) + ); } } diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index 77c6f2f1a2c46..8d62f0239466f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -55,7 +55,7 @@ public void testProcessResults() { TokenizationResult tokenization = new BertTokenizationResult( vocab, - List.of(new TokenizationResult.Tokens(input, tokens, false, tokenIds, tokenMap, -1, 0)), + List.of(new TokenizationResult.Tokens(List.of(input), List.of(tokens), false, tokenIds, tokenMap, -1, 0, 0)), 0 ); @@ -89,7 +89,7 @@ public void testProcessResults_GivenMissingTokens() { TokenizationResult tokenization = new BertTokenizationResult( List.of(), - List.of(new TokenizationResult.Tokens("", List.of(), false, new int[0], new int[0], -1, 0)), + List.of(new TokenizationResult.Tokens(List.of(""), List.of(), false, new int[0], new int[0], -1, 0, 0)), 0 ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java new file mode 100644 index 0000000000000..f6d8dcccb9c26 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; +import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class QuestionAnsweringProcessorTests extends ESTestCase { + + private static final double[] START_TOKEN_SCORES = new double[] { + 1.6665655, + -7.988514, + -8.249796, + .529973, + -8.46703, + -8.345977, + -8.459701, + -8.260341, + .071103, + -7.339133, + -7.647086, + -8.165343, + -8.277936, + -8.156116, + -8.104215, + -8.45849, + -8.249917, + -2.0896196, + -0.67172474 }; + + private static final double[] END_TOKEN_SCORES = new double[] { + 1.0593028, + -8.276232, + -7.9352865, + -8.340191, + -8.326643, + -8.225507, + -8.548992, + -8.50256, + -8.716394, + -8.0558195, + -8.4110565, + -6.564298, + -8.570332, + .01, + -7.2465587, + .6000237, + -8.045577, + -6.3548584, + -3.5642238 }; + + // The data here is nonsensical. 
We just want to make sure tokens chosen match up with our scores + public void testProcessor() throws IOException { + String question = "is Elasticsearch fun?"; + String input = "Pancake day is fun with Elasticsearch and little red car"; + BertTokenization tokenization = new BertTokenization(false, true, 384, Tokenization.Truncate.NONE, 128); + BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build(); + QuestionAnsweringConfig config = new QuestionAnsweringConfig(question, 1, 10, new VocabularyConfig(""), tokenization, "prediction"); + QuestionAnsweringProcessor processor = new QuestionAnsweringProcessor(tokenizer, config); + TokenizationResult tokenizationResult = processor.getRequestBuilder(config) + .buildRequest(List.of(input), "1", Tokenization.Truncate.NONE, 128) + .tokenization(); + assertThat(tokenizationResult.anyTruncated(), is(false)); + assertThat(tokenizationResult.getTokenization(0).tokenIds().length, equalTo(END_TOKEN_SCORES.length)); + // tokenized question length with cls and sep token + assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7)); + double[][][] scores = { { START_TOKEN_SCORES }, { END_TOKEN_SCORES } }; + NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config); + PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, null); + QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult( + tokenizationResult, + pyTorchResult + ); + + // Note this is a different answer to testTopScores because of the question length + assertThat(result.getScore(), closeTo(0.05264939, 1e-6)); + // These are the token offsets by char + assertThat(result.getStartOffset(), equalTo(8)); + assertThat(result.getEndOffset(), equalTo(48)); + assertThat(result.getAnswer(), equalTo(input.substring(8, 48))); + } + + public void testTopScoresRespectsAnswerLength() { + int seq2Start = 8; + int numAnswersToGather = 1; + 
AtomicReference result = new AtomicReference<>(); + QuestionAnsweringProcessor.topScores( + START_TOKEN_SCORES, + END_TOKEN_SCORES, + numAnswersToGather, + result::set, + seq2Start, + START_TOKEN_SCORES.length, + 10, + 0 + ); + assertThat(result.get().score(), closeTo(0.05265336, 1e-6)); + // The token positions as related to the second sequence start + assertThat(result.get().startToken(), equalTo(0)); + assertThat(result.get().endToken(), equalTo(7)); + + // Restrict to a shorter answer length + QuestionAnsweringProcessor.topScores( + START_TOKEN_SCORES, + END_TOKEN_SCORES, + numAnswersToGather, + result::set, + seq2Start, + START_TOKEN_SCORES.length, + 6, + 0 + ); + assertThat(result.get().score(), closeTo(0.0291865, 1e-6)); + // The token positions as related to the second sequence start + assertThat(result.get().startToken(), equalTo(0)); + assertThat(result.get().endToken(), equalTo(5)); + } + + public void testTopScoresMoreThanOne() { + int seq2Start = 8; + int numAnswersToGather = 2; + QuestionAnsweringProcessor.ScoreAndIndicesPriorityQueue result = new QuestionAnsweringProcessor.ScoreAndIndicesPriorityQueue(2); + QuestionAnsweringProcessor.topScores( + START_TOKEN_SCORES, + END_TOKEN_SCORES, + numAnswersToGather, + result::insertWithOverflow, + seq2Start, + START_TOKEN_SCORES.length, + 10, + 0 + ); + + assertThat(result.size(), equalTo(numAnswersToGather)); + + QuestionAnsweringProcessor.ScoreAndIndices[] topScores = new QuestionAnsweringProcessor.ScoreAndIndices[numAnswersToGather]; + for (int i = numAnswersToGather - 1; i >= 0; i--) { + topScores[i] = result.pop(); + } + + assertThat(topScores[0].score(), closeTo(0.05265336, 1e-6)); + assertThat(topScores[0].startToken(), equalTo(0)); + assertThat(topScores[0].endToken(), equalTo(7)); + + assertThat(topScores[1].score(), closeTo(0.0291865, 1e-6)); + assertThat(topScores[1].startToken(), equalTo(0)); + assertThat(topScores[1].endToken(), equalTo(5)); + } + +} diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java index e8edac804910c..a0778b8575d44 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java @@ -44,7 +44,8 @@ public class BertTokenizerTests extends ESTestCase { "day", "Pancake", "with", - BertTokenizer.PAD_TOKEN + BertTokenizer.PAD_TOKEN, + "?" ); private List tokenStrings(List tokens) { @@ -59,7 +60,7 @@ public void testTokenize() { ).build() ) { TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE, -1, 0).get(0); - assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", "fun")); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("Elastic", "##search", "fun")); assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.tokenIds()); assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap()); } @@ -227,7 +228,7 @@ public void testNeverSplitTokens() { -1, 0 ).get(0); - assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", specialToken, "fun")); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("Elastic", "##search", specialToken, "fun")); assertArrayEquals(new int[] { 0, 1, 15, 3 }, tokenization.tokenIds()); assertArrayEquals(new int[] { 0, 0, 1, 2 }, tokenization.tokenMap()); } @@ -270,7 +271,7 @@ public void testPunctuation() { .build() ) { TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch, fun.", Tokenization.Truncate.NONE, -1, 0).get(0); - assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", ",", "fun", ".")); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("Elastic", 
"##search", ",", "fun", ".")); assertArrayEquals(new int[] { 0, 1, 11, 3, 10 }, tokenization.tokenIds()); assertArrayEquals(new int[] { 0, 0, 1, 2, 3 }, tokenization.tokenMap()); @@ -305,17 +306,17 @@ public void testPunctuationWithMask() { ) { TokenizationResult.Tokens tokenization = tokenizer.tokenize("This is [MASK]-tastic!", Tokenization.Truncate.NONE, -1, 0).get(0); - assertThat(tokenStrings(tokenization.tokens()), contains("This", "is", "[MASK]", "-", "ta", "##stic", "!")); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("This", "is", "[MASK]", "-", "ta", "##stic", "!")); assertArrayEquals(new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9 }, tokenization.tokenIds()); assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 4, 5, -1 }, tokenization.tokenMap()); tokenization = tokenizer.tokenize("This is sub~[MASK]!", Tokenization.Truncate.NONE, -1, 0).get(0); - assertThat(tokenStrings(tokenization.tokens()), contains("This", "is", "sub", "~", "[MASK]", "!")); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("This", "is", "sub", "~", "[MASK]", "!")); assertArrayEquals(new int[] { 0, 1, 2, 10, 5, 3, 8, 9 }, tokenization.tokenIds()); assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, -1 }, tokenization.tokenMap()); tokenization = tokenizer.tokenize("This is sub,[MASK].tastic!", Tokenization.Truncate.NONE, -1, 0).get(0); - assertThat(tokenStrings(tokenization.tokens()), contains("This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!")); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!")); assertArrayEquals(new int[] { 0, 1, 2, 10, 11, 3, 12, 6, 7, 8, 9 }, tokenization.tokenIds()); assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, 6, 6, 7, -1 }, tokenization.tokenMap()); } @@ -394,6 +395,80 @@ public void testMultiSeqTokenization() { } } + public void testMultiSeqTokenizationWithSpan() { + try ( + BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, 
Tokenization.createDefault()) + .setDoLowerCase(false) + .setWithSpecialTokens(true) + .build() + ) { + List tokenization = tokenizer.tokenize( + "Elasticsearch is fun", + "Godzilla my little red car", + Tokenization.Truncate.NONE, + 1, + 0 + ); + assertThat(tokenization, hasSize(1)); + + var tokenStream = Arrays.stream(tokenization.get(0).tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList()); + assertThat( + tokenStream, + contains( + BertTokenizer.CLASS_TOKEN, + "Elastic", + "##search", + "is", + "fun", + BertTokenizer.SEPARATOR_TOKEN, + "God", + "##zilla", + "my", + "little", + "red", + "car", + BertTokenizer.SEPARATOR_TOKEN + ) + ); + assertArrayEquals(new int[] { 12, 0, 1, 2, 3, 13, 8, 9, 4, 5, 6, 7, 13 }, tokenization.get(0).tokenIds()); + } + } + + public void testMultiSeqTokenizationWithSpanOnLongInput() { + try ( + BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault()) + .setDoLowerCase(false) + .setWithSpecialTokens(true) + .setMaxSequenceLength(8) + .build() + ) { + List tokenizationList = tokenizer.tokenize( + "Elasticsearch is fun", + "Godzilla my little red car", + Tokenization.Truncate.NONE, + 0, + 0 + ); + assertThat(tokenizationList, hasSize(6)); + String[] seventhToken = new String[] { "God", "##zilla", "my", "little", "red", "car" }; + for (int i = 0; i < seventhToken.length; i++) { + assertThat( + Arrays.stream(tokenizationList.get(i).tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList()), + contains( + BertTokenizer.CLASS_TOKEN, + "Elastic", + "##search", + "is", + "fun", + BertTokenizer.SEPARATOR_TOKEN, + seventhToken[i], + BertTokenizer.SEPARATOR_TOKEN + ) + ); + } + } + } + public void testTokenizeLargeInputMultiSequenceTruncation() { try ( BertTokenizer tokenizer = BertTokenizer.builder( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java index 2aa44c6e4d5e3..f15ff3a610a92 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java @@ -54,7 +54,7 @@ public void testTokenize() { ).build() ) { TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE, -1, 0).get(0); - assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", "fun")); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("Elastic", "##search", "fun")); assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.tokenIds()); assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizerTests.java index 27ba0b2b75608..202ca786d5207 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizerTests.java @@ -34,7 +34,7 @@ public void testTokenize() { ).build() ) { TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE, -1, 0).get(0); - assertThat(tokenStrings(tokenization.tokens()), contains("Elast", "icsearch", "Ġfun")); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("Elast", "icsearch", "Ġfun")); assertArrayEquals(new int[] { 0, 297, 299, 275, 2 }, tokenization.tokenIds()); assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.tokenMap()); }