Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Get stats by deployment or model id #95440

Merged
merged 4 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/95440.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 95440
summary: "[ML} Get stats by deployment or model id"
davidkyle marked this conversation as resolved.
Show resolved Hide resolved
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.ingest.IngestStats;
Expand All @@ -30,7 +29,6 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -236,7 +234,7 @@ protected Reader<Response.TrainedModelStats> getReader() {
public static class Builder {

private long totalModelCount;
private Map<String, Set<String>> expandedIdsWithAliases;
private Map<String, Set<String>> expandedModelIdsWithAliases;
private Map<String, TrainedModelSizeStats> modelSizeStatsMap;
private Map<String, IngestStats> ingestStatsMap;
private Map<String, InferenceStats> inferenceStatsMap;
Expand All @@ -247,13 +245,13 @@ public Builder setTotalModelCount(long totalModelCount) {
return this;
}

public Builder setExpandedIdsWithAliases(Map<String, Set<String>> expandedIdsWithAliases) {
this.expandedIdsWithAliases = expandedIdsWithAliases;
public Builder setExpandedModelIdsWithAliases(Map<String, Set<String>> expandedIdsWithAliases) {
this.expandedModelIdsWithAliases = expandedIdsWithAliases;
return this;
}

public Map<String, Set<String>> getExpandedIdsWithAliases() {
return this.expandedIdsWithAliases;
public Map<String, Set<String>> getExpandedModelIdsWithAliases() {
return this.expandedModelIdsWithAliases;
}

public Builder setModelSizeStatsByModelId(Map<String, TrainedModelSizeStats> modelSizeStatsByModelId) {
Expand All @@ -276,36 +274,86 @@ public Builder setInferenceStatsByModelId(Map<String, InferenceStats> inferenceS
* @param assignmentStatsMap map of model_id to assignment stats
* @return the builder with inference stats map updated and assignment stats map set
*/
public Builder setDeploymentStatsByModelId(Map<String, AssignmentStats> assignmentStatsMap) {
public Builder setDeploymentStatsByDeploymentId(Map<String, AssignmentStats> assignmentStatsMap) {
this.assignmentStatsMap = assignmentStatsMap;
if (inferenceStatsMap == null) {
inferenceStatsMap = Maps.newHashMapWithExpectedSize(assignmentStatsMap.size());
}
assignmentStatsMap.forEach(
(modelId, assignmentStats) -> inferenceStatsMap.put(modelId, assignmentStats.getOverallInferenceStats())
);
return this;
}

public Response build() {
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size());
expandedIdsWithAliases.keySet().forEach(id -> {
TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(id);
IngestStats ingestStats = ingestStatsMap.get(id);
InferenceStats inferenceStats = inferenceStatsMap.get(id);
AssignmentStats assignmentStats = assignmentStatsMap.get(id);
trainedModelStats.add(
new TrainedModelStats(
id,
modelSizeStats,
ingestStats,
ingestStats == null ? 0 : ingestStats.getPipelineStats().size(),
inferenceStats,
assignmentStats
)
);
public Response build(Map<String, Set<String>> modelToDeploymentIds) {
int numResponses = expandedModelIdsWithAliases.size();
// plus an extra response for every deployment after
// the first per model
for (var entry : modelToDeploymentIds.entrySet()) {
assert expandedModelIdsWithAliases.containsKey(entry.getKey()); // model id
assert entry.getValue().size() > 0; // must have a deployment
numResponses += entry.getValue().size() - 1;
}

if (inferenceStatsMap == null) {
inferenceStatsMap = Collections.emptyMap();
}

List<TrainedModelStats> trainedModelStats = new ArrayList<>(numResponses);
expandedModelIdsWithAliases.keySet().forEach(modelId -> {
if (modelToDeploymentIds.containsKey(modelId) == false) { // not deployed
TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(modelId);
IngestStats ingestStats = ingestStatsMap.get(modelId);
InferenceStats inferenceStats = inferenceStatsMap.get(modelId);
trainedModelStats.add(
new TrainedModelStats(
modelId,
modelSizeStats,
ingestStats,
ingestStats == null ? 0 : ingestStats.getPipelineStats().size(),
inferenceStats,
null // no assignment stats for undeployed models
)
);
} else {
for (var deploymentId : modelToDeploymentIds.get(modelId)) {
AssignmentStats assignmentStats = assignmentStatsMap.get(deploymentId);
if (assignmentStats == null) {
continue;
}
InferenceStats inferenceStats = assignmentStats.getOverallInferenceStats();
IngestStats ingestStats = ingestStatsMap.get(deploymentId);
if (ingestStats == null) {
// look up by model id
ingestStats = ingestStatsMap.get(modelId);
}
TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(modelId);
trainedModelStats.add(
new TrainedModelStats(
modelId,
modelSizeStats,
ingestStats,
ingestStats == null ? 0 : ingestStats.getPipelineStats().size(),
inferenceStats,
assignmentStats
)
);
}
}
});

// Sort first by model id then by deployment id
trainedModelStats.sort((modelStats1, modelStats2) -> {
var comparison = modelStats1.getModelId().compareTo(modelStats2.getModelId());
if (comparison == 0) {
var deploymentId1 = modelStats1.getDeploymentStats() == null
? null
: modelStats1.getDeploymentStats().getDeploymentId();
var deploymentId2 = modelStats2.getDeploymentStats() == null
? null
: modelStats1.getDeploymentStats().getDeploymentId();

assert deploymentId1 != null && deploymentId2 != null
: "2 results for model " + modelStats1.getModelId() + " both should have deployment stats";

comparison = deploymentId1.compareTo(deploymentId2);
}
return comparison;
});
trainedModelStats.sort(Comparator.comparing(TrainedModelStats::getModelId));
return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,24 @@
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.client.Response;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;

public class MultipleDeploymentsIT extends PyTorchModelRestTestCase {

@SuppressWarnings("unchecked")
public void testDeployModelMultipleTimes() throws IOException {
String baseModelId = "base-model";
createPassThroughModel(baseModelId);
putModelDefinition(baseModelId);
putVocabulary(List.of("these", "are", "my", "words"), baseModelId);
putAllModelParts(baseModelId);

String forSearch = "for-search";
startWithDeploymentId(baseModelId, forSearch);
Expand All @@ -35,12 +41,141 @@ public void testDeployModelMultipleTimes() throws IOException {
inference = infer("my words", forIngest);
assertOK(inference);

// TODO
// assertInferenceCount(1, forSearch);
// assertInferenceCount(2, forIngest);
assertInferenceCount(1, forSearch);
assertInferenceCount(2, forIngest);

stopDeployment(forSearch);
stopDeployment(forIngest);

Response statsResponse = getTrainedModelStats("_all");
Map<String, Object> stats = entityAsMap(statsResponse);
List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
assertThat(stats.toString(), trainedModelStats, hasSize(2));

for (var statsMap : trainedModelStats) {
// no deployment stats when the deployment is stopped
assertNull(stats.toString(), statsMap.get("deployment_stats"));
}
}

@SuppressWarnings("unchecked")
public void testGetStats() throws IOException {
String undeployedModel1 = "undeployed_1";
putAllModelParts(undeployedModel1);
String undeployedModel2 = "undeployed_2";
putAllModelParts(undeployedModel2);

String modelWith1Deployment = "model-with-1-deployment";
putAllModelParts(modelWith1Deployment);

String modelWith2Deployments = "model-with-2-deployments";
putAllModelParts(modelWith2Deployments);
String forSearchDeployment = "for-search";
startWithDeploymentId(modelWith2Deployments, forSearchDeployment);
String forIngestDeployment = "for-ingest";
startWithDeploymentId(modelWith2Deployments, forIngestDeployment);

// deployment Id is the same as model
startDeployment(modelWith1Deployment);

{
Map<String, Object> stats = entityAsMap(getTrainedModelStats("_all"));
List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
checkExpectedStats(
List.of(
new Tuple<>(undeployedModel1, null),
new Tuple<>(undeployedModel2, null),
new Tuple<>(modelWith1Deployment, modelWith1Deployment),
new Tuple<>(modelWith2Deployments, forSearchDeployment),
new Tuple<>(modelWith2Deployments, forIngestDeployment)
),
trainedModelStats,
true
);

// check the sorted order
assertEquals(trainedModelStats.get(0).get("model_id"), "lang_ident_model_1");
assertEquals(trainedModelStats.get(1).get("model_id"), modelWith1Deployment);
assertEquals(MapHelper.dig("deployment_stats.deployment_id", trainedModelStats.get(1)), modelWith1Deployment);
assertEquals(trainedModelStats.get(2).get("model_id"), modelWith2Deployments);
assertEquals(MapHelper.dig("deployment_stats.deployment_id", trainedModelStats.get(2)), forIngestDeployment);
assertEquals(trainedModelStats.get(3).get("model_id"), modelWith2Deployments);
assertEquals(MapHelper.dig("deployment_stats.deployment_id", trainedModelStats.get(3)), forSearchDeployment);
assertEquals(trainedModelStats.get(4).get("model_id"), undeployedModel1);
assertEquals(trainedModelStats.get(5).get("model_id"), undeployedModel2);
}
{
Map<String, Object> stats = entityAsMap(getTrainedModelStats(modelWith1Deployment));
List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
checkExpectedStats(List.of(new Tuple<>(modelWith1Deployment, modelWith1Deployment)), trainedModelStats);
}
{
Map<String, Object> stats = entityAsMap(getTrainedModelStats(modelWith2Deployments));
List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
checkExpectedStats(
List.of(new Tuple<>(modelWith2Deployments, forSearchDeployment), new Tuple<>(modelWith2Deployments, forIngestDeployment)),
trainedModelStats
);
}
{
Map<String, Object> stats = entityAsMap(getTrainedModelStats(forIngestDeployment));
List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
checkExpectedStats(List.of(new Tuple<>(modelWith2Deployments, forIngestDeployment)), trainedModelStats);
}
{
// wildcard model id matching
Map<String, Object> stats = entityAsMap(getTrainedModelStats("model-with-*"));
List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
checkExpectedStats(
List.of(
new Tuple<>(modelWith1Deployment, modelWith1Deployment),
new Tuple<>(modelWith2Deployments, forSearchDeployment),
new Tuple<>(modelWith2Deployments, forIngestDeployment)
),
trainedModelStats
);
}
{
// wildcard deployment id matching
Map<String, Object> stats = entityAsMap(getTrainedModelStats("for-*"));
List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
checkExpectedStats(
List.of(new Tuple<>(modelWith2Deployments, forSearchDeployment), new Tuple<>(modelWith2Deployments, forIngestDeployment)),
trainedModelStats
);
}
}

private void checkExpectedStats(List<Tuple<String, String>> modelDeploymentPairs, List<Map<String, Object>> trainedModelStats) {
checkExpectedStats(modelDeploymentPairs, trainedModelStats, false);
}

private void checkExpectedStats(
List<Tuple<String, String>> modelDeploymentPairs,
List<Map<String, Object>> trainedModelStats,
boolean plusOneForLangIdent
) {
var concatenatedIds = new HashSet<String>();
modelDeploymentPairs.forEach(t -> concatenatedIds.add(t.v1() + t.v2()));

int expectedSize = modelDeploymentPairs.size();
if (plusOneForLangIdent) {
expectedSize++;
}
assertEquals(trainedModelStats.toString(), trainedModelStats.size(), expectedSize);
for (var tmStats : trainedModelStats) {
String modelId = (String) tmStats.get("model_id");
String deploymentId = (String) XContentMapValues.extractValue("deployment_stats.deployment_id", tmStats);
concatenatedIds.remove(modelId + deploymentId);
}

assertThat("Missing stats for " + concatenatedIds, concatenatedIds, empty());
}

private void putAllModelParts(String modelId) throws IOException {
createPassThroughModel(modelId);
putModelDefinition(modelId);
putVocabulary(List.of("these", "are", "my", "words"), modelId);
}

private void putModelDefinition(String modelId) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ public void setLogging() throws IOException {
"logger.org.elasticsearch.xpack.ml.inference.assignment" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.inference.deployment" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.inference.pytorch" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG"
"logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.action" : "DEBUG"
}}""");
client().performRequest(loggingSettings);
}
Expand Down Expand Up @@ -122,15 +123,24 @@ protected void assertAllocationCount(String modelId, int expectedAllocationCount

@SuppressWarnings("unchecked")
protected void assertInferenceCount(int expectedCount, String deploymentId) throws IOException {
Response noInferenceCallsStatsResponse = getTrainedModelStats(deploymentId);
Map<String, Object> stats = entityAsMap(noInferenceCallsStatsResponse);
Response statsResponse = getTrainedModelStats(deploymentId);
Map<String, Object> stats = entityAsMap(statsResponse);
List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");

boolean deploymentFound = false;
for (var statsMap : trainedModelStats) {
var deploymentStats = (Map<String, Object>) XContentMapValues.extractValue("deployment_stats", statsMap);
// find the matching deployment
if (deploymentId.equals(deploymentStats.get("deployment_id"))) {
List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue("nodes", deploymentStats);
int inferenceCount = sumInferenceCountOnNodes(nodes);
assertEquals(stats.toString(), expectedCount, inferenceCount);
deploymentFound = true;
break;
}
}

List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue(
"trained_model_stats.0.deployment_stats.nodes",
stats
);
int inferenceCount = sumInferenceCountOnNodes(nodes);
assertEquals(expectedCount, inferenceCount);
assertTrue("No deployment stats found for deployment [" + deploymentId + "]", deploymentFound);
}

protected int sumInferenceCountOnNodes(List<Map<String, Object>> nodes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ public void testStoreModelViaChunkedPersisterWithNodeInfo() throws IOException {
Collections.emptySet(),
ModelAliasMetadata.EMPTY,
null,
Collections.emptySet(),
getIdsFuture
);
Tuple<Long, Map<String, Set<String>>> ids = getIdsFuture.actionGet();
Expand Down Expand Up @@ -184,6 +185,7 @@ public void testStoreModelViaChunkedPersisterWithoutNodeInfo() throws IOExceptio
Collections.emptySet(),
ModelAliasMetadata.EMPTY,
null,
Collections.emptySet(),
getIdsFuture
);
Tuple<Long, Map<String, Set<String>>> ids = getIdsFuture.actionGet();
Expand Down
Loading