Skip to content

Commit

Permalink
wip: Add per model metrics
Browse files Browse the repository at this point in the history
Signed-off-by: Prashant Sharma <prashsh1@in.ibm.com>
  • Loading branch information
ScrapCodes committed Feb 24, 2023
1 parent 92aa82f commit 53378d9
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 39 deletions.
62 changes: 38 additions & 24 deletions src/main/java/com/ibm/watson/modelmesh/Metrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,8 @@

import java.net.SocketAddress;
import java.nio.channels.DatagramChannel;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.*;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
Expand All @@ -59,12 +54,11 @@
interface Metrics extends AutoCloseable {

boolean isEnabled();

void logTimingMetricSince(Metric metric, long prevTime, boolean isNano);

void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano);
void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String[] extraLabels);

void logSizeEventMetric(Metric metric, long value);
void logSizeEventMetric(Metric metric, long value, String[] extraLabels);

void logGaugeMetric(Metric metric, long value);

Expand Down Expand Up @@ -102,7 +96,7 @@ default void logInstanceStats(final InstanceRecord ir) {
* @param respPayloadSize response payload size in bytes (or -1 if not applicable)
*/
void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code,
int reqPayloadSize, int respPayloadSize);
int reqPayloadSize, int respPayloadSize, String[] extraLabels);

default void registerGlobals() {}

Expand All @@ -121,10 +115,10 @@ public boolean isEnabled() {
public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {}

@Override
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) {}
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String[] extraLabels) {}

@Override
public void logSizeEventMetric(Metric metric, long value) {}
public void logSizeEventMetric(Metric metric, long value, String[] extraLabels) {}

@Override
public void logGaugeMetric(Metric metric, long value) {}
Expand All @@ -137,7 +131,7 @@ public void logInstanceStats(InstanceRecord ir) {}

@Override
public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code,
int reqPayloadSize, int respPayloadSize) {}
int reqPayloadSize, int respPayloadSize, String[] extraLabels) {}
};

