Skip to content

Commit

Permalink
feat: Add per model metrics (kserve#90)
Browse files Browse the repository at this point in the history
- Add `modelId` parameter to `logTimingMetricDuration` function in `Metrics.java`:
  - `modelmesh_cache_miss_milliseconds`
  - `modelmesh_loadmodel_milliseconds`
  - `modelmesh_unloadmodel_milliseconds`
  - `modelmesh_req_queue_delay_milliseconds`
  - `modelmesh_model_sizing_milliseconds`
  - `modelmesh_age_at_eviction_milliseconds`
- Add `modelId` parameter to `logSizeEventMetric` function in `Metrics.java`:
  - `modelmesh_loading_queue_delay_milliseconds`
  - `modelmesh_loaded_model_size_bytes`
- Add `modelId` and `vModelId` param to `logRequestMetrics` in `Metrics.java`:
  - `modelmesh_invoke_model_milliseconds`
  - `modelmesh_api_request_milliseconds`

Closes red-hat-data-services#60

Signed-off-by: Vedant Mahabaleshwarkar <vmahabal@redhat.com>
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Prashant Sharma <prashsh1@in.ibm.com>
Co-authored-by: Daniele Zonca <dzonca@redhat.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
  • Loading branch information
4 people committed Oct 11, 2023
1 parent f4ea9c3 commit a853e7e
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 74 deletions.
112 changes: 84 additions & 28 deletions src/main/java/com/ibm/watson/modelmesh/Metrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.ibm.watson.modelmesh;

import com.google.common.base.Strings;
import com.ibm.watson.prometheus.Counter;
import com.ibm.watson.prometheus.Gauge;
import com.ibm.watson.prometheus.Histogram;
Expand All @@ -36,34 +37,39 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.lang.reflect.Array;
import java.net.SocketAddress;
import java.nio.channels.DatagramChannel;
import java.util.*;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

import static com.ibm.watson.modelmesh.Metric.*;
import static com.ibm.watson.modelmesh.Metric.MetricType.*;
import static com.ibm.watson.modelmesh.ModelMesh.M;
import static com.ibm.watson.modelmesh.ModelMeshEnvVars.MMESH_CUSTOM_ENV_VAR;
import static com.ibm.watson.modelmesh.ModelMeshEnvVars.MMESH_METRICS_ENV_VAR;
import static java.util.concurrent.TimeUnit.*;

/**
*
*/
interface Metrics extends AutoCloseable {
boolean isPerModelMetricsEnabled();

boolean isEnabled();

void logTimingMetricSince(Metric metric, long prevTime, boolean isNano);

void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano);
void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId);

void logSizeEventMetric(Metric metric, long value);
void logSizeEventMetric(Metric metric, long value, String modelId);

void logGaugeMetric(Metric metric, long value);

Expand Down Expand Up @@ -101,7 +107,7 @@ default void logInstanceStats(final InstanceRecord ir) {
* @param respPayloadSize response payload size in bytes (or -1 if not applicable)
*/
void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code,
int reqPayloadSize, int respPayloadSize);
int reqPayloadSize, int respPayloadSize, String modelId, String vModelId);

default void registerGlobals() {}

Expand All @@ -111,6 +117,11 @@ default void unregisterGlobals() {}
default void close() {}

Metrics NO_OP_METRICS = new Metrics() {
@Override
public boolean isPerModelMetricsEnabled() {
return false;
}

@Override
public boolean isEnabled() {
return false;
Expand All @@ -120,10 +131,10 @@ public boolean isEnabled() {
public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {}

@Override
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) {}
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId){}

@Override
public void logSizeEventMetric(Metric metric, long value) {}
public void logSizeEventMetric(Metric metric, long value, String modelId){}

@Override
public void logGaugeMetric(Metric metric, long value) {}
Expand All @@ -136,7 +147,7 @@ public void logInstanceStats(InstanceRecord ir) {}

@Override
public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code,
int reqPayloadSize, int respPayloadSize) {}
int reqPayloadSize, int respPayloadSize, String modelId, String vModelId) {}
};

final class PrometheusMetrics implements Metrics {
Expand All @@ -154,12 +165,14 @@ final class PrometheusMetrics implements Metrics {
private final CollectorRegistry registry;
private final NettyServer metricServer;
private final boolean shortNames;
private final boolean perModelMetricsEnabled;
private final EnumMap<Metric, Collector> metricsMap = new EnumMap<>(Metric.class);

public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMetricParams) throws Exception {
int port = 2112;
boolean shortNames = true;
boolean https = true;
boolean perModelMetricsEnabled = true;
String memMetrics = "all"; // default to all
for (Entry<String, String> ent : params.entrySet()) {
switch (ent.getKey()) {
Expand All @@ -170,6 +183,9 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet
throw new Exception("Invalid metrics port: " + ent.getValue());
}
break;
case "per_model_metrics":
perModelMetricsEnabled= "true".equalsIgnoreCase(ent.getValue());
break;
case "fq_names":
shortNames = !"true".equalsIgnoreCase(ent.getValue());
break;
Expand All @@ -188,6 +204,7 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet
throw new Exception("Unrecognized metrics config parameter: " + ent.getKey());
}
}
this.perModelMetricsEnabled = perModelMetricsEnabled;

registry = new CollectorRegistry();
for (Metric m : Metric.values()) {
Expand Down Expand Up @@ -220,10 +237,15 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet
}

if (m == API_REQUEST_TIME || m == API_REQUEST_COUNT || m == INVOKE_MODEL_TIME
|| m == INVOKE_MODEL_COUNT || m == REQUEST_PAYLOAD_SIZE || m == RESPONSE_PAYLOAD_SIZE) {
builder.labelNames("method", "code");
|| m == INVOKE_MODEL_COUNT || m == REQUEST_PAYLOAD_SIZE || m == RESPONSE_PAYLOAD_SIZE) {
if (this.perModelMetricsEnabled) {
builder.labelNames("method", "code", "modelId", "vModelId");
} else {
builder.labelNames("method", "code");
}
} else if (this.perModelMetricsEnabled && m.type != GAUGE && m.type != COUNTER && m.type != COUNTER_WITH_HISTO) {
builder.labelNames("modelId", "vModelId");
}

Collector collector = builder.name(m.promName).help(m.description).create();
metricsMap.put(m, collector);
if (!m.global) {
Expand Down Expand Up @@ -330,6 +352,11 @@ public void close() {
this.metricServer.close();
}

@Override
public boolean isPerModelMetricsEnabled() {
return perModelMetricsEnabled;
}

@Override
public boolean isEnabled() {
return true;
Expand All @@ -342,13 +369,23 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {
}

@Override
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) {
((Histogram) metricsMap.get(metric)).observe(isNano ? elapsed / M : elapsed);
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId) {
Histogram histogram = (Histogram) metricsMap.get(metric);
if (perModelMetricsEnabled && modelId != null) {
histogram.labels(modelId, "").observe(isNano ? elapsed / M : elapsed);
} else {
histogram.observe(isNano ? elapsed / M : elapsed);
}
}

@Override
public void logSizeEventMetric(Metric metric, long value) {
((Histogram) metricsMap.get(metric)).observe(value * metric.newMultiplier);
public void logSizeEventMetric(Metric metric, long value, String modelId) {
Histogram histogram = (Histogram) metricsMap.get(metric);
if (perModelMetricsEnabled) {
histogram.labels(modelId, "").observe(value * metric.newMultiplier);
} else {
histogram.observe(value * metric.newMultiplier);
}
}

@Override
Expand All @@ -365,23 +402,37 @@ public void logCounterMetric(Metric metric) {

@Override
public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code,
int reqPayloadSize, int respPayloadSize) {
int reqPayloadSize, int respPayloadSize, String modelId, String vModelId) {
final long elapsedMillis = elapsedNanos / M;
final Histogram timingHisto = (Histogram) metricsMap
.get(external ? API_REQUEST_TIME : INVOKE_MODEL_TIME);

int idx = shortNames ? name.indexOf('/') : -1;
final String methodName = idx == -1 ? name : name.substring(idx + 1);

timingHisto.labels(methodName, code.name()).observe(elapsedMillis);

String methodName = idx == -1 ? name : name.substring(idx + 1);
if (perModelMetricsEnabled) {
modelId = Strings.nullToEmpty(modelId);
vModelId = Strings.nullToEmpty(vModelId);
}
if (perModelMetricsEnabled) {
timingHisto.labels(methodName, code.name(), modelId, vModelId).observe(elapsedMillis);
} else {
timingHisto.labels(methodName, code.name()).observe(elapsedMillis);
}
if (reqPayloadSize != -1) {
((Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE))
.labels(methodName, code.name()).observe(reqPayloadSize);
Histogram reqPayloadHisto = (Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE);
if (perModelMetricsEnabled) {
reqPayloadHisto.labels(methodName, code.name(), modelId, vModelId).observe(reqPayloadSize);
} else {
reqPayloadHisto.labels(methodName, code.name()).observe(reqPayloadSize);
}
}
if (respPayloadSize != -1) {
((Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE))
.labels(methodName, code.name()).observe(respPayloadSize);
Histogram respPayloadHisto = (Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE);
if (perModelMetricsEnabled) {
respPayloadHisto.labels(methodName, code.name(), modelId, vModelId).observe(respPayloadSize);
} else {
respPayloadHisto.labels(methodName, code.name()).observe(respPayloadSize);
}
}
}

Expand Down Expand Up @@ -437,6 +488,11 @@ protected StatsDSender createSender(Callable<SocketAddress> addressLookup, int q
+ (shortNames ? "short" : "fully-qualified") + " method names");
}

@Override
public boolean isPerModelMetricsEnabled() {
return false;
}

@Override
public boolean isEnabled() {
return true;
Expand All @@ -454,12 +510,12 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {
}

@Override
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) {
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId) {
client.recordExecutionTime(name(metric), isNano ? elapsed / M : elapsed);
}

@Override
public void logSizeEventMetric(Metric metric, long value) {
public void logSizeEventMetric(Metric metric, long value, String modelId) {
if (!legacy) {
value *= metric.newMultiplier;
}
Expand Down Expand Up @@ -497,7 +553,7 @@ static String[] getOkTags(String method, boolean shortName) {

@Override
public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code,
int reqPayloadSize, int respPayloadSize) {
int reqPayloadSize, int respPayloadSize, String modelId, String vModelId) {
final StatsDClient client = this.client;
final long elapsedMillis = elapsedNanos / M;
final String countName = name(external ? API_REQUEST_COUNT : INVOKE_MODEL_COUNT);
Expand Down
Loading

0 comments on commit a853e7e

Please sign in to comment.