Skip to content

Commit

Permalink
[ML] Enrich documents with inference results at Fetch (#53230)
Browse files Browse the repository at this point in the history
Adds a FetchSubPhase which adds a new field to the search hits containing the result of the model
inference performed on each hit. There isn't a direct way of configuring FetchSubPhases, so a
SearchExtSpec is used for this purpose.
  • Loading branch information
davidkyle authored Mar 11, 2020
1 parent f153f19 commit 54fb29f
Show file tree
Hide file tree
Showing 20 changed files with 692 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,24 @@ public void writeResult(IngestDocument document, String parentResultField) {
}
}

/**
 * Builds a map representation of this classification result, nested under
 * {@code parentResultField}, suitable for attaching to a search hit.
 * The predicted class is written under the configured results field; top
 * classes and feature importance are included only when present.
 *
 * @param parentResultField the top-level key the results are nested under
 * @return a single-entry map of {@code parentResultField} to the result fields
 */
@Override
public Map<String, Object> writeResultToMap(String parentResultField) {
    Map<String, Object> parentField = new HashMap<>();
    Map<String, Object> results = new HashMap<>();
    parentField.put(parentResultField, results);

    results.put(resultsField, valueAsString());
    // Only emit the top classes field when top classes were actually computed
    if (topClasses.isEmpty() == false) {
        results.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
    }
    if (getFeatureImportance().isEmpty() == false) {
        results.put("feature_importance", getFeatureImportance());
    }

    return parentField;
}


@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.ingest.IngestDocument;

import java.util.Map;

/**
 * The result of running inference against a single document.
 * Implementations know how to write themselves either into an
 * {@link IngestDocument} (ingest pipelines) or into a plain map
 * (e.g. for attaching to search hits).
 */
public interface InferenceResults extends NamedWriteable {

/**
 * Writes this result into {@code document}, nested under {@code parentResultField}.
 */
void writeResult(IngestDocument document, String parentResultField);

/**
 * Returns this result as a map keyed by {@code parentResultField}.
 */
Map<String, Object> writeResultToMap(String parentResultField);
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ public void writeResult(IngestDocument document, String parentResultField) {
throw new UnsupportedOperationException("[raw] does not support writing inference results");
}

@Override
public Map<String, Object> writeResultToMap(String parentResultField) {
// Raw results are an internal intermediate representation and cannot be
// rendered as a user-facing map — mirrors writeResult above.
throw new UnsupportedOperationException("[raw] does not support writing inference results");
}

@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -74,6 +75,20 @@ public void writeResult(IngestDocument document, String parentResultField) {
}
}

/**
 * Builds a map representation of this regression result, nested under
 * {@code parentResultField}. Feature importance is included only when present.
 *
 * @param parentResultField the top-level key the result is nested under
 * @return a single-entry map of {@code parentResultField} to the result fields
 */
@Override
public Map<String, Object> writeResultToMap(String parentResultField) {
    Map<String, Object> parentResult = new HashMap<>();
    Map<String, Object> result = new HashMap<>();
    parentResult.put(parentResultField, result);

    result.put(resultsField, value());
    if (getFeatureImportance().isEmpty() == false) {
        result.put("feature_importance", getFeatureImportance());
    }

    return parentResult;
}

@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

public class WarningInferenceResults implements InferenceResults {

public static final String NAME = "warning";
public static final ParseField WARNING = new ParseField("warning");
public static final ParseField WARNING = new ParseField(NAME);

private final String warning;

Expand Down Expand Up @@ -55,7 +57,12 @@ public int hashCode() {
/**
 * Writes the warning message into {@code document} under
 * {@code parentResultField + ".warning"}.
 *
 * @param document          the ingest document to write to; must not be null
 * @param parentResultField the parent field name; must not be null
 */
public void writeResult(IngestDocument document, String parentResultField) {
    ExceptionsHelper.requireNonNull(document, "document");
    // Fix: the error message previously named a non-existent "resultField" parameter
    ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
    document.setFieldValue(parentResultField + "." + NAME, warning);
}

/**
 * Returns the warning as a map of the form
 * {@code {parentResultField: {"warning": message}}}.
 *
 * @param parentResultField the parent field name; must not be null
 */
@Override
public Map<String, Object> writeResultToMap(String parentResultField) {
    // Validate like writeResult does, so both write paths fail the same way
    ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
    return Collections.singletonMap(parentResultField, Collections.singletonMap(NAME, warning));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public String getTopClassesResultsField() {
return topClassesResultsField;
}

@Override
public String getResultsField() {
return resultsField;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
/**
 * Whether this configuration requests feature importance values.
 * Defaults to {@code false}; configs supporting importance override this.
 */
default boolean requestingImportance() {
return false;
}

/** The name of the field the inference result value is written under. */
String getResultsField();
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ public NullInferenceConfig(boolean requestingFeatureImportance) {
this.requestingFeatureImportance = requestingFeatureImportance;
}

@Override
public String getResultsField() {
// NullInferenceConfig is a placeholder config — it carries no results field
return null;
}

@Override
public boolean isTargetTypeSupported(TargetType targetType) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}

@Override
public String getResultsField() {
return resultsField;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -81,6 +82,30 @@ public void testWriteResultsWithTopClasses() {
assertThat(document.getFieldValue("result_field.my_results", String.class), equalTo("foo"));
}

@SuppressWarnings("unchecked")
public void testWriteResultsToMapWithTopClasses() {
List<ClassificationInferenceResults.TopClassEntry> entries = Arrays.asList(
new ClassificationInferenceResults.TopClassEntry("foo", 0.7),
new ClassificationInferenceResults.TopClassEntry("bar", 0.2),
new ClassificationInferenceResults.TopClassEntry("baz", 0.1));
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
"foo",
entries,
new ClassificationConfig(3, "my_results", "bar"));
Map<String, Object> resultsDoc = result.writeResultToMap("result_field");

List<?> list = (List<?>) MapHelper.dig("result_field.bar", resultsDoc);
assertThat(list.size(), equalTo(3));

for(int i = 0; i < 3; i++) {
Map<String, Object> map = (Map<String, Object>)list.get(i);
assertThat(map, equalTo(entries.get(i).asValueMap()));
}

Object value = MapHelper.dig("result_field.my_results", resultsDoc);
assertThat(value, equalTo("foo"));
}

@Override
protected ClassificationInferenceResults createTestInstance() {
return createRandomResults();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;

Expand All @@ -31,6 +32,14 @@ public void testWriteResults() {
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
}

public void testWriteResultsToMap() {
    // Default params write the prediction under "predicted_value"
    RegressionInferenceResults result = new RegressionInferenceResults(0.3, RegressionConfig.EMPTY_PARAMS);

    Map<String, Object> doc = result.writeResultToMap("result_field");

    assertThat(MapHelper.dig("result_field.predicted_value", doc), equalTo(0.3));
}

@Override
protected RegressionInferenceResults createTestInstance() {
return createRandomResults();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;

Expand All @@ -27,6 +29,14 @@ public void testWriteResults() {
assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo"));
}

public void testWriteResultToMap() {
    // The warning message is nested under the parent field's "warning" key
    WarningInferenceResults result = new WarningInferenceResults("foo");

    Map<String, Object> doc = result.writeResultToMap("result_field");

    assertThat(MapHelper.dig("result_field.warning", doc), equalTo("foo"));
}

@Override
protected WarningInferenceResults createTestInstance() {
return createRandomResults();
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ integTest.runner {
'ml/delete_model_snapshot/Test delete snapshot missing snapshotId',
'ml/delete_model_snapshot/Test delete snapshot missing job_id',
'ml/delete_model_snapshot/Test delete with in-use model',
'ml/fetch_inference/Test fetch regression',
'ml/fetch_inference/Test fetch classification',
'ml/filter_crud/Test create filter api with mismatching body ID',
'ml/filter_crud/Test create filter given invalid filter_id',
'ml/filter_crud/Test get filter API with bad ID',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.PersistentTaskPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.plugins.SystemIndexPlugin;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -211,6 +213,8 @@
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.inference.search.InferencePhase;
import org.elasticsearch.xpack.ml.inference.search.InferenceSearchExtBuilder;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
Expand Down Expand Up @@ -318,7 +322,7 @@

import static java.util.Collections.emptyList;

public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin, SearchPlugin {
public static final String NAME = "ml";
public static final String BASE_PATH = "/_ml/";
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
Expand Down Expand Up @@ -414,6 +418,7 @@ public Set<DiscoveryNodeRole> getRoles() {
private final SetOnce<DataFrameAnalyticsManager> dataFrameAnalyticsManager = new SetOnce<>();
private final SetOnce<DataFrameAnalyticsAuditor> dataFrameAnalyticsAuditor = new SetOnce<>();
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();

public MachineLearning(Settings settings, Path configPath) {
this.settings = settings;
Expand Down Expand Up @@ -628,6 +633,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
clusterService,
xContentRegistry,
settings);
this.modelLoadingService.set(modelLoadingService);

// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
Expand Down Expand Up @@ -886,6 +892,18 @@ public Map<String, AnalysisProvider<TokenizerFactory>> getTokenizers() {
return Collections.singletonMap(MlClassicTokenizer.NAME, MlClassicTokenizerFactory::new);
}

// Registers the inference fetch sub phase so search hits can be enriched
// with model inference results at fetch time.
@Override
public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) {
return Collections.singletonList(new InferencePhase(modelLoadingService));
}

// FetchSubPhases cannot be configured directly in a search request, so a
// SearchExtSpec is registered to carry the inference configuration through
// to the fetch phase.
@Override
public List<SearchExtSpec<?>> getSearchExts() {
return Collections.singletonList(
new SearchExtSpec<>(InferenceSearchExtBuilder.NAME, InferenceSearchExtBuilder::new,
InferenceSearchExtBuilder::fromXContent));
}

@Override
public UnaryOperator<Map<String, IndexTemplateMetaData>> getIndexTemplateMetaDataUpgrader() {
return UnaryOperator.identity();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.util.HashSet;
Expand Down Expand Up @@ -44,6 +44,11 @@ public String getModelId() {
return modelId;
}

// NOTE(review): returns the internal set directly — callers must not mutate it;
// consider wrapping in Collections.unmodifiableSet if callers cannot be trusted.
@Override
public Set<String> getFieldNames() {
return fieldNames;
}

@Override
public String getResultsType() {
switch (trainedModelDefinition.getTrainedModel().targetType()) {
Expand All @@ -53,23 +58,26 @@ public String getResultsType() {
return RegressionInferenceResults.NAME;
default:
throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]",
modelId,
trainedModelDefinition.getTrainedModel().targetType());
modelId,
trainedModelDefinition.getTrainedModel().targetType());
}
}

@Override
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
public void infer(Map<String, Object> fields, InferenceConfig inferenceConfig, ActionListener<InferenceResults> listener) {
try {
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
return;
}

listener.onResponse(trainedModelDefinition.infer(fields, config));
listener.onResponse(infer(fields, inferenceConfig));
} catch (Exception e) {
listener.onFailure(e);
}
}

/**
 * Runs inference synchronously. If none of the model's input fields are
 * present in {@code fields}, a {@link WarningInferenceResults} is returned
 * rather than failing the call.
 */
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
// MapHelper.dig resolves dotted field paths; all-null means no inputs exist
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
return new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId));
}

return trainedModelDefinition.infer(fields, config);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;

import java.util.Map;
import java.util.Set;

/**
 * A loaded trained model that can run inference over a document's fields,
 * either asynchronously (listener-based) or synchronously.
 */
public interface Model {

/** The writeable name of the result type this model produces (e.g. classification or regression). */
String getResultsType();

/** Runs inference asynchronously, notifying {@code listener} with the result or any failure. */
void infer(Map<String, Object> fields, InferenceConfig inferenceConfig, ActionListener<InferenceResults> listener);

/** Runs inference synchronously and returns the result. */
InferenceResults infer(Map<String, Object> fields, InferenceConfig inferenceConfig);

/** The ID of the trained model backing this instance. */
String getModelId();

/** The input field names the model reads from a document. */
Set<String> getFieldNames();
}
Loading

0 comments on commit 54fb29f

Please sign in to comment.