final class PrometheusMetrics implements Metrics {
Expand All @@ -153,12 +147,14 @@ final class PrometheusMetrics implements Metrics {
private final CollectorRegistry registry;
private final NettyServer metricServer;
private final boolean shortNames;
private final boolean enablePerModelMetrics;
private final EnumMap<Metric, Collector> metricsMap = new EnumMap<>(Metric.class);

public PrometheusMetrics(Map<String, String> params) throws Exception {
int port = 2112;
boolean shortNames = true;
boolean https = true;
boolean enablePerModelMetrics = false;
String memMetrics = "all"; // default to all
for (Entry<String, String> ent : params.entrySet()) {
switch (ent.getKey()) {
Expand All @@ -169,6 +165,9 @@ public PrometheusMetrics(Map<String, String> params) throws Exception {
throw new Exception("Invalid metrics port: " + ent.getValue());
}
break;
case "single_model_metrics":
enablePerModelMetrics = "true".equalsIgnoreCase(ent.getValue());
break;
case "fq_names":
shortNames = "true".equalsIgnoreCase(ent.getValue());
break;
Expand All @@ -187,6 +186,7 @@ public PrometheusMetrics(Map<String, String> params) throws Exception {
throw new Exception("Unrecognized metrics config parameter: " + ent.getKey());
}
}
this.enablePerModelMetrics = enablePerModelMetrics;

registry = new CollectorRegistry();
for (Metric m : Metric.values()) {
Expand Down Expand Up @@ -222,7 +222,9 @@ public PrometheusMetrics(Map<String, String> params) throws Exception {
|| m == INVOKE_MODEL_COUNT || m == REQUEST_PAYLOAD_SIZE || m == RESPONSE_PAYLOAD_SIZE) {
builder.labelNames("method", "code");
}

if (this.enablePerModelMetrics) {
builder.enableArbitraryLabels(true);
}
Collector collector = builder.name(m.promName).help(m.description).create();
metricsMap.put(m, collector);
if (!m.global) {
Expand All @@ -232,7 +234,6 @@ public PrometheusMetrics(Map<String, String> params) throws Exception {

this.metricServer = new NettyServer(registry, port, https);
this.shortNames = shortNames;

logger.info("Will expose " + (https ? "https" : "http") + " Prometheus metrics on port " + port
+ " using " + (shortNames ? "short" : "fully-qualified") + " method names");

Expand Down Expand Up @@ -323,13 +324,21 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {
}

@Override
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) {
((Histogram) metricsMap.get(metric)).observe(isNano ? elapsed / M : elapsed);
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String[] extraLabels) {
if (enablePerModelMetrics) {
((Histogram) metricsMap.get(metric)).labels(extraLabels).observe(isNano ? elapsed / M : elapsed);
} else {
((Histogram) metricsMap.get(metric)).observe(isNano ? elapsed / M : elapsed);
}
}

@Override
public void logSizeEventMetric(Metric metric, long value) {
((Histogram) metricsMap.get(metric)).observe(value * metric.newMultiplier);
public void logSizeEventMetric(Metric metric, long value, String[] extraLabels) {
if (enablePerModelMetrics) {
((Histogram) metricsMap.get(metric)).labels(extraLabels).observe(value * metric.newMultiplier);
} else {
((Histogram) metricsMap.get(metric)).observe(value * metric.newMultiplier);
}
}

@Override
Expand All @@ -346,14 +355,19 @@ public void logCounterMetric(Metric metric) {

@Override
public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code,
int reqPayloadSize, int respPayloadSize) {
int reqPayloadSize, int respPayloadSize, String[] extraLabels) {
final long elapsedMillis = elapsedNanos / M;
final Histogram timingHisto = (Histogram) metricsMap
.get(external ? API_REQUEST_TIME : INVOKE_MODEL_TIME);

int idx = shortNames ? name.indexOf('/') : -1;
final String methodName = idx == -1 ? name : name.substring(idx + 1);

// Label values for the timing sample: method name and status code first,
// followed by any per-model labels when that feature is enabled.
List<String> labels = new ArrayList<>();
labels.add(methodName);
labels.add(code.name());
if (enablePerModelMetrics && extraLabels != null) {
    // Collections.addAll accepts the array directly; the explicit null guard
    // avoids the NullPointerException List.of(...) would raise for a null array.
    Collections.addAll(labels, extraLabels);
}
// Use the assembled list. Previously the list was built and then discarded,
// so the extra labels never reached the histogram.
timingHisto.labels(labels.toArray(new String[0])).observe(elapsedMillis);

if (reqPayloadSize != -1) {
Expand Down Expand Up @@ -435,12 +449,12 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {
}

@Override
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) {
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String[] extraLabels) {
client.recordExecutionTime(name(metric), isNano ? elapsed / M : elapsed);
}

@Override
public void logSizeEventMetric(Metric metric, long value) {
public void logSizeEventMetric(Metric metric, long value, String[] extraLabels) {
if (!legacy) {
value *= metric.newMultiplier;
}
Expand Down Expand Up @@ -478,7 +492,7 @@ static String[] getOkTags(String method, boolean shortName) {

@Override
public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code,
int reqPayloadSize, int respPayloadSize) {
int reqPayloadSize, int respPayloadSize, String[] extraLabels) {
final StatsDClient client = this.client;
final long elapsedMillis = elapsedNanos / M;
final String countName = name(external ? API_REQUEST_COUNT : INVOKE_MODEL_COUNT);
Expand Down
26 changes: 14 additions & 12 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -1896,7 +1896,7 @@ final synchronized boolean doRemove(final boolean evicted,
// "unload" event if explicit unloading isn't enabled.
// Otherwise, this gets recorded in a callback set in the
// CacheEntry.unload(int) method
metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, 0L, false);
metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, 0L, false, new String[]{modelId});
metrics.logCounterMetric(Metric.UNLOAD_MODEL);
}
}
Expand Down Expand Up @@ -1967,7 +1967,7 @@ public void onSuccess(Boolean reallyHappened) {
//TODO probably only log if took longer than a certain time
long tookMillis = msSince(beforeNanos);
logger.info("Unload of " + modelId + " completed in " + tookMillis + "ms");
metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, tookMillis, false);
metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, tookMillis, false, new String[]{modelId});
metrics.logCounterMetric(Metric.UNLOAD_MODEL);
}
// else considered trivially succeeded because the corresponding
Expand Down Expand Up @@ -2088,7 +2088,7 @@ public final void run() {
long queueStartTimeNanos = getAndResetLoadingQueueStartTimeNanos();
if (queueStartTimeNanos > 0) {
long queueDelayMillis = (nanoTime() - queueStartTimeNanos) / M;
metrics.logSizeEventMetric(Metric.LOAD_MODEL_QUEUE_DELAY, queueDelayMillis);
metrics.logSizeEventMetric(Metric.LOAD_MODEL_QUEUE_DELAY, queueDelayMillis, new String[]{modelId});
// Only log if the priority value is "in the future" which indicates
// that there is or were runtime requests waiting for this load.
// Otherwise we don't care about arbitrary delays here
Expand Down Expand Up @@ -2158,7 +2158,7 @@ public final void run() {
loadingTimeStats(modelType).recordTime(tookMillis);
logger.info("Load of model " + modelId + " type=" + modelType + " completed in " + tookMillis
+ "ms");
metrics.logTimingMetricDuration(Metric.LOAD_MODEL_TIME, tookMillis, false);
metrics.logTimingMetricDuration(Metric.LOAD_MODEL_TIME, tookMillis, false, new String[]{modelId});
metrics.logCounterMetric(Metric.LOAD_MODEL);
} catch (Throwable t) {
loadFuture = null;
Expand Down Expand Up @@ -2318,7 +2318,7 @@ protected final void complete(LoadedRuntime<T> result, Throwable error) {
if (size > 0) {
long sizeBytes = size * UNIT_SIZE;
logger.info("Model " + modelId + " size = " + size + " units" + ", ~" + mb(sizeBytes));
metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes);
metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes, new String[]{modelId});
} else {
try {
long before = nanoTime();
Expand All @@ -2327,9 +2327,9 @@ protected final void complete(LoadedRuntime<T> result, Throwable error) {
long took = msSince(before), sizeBytes = size * UNIT_SIZE;
logger.info("Model " + modelId + " size = " + size + " units" + ", ~" + mb(sizeBytes)
+ " sizing took " + took + "ms");
metrics.logTimingMetricDuration(Metric.MODEL_SIZING_TIME, took, false);
metrics.logTimingMetricDuration(Metric.MODEL_SIZING_TIME, took, false, new String[]{modelId});
// this is actually a size (bytes), not a "time"
metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes);
metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes, new String[]{modelId});
}
} catch (Exception e) {
if (!isInterruption(e) && state == SIZING) {
Expand Down Expand Up @@ -2638,7 +2638,7 @@ protected void beforeInvoke(int requestWeight)
//noinspection ThrowFromFinallyBlock
throw new ModelNotHereException(instanceId, modelId);
}
metrics.logTimingMetricDuration(Metric.QUEUE_DELAY, tookMillis, false);
metrics.logTimingMetricDuration(Metric.QUEUE_DELAY, tookMillis, false, new String[]{modelId});
}
}
}
Expand Down Expand Up @@ -2817,7 +2817,7 @@ public void onEviction(String key, CacheEntry<?> ce, long lastUsed) {
logger.info("Evicted " + (failed ? "failed model record" : "model") + " " + key
+ " from local cache, last used " + readableTime(millisSinceLastUsed) + " ago (" + lastUsed
+ "ms), invoked " + ce.getTotalInvocationCount() + " times");
metrics.logTimingMetricDuration(Metric.AGE_AT_EVICTION, millisSinceLastUsed, false);
metrics.logTimingMetricDuration(Metric.AGE_AT_EVICTION, millisSinceLastUsed, false, new String[]{ce.modelId});
metrics.logCounterMetric(Metric.EVICT_MODEL);
}

Expand Down Expand Up @@ -3903,9 +3903,10 @@ else if (mr.getInstanceIds().containsKey(instanceId)) {
throw t;
} finally {
if (methodStartNanos > 0L && metrics.isEnabled()) {
String[] extraLabels = new String[]{"modelId:" + modelId};
// only logged here in non-grpc (legacy) mode
metrics.logRequestMetrics(true, getRequestMethodName(method, args),
nanoTime() - methodStartNanos, metricStatusCode, -1, -1);
nanoTime() - methodStartNanos, metricStatusCode, -1, -1, extraLabels);
}
curThread.setName(threadNameBefore);
}
Expand Down Expand Up @@ -4364,7 +4365,7 @@ private Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args)
long delayMillis = msSince(beforeNanos);
logger.info("Cache miss for model invocation, held up " + delayMillis + "ms");
metrics.logCounterMetric(Metric.CACHE_MISS);
metrics.logTimingMetricDuration(Metric.CACHE_MISS_DELAY, delayMillis, false);
metrics.logTimingMetricDuration(Metric.CACHE_MISS_DELAY, delayMillis, false, new String[]{ce.modelId});
}
}
} else {
Expand Down Expand Up @@ -4441,8 +4442,9 @@ private Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args)
long tookNanos = nanoTime() - beforeNanos;
ce.afterInvoke(weight, tookNanos);
if (code != null && metrics.isEnabled()) {
String[] extraLabels = new String[]{ce.modelId};
metrics.logRequestMetrics(false, getRequestMethodName(method, args),
tookNanos, code, -1, -1);
tookNanos, code, -1, -1, extraLabels);
}
}
}
Expand Down
7 changes: 6 additions & 1 deletion src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.FastThreadLocalThread;
import org.apache.thrift.TException;
import org.checkerframework.checker.units.qual.A;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
Expand Down Expand Up @@ -751,8 +752,12 @@ public void onHalfClose() {
call.close(status, emptyMeta());
Metrics metrics = delegate.metrics;
if (metrics.isEnabled()) {
    // One "modelId:<id>" label per model targeted by this call.
    ArrayList<String> extraLabels = new ArrayList<>();
    for (String modelId : modelIds) {
        extraLabels.add("modelId:" + modelId);
    }
    // The no-arg toArray() returns Object[], and casting that to String[]
    // throws ClassCastException at runtime; the typed overload is required.
    metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
            status.getCode(), reqSize, respSize, extraLabels.toArray(new String[0]));
}
}
}
Expand Down
16 changes: 14 additions & 2 deletions src/main/java/com/ibm/watson/prometheus/SimpleCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ public abstract class SimpleCollector<Child> extends Collector {
protected final String help;
protected final List<String> labelNames;
protected final int labelCount;
protected final int MAX_LABEL_COUNT = 1000; // TODO: make it configurable.
boolean allowArbitraryLabelCount; // Required for per model metrics in prometheus.

@SuppressWarnings("rawtypes")
private static final ChildEntry[] EMPTY = new ChildEntry[1]; // one null slot
Expand Down Expand Up @@ -160,7 +162,11 @@ private static int nextIdx(int i, int len) {
}

private void validateCount(int count) {
if (count != labelCount) {
if (allowArbitraryLabelCount) {
if (count > MAX_LABEL_COUNT) {
throw new IllegalArgumentException("Number of labels limit exceeded i.e :" + MAX_LABEL_COUNT);
}
} else if (count != labelCount) {
throw new IllegalArgumentException("Incorrect number of labels.");
}
}
Expand Down Expand Up @@ -445,6 +451,7 @@ protected List<MetricFamilySamples> familySamplesList(Collector.Type type, List<
protected SimpleCollector(Builder b) {
if (b.name.isEmpty()) throw new IllegalStateException("Name hasn't been set.");
String name = b.name;
allowArbitraryLabelCount = b.allowArbitraryLabelCount;
if (!b.subsystem.isEmpty()) {
name = b.subsystem + '_' + name;
}
Expand Down Expand Up @@ -479,7 +486,7 @@ public abstract static class Builder<B extends Builder<B, C>, C extends SimpleCo
String[] labelNames = new String[]{};
// Some metrics require additional setup before the initialization can be done.
boolean dontInitializeNoLabelsChild;

boolean allowArbitraryLabelCount = false;
/**
* Set the name of the metric. Required.
*/
Expand Down Expand Up @@ -516,6 +523,11 @@ public B labelNames(String... labelNames) {
return (B)this;
}

/**
 * Set whether the collector being built accepts an arbitrary number of label values.
 * <p>
 * The value is stored in {@code allowArbitraryLabelCount} on this builder; it defaults
 * to {@code false}.
 *
 * @param value {@code true} to allow an arbitrary number of label values
 * @return this builder, for call chaining
 */
public B enableArbitraryLabels(boolean value) {
this.allowArbitraryLabelCount = value;
return (B)this;
}

/**
* Return the constructed collector.
* <p>
Expand Down

0 comments on commit 53378d9

Please sign in to comment.