Skip to content

Commit

Permalink
Allow for deployments already started
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Nov 16, 2024
1 parent 4605650 commit 2509f92
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,9 @@

package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

public class CustomElandModel extends ElasticsearchInternalModel {

Expand All @@ -39,31 +34,10 @@ public CustomElandModel(
}

@Override
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
Model model,
ActionListener<Boolean> listener
) {

return new ActionListener<>() {
@Override
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
listener.onResponse(Boolean.TRUE);
}

@Override
public void onFailure(Exception e) {
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
listener.onFailure(
new ResourceNotFoundException(
"Could not start the inference as the custom eland model [{0}] for this platform cannot be found."
+ " Custom models need to be loaded into the cluster with eland before they can be started.",
internalServiceSettings.modelId()
)
);
return;
}
listener.onFailure(e);
}
};
/**
 * Builds the error text used when the model with the given id cannot be found,
 * including a hint that custom models must be loaded with Eland first.
 *
 * @param modelId the id of the model that could not be found
 * @return the complete error message
 */
protected String modelNotFoundErrorMessage(String modelId) {
    final StringBuilder message = new StringBuilder("Could not deploy model [");
    message.append(modelId);
    message.append("] as the model cannot be found.");
    message.append(" Custom models need to be loaded into the cluster with Eland before they can be started.");
    return message.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA
throw new IllegalStateException("cannot start model that uses an existing deployment");
}

@Override
/**
 * Never returns a message: this implementation always throws, naming the
 * model id in the exception text.
 *
 * @param modelId the id of the model, included in the exception message
 * @return never returns normally
 * @throws IllegalStateException always
 */
protected String modelNotFoundErrorMessage(String modelId) {
    final String reason = "cannot start model [" + modelId + "] that uses an existing deployment";
    throw new IllegalStateException(reason);
}

@Override
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
Model model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.TimeValue;
Expand All @@ -15,8 +18,10 @@
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED;

Expand Down Expand Up @@ -79,10 +84,38 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA
return startRequest;
}

public abstract ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
Model model,
ActionListener<Boolean> listener
);
) {
return new ActionListener<>() {
@Override
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
listener.onResponse(Boolean.TRUE);
}

/**
 * Translates a failure to create the trained model assignment into the
 * outcome reported to {@code listener}. Exactly one listener callback is
 * invoked on every path.
 * <ul>
 *   <li>Unwrapped cause is a {@link ResourceNotFoundException}: fail with a new
 *       {@link ResourceNotFoundException} carrying the model-specific message.</li>
 *   <li>Unwrapped cause is an {@link ElasticsearchStatusException} with status
 *       {@code CONFLICT} whose root cause is a {@link ResourceAlreadyExistsException}:
 *       the deployment is already started, so report success.</li>
 *   <li>Anything else, including other status exceptions: forward the original
 *       exception unchanged.</li>
 * </ul>
 *
 * @param e the failure reported by the create-assignment action
 */
@Override
public void onFailure(Exception e) {
    var cause = ExceptionsHelper.unwrapCause(e);
    if (cause instanceof ResourceNotFoundException) {
        listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId())));
        return;
    }
    if (cause instanceof ElasticsearchStatusException statusException
        && statusException.status() == RestStatus.CONFLICT
        && statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
        // Deployment is already started
        listener.onResponse(Boolean.TRUE);
        return;
    }
    // Previously a status exception that was not an "already started" conflict
    // returned without notifying the listener, so the caller was never completed.
    listener.onFailure(e);
}
};
}

/**
 * Builds the default error text used when the model with the given id
 * cannot be found.
 *
 * @param modelId the id of the model that could not be found
 * @return the error message
 */
protected String modelNotFoundErrorMessage(String modelId) {
    final String prefix = "Could not deploy model [";
    final String suffix = "] as the model cannot be found.";
    return prefix + modelId + suffix;
}

public boolean usesExistingDeployment() {
return internalServiceSettings.getDeploymentId() != null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@

package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

public class ElserInternalModel extends ElasticsearchInternalModel {

Expand All @@ -37,31 +32,4 @@ public ElserInternalServiceSettings getServiceSettings() {
/**
 * Returns the task settings held by the superclass, cast to
 * {@link ElserMlNodeTaskSettings}.
 *
 * @return the task settings as {@link ElserMlNodeTaskSettings}
 */
public ElserMlNodeTaskSettings getTaskSettings() {
    final var settings = super.getTaskSettings();
    return (ElserMlNodeTaskSettings) settings;
}

/**
 * Creates the listener for the create-trained-model-assignment response.
 * A successful response is reported as {@code TRUE}. A failure whose unwrapped
 * cause is a {@link ResourceNotFoundException} is replaced with an
 * ELSER-specific not-found error; any other failure is forwarded unchanged.
 *
 * @param model    the model being started (not read by this implementation)
 * @param listener receives the outcome
 * @return the listener to pass to the create-assignment action
 */
@Override
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
    Model model,
    ActionListener<Boolean> listener
) {
    return new ActionListener<>() {
        @Override
        public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
            listener.onResponse(Boolean.TRUE);
        }

        @Override
        public void onFailure(Exception e) {
            final boolean modelMissing = ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException;
            if (modelMissing == false) {
                // Not a missing-model problem: pass the failure through as-is.
                listener.onFailure(e);
                return;
            }
            final ResourceNotFoundException notFound = new ResourceNotFoundException(
                "Could not start the ELSER service as the ELSER model for this platform cannot be found."
                    + " ELSER needs to be downloaded before it can be started."
            );
            listener.onFailure(notFound);
        }
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@

package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

public class MultilingualE5SmallModel extends ElasticsearchInternalModel {

Expand All @@ -31,34 +26,4 @@ public MultilingualE5SmallModel(
/**
 * Returns the service settings held by the superclass, cast to
 * {@link MultilingualE5SmallInternalServiceSettings}.
 *
 * @return the service settings as {@link MultilingualE5SmallInternalServiceSettings}
 */
public MultilingualE5SmallInternalServiceSettings getServiceSettings() {
    final var settings = super.getServiceSettings();
    return (MultilingualE5SmallInternalServiceSettings) settings;
}

/**
 * Creates the listener for the create-trained-model-assignment response.
 * A successful response is reported as {@code TRUE}. A failure whose unwrapped
 * cause is a {@link ResourceNotFoundException} is replaced with a
 * Multilingual-E5-Small-specific not-found error; any other failure is
 * forwarded unchanged.
 *
 * @param model    the model being started (not read by this implementation)
 * @param listener receives the outcome
 * @return the listener to pass to the create-assignment action
 */
@Override
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
    Model model,
    ActionListener<Boolean> listener
) {
    return new ActionListener<>() {
        @Override
        public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
            listener.onResponse(Boolean.TRUE);
        }

        @Override
        public void onFailure(Exception e) {
            final boolean modelMissing = ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException;
            if (modelMissing == false) {
                // Not a missing-model problem: pass the failure through as-is.
                listener.onFailure(e);
                return;
            }
            final ResourceNotFoundException notFound = new ResourceNotFoundException(
                "Could not start the TextEmbeddingService service as the "
                    + "Multilingual-E5-Small model for this platform cannot be found."
                    + " Multilingual-E5-Small needs to be downloaded before it can be started"
            );
            listener.onFailure(notFound);
        }
    };
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<

private final Client client;
private final CircuitBreakerService circuitBreakerService;
final Map<String, List<DownloadTaskRemovedListener>> downloadTrackersByModelId;
final Map<String, List<DownloadTaskRemovedListener>> taskRemovedListenersByModelId;

@Inject
public TransportLoadTrainedModelPackage(
Expand All @@ -83,7 +83,7 @@ public TransportLoadTrainedModelPackage(
);
this.client = new OriginSettingClient(client, ML_ORIGIN);
this.circuitBreakerService = circuitBreakerService;
downloadTrackersByModelId = new HashMap<>();
taskRemovedListenersByModelId = new HashMap<>();
}

@Override
Expand Down Expand Up @@ -174,7 +174,7 @@ synchronized boolean existingDownloadInProgress(
// Otherwise register a task removed listener which is called
// once the task is complete and unregistered
var tracker = new DownloadTaskRemovedListener(inProgress, listener);
downloadTrackersByModelId.computeIfAbsent(modelId, s -> new ArrayList<>()).add(tracker);
taskRemovedListenersByModelId.computeIfAbsent(modelId, s -> new ArrayList<>()).add(tracker);
taskManager.registerRemovedTaskListener(tracker);
return true;
}
Expand All @@ -191,7 +191,7 @@ synchronized boolean existingDownloadInProgress(
synchronized void unregisterTask(ModelDownloadTask task) {
taskManager.unregister(task); // unregister will call the on remove function

var trackers = downloadTrackersByModelId.remove(task.getModelId());
var trackers = taskRemovedListenersByModelId.remove(task.getModelId());
if (trackers != null) {
for (var tracker : trackers) {
taskManager.unregisterRemovedTaskListener(tracker);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,14 @@ public void testWaitForExistingDownload() {

assertTrue(action.existingDownloadInProgress(modelId, true, ActionListener.noop()));
verify(taskManager).registerRemovedTaskListener(any());
assertThat(action.downloadTrackersByModelId.entrySet(), hasSize(1));
assertThat(action.downloadTrackersByModelId.get(modelId), hasSize(1));
assertThat(action.taskRemovedListenersByModelId.entrySet(), hasSize(1));
assertThat(action.taskRemovedListenersByModelId.get(modelId), hasSize(1));

// With wait for completion == false no new removed listener will be added
assertTrue(action.existingDownloadInProgress(modelId, false, ActionListener.noop()));
verify(taskManager, times(1)).registerRemovedTaskListener(any());
assertThat(action.downloadTrackersByModelId.entrySet(), hasSize(1));
assertThat(action.downloadTrackersByModelId.get(modelId), hasSize(1));
assertThat(action.taskRemovedListenersByModelId.entrySet(), hasSize(1));
assertThat(action.taskRemovedListenersByModelId.get(modelId), hasSize(1));

assertFalse(action.existingDownloadInProgress("no-task-for-this-one", randomBoolean(), ActionListener.noop()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ protected void masterOperation(
() -> "[" + request.getDeploymentId() + "] creating new assignment for model [" + request.getModelId() + "] failed",
e
);
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) {
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException resourceAlreadyExistsException) {
e = new ElasticsearchStatusException(
"Cannot start deployment [{}] because it has already been started",
RestStatus.CONFLICT,
e,
resourceAlreadyExistsException,
request.getDeploymentId()
);
}
Expand Down

0 comments on commit 2509f92

Please sign in to comment.