Skip to content

Commit

Permalink
prefix type in the request
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Nov 10, 2023
1 parent 2d4dc4c commit c4d3327
Show file tree
Hide file tree
Showing 19 changed files with 168 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down Expand Up @@ -88,6 +89,7 @@ public static Builder parseRequest(String id, XContentParser parser) {
// input and so cannot construct a document.
private final List<String> textInput;
private boolean highPriority;
private TrainedModelPrefixStrings.PrefixType prefixType = TrainedModelPrefixStrings.PrefixType.NONE;

/**
* Build a request from a list of documents as maps.
Expand Down Expand Up @@ -190,6 +192,11 @@ public Request(StreamInput in) throws IOException {
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
highPriority = in.readBoolean();
}
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_TRAINED_MODEL_PREFIX_STRINGS_ADDED)) {
prefixType = in.readEnum(TrainedModelPrefixStrings.PrefixType.class);
} else {
prefixType = TrainedModelPrefixStrings.PrefixType.NONE;
}
}

public int numberOfDocuments() {
Expand Down Expand Up @@ -232,6 +239,14 @@ public void setHighPriority(boolean highPriority) {
this.highPriority = highPriority;
}

/**
 * Sets which of the model's configured prefix strings should be applied to the
 * input of this request.
 *
 * @param prefixType the prefix type to use; must not be {@code null}
 * @throws NullPointerException if {@code prefixType} is {@code null}
 */
public void setPrefixType(TrainedModelPrefixStrings.PrefixType prefixType) {
    // Fail fast. The field defaults to NONE and is never null when read from the
    // wire, so later code is written for non-null values; a null stored here would
    // only blow up much later (e.g. when the prefix type is switched on at
    // inference time), far away from the caller that supplied it.
    this.prefixType = Objects.requireNonNull(prefixType, "prefixType");
}

/**
 * Returns the prefix type carried by this request.
 *
 * @return the prefix type; {@link TrainedModelPrefixStrings.PrefixType#NONE} unless
 *         another value was set or read from the wire
 */
public TrainedModelPrefixStrings.PrefixType getPrefixType() {
    return this.prefixType;
}

@Override
public ActionRequestValidationException validate() {
return null;
Expand All @@ -253,6 +268,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
out.writeBoolean(highPriority);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_TRAINED_MODEL_PREFIX_STRINGS_ADDED)) {
out.writeEnum(prefixType);
}
}

@Override
Expand All @@ -266,7 +284,8 @@ public boolean equals(Object o) {
&& Objects.equals(inferenceTimeout, that.inferenceTimeout)
&& Objects.equals(objectsToInfer, that.objectsToInfer)
&& Objects.equals(textInput, that.textInput)
&& (highPriority == that.highPriority);
&& (highPriority == that.highPriority)
&& (prefixType == that.prefixType);
}

@Override
Expand All @@ -276,7 +295,7 @@ public Task createTask(long id, String type, String action, TaskId parentTaskId,

@Override
public int hashCode() {
return Objects.hash(id, objectsToInfer, update, previouslyLicensed, inferenceTimeout, textInput, highPriority);
return Objects.hash(id, objectsToInfer, update, previouslyLicensed, inferenceTimeout, textInput, highPriority, prefixType);
}

public static class Builder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down Expand Up @@ -102,6 +103,7 @@ public static Request.Builder parseRequest(String id, XContentParser parser) {
// and do not know which field the model expects to find its
// input and so cannot construct a document.
private final List<String> textInput;
private TrainedModelPrefixStrings.PrefixType prefixType = TrainedModelPrefixStrings.PrefixType.NONE;

public static Request forDocs(String id, InferenceConfigUpdate update, List<Map<String, Object>> docs, TimeValue inferenceTimeout) {
return new Request(
Expand Down Expand Up @@ -156,6 +158,11 @@ public Request(StreamInput in) throws IOException {
} else {
textInput = null;
}
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_TRAINED_MODEL_PREFIX_STRINGS_ADDED)) {
prefixType = in.readEnum(TrainedModelPrefixStrings.PrefixType.class);
} else {
prefixType = TrainedModelPrefixStrings.PrefixType.NONE;
}
}

public String getId() {
Expand Down Expand Up @@ -200,6 +207,14 @@ public boolean isHighPriority() {
return highPriority;
}

/**
 * Sets which of the model's configured prefix strings should be prepended to the
 * input when this deployment request is evaluated.
 *
 * @param prefixType the prefix type to use; must not be {@code null}
 * @throws NullPointerException if {@code prefixType} is {@code null}
 */
public void setPrefixType(TrainedModelPrefixStrings.PrefixType prefixType) {
    // Fail fast. The field defaults to NONE and is never null when read from the
    // wire, so later code is written for non-null values; a null stored here would
    // only blow up much later (e.g. when the prefix type is switched on at
    // inference time), far away from the caller that supplied it.
    this.prefixType = Objects.requireNonNull(prefixType, "prefixType");
}

/**
 * Returns the prefix type carried by this deployment request.
 *
 * @return the prefix type; {@link TrainedModelPrefixStrings.PrefixType#NONE} unless
 *         another value was set or read from the wire
 */
public TrainedModelPrefixStrings.PrefixType getPrefixType() {
    return this.prefixType;
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException validationException = super.validate();
Expand All @@ -226,6 +241,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_7_0)) {
out.writeOptionalStringCollection(textInput);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_TRAINED_MODEL_PREFIX_STRINGS_ADDED)) {
out.writeEnum(prefixType);
}
}

@Override
Expand All @@ -243,12 +261,13 @@ public boolean equals(Object o) {
&& Objects.equals(update, that.update)
&& Objects.equals(inferenceTimeout, that.inferenceTimeout)
&& Objects.equals(highPriority, that.highPriority)
&& Objects.equals(textInput, that.textInput);
&& Objects.equals(textInput, that.textInput)
&& (prefixType == that.prefixType);
}

@Override
public int hashCode() {
return Objects.hash(id, update, docs, inferenceTimeout, highPriority, textInput);
return Objects.hash(id, update, docs, inferenceTimeout, highPriority, textInput, prefixType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@

public record TrainedModelPrefixStrings(String ingestPrefix, String searchPrefix) implements ToXContentObject, Writeable {

/**
 * Selects which of a model's configured prefix strings, if any, is prepended to
 * the text passed to inference.
 *
 * NOTE(review): values of this enum are sent between nodes with
 * writeEnum/readEnum. If that encoding is ordinal-based, reordering or inserting
 * constants would break wire compatibility; confirm before changing the
 * declaration order.
 */
public enum PrefixType {
/** Prepend the ingest prefix, when one is configured and non-empty. */
INGEST,
/** Prepend the search prefix, when one is configured and non-empty. */
SEARCH,
/** Do not prepend anything; the default for requests that do not set a type. */
NONE
}

public static final ParseField INGEST_PREFIX = new ParseField("ingest");
public static final ParseField SEARCH_PREFIX = new ParseField("search");
public static final String NAME = "trained_model_config_prefix_strings";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
Expand Down Expand Up @@ -66,6 +67,9 @@ protected Request createTestInstance() {
);

request.setHighPriority(randomBoolean());
if (randomBoolean()) {
request.setPrefixType(randomFrom(TrainedModelPrefixStrings.PrefixType.values()));
}
return request;
}

Expand All @@ -79,8 +83,9 @@ protected Request mutateInstance(Request instance) {
var update = instance.getUpdate();
var previouslyLicensed = instance.isPreviouslyLicensed();
var timeout = instance.getInferenceTimeout();
var prefixType = instance.getPrefixType();

int change = randomIntBetween(0, 6);
int change = randomIntBetween(0, 7);
switch (change) {
case 0:
modelId = modelId + "foo";
Expand Down Expand Up @@ -111,12 +116,17 @@ protected Request mutateInstance(Request instance) {
case 6:
timeout = TimeValue.timeValueSeconds(timeout.getSeconds() - 1);
break;
case 7:
prefixType = TrainedModelPrefixStrings.PrefixType.values()[(prefixType.ordinal() + 1) % TrainedModelPrefixStrings.PrefixType
.values().length];
break;
default:
throw new IllegalStateException();
}

var r = new Request(modelId, update, objectsToInfer, textInput, timeout, previouslyLicensed);
r.setHighPriority(highPriority);
r.setPrefixType(prefixType);
return r;
}

Expand Down Expand Up @@ -211,6 +221,18 @@ protected Request mutateInstanceForVersion(Request instance, TransportVersion ve
);
r.setHighPriority(false);
return r;
} else if (version.before(TransportVersions.ML_TRAINED_MODEL_PREFIX_STRINGS_ADDED)) {
var r = new Request(
instance.getId(),
adjustedUpdate,
instance.getObjectsToInfer(),
instance.getTextInput(),
instance.getInferenceTimeout(),
instance.isPreviouslyLicensed()
);
r.setHighPriority(instance.isHighPriority());
r.setPrefixType(TrainedModelPrefixStrings.PrefixType.NONE);
return r;
}

return instance;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdateTests;
Expand All @@ -36,9 +37,10 @@ protected Writeable.Reader<InferTrainedModelDeploymentAction.Request> instanceRe
@Override
protected InferTrainedModelDeploymentAction.Request createTestInstance() {
boolean createQueryStringRequest = randomBoolean();
InferTrainedModelDeploymentAction.Request request;

if (createQueryStringRequest) {
return InferTrainedModelDeploymentAction.Request.forTextInput(
request = InferTrainedModelDeploymentAction.Request.forTextInput(
randomAlphaOfLength(4),
randomBoolean() ? null : randomInferenceConfigUpdate(),
Arrays.asList(generateRandomStringArray(4, 7, false)),
Expand All @@ -50,13 +52,16 @@ protected InferTrainedModelDeploymentAction.Request createTestInstance() {
() -> randomMap(1, 3, () -> Tuple.tuple(randomAlphaOfLength(7), randomAlphaOfLength(7)))
);

return InferTrainedModelDeploymentAction.Request.forDocs(
request = InferTrainedModelDeploymentAction.Request.forDocs(
randomAlphaOfLength(4),
randomBoolean() ? null : randomInferenceConfigUpdate(),
docs,
randomBoolean() ? null : TimeValue.parseTimeValue(randomTimeValue(), "timeout")
);
}
request.setHighPriority(randomBoolean());
request.setPrefixType(randomFrom(TrainedModelPrefixStrings.PrefixType.values()));
return request;
}

@Override
Expand All @@ -66,8 +71,7 @@ protected InferTrainedModelDeploymentAction.Request mutateInstance(InferTrainedM

@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ protected void taskOperation(
request.getUpdate(),
request.isHighPriority(),
request.getInferenceTimeout(),
request.getPrefixType(),
actionTask,
orderedListener(count, results, slot++, nlpInputs.size(), listener)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ private void inferAgainstAllocatedModel(
);
}
deploymentRequest.setHighPriority(request.isHighPriority());
deploymentRequest.setPrefixType(request.getPrefixType());
deploymentRequest.setNodes(node.v1());
deploymentRequest.setParentTask(parentTaskId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate;
Expand Down Expand Up @@ -289,10 +290,11 @@ public void infer(
NlpInferenceInput input,
boolean skipQueue,
TimeValue timeout,
TrainedModelPrefixStrings.PrefixType prefixType,
CancellableTask parentActionTask,
ActionListener<InferenceResults> listener
) {
deploymentManager.infer(task, config, input, skipQueue, timeout, parentActionTask, listener);
deploymentManager.infer(task, config, input, skipQueue, timeout, prefixType, parentActionTask, listener);
}

public Optional<ModelStats> modelStats(TrainedModelDeploymentTask task) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ public void infer(
NlpInferenceInput input,
boolean skipQueue,
TimeValue timeout,
TrainedModelPrefixStrings.PrefixType prefixType,
CancellableTask parentActionTask,
ActionListener<InferenceResults> listener
) {
Expand All @@ -338,7 +339,7 @@ public void infer(
processContext,
config,
input,
skipQueue,
prefixType,
threadPool,
parentActionTask,
listener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand All @@ -39,7 +40,7 @@ class InferencePyTorchAction extends AbstractPyTorchAction<InferenceResults> {
private final NlpInferenceInput input;
@Nullable
private final CancellableTask parentActionTask;
private final boolean forSearch;
private final TrainedModelPrefixStrings.PrefixType prefixType;

InferencePyTorchAction(
String deploymentId,
Expand All @@ -48,15 +49,15 @@ class InferencePyTorchAction extends AbstractPyTorchAction<InferenceResults> {
DeploymentManager.ProcessContext processContext,
InferenceConfig config,
NlpInferenceInput input,
boolean forSearch,
TrainedModelPrefixStrings.PrefixType prefixType,
ThreadPool threadPool,
@Nullable CancellableTask parentActionTask,
ActionListener<InferenceResults> listener
) {
super(deploymentId, requestId, timeout, processContext, threadPool, listener);
this.config = config;
this.input = input;
this.forSearch = forSearch;
this.prefixType = prefixType;
this.parentActionTask = parentActionTask;
}

Expand Down Expand Up @@ -87,12 +88,25 @@ protected void doRun() throws Exception {
final String requestIdStr = String.valueOf(getRequestId());
try {
String inputText = input.extractInput(getProcessContext().getModelInput().get());
var prefixStrings = getProcessContext().getPrefixStrings().get();
if (prefixStrings != null) {
if (forSearch && Strings.isNullOrEmpty(prefixStrings.searchPrefix()) == false) {
inputText = prefixStrings.searchPrefix() + inputText;
} else if (forSearch == false && Strings.isNullOrEmpty(prefixStrings.ingestPrefix()) == false) {
inputText = prefixStrings.ingestPrefix() + inputText;
if (prefixType != TrainedModelPrefixStrings.PrefixType.NONE) {
var prefixStrings = getProcessContext().getPrefixStrings().get();
if (prefixStrings != null) {
switch (prefixType) {
case SEARCH: {
if (Strings.isNullOrEmpty(prefixStrings.searchPrefix()) == false) {
inputText = prefixStrings.searchPrefix() + inputText;
}
}
break;
case INGEST: {
if (Strings.isNullOrEmpty(prefixStrings.ingestPrefix()) == false) {
inputText = prefixStrings.ingestPrefix() + inputText;
}
}
break;
default:
throw new IllegalStateException("[" + getDeploymentId() + "] Unhandled input prefix type [" + prefixType + "]");
}
}
}

Expand Down
Loading

0 comments on commit c4d3327

Please sign in to comment.