Skip to content

Commit

Permalink
[ML] Make Inference Services pluggable (elastic#99886)
Browse files Browse the repository at this point in the history
Creates an InferenceServicePlugin interface for inference services
to implement and adds a test implementation to mock an inference
service.
  • Loading branch information
davidkyle authored Sep 27, 2023
1 parent 9ea5250 commit 096cf81
Show file tree
Hide file tree
Showing 97 changed files with 705 additions and 392 deletions.
1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@
exports org.elasticsearch.indices.recovery;
exports org.elasticsearch.indices.recovery.plan;
exports org.elasticsearch.indices.store;
exports org.elasticsearch.inference;
exports org.elasticsearch.ingest;
exports org.elasticsearch.internal
to
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.Map;
import java.util.Objects;

public interface InferenceResults extends NamedWriteable, ToXContentFragment {
String PREDICTION_PROBABILITY = "prediction_probability";
String MODEL_ID_RESULTS_FIELD = "model_id";

static void writeResult(InferenceResults results, IngestDocument ingestDocument, String resultField, String modelId) {
ExceptionsHelper.requireNonNull(results, "results");
ExceptionsHelper.requireNonNull(ingestDocument, "ingestDocument");
ExceptionsHelper.requireNonNull(resultField, "resultField");
Objects.requireNonNull(results, "results");
Objects.requireNonNull(ingestDocument, "ingestDocument");
Objects.requireNonNull(resultField, "resultField");
Map<String, Object> resultMap = results.asMap();
resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
if (ingestDocument.hasField(resultField)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.xpack.inference.services;
package org.elasticsearch.inference;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.inference.Model;
import org.elasticsearch.xpack.inference.TaskType;
import org.elasticsearch.xpack.inference.results.InferenceResult;

import java.util.Map;

Expand Down Expand Up @@ -45,20 +43,20 @@ public interface InferenceService {
*/
Model parseConfigLenient(String modelId, TaskType taskType, Map<String, Object> config);

/**
* Start or prepare the model for use.
* @param model The model
* @param listener The listener
*/
void start(Model model, ActionListener<Boolean> listener);

/**
* Perform inference on the model.
*
* @param model Model configuration
* @param model The model
* @param input Inference input
* @param requestTaskSettings Settings in the request to override the model's defaults
* @param taskSettings Settings in the request to override the model's defaults
* @param listener Inference result listener
*/
void infer(Model model, String input, Map<String, Object> requestTaskSettings, ActionListener<InferenceResult> listener);
void infer(Model model, String input, Map<String, Object> taskSettings, ActionListener<InferenceResults> listener);

/**
* Start or prepare the model for use.
* @param model The model
* @param listener The listener
*/
void start(Model model, ActionListener<Boolean> listener);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.inference;

import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.plugins.InferenceServicePlugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Registry of the {@link InferenceService}s contributed by the installed
 * {@link InferenceServicePlugin}s, keyed by {@link InferenceService#name()}.
 * It also collects the named writeables that those plugins declare so they
 * can be added to the node's {@link NamedWriteableRegistry}.
 *
 * The registry contents are fixed at construction time; the collections
 * returned by the accessors are unmodifiable.
 */
public class InferenceServiceRegistry extends AbstractLifecycleComponent {

    private final Map<String, InferenceService> services;
    private final List<NamedWriteableRegistry.Entry> namedWriteables;

    /**
     * Creates every service offered by the given plugins and indexes them by name.
     *
     * @param inferenceServicePlugins the plugins providing service factories and named writeables
     * @param factoryContext the context passed to each factory when creating its service
     * @throws IllegalStateException if two services report the same name
     */
    public InferenceServiceRegistry(
        List<InferenceServicePlugin> inferenceServicePlugins,
        InferenceServicePlugin.InferenceServiceFactoryContext factoryContext
    ) {
        services = inferenceServicePlugins.stream()
            .flatMap(plugin -> plugin.getInferenceServiceFactories().stream())
            .map(factory -> factory.create(factoryContext))
            .collect(
                // Service names must be unique: fail fast with a message that names the clash
                // instead of relying on the collector's generic duplicate-key error.
                Collectors.toUnmodifiableMap(InferenceService::name, Function.identity(), (existing, duplicate) -> {
                    throw new IllegalStateException("Duplicate inference service name [" + existing.name() + "]");
                })
            );

        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
        for (var plugin : inferenceServicePlugins) {
            entries.addAll(plugin.getInferenceServiceNamedWriteables());
        }
        // Snapshot so callers cannot alter the registry's state through the getter.
        namedWriteables = List.copyOf(entries);
    }

    /**
     * @return an unmodifiable map of all registered services keyed by service name
     */
    public Map<String, InferenceService> getServices() {
        return services;
    }

    /**
     * Looks up a service by name.
     *
     * @param serviceName the name the service was registered under
     * @return the service, or an empty {@link Optional} if no service has that name
     */
    public Optional<InferenceService> getService(String serviceName) {
        return Optional.ofNullable(services.get(serviceName));
    }

    /**
     * @return an unmodifiable list of the named writeables declared by all plugins
     */
    public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
        return namedWriteables;
    }

    @Override
    protected void doStart() {
        // nothing to start
    }

    @Override
    protected void doStop() {
        // nothing to stop
    }

    @Override
    protected void doClose() throws IOException {
        // nothing to release
    }
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.xpack.inference;
package org.elasticsearch.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.xpack.inference;
package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentObject;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.xpack.inference;
package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentObject;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.xpack.inference;
package org.elasticsearch.inference;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -49,4 +50,8 @@ public static TaskType fromStream(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
// Serialize this constant by handing it to StreamOutput#writeEnum.
out.writeEnum(this);
}

/**
 * Builds the error message used when a service is asked for a task type it cannot handle.
 *
 * @param taskType the task type that was requested
 * @param serviceName the name of the service that does not support it
 * @return the formatted error message
 */
public static String unsupportedTaskTypeErrorMsg(TaskType taskType, String serviceName) {
    final StringBuilder message = new StringBuilder("The [");
    message.append(serviceName);
    message.append("] service does not support task type [");
    message.append(taskType);
    message.append("]");
    return message.toString();
}
}
12 changes: 11 additions & 1 deletion server/src/main/java/org/elasticsearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService;
import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService;
import org.elasticsearch.indices.store.IndicesStore;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.monitor.MonitorService;
import org.elasticsearch.monitor.fs.FsHealthService;
Expand All @@ -165,6 +166,7 @@
import org.elasticsearch.plugins.EnginePlugin;
import org.elasticsearch.plugins.HealthPlugin;
import org.elasticsearch.plugins.IndexStorePlugin;
import org.elasticsearch.plugins.InferenceServicePlugin;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.MetadataUpgrader;
Expand Down Expand Up @@ -532,6 +534,12 @@ protected Node(

Supplier<DocumentParsingObserver> documentParsingObserverSupplier = getDocumentParsingObserverSupplier();

var factoryContext = new InferenceServicePlugin.InferenceServiceFactoryContext(client);
final InferenceServiceRegistry inferenceServiceRegistry = new InferenceServiceRegistry(
pluginsService.filterPlugins(InferenceServicePlugin.class),
factoryContext
);

final IngestService ingestService = new IngestService(
clusterService,
threadPool,
Expand All @@ -555,7 +563,8 @@ protected Node(
searchModule.getNamedWriteables().stream(),
pluginsService.flatMap(Plugin::getNamedWriteables),
ClusterModule.getNamedWriteables().stream(),
SystemIndexMigrationExecutor.getNamedWriteables().stream()
SystemIndexMigrationExecutor.getNamedWriteables().stream(),
inferenceServiceRegistry.getNamedWriteables().stream()
).flatMap(Function.identity()).toList();
final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables);
NamedXContentRegistry xContentRegistry = new NamedXContentRegistry(
Expand Down Expand Up @@ -1170,6 +1179,7 @@ protected Node(
b.bind(WriteLoadForecaster.class).toInstance(writeLoadForecaster);
b.bind(HealthPeriodicLogger.class).toInstance(healthPeriodicLogger);
b.bind(CompatibilityVersions.class).toInstance(compatibilityVersions);
b.bind(InferenceServiceRegistry.class).toInstance(inferenceServiceRegistry);
});

if (ReadinessService.enabled(environment)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.plugins;

import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.inference.InferenceService;

import java.util.List;

/**
 * Extension point for plugins that provide one or more
 * {@link InferenceService} implementations.
 */
public interface InferenceServicePlugin {

    /**
     * @return the factories used to create the inference services this plugin provides
     */
    List<Factory> getInferenceServiceFactories();

    /**
     * Context handed to each {@link Factory} when a service is created.
     *
     * @param client the client made available to the created services
     */
    record InferenceServiceFactoryContext(Client client) {}

    /**
     * Creates a single {@link InferenceService}. It has exactly one abstract
     * method, so it can be supplied as a lambda or constructor reference.
     */
    @FunctionalInterface
    interface Factory {
        /**
         * InferenceServices are created from the factory context
         *
         * @param context the context to build the service from
         * @return the created service
         */
        InferenceService create(InferenceServiceFactoryContext context);
    }

    /**
     * The named writeables defined and used by each of the implemented
     * InferenceServices. Each service should define named writeables for
     * <ul>
     *   <li>{@link org.elasticsearch.inference.TaskSettings}</li>
     *   <li>{@link org.elasticsearch.inference.ServiceSettings}</li>
     * </ul>
     * And optionally for {@link org.elasticsearch.inference.InferenceResults}
     * if the service uses a new type of result.
     *
     * @return All named writeables defined by the services
     */
    List<NamedWriteableRegistry.Entry> getInferenceServiceNamedWriteables();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
Expand All @@ -22,7 +23,6 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
Expand All @@ -25,7 +26,6 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
Expand All @@ -21,7 +22,6 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.stream.Collectors;

public class ClassificationInferenceResults extends SingleValueInferenceResults {
/** Constant for the {@code prediction_probability} field name; final so it cannot be reassigned. */
public static final String PREDICTION_PROBABILITY = "prediction_probability";

public static final String NAME = "classification";

Expand Down
Loading

0 comments on commit 096cf81

Please sign in to comment.