diff --git a/docs/changelog/95440.yaml b/docs/changelog/95440.yaml new file mode 100644 index 0000000000000..e6b888d55bff8 --- /dev/null +++ b/docs/changelog/95440.yaml @@ -0,0 +1,5 @@ +pr: 95440 +summary: "[ML] Get trained model stats by deployment id or model id" +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java index 869e1059f63ce..7cd0cf60ae057 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java @@ -12,7 +12,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.ingest.IngestStats; @@ -30,7 +29,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -236,7 +234,7 @@ protected Reader getReader() { public static class Builder { private long totalModelCount; - private Map> expandedIdsWithAliases; + private Map> expandedModelIdsWithAliases; private Map modelSizeStatsMap; private Map ingestStatsMap; private Map inferenceStatsMap; @@ -247,13 +245,13 @@ public Builder setTotalModelCount(long totalModelCount) { return this; } - public Builder setExpandedIdsWithAliases(Map> expandedIdsWithAliases) { - this.expandedIdsWithAliases = expandedIdsWithAliases; + public Builder setExpandedModelIdsWithAliases(Map> expandedIdsWithAliases) { + this.expandedModelIdsWithAliases = expandedIdsWithAliases; return this; } - public Map> getExpandedIdsWithAliases() { - return this.expandedIdsWithAliases; + public Map> getExpandedModelIdsWithAliases() { + return this.expandedModelIdsWithAliases; } public Builder setModelSizeStatsByModelId(Map modelSizeStatsByModelId) { @@ -276,36 +274,86 @@ public Builder setInferenceStatsByModelId(Map inferenceS * @param assignmentStatsMap map of model_id to assignment stats * @return the builder with inference stats map updated and assignment stats map set */ - public Builder setDeploymentStatsByModelId(Map assignmentStatsMap) { + public Builder setDeploymentStatsByDeploymentId(Map assignmentStatsMap) { this.assignmentStatsMap = assignmentStatsMap; - if (inferenceStatsMap == null) { - inferenceStatsMap = Maps.newHashMapWithExpectedSize(assignmentStatsMap.size()); - } - assignmentStatsMap.forEach( - (modelId, assignmentStats) -> inferenceStatsMap.put(modelId, assignmentStats.getOverallInferenceStats()) - ); return this; } - public Response build() { - List trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size()); - expandedIdsWithAliases.keySet().forEach(id -> { - TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(id); - IngestStats ingestStats = ingestStatsMap.get(id); - InferenceStats inferenceStats = inferenceStatsMap.get(id); - AssignmentStats assignmentStats = assignmentStatsMap.get(id); - trainedModelStats.add( - new TrainedModelStats( - id, - modelSizeStats, - ingestStats, - ingestStats == null ? 0 : ingestStats.getPipelineStats().size(), - inferenceStats, - assignmentStats - ) - ); + public Response build(Map> modelToDeploymentIds) { + int numResponses = expandedModelIdsWithAliases.size(); + // plus an extra response for every deployment after + // the first per model + for (var entry : modelToDeploymentIds.entrySet()) { + assert expandedModelIdsWithAliases.containsKey(entry.getKey()); // model id + assert entry.getValue().size() > 0; // must have a deployment + numResponses += entry.getValue().size() - 1; + } + + if (inferenceStatsMap == null) { + inferenceStatsMap = Collections.emptyMap(); + } + + List trainedModelStats = new ArrayList<>(numResponses); + expandedModelIdsWithAliases.keySet().forEach(modelId -> { + if (modelToDeploymentIds.containsKey(modelId) == false) { // not deployed + TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(modelId); + IngestStats ingestStats = ingestStatsMap.get(modelId); + InferenceStats inferenceStats = inferenceStatsMap.get(modelId); + trainedModelStats.add( + new TrainedModelStats( + modelId, + modelSizeStats, + ingestStats, + ingestStats == null ? 0 : ingestStats.getPipelineStats().size(), + inferenceStats, + null // no assignment stats for undeployed models + ) + ); + } else { + for (var deploymentId : modelToDeploymentIds.get(modelId)) { + AssignmentStats assignmentStats = assignmentStatsMap.get(deploymentId); + if (assignmentStats == null) { + continue; + } + InferenceStats inferenceStats = assignmentStats.getOverallInferenceStats(); + IngestStats ingestStats = ingestStatsMap.get(deploymentId); + if (ingestStats == null) { + // look up by model id + ingestStats = ingestStatsMap.get(modelId); + } + TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(modelId); + trainedModelStats.add( + new TrainedModelStats( + modelId, + modelSizeStats, + ingestStats, + ingestStats == null ? 0 : ingestStats.getPipelineStats().size(), + inferenceStats, + assignmentStats + ) + ); + } + } + }); + + // Sort first by model id then by deployment id + trainedModelStats.sort((modelStats1, modelStats2) -> { + var comparison = modelStats1.getModelId().compareTo(modelStats2.getModelId()); + if (comparison == 0) { + var deploymentId1 = modelStats1.getDeploymentStats() == null + ? null + : modelStats1.getDeploymentStats().getDeploymentId(); + var deploymentId2 = modelStats2.getDeploymentStats() == null + ? null + : modelStats1.getDeploymentStats().getDeploymentId(); + + assert deploymentId1 != null && deploymentId2 != null + : "2 results for model " + modelStats1.getModelId() + " both should have deployment stats"; + + comparison = deploymentId1.compareTo(deploymentId2); + } + return comparison; }); - trainedModelStats.sort(Comparator.comparing(TrainedModelStats::getModelId)); return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD)); } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MultipleDeploymentsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MultipleDeploymentsIT.java index e7d06ea8c9ae8..e78f761756e6a 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MultipleDeploymentsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MultipleDeploymentsIT.java @@ -8,18 +8,24 @@ package org.elasticsearch.xpack.ml.integration; import org.elasticsearch.client.Response; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.xpack.core.ml.utils.MapHelper; import java.io.IOException; +import java.util.HashSet; import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; public class MultipleDeploymentsIT extends PyTorchModelRestTestCase { @SuppressWarnings("unchecked") public void testDeployModelMultipleTimes() throws IOException { String baseModelId = "base-model"; - createPassThroughModel(baseModelId); - putModelDefinition(baseModelId); - putVocabulary(List.of("these", "are", "my", "words"), baseModelId); + putAllModelParts(baseModelId); String forSearch = "for-search"; startWithDeploymentId(baseModelId, forSearch); @@ -35,12 +41,141 @@ public void testDeployModelMultipleTimes() throws IOException { inference = infer("my words", forIngest); assertOK(inference); - // TODO - // assertInferenceCount(1, forSearch); - // assertInferenceCount(2, forIngest); + assertInferenceCount(1, forSearch); + assertInferenceCount(2, forIngest); stopDeployment(forSearch); stopDeployment(forIngest); + + Response statsResponse = getTrainedModelStats("_all"); + Map stats = entityAsMap(statsResponse); + List> trainedModelStats = (List>) stats.get("trained_model_stats"); + assertThat(stats.toString(), trainedModelStats, hasSize(2)); + + for (var statsMap : trainedModelStats) { + // no deployment stats when the deployment is stopped + assertNull(stats.toString(), statsMap.get("deployment_stats")); + } + } + + @SuppressWarnings("unchecked") + public void testGetStats() throws IOException { + String undeployedModel1 = "undeployed_1"; + putAllModelParts(undeployedModel1); + String undeployedModel2 = "undeployed_2"; + putAllModelParts(undeployedModel2); + + String modelWith1Deployment = "model-with-1-deployment"; + putAllModelParts(modelWith1Deployment); + + String modelWith2Deployments = "model-with-2-deployments"; + putAllModelParts(modelWith2Deployments); + String forSearchDeployment = "for-search"; + startWithDeploymentId(modelWith2Deployments, forSearchDeployment); + String forIngestDeployment = "for-ingest"; + startWithDeploymentId(modelWith2Deployments, forIngestDeployment); + + // deployment Id is the same as model + startDeployment(modelWith1Deployment); + + { + Map stats = entityAsMap(getTrainedModelStats("_all")); + List> trainedModelStats = (List>) stats.get("trained_model_stats"); + checkExpectedStats( + List.of( + new Tuple<>(undeployedModel1, null), + new Tuple<>(undeployedModel2, null), + new Tuple<>(modelWith1Deployment, modelWith1Deployment), + new Tuple<>(modelWith2Deployments, forSearchDeployment), + new Tuple<>(modelWith2Deployments, forIngestDeployment) + ), + trainedModelStats, + true + ); + + // check the sorted order + assertEquals(trainedModelStats.get(0).get("model_id"), "lang_ident_model_1"); + assertEquals(trainedModelStats.get(1).get("model_id"), modelWith1Deployment); + assertEquals(MapHelper.dig("deployment_stats.deployment_id", trainedModelStats.get(1)), modelWith1Deployment); + assertEquals(trainedModelStats.get(2).get("model_id"), modelWith2Deployments); + assertEquals(MapHelper.dig("deployment_stats.deployment_id", trainedModelStats.get(2)), forIngestDeployment); + assertEquals(trainedModelStats.get(3).get("model_id"), modelWith2Deployments); + assertEquals(MapHelper.dig("deployment_stats.deployment_id", trainedModelStats.get(3)), forSearchDeployment); + assertEquals(trainedModelStats.get(4).get("model_id"), undeployedModel1); + assertEquals(trainedModelStats.get(5).get("model_id"), undeployedModel2); + } + { + Map stats = entityAsMap(getTrainedModelStats(modelWith1Deployment)); + List> trainedModelStats = (List>) stats.get("trained_model_stats"); + checkExpectedStats(List.of(new Tuple<>(modelWith1Deployment, modelWith1Deployment)), trainedModelStats); + } + { + Map stats = entityAsMap(getTrainedModelStats(modelWith2Deployments)); + List> trainedModelStats = (List>) stats.get("trained_model_stats"); + checkExpectedStats( + List.of(new Tuple<>(modelWith2Deployments, forSearchDeployment), new Tuple<>(modelWith2Deployments, forIngestDeployment)), + trainedModelStats + ); + } + { + Map stats = entityAsMap(getTrainedModelStats(forIngestDeployment)); + List> trainedModelStats = (List>) stats.get("trained_model_stats"); + checkExpectedStats(List.of(new Tuple<>(modelWith2Deployments, forIngestDeployment)), trainedModelStats); + } + { + // wildcard model id matching + Map stats = entityAsMap(getTrainedModelStats("model-with-*")); + List> trainedModelStats = (List>) stats.get("trained_model_stats"); + checkExpectedStats( + List.of( + new Tuple<>(modelWith1Deployment, modelWith1Deployment), + new Tuple<>(modelWith2Deployments, forSearchDeployment), + new Tuple<>(modelWith2Deployments, forIngestDeployment) + ), + trainedModelStats + ); + } + { + // wildcard deployment id matching + Map stats = entityAsMap(getTrainedModelStats("for-*")); + List> trainedModelStats = (List>) stats.get("trained_model_stats"); + checkExpectedStats( + List.of(new Tuple<>(modelWith2Deployments, forSearchDeployment), new Tuple<>(modelWith2Deployments, forIngestDeployment)), + trainedModelStats + ); + } + } + + private void checkExpectedStats(List> modelDeploymentPairs, List> trainedModelStats) { + checkExpectedStats(modelDeploymentPairs, trainedModelStats, false); + } + + private void checkExpectedStats( + List> modelDeploymentPairs, + List> trainedModelStats, + boolean plusOneForLangIdent + ) { + var concatenatedIds = new HashSet(); + modelDeploymentPairs.forEach(t -> concatenatedIds.add(t.v1() + t.v2())); + + int expectedSize = modelDeploymentPairs.size(); + if (plusOneForLangIdent) { + expectedSize++; + } + assertEquals(trainedModelStats.toString(), trainedModelStats.size(), expectedSize); + for (var tmStats : trainedModelStats) { + String modelId = (String) tmStats.get("model_id"); + String deploymentId = (String) XContentMapValues.extractValue("deployment_stats.deployment_id", tmStats); + concatenatedIds.remove(modelId + deploymentId); + } + + assertThat("Missing stats for " + concatenatedIds, concatenatedIds, empty()); + } + + private void putAllModelParts(String modelId) throws IOException { + createPassThroughModel(modelId); + putModelDefinition(modelId); + putVocabulary(List.of("these", "are", "my", "words"), modelId); } private void putModelDefinition(String modelId) throws IOException { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java index d0a56fb446b5f..94773c0b78b2a 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java @@ -60,7 +60,8 @@ public void setLogging() throws IOException { "logger.org.elasticsearch.xpack.ml.inference.assignment" : "DEBUG", "logger.org.elasticsearch.xpack.ml.inference.deployment" : "DEBUG", "logger.org.elasticsearch.xpack.ml.inference.pytorch" : "DEBUG", - "logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG" + "logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG", + "logger.org.elasticsearch.xpack.ml.action" : "DEBUG" }}"""); client().performRequest(loggingSettings); } @@ -122,15 +123,24 @@ protected void assertAllocationCount(String modelId, int expectedAllocationCount @SuppressWarnings("unchecked") protected void assertInferenceCount(int expectedCount, String deploymentId) throws IOException { - Response noInferenceCallsStatsResponse = getTrainedModelStats(deploymentId); - Map stats = entityAsMap(noInferenceCallsStatsResponse); + Response statsResponse = getTrainedModelStats(deploymentId); + Map stats = entityAsMap(statsResponse); + List> trainedModelStats = (List>) stats.get("trained_model_stats"); + + boolean deploymentFound = false; + for (var statsMap : trainedModelStats) { + var deploymentStats = (Map) XContentMapValues.extractValue("deployment_stats", statsMap); + // find the matching deployment + if (deploymentId.equals(deploymentStats.get("deployment_id"))) { + List> nodes = (List>) XContentMapValues.extractValue("nodes", deploymentStats); + int inferenceCount = sumInferenceCountOnNodes(nodes); + assertEquals(stats.toString(), expectedCount, inferenceCount); + deploymentFound = true; + break; + } + } - List> nodes = (List>) XContentMapValues.extractValue( - "trained_model_stats.0.deployment_stats.nodes", - stats - ); - int inferenceCount = sumInferenceCountOnNodes(nodes); - assertEquals(expectedCount, inferenceCount); + assertTrue("No deployment stats found for deployment [" + deploymentId + "]", deploymentFound); } protected int sumInferenceCountOnNodes(List> nodes) { diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index 72acc7dd181b8..a604ff5ffc5a4 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -113,6 +113,7 @@ public void testStoreModelViaChunkedPersisterWithNodeInfo() throws IOException { Collections.emptySet(), ModelAliasMetadata.EMPTY, null, + Collections.emptySet(), getIdsFuture ); Tuple>> ids = getIdsFuture.actionGet(); @@ -184,6 +185,7 @@ public void testStoreModelViaChunkedPersisterWithoutNodeInfo() throws IOExceptio Collections.emptySet(), ModelAliasMetadata.EMPTY, null, + Collections.emptySet(), getIdsFuture ); Tuple>> ids = getIdsFuture.actionGet(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java index 8d03174ac199c..833955400b3cd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java @@ -134,7 +134,7 @@ protected void masterOperation( } } - if (TrainedModelAssignmentMetadata.fromState(state).isAssigned(request.getId())) { + if (TrainedModelAssignmentMetadata.fromState(state).modelIsDeployed(request.getId())) { if (request.isForce()) { forceStopDeployment( request.getId(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java index 804f4e9436a6f..d3eb172c94956 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -144,6 +144,7 @@ protected void doExecute(Task task, Request request, ActionListener li new HashSet<>(request.getTags()), ModelAliasMetadata.fromState(clusterService.state()), parentTaskId, + Collections.emptySet(), idExpansionListener ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index 6e76147c7be7d..2ad4fed704244 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -6,7 +6,10 @@ */ package org.elasticsearch.xpack.ml.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.StepListener; import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; @@ -18,6 +21,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.document.DocumentField; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.metrics.CounterMetric; @@ -32,6 +36,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; @@ -43,12 +48,15 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; +import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -59,12 +67,14 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsByModelIdsOrAliases; +import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsByResource; public class TransportGetTrainedModelsStatsAction extends HandledTransportAction< GetTrainedModelsStatsAction.Request, GetTrainedModelsStatsAction.Response> { + private static final Logger logger = LogManager.getLogger(TransportGetTrainedModelsStatsAction.class); + private final Client client; private final ClusterService clusterService; private final TrainedModelProvider trainedModelProvider; @@ -90,81 +100,175 @@ protected void doExecute( ActionListener listener ) { final TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId()); - final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(clusterService.state()); + final ModelAliasMetadata modelAliasMetadata = ModelAliasMetadata.fromState(clusterService.state()); + final TrainedModelAssignmentMetadata assignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state()); + final Set matchedDeploymentIds = matchedDeploymentIds(request.getResourceId(), assignmentMetadata); + GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder(); - ActionListener> modelSizeStatsListener = ActionListener.wrap(modelSizeStatsByModelId -> { + StepListener> modelSizeStatsListener = new StepListener<>(); + modelSizeStatsListener.whenComplete(modelSizeStatsByModelId -> { responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId); - listener.onResponse(responseBuilder.build()); + listener.onResponse( + responseBuilder.build(modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata)) + ); }, listener::onFailure); - ActionListener deploymentStatsListener = ActionListener.wrap(deploymentStats -> { - responseBuilder.setDeploymentStatsByModelId( + StepListener deploymentStatsListener = new StepListener<>(); + deploymentStatsListener.whenComplete(deploymentStats -> { + // deployment stats for each matching deployment + // not necessarily for all models + responseBuilder.setDeploymentStatsByDeploymentId( deploymentStats.getStats() .results() .stream() .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity())) ); - modelSizeStats(responseBuilder.getExpandedIdsWithAliases(), request.isAllowNoResources(), parentTaskId, modelSizeStatsListener); + modelSizeStats( + responseBuilder.getExpandedModelIdsWithAliases(), + request.isAllowNoResources(), + parentTaskId, + modelSizeStatsListener + ); }, listener::onFailure); - ActionListener> inferenceStatsListener = ActionListener.wrap(inferenceStats -> { + StepListener> inferenceStatsListener = new StepListener<>(); + // inference stats are per model and are only + // persisted for boosted tree models + inferenceStatsListener.whenComplete(inferenceStats -> { responseBuilder.setInferenceStatsByModelId( inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity())) ); - GetDeploymentStatsAction.Request getDeploymentStatsRequest = new GetDeploymentStatsAction.Request(request.getResourceId()); - getDeploymentStatsRequest.setParentTask(parentTaskId); - executeAsyncWithOrigin( - client, - ML_ORIGIN, - GetDeploymentStatsAction.INSTANCE, - getDeploymentStatsRequest, - deploymentStatsListener - ); + getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, deploymentStatsListener); }, listener::onFailure); - ActionListener nodesStatsListener = ActionListener.wrap(nodesStatsResponse -> { - Set allPossiblePipelineReferences = responseBuilder.getExpandedIdsWithAliases() + StepListener nodesStatsListener = new StepListener<>(); + nodesStatsListener.whenComplete(nodesStatsResponse -> { + // find all pipelines whether using the model id, + // alias or deployment id. + Set allPossiblePipelineReferences = responseBuilder.getExpandedModelIdsWithAliases() .entrySet() .stream() .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey()))) .collect(Collectors.toSet()); - Map> pipelineIdsByModelIdsOrAliases = pipelineIdsByModelIdsOrAliases( - clusterService.state(), - allPossiblePipelineReferences - ); + allPossiblePipelineReferences.addAll(matchedDeploymentIds); + + Map> pipelineIdsByResource = pipelineIdsByResource(clusterService.state(), allPossiblePipelineReferences); Map modelIdIngestStats = inferenceIngestStatsByModelId( nodesStatsResponse, - currentMetadata, - pipelineIdsByModelIdsOrAliases + modelAliasMetadata, + pipelineIdsByResource ); responseBuilder.setIngestStatsByModelId(modelIdIngestStats); trainedModelProvider.getInferenceStats( - responseBuilder.getExpandedIdsWithAliases().keySet().toArray(new String[0]), + responseBuilder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]), parentTaskId, inferenceStatsListener ); }, listener::onFailure); - ActionListener>>> idsListener = ActionListener.wrap(tuple -> { - responseBuilder.setExpandedIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1()); - String[] ingestNodes = ingestNodes(clusterService.state()); - NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear() - .addMetric(NodesStatsRequest.Metric.INGEST.metricName()); - nodesStatsRequest.setParentTask(parentTaskId); - executeAsyncWithOrigin(client, ML_ORIGIN, NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener); + StepListener>>> idsListener = new StepListener<>(); + idsListener.whenComplete(tuple -> { + responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1()); + executeAsyncWithOrigin( + client, + ML_ORIGIN, + NodesStatsAction.INSTANCE, + nodeStatsRequest(clusterService.state(), parentTaskId), + nodesStatsListener + ); }, listener::onFailure); + + // When the request resource is a deployment find the + // model used in that deployment for the model stats + String idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata); + logger.debug("Expanded models/deployment Ids request [{}]", idExpression); + + // the request id may contain deployment ids + // It is not an error if these don't match a model id but + // they need to be included in case the deployment id is also + // a model id. Hence, the `matchedDeploymentIds` parameter trainedModelProvider.expandIds( - request.getResourceId(), + idExpression, request.isAllowNoResources(), request.getPageParams(), Collections.emptySet(), - currentMetadata, + modelAliasMetadata, parentTaskId, + matchedDeploymentIds, idsListener ); } + static String addModelsUsedInMatchingDeployments(String idExpression, TrainedModelAssignmentMetadata assignmentMetadata) { + if (Strings.isAllOrWildcard(idExpression)) { + return idExpression; + } else { + var tokens = new HashSet<>(Arrays.asList(ExpandedIdsMatcher.tokenizeExpression(idExpression))); + var modelsUsedByMatchingDeployments = modelsUsedByMatchingDeploymentId(idExpression, assignmentMetadata); + tokens.addAll(modelsUsedByMatchingDeployments); + return String.join(",", tokens); + } + } + + static Map> modelToDeployments(Set modelIds, TrainedModelAssignmentMetadata assignments) { + var modelToDeploymentMap = new HashMap>(); + for (var assignment : assignments.allAssignments().values()) { + if (modelIds.contains(assignment.getModelId())) { + modelToDeploymentMap.computeIfAbsent(assignment.getModelId(), k -> new HashSet<>()).add(assignment.getDeploymentId()); + } + } + return modelToDeploymentMap; + } + + static Set matchedDeploymentIds(String resourceId, TrainedModelAssignmentMetadata assignments) { + var deploymentIds = new HashSet(); + var matcher = new ExpandedIdsMatcher.SimpleIdsMatcher(resourceId); + for (var assignment : assignments.allAssignments().values()) { + if (matcher.idMatches(assignment.getDeploymentId())) { + deploymentIds.add(assignment.getDeploymentId()); + } + } + return deploymentIds; + } + + static Set modelsUsedByMatchingDeploymentId(String resourceId, TrainedModelAssignmentMetadata assignments) { + var modelIds = new HashSet(); + var matcher = new ExpandedIdsMatcher.SimpleIdsMatcher(resourceId); + for (var assignment : assignments.allAssignments().values()) { + if (matcher.idMatches(assignment.getDeploymentId())) { + modelIds.add(assignment.getModelId()); + } + } + return modelIds; + } + + static void getDeploymentStats( + Client client, + String resourceId, + TaskId parentTaskId, + TrainedModelAssignmentMetadata assignments, + ActionListener deploymentStatsListener + ) { + // include all matched deployments and models + var matcher = new ExpandedIdsMatcher.SimpleIdsMatcher(resourceId); + var matchedDeployments = new HashSet(); + for (var assignment : assignments.allAssignments().values()) { + if (matcher.idMatches(assignment.getDeploymentId())) { + matchedDeployments.add(assignment.getDeploymentId()); + } else if (matcher.idMatches(assignment.getModelId())) { + matchedDeployments.add(assignment.getDeploymentId()); + } + } + String deployments = matchedDeployments.stream().collect(Collectors.joining(",")); + + logger.info("Fetching stats for deployments [{}]", deployments); + + GetDeploymentStatsAction.Request getDeploymentStatsRequest = new GetDeploymentStatsAction.Request(deployments); + getDeploymentStatsRequest.setParentTask(parentTaskId); + executeAsyncWithOrigin(client, ML_ORIGIN, GetDeploymentStatsAction.INSTANCE, getDeploymentStatsRequest, deploymentStatsListener); + } + private void modelSizeStats( Map> expandedIdsWithAliases, boolean allowNoResources, @@ -260,8 +364,12 @@ static Map inferenceIngestStatsByModelId( return ingestStatsMap; } - static String[] ingestNodes(final ClusterState clusterState) { - return clusterState.nodes().getIngestNodes().keySet().toArray(String[]::new); + static NodesStatsRequest nodeStatsRequest(ClusterState state, TaskId parentTaskId) { + String[] ingestNodes = state.nodes().getIngestNodes().keySet().toArray(String[]::new); + NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear() + .addMetric(NodesStatsRequest.Metric.INGEST.metricName()); + nodesStatsRequest.setParentTask(parentTaskId); + return nodesStatsRequest; } static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set pipelineIds) { @@ -361,4 +469,6 @@ IngestStats.Stats build() { } } + private record ModelAndDeployment(String modelId, String deploymentId) {} + } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadata.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadata.java index 02ca02acffad2..0927c96750047 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadata.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadata.java @@ -114,6 +114,17 @@ public boolean isAssigned(String deploymentId) { return deploymentRoutingEntries.containsKey(deploymentId); } + public boolean modelIsDeployed(String modelId) { + return deploymentRoutingEntries.values().stream().anyMatch(assignment -> modelId.equals(assignment.getModelId())); + } + + public List getDeploymentsUsingModel(String modelId) { + return deploymentRoutingEntries.values() + .stream() + .filter(assignment -> modelId.equals(assignment.getModelId())) + .collect(Collectors.toList()); + } + public Map allAssignments() { return Collections.unmodifiableMap(deploymentRoutingEntries); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 5a8f9db285968..bb3745684609e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -863,6 +863,23 @@ public void deleteTrainedModel(String modelId, ActionListener listener) })); } + /** + * Returns a Tuple of + * - hit count: the number of matching model Ids + * - Map model id -> aliases: All matched model Ids and + * the list of aliases that reference the model Id + * + * @param idExpression The expression to expand + * @param allowNoResources When wildcard expressions are used allow + * no matches (don't error) + * @param pageParams paging + * @param tags Tags the model must contain + * @param modelAliasMetadata Aliases + * @param parentTaskId Optional parent task Id + * @param previouslyMatchedIds Ids that have already been matched (e.g. deployment Id). + * It is not an error if these Ids are not matched in the query + * @param idsListener The listener + */ public void expandIds( String idExpression, boolean allowNoResources, @@ -870,6 +887,7 @@ public void expandIds( Set tags, ModelAliasMetadata modelAliasMetadata, @Nullable TaskId parentTaskId, + Set previouslyMatchedIds, ActionListener>>> idsListener ) { String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); @@ -973,6 +991,7 @@ public void expandIds( // Reverse lookup to see what model aliases were matched by their found trained model IDs ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources); requiredMatches.filterMatchedIds(matchedTokens); + requiredMatches.filterMatchedIds(previouslyMatchedIds); if (requiredMatches.hasUnmatchedIds()) { idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); } else { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java index 447e6d35f7aad..80bc6f208a501 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java @@ -71,10 +71,10 @@ public static int countInferenceProcessors(ClusterState state) { /** * @param state Current cluster state - * @return a map from Model IDs or Aliases to each pipeline referencing them. + * @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them. */ @SuppressWarnings("unchecked") - public static Map> pipelineIdsByModelIdsOrAliases(ClusterState state, Set modelIds) { + public static Map> pipelineIdsByResource(ClusterState state, Set ids) { Map> pipelineIdsByModelIds = new HashMap<>(); Metadata metadata = state.metadata(); if (metadata == null) { @@ -90,7 +90,7 @@ public static Map> pipelineIdsByModelIdsOrAliases(ClusterSta for (Map processorConfigWithKey : processorConfigs) { for (Map.Entry entry : processorConfigWithKey.entrySet()) { addModelsAndPipelines(entry.getKey(), pipelineId, (Map) entry.getValue(), pam -> { - if (modelIds.contains(pam.modelIdOrAlias)) { + if (ids.contains(pam.modelIdOrAlias)) { pipelineIdsByModelIds.computeIfAbsent(pam.modelIdOrAlias, m -> new LinkedHashSet<>()).add(pipelineId); } }, 0); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java index 8d98f32537ecb..00e4f48447ef7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java @@ -23,6 +23,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; public class TrainedModelAssignmentMetadataTests extends AbstractChunkedSerializingTestCase { @@ -56,7 +57,7 @@ protected TrainedModelAssignmentMetadata mutateInstance(TrainedModelAssignmentMe return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929 } - public void testIsAllocated() { + public void testIsAssigned() { String allocatedModelId = "test_model_id"; String allocatedDeploymentId = "test_deployment"; TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() @@ -66,9 +67,43 @@ public void testIsAllocated() { ) .build(); assertThat(metadata.isAssigned(allocatedDeploymentId), is(true)); + assertThat(metadata.isAssigned(allocatedModelId), is(false)); assertThat(metadata.isAssigned("unknown_model_id"), is(false)); } + public void testModelIsDeployed() { + String allocatedModelId = "test_model_id"; + String allocatedDeploymentId = "test_deployment"; + TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + allocatedDeploymentId, + TrainedModelAssignment.Builder.empty(randomParams(allocatedDeploymentId, allocatedModelId)) + ) + .build(); + assertThat(metadata.modelIsDeployed(allocatedDeploymentId), is(false)); + assertThat(metadata.modelIsDeployed(allocatedModelId), is(true)); + assertThat(metadata.modelIsDeployed("unknown_model_id"), is(false)); + } + + public void testGetDeploymentsUsingModel() { + String modelId1 = "test_model_id_1"; + String deployment1 = "test_deployment_1"; + String deployment2 = "test_deployment_2"; + String deployment3 = "test_deployment_3"; + TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment(deployment1, TrainedModelAssignment.Builder.empty(randomParams(deployment1, modelId1))) + .addNewAssignment(deployment2, TrainedModelAssignment.Builder.empty(randomParams(deployment2, modelId1))) + .addNewAssignment(deployment3, TrainedModelAssignment.Builder.empty(randomParams(deployment3, "different_model"))) + .build(); + var assignments = metadata.getDeploymentsUsingModel(modelId1); + assertThat(assignments, hasSize(2)); + assertEquals(assignments.get(0).getModelId(), modelId1); + assertEquals(assignments.get(1).getModelId(), modelId1); + + assignments = metadata.getDeploymentsUsingModel("not-deployed"); + assertThat(assignments, hasSize(0)); + } + private static StartTrainedModelDeploymentAction.TaskParams randomParams(String deploymentId, String modelId) { return new StartTrainedModelDeploymentAction.TaskParams( modelId, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java index cdead95ea5d39..51ac8f0972b2b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java @@ -48,10 +48,7 @@ public void testPipelineIdsByModelIds() throws IOException { ClusterState clusterState = buildClusterStateWithModelReferences(2, modelId1, modelId2, modelId3); - Map> pipelineIdsByModelIds = InferenceProcessorInfoExtractor.pipelineIdsByModelIdsOrAliases( - clusterState, - modelIds - ); + Map> pipelineIdsByModelIds = InferenceProcessorInfoExtractor.pipelineIdsByResource(clusterState, modelIds); assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds)); assertThat( diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml index bdbbf0f39a9d6..df3ce84865c38 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml @@ -458,6 +458,14 @@ setup: { "input": "words" } ] } + - do: + ml.get_trained_models_stats: + model_id: test_model + - match: { count: 1 } # one model matched + - match: { trained_model_stats.0.model_id: test_model } + - match: { trained_model_stats.0.deployment_stats.deployment_id: test_model_for_ingest } + - match: { trained_model_stats.1.model_id: test_model } + - match: { trained_model_stats.1.deployment_stats.deployment_id: test_model_for_search } - do: ml.stop_trained_model_deployment: