From a853e7ebba6ab1711541939a16e04bf7b8ca8a9d Mon Sep 17 00:00:00 2001 From: Vedant Mahabaleshwarkar Date: Tue, 5 Sep 2023 20:37:11 -0400 Subject: [PATCH] feat: Add per model metrics (#90) - Add `modelId` parameter to `logTimingMetricDuration` function in `Metrics.java`: - `modelmesh_cache_miss_milliseconds` - `modelmesh_loadmodel_milliseconds` - `modelmesh_unloadmodel_milliseconds` - `modelmesh_req_queue_delay_milliseconds` - `modelmesh_model_sizing_milliseconds` - `modelmesh_age_at_eviction_milliseconds` - Add `modelId` parameter to `logSizeEventMetric` function in `Metrics.java`: - `modelmesh_loading_queue_delay_milliseconds` - `modelmesh_loaded_model_size_bytes` - Add `modelId` and `vModelId` param to `logRequestMetrics` in `Metrics.java`: - `modelmesh_invoke_model_milliseconds` - `modelmesh_api_request_milliseconds` Closes #60 Signed-off-by: Vedant Mahabaleshwarkar Signed-off-by: Nick Hill Co-authored-by: Prashant Sharma Co-authored-by: Daniele Zonca Co-authored-by: Nick Hill --- .../com/ibm/watson/modelmesh/Metrics.java | 112 +++++++++++++----- .../com/ibm/watson/modelmesh/ModelMesh.java | 40 ++++--- .../ibm/watson/modelmesh/ModelMeshApi.java | 47 ++++++-- .../ibm/watson/modelmesh/VModelManager.java | 1 - .../watson/prometheus/SimpleCollector.java | 2 +- .../modelmesh/ModelMeshMetricsTest.java | 38 +++--- 6 files changed, 166 insertions(+), 74 deletions(-) diff --git a/src/main/java/com/ibm/watson/modelmesh/Metrics.java b/src/main/java/com/ibm/watson/modelmesh/Metrics.java index b246a5c3..7be788fe 100644 --- a/src/main/java/com/ibm/watson/modelmesh/Metrics.java +++ b/src/main/java/com/ibm/watson/modelmesh/Metrics.java @@ -16,6 +16,7 @@ package com.ibm.watson.modelmesh; +import com.google.common.base.Strings; import com.ibm.watson.prometheus.Counter; import com.ibm.watson.prometheus.Gauge; import com.ibm.watson.prometheus.Histogram; @@ -36,34 +37,39 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import java.lang.reflect.Array; import java.net.SocketAddress; import java.nio.channels.DatagramChannel; -import java.util.*; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.stream.Stream; import static com.ibm.watson.modelmesh.Metric.*; +import static com.ibm.watson.modelmesh.Metric.MetricType.*; import static com.ibm.watson.modelmesh.ModelMesh.M; import static com.ibm.watson.modelmesh.ModelMeshEnvVars.MMESH_CUSTOM_ENV_VAR; -import static com.ibm.watson.modelmesh.ModelMeshEnvVars.MMESH_METRICS_ENV_VAR; import static java.util.concurrent.TimeUnit.*; /** * */ interface Metrics extends AutoCloseable { + boolean isPerModelMetricsEnabled(); boolean isEnabled(); void logTimingMetricSince(Metric metric, long prevTime, boolean isNano); - void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano); + void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId); - void logSizeEventMetric(Metric metric, long value); + void logSizeEventMetric(Metric metric, long value, String modelId); void logGaugeMetric(Metric metric, long value); @@ -101,7 +107,7 @@ default void logInstanceStats(final InstanceRecord ir) { * @param respPayloadSize response payload size in bytes (or -1 if not applicable) */ void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code, - int reqPayloadSize, int respPayloadSize); + int reqPayloadSize, int respPayloadSize, String modelId, String vModelId); default void registerGlobals() {} @@ -111,6 +117,11 @@ default void unregisterGlobals() {} default void close() {} Metrics NO_OP_METRICS = new Metrics() { + @Override + public boolean isPerModelMetricsEnabled() { + return false; + } + @Override public boolean isEnabled() { return false; @@ -120,10 +131,10 @@ public boolean isEnabled() { public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {} @Override - public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) {} + public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId){} @Override - public void logSizeEventMetric(Metric metric, long value) {} + public void logSizeEventMetric(Metric metric, long value, String modelId){} @Override public void logGaugeMetric(Metric metric, long value) {} @@ -136,7 +147,7 @@ public void logInstanceStats(InstanceRecord ir) {} @Override public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code, - int reqPayloadSize, int respPayloadSize) {} + int reqPayloadSize, int respPayloadSize, String modelId, String vModelId) {} }; final class PrometheusMetrics implements Metrics { @@ -154,12 +165,14 @@ final class PrometheusMetrics implements Metrics { private final CollectorRegistry registry; private final NettyServer metricServer; private final boolean shortNames; + private final boolean perModelMetricsEnabled; private final EnumMap metricsMap = new EnumMap<>(Metric.class); public PrometheusMetrics(Map params, Map infoMetricParams) throws Exception { int port = 2112; boolean shortNames = true; boolean https = true; + boolean perModelMetricsEnabled = true; String memMetrics = "all"; // default to all for (Entry ent : params.entrySet()) { switch (ent.getKey()) { @@ -170,6 +183,9 @@ public PrometheusMetrics(Map params, Map infoMet throw new Exception("Invalid metrics port: " + ent.getValue()); } break; + case "per_model_metrics": + perModelMetricsEnabled= "true".equalsIgnoreCase(ent.getValue()); + break; case "fq_names": shortNames = !"true".equalsIgnoreCase(ent.getValue()); break; @@ -188,6 +204,7 @@ public PrometheusMetrics(Map params, Map infoMet throw new Exception("Unrecognized metrics config parameter: " + ent.getKey()); } } + this.perModelMetricsEnabled = perModelMetricsEnabled; registry = new CollectorRegistry(); for (Metric m : Metric.values()) { @@ -220,10 +237,15 @@ public PrometheusMetrics(Map params, Map infoMet } if (m == API_REQUEST_TIME || m == API_REQUEST_COUNT || m == INVOKE_MODEL_TIME - || m == INVOKE_MODEL_COUNT || m == REQUEST_PAYLOAD_SIZE || m == RESPONSE_PAYLOAD_SIZE) { - builder.labelNames("method", "code"); + || m == INVOKE_MODEL_COUNT || m == REQUEST_PAYLOAD_SIZE || m == RESPONSE_PAYLOAD_SIZE) { + if (this.perModelMetricsEnabled) { + builder.labelNames("method", "code", "modelId", "vModelId"); + } else { + builder.labelNames("method", "code"); + } + } else if (this.perModelMetricsEnabled && m.type != GAUGE && m.type != COUNTER && m.type != COUNTER_WITH_HISTO) { + builder.labelNames("modelId", "vModelId"); } - Collector collector = builder.name(m.promName).help(m.description).create(); metricsMap.put(m, collector); if (!m.global) { @@ -330,6 +352,11 @@ public void close() { this.metricServer.close(); } + @Override + public boolean isPerModelMetricsEnabled() { + return perModelMetricsEnabled; + } + @Override public boolean isEnabled() { return true; @@ -342,13 +369,23 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) { } @Override - public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) { - ((Histogram) metricsMap.get(metric)).observe(isNano ? elapsed / M : elapsed); + public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId) { + Histogram histogram = (Histogram) metricsMap.get(metric); + if (perModelMetricsEnabled && modelId != null) { + histogram.labels(modelId, "").observe(isNano ? elapsed / M : elapsed); + } else { + histogram.observe(isNano ? elapsed / M : elapsed); + } } @Override - public void logSizeEventMetric(Metric metric, long value) { - ((Histogram) metricsMap.get(metric)).observe(value * metric.newMultiplier); + public void logSizeEventMetric(Metric metric, long value, String modelId) { + Histogram histogram = (Histogram) metricsMap.get(metric); + if (perModelMetricsEnabled) { + histogram.labels(modelId, "").observe(value * metric.newMultiplier); + } else { + histogram.observe(value * metric.newMultiplier); + } } @Override @@ -365,23 +402,37 @@ public void logCounterMetric(Metric metric) { @Override public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code, - int reqPayloadSize, int respPayloadSize) { + int reqPayloadSize, int respPayloadSize, String modelId, String vModelId) { final long elapsedMillis = elapsedNanos / M; final Histogram timingHisto = (Histogram) metricsMap .get(external ? API_REQUEST_TIME : INVOKE_MODEL_TIME); int idx = shortNames ? name.indexOf('/') : -1; - final String methodName = idx == -1 ? name : name.substring(idx + 1); - - timingHisto.labels(methodName, code.name()).observe(elapsedMillis); - + String methodName = idx == -1 ? name : name.substring(idx + 1); + if (perModelMetricsEnabled) { + modelId = Strings.nullToEmpty(modelId); + vModelId = Strings.nullToEmpty(vModelId); + } + if (perModelMetricsEnabled) { + timingHisto.labels(methodName, code.name(), modelId, vModelId).observe(elapsedMillis); + } else { + timingHisto.labels(methodName, code.name()).observe(elapsedMillis); + } if (reqPayloadSize != -1) { - ((Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE)) - .labels(methodName, code.name()).observe(reqPayloadSize); + Histogram reqPayloadHisto = (Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE); + if (perModelMetricsEnabled) { + reqPayloadHisto.labels(methodName, code.name(), modelId, vModelId).observe(reqPayloadSize); + } else { + reqPayloadHisto.labels(methodName, code.name()).observe(reqPayloadSize); + } } if (respPayloadSize != -1) { - ((Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE)) - .labels(methodName, code.name()).observe(respPayloadSize); + Histogram respPayloadHisto = (Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE); + if (perModelMetricsEnabled) { + respPayloadHisto.labels(methodName, code.name(), modelId, vModelId).observe(respPayloadSize); + } else { + respPayloadHisto.labels(methodName, code.name()).observe(respPayloadSize); + } } } @@ -437,6 +488,11 @@ protected StatsDSender createSender(Callable addressLookup, int q + (shortNames ? "short" : "fully-qualified") + " method names"); } + @Override + public boolean isPerModelMetricsEnabled() { + return false; + } + @Override public boolean isEnabled() { return true; @@ -454,12 +510,12 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) { } @Override - public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) { + public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId) { client.recordExecutionTime(name(metric), isNano ? elapsed / M : elapsed); } @Override - public void logSizeEventMetric(Metric metric, long value) { + public void logSizeEventMetric(Metric metric, long value, String modelId) { if (!legacy) { value *= metric.newMultiplier; } @@ -497,7 +553,7 @@ static String[] getOkTags(String method, boolean shortName) { @Override public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code, - int reqPayloadSize, int respPayloadSize) { + int reqPayloadSize, int respPayloadSize, String modelId, String vModelId) { final StatsDClient client = this.client; final long elapsedMillis = elapsedNanos / M; final String countName = name(external ? API_REQUEST_COUNT : INVOKE_MODEL_COUNT); diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java index 9755df49..78c776b4 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java @@ -1966,7 +1966,7 @@ final synchronized boolean doRemove(final boolean evicted, // "unload" event if explicit unloading isn't enabled. // Otherwise, this gets recorded in a callback set in the // CacheEntry.unload(int) method - metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, 0L, false); + metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, 0L, false, modelId); metrics.logCounterMetric(Metric.UNLOAD_MODEL); } } @@ -2037,7 +2037,7 @@ public void onSuccess(Boolean reallyHappened) { //TODO probably only log if took longer than a certain time long tookMillis = msSince(beforeNanos); logger.info("Unload of " + modelId + " completed in " + tookMillis + "ms"); - metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, tookMillis, false); + metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, tookMillis, false, modelId); metrics.logCounterMetric(Metric.UNLOAD_MODEL); } // else considered trivially succeeded because the corresponding @@ -2158,7 +2158,7 @@ public final void run() { long queueStartTimeNanos = getAndResetLoadingQueueStartTimeNanos(); if (queueStartTimeNanos > 0) { long queueDelayMillis = (nanoTime() - queueStartTimeNanos) / M; - metrics.logSizeEventMetric(Metric.LOAD_MODEL_QUEUE_DELAY, queueDelayMillis); + metrics.logSizeEventMetric(Metric.LOAD_MODEL_QUEUE_DELAY, queueDelayMillis, modelId); // Only log if the priority value is "in the future" which indicates // that there is or were runtime requests waiting for this load. // Otherwise we don't care about arbitrary delays here @@ -2228,7 +2228,7 @@ public final void run() { loadingTimeStats(modelType).recordTime(tookMillis); logger.info("Load of model " + modelId + " type=" + modelType + " completed in " + tookMillis + "ms"); - metrics.logTimingMetricDuration(Metric.LOAD_MODEL_TIME, tookMillis, false); + metrics.logTimingMetricDuration(Metric.LOAD_MODEL_TIME, tookMillis, false, modelId); metrics.logCounterMetric(Metric.LOAD_MODEL); } catch (Throwable t) { loadFuture = null; @@ -2388,7 +2388,7 @@ protected final void complete(LoadedRuntime result, Throwable error) { if (size > 0) { long sizeBytes = size * UNIT_SIZE; logger.info("Model " + modelId + " size = " + size + " units" + ", ~" + mb(sizeBytes)); - metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes); + metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes, modelId); } else { try { long before = nanoTime(); @@ -2397,9 +2397,9 @@ protected final void complete(LoadedRuntime result, Throwable error) { long took = msSince(before), sizeBytes = size * UNIT_SIZE; logger.info("Model " + modelId + " size = " + size + " units" + ", ~" + mb(sizeBytes) + " sizing took " + took + "ms"); - metrics.logTimingMetricDuration(Metric.MODEL_SIZING_TIME, took, false); + metrics.logTimingMetricDuration(Metric.MODEL_SIZING_TIME, took, false, modelId); // this is actually a size (bytes), not a "time" - metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes); + metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes, modelId); } } catch (Exception e) { if (!isInterruption(e) && state == SIZING) { @@ -2722,7 +2722,7 @@ protected void beforeInvoke(int requestWeight) //noinspection ThrowFromFinallyBlock throw new ModelNotHereException(instanceId, modelId); } - metrics.logTimingMetricDuration(Metric.QUEUE_DELAY, tookMillis, false); + metrics.logTimingMetricDuration(Metric.QUEUE_DELAY, tookMillis, false, modelId); } } } @@ -2901,7 +2901,7 @@ public void onEviction(String key, CacheEntry ce, long lastUsed) { logger.info("Evicted " + (failed ? "failed model record" : "model") + " " + key + " from local cache, last used " + readableTime(millisSinceLastUsed) + " ago (" + lastUsed + "ms), invoked " + ce.getTotalInvocationCount() + " times"); - metrics.logTimingMetricDuration(Metric.AGE_AT_EVICTION, millisSinceLastUsed, false); + metrics.logTimingMetricDuration(Metric.AGE_AT_EVICTION, millisSinceLastUsed, false, ce.modelId); metrics.logCounterMetric(Metric.EVICT_MODEL); } @@ -3315,6 +3315,7 @@ protected Map getMap(Object[] arr) { static final String KNOWN_SIZE_CXT_KEY = "tas.known_size"; static final String UNBALANCED_KEY = "mmesh.unbalanced"; static final String DEST_INST_ID_KEY = "tas.dest_iid"; + static final String VMODEL_ID = "vmodelid"; // these are the possible values for the tas.internal context parameter // it won't be set on requests from outside of the cluster, and will @@ -3430,6 +3431,7 @@ protected Object invokeModel(final String modelId, final Method method, final Me } final String tasInternal = contextMap.get(TAS_INTERNAL_CXT_KEY); + final String vModelId = contextMap.getOrDefault(VMODEL_ID, ""); // Set the external request flag if it's not a tasInternal call or if // tasInternal == INTERNAL_REQ. The latter is a new ensureLoaded // invocation originating from within the cluster. @@ -3502,7 +3504,7 @@ protected Object invokeModel(final String modelId, final Method method, final Me throw new ModelNotHereException(instanceId, modelId); } try { - return invokeLocalModel(ce, method, args, modelId); + return invokeLocalModel(ce, method, args, vModelId); } catch (ModelLoadException mle) { mr = registry.get(modelId); if (mr == null || !mr.loadFailedInInstance(instanceId)) { @@ -3716,7 +3718,7 @@ protected Object invokeModel(final String modelId, final Method method, final Me localInvokesInFlight.incrementAndGet(); } try { - Object result = invokeLocalModel(cacheEntry, method, args, modelId); + Object result = invokeLocalModel(cacheEntry, method, args, vModelId); return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result; } finally { if (!favourSelfForHits) { @@ -3936,7 +3938,7 @@ else if (mr.getInstanceIds().containsKey(instanceId)) { // invoke model try { - Object result = invokeLocalModel(cacheEntry, method, args, modelId); + Object result = invokeLocalModel(cacheEntry, method, args, vModelId); return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result; } catch (ModelNotHereException e) { if (loadTargetFilter != null) loadTargetFilter.remove(instanceId); @@ -3991,7 +3993,7 @@ else if (mr.getInstanceIds().containsKey(instanceId)) { if (methodStartNanos > 0L && metrics.isEnabled()) { // only logged here in non-grpc (legacy) mode metrics.logRequestMetrics(true, getRequestMethodName(method, args), - nanoTime() - methodStartNanos, metricStatusCode, -1, -1); + nanoTime() - methodStartNanos, metricStatusCode, -1, -1, modelId, vModelId); } curThread.setName(threadNameBefore); } @@ -4403,17 +4405,17 @@ protected Object invokeRemoteModel(BaseModelMeshService.Iface client, Method met return remoteMeth.invoke(client, ObjectArrays.concat(modelId, args)); } - protected Object invokeLocalModel(CacheEntry ce, Method method, Object[] args, String modelId) + protected Object invokeLocalModel(CacheEntry ce, Method method, Object[] args, String vModelId) throws InterruptedException, TException { - Object result = invokeLocalModel(ce, method, args); + final Object result = _invokeLocalModel(ce, method, args, vModelId); // if this is an ensure-loaded request, check-for and trigger a "chained" load if necessary if (method == null) { - triggerChainedLoadIfNecessary(modelId, result, args, ce.getWeight(), null); + triggerChainedLoadIfNecessary(ce.modelId, result, args, ce.getWeight(), null); } return result; } - private Object invokeLocalModel(CacheEntry ce, Method method, Object[] args) + private Object _invokeLocalModel(CacheEntry ce, Method method, Object[] args, String vModelId) throws InterruptedException, TException { if (method == null) { @@ -4450,7 +4452,7 @@ private Object invokeLocalModel(CacheEntry ce, Method method, Object[] args) long delayMillis = msSince(beforeNanos); logger.info("Cache miss for model invocation, held up " + delayMillis + "ms"); metrics.logCounterMetric(Metric.CACHE_MISS); - metrics.logTimingMetricDuration(Metric.CACHE_MISS_DELAY, delayMillis, false); + metrics.logTimingMetricDuration(Metric.CACHE_MISS_DELAY, delayMillis, false, ce.modelId); } } } else { @@ -4528,7 +4530,7 @@ private Object invokeLocalModel(CacheEntry ce, Method method, Object[] args) ce.afterInvoke(weight, tookNanos); if (code != null && metrics.isEnabled()) { metrics.logRequestMetrics(false, getRequestMethodName(method, args), - tookNanos, code, -1, -1); + tookNanos, code, -1, -1, ce.modelId, vModelId); } } } diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java b/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java index 6f3f8202..715c0efe 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java @@ -30,6 +30,7 @@ import com.ibm.watson.litelinks.server.ReleaseAfterResponse; import com.ibm.watson.litelinks.server.ServerRequestThread; import com.ibm.watson.modelmesh.DataplaneApiConfig.RpcConfig; +import com.ibm.watson.modelmesh.GrpcSupport.InterruptingListener; import com.ibm.watson.modelmesh.ModelMesh.ExtendedStatusInfo; import com.ibm.watson.modelmesh.api.DeleteVModelRequest; import com.ibm.watson.modelmesh.api.DeleteVModelResponse; @@ -68,6 +69,7 @@ import io.grpc.ServerInterceptors; import io.grpc.ServerMethodDefinition; import io.grpc.ServerServiceDefinition; +import io.grpc.Status.Code; import io.grpc.StatusException; import io.grpc.StatusRuntimeException; import io.grpc.netty.GrpcSslContexts; @@ -85,6 +87,7 @@ import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.FastThreadLocalThread; import org.apache.thrift.TException; import org.slf4j.Logger; @@ -344,6 +347,10 @@ protected static void setUnbalancedLitelinksContextParam() { ThreadContext.addContextEntry(ModelMesh.UNBALANCED_KEY, "true"); // unbalanced } + protected static void setVModelIdLiteLinksContextParam(String vModelId) { + ThreadContext.addContextEntry(ModelMesh.VMODEL_ID, vModelId); + } + // ----------------- concrete model management methods @Override @@ -427,6 +434,9 @@ public void ensureLoaded(EnsureLoadedRequest request, StreamObserver resolvedModelId = new FastThreadLocal<>(); + // Returned ModelResponse will be released once the request thread exits so // must be retained before transferring. // non-private to avoid synthetic method access @@ -441,9 +451,13 @@ ModelResponse callModel(String modelId, boolean isVModel, String methodName, Str } String vModelId = modelId; modelId = null; + if (delegate.metrics.isPerModelMetricsEnabled()) { + setVModelIdLiteLinksContextParam(vModelId); + } boolean first = true; while (true) { modelId = vmm().resolveVModelId(vModelId, modelId); + resolvedModelId.set(modelId); if (unbalanced) { setUnbalancedLitelinksContextParam(); } @@ -542,7 +556,7 @@ protected static void respondAndComplete(StreamObserver response, } protected static io.grpc.Status toStatus(Exception e) { - io.grpc.Status s = null; + io.grpc.Status s; String msg = e.getMessage(); if (e instanceof ModelNotFoundException) { return MODEL_NOT_FOUND_STATUS; @@ -655,7 +669,7 @@ public Listener startCall(ServerCall call, Metadata h call.request(2); // request 2 to force failure if streaming method - return new Listener() { + return new Listener<>() { ByteBuf reqMessage; boolean canInvoke = true; Iterable modelIds = mids.modelIds; @@ -707,7 +721,8 @@ public void onHalfClose() { int respReaderIndex = 0; io.grpc.Status status = INTERNAL; - String modelId = null; + String resolvedModelId = null; + String vModelId = null; String requestId = null; ModelResponse response = null; try (InterruptingListener cancelListener = newInterruptingListener()) { @@ -721,16 +736,28 @@ public void onHalfClose() { String balancedMetaVal = headers.get(BALANCED_META_KEY); Iterator midIt = modelIds.iterator(); // guaranteed at least one - modelId = validateModelId(midIt.next(), isVModel); + String modelOrVModelId = validateModelId(midIt.next(), isVModel); if (!midIt.hasNext()) { // single model case (most common) - response = callModel(modelId, isVModel, methodName, - balancedMetaVal, headers, reqMessage).retain(); + if (isVModel) { + ModelMeshApi.resolvedModelId.set(null); + } + try { + response = callModel(modelOrVModelId, isVModel, methodName, + balancedMetaVal, headers, reqMessage).retain(); + } finally { + if (isVModel) { + vModelId = modelOrVModelId; + resolvedModelId = ModelMeshApi.resolvedModelId.getIfExists(); + } else { + resolvedModelId = modelOrVModelId; + } + } } else { // multi-model case (specialized) boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY)); List idList = new ArrayList<>(); - idList.add(modelId); + idList.add(modelOrVModelId); while (midIt.hasNext()) { idList.add(validateModelId(midIt.next(), isVModel)); } @@ -740,7 +767,7 @@ public void onHalfClose() { } finally { if (payloadProcessor != null) { processPayload(reqMessage.readerIndex(reqReaderIndex), - requestId, modelId, methodName, headers, null, true); + requestId, resolvedModelId, methodName, headers, null, true); } else { releaseReqMessage(); } @@ -776,7 +803,7 @@ public void onHalfClose() { data = response.data.readerIndex(respReaderIndex); metadata = response.metadata; } - processPayload(data, requestId, modelId, methodName, metadata, status, releaseResponse); + processPayload(data, requestId, resolvedModelId, methodName, metadata, status, releaseResponse); } else if (releaseResponse && response != null) { response.release(); } @@ -787,7 +814,7 @@ public void onHalfClose() { Metrics metrics = delegate.metrics; if (metrics.isEnabled()) { metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos, - status.getCode(), reqSize, respSize); + status.getCode(), reqSize, respSize, resolvedModelId, vModelId); } } } diff --git a/src/main/java/com/ibm/watson/modelmesh/VModelManager.java b/src/main/java/com/ibm/watson/modelmesh/VModelManager.java index d2706a16..7ad5da8a 100644 --- a/src/main/java/com/ibm/watson/modelmesh/VModelManager.java +++ b/src/main/java/com/ibm/watson/modelmesh/VModelManager.java @@ -27,7 +27,6 @@ import com.ibm.watson.kvutils.KVTable.Helper.TableTxn; import com.ibm.watson.kvutils.KVTable.TableView; import com.ibm.watson.kvutils.factory.KVUtilsFactory; -import com.ibm.watson.litelinks.ThreadContext; import com.ibm.watson.litelinks.ThreadPoolHelper; import com.ibm.watson.modelmesh.GrpcSupport.InterruptingListener; import com.ibm.watson.modelmesh.api.ModelInfo; diff --git a/src/main/java/com/ibm/watson/prometheus/SimpleCollector.java b/src/main/java/com/ibm/watson/prometheus/SimpleCollector.java index ffca070b..c7b25c1f 100644 --- a/src/main/java/com/ibm/watson/prometheus/SimpleCollector.java +++ b/src/main/java/com/ibm/watson/prometheus/SimpleCollector.java @@ -161,7 +161,7 @@ private static int nextIdx(int i, int len) { private void validateCount(int count) { if (count != labelCount) { - throw new IllegalArgumentException("Incorrect number of labels."); + throw new IllegalArgumentException("Incorrect number of labels. Expected: " + labelCount + ", got: " + count); } } diff --git a/src/test/java/com/ibm/watson/modelmesh/ModelMeshMetricsTest.java b/src/test/java/com/ibm/watson/modelmesh/ModelMeshMetricsTest.java index dc6ee35e..a78cef1c 100644 --- a/src/test/java/com/ibm/watson/modelmesh/ModelMeshMetricsTest.java +++ b/src/test/java/com/ibm/watson/modelmesh/ModelMeshMetricsTest.java @@ -32,6 +32,7 @@ import io.grpc.ManagedChannel; import io.grpc.netty.NettyChannelBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import javax.net.ssl.SSLContext; @@ -76,10 +77,11 @@ protected int requestCount() { @Override protected Map extraEnvVars() { - return ImmutableMap.of("MM_METRICS", "prometheus:port=" + METRICS_PORT + ";scheme=" + SCHEME); + return ImmutableMap.of("MM_METRICS", "prometheus:port=" + METRICS_PORT + ";scheme=" + SCHEME + + ";per_model_metrics=true"); } - @Test + @BeforeAll public void metricsTest() throws Exception { ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 9000).usePlaintext().build(); @@ -151,11 +153,11 @@ public void metricsTest() throws Exception { } } - public void verifyMetrics() throws Exception { + protected Map prepareMetrics() throws Exception { // Insecure trust manager - skip TLS verification SSLContext sslContext = SSLContext.getInstance("TLS"); sslContext.init(null, InsecureTrustManagerFactory.INSTANCE.getTrustManagers(), null); - + HttpClient client = HttpClient.newBuilder().sslContext(sslContext).build(); HttpRequest metricsRequest = HttpRequest.newBuilder() .uri(URI.create(SCHEME + "://localhost:" + METRICS_PORT + "/metrics")).build(); @@ -172,29 +174,35 @@ public void verifyMetrics() throws Exception { .filter(Matcher::matches) .collect(Collectors.toMap(m -> m.group(1), m -> Double.parseDouble(m.group(2)))); + return metrics; + } + + @Test + public void verifyMetrics() throws Exception { + // Insecure trust manager - skip TLS verification + Map metrics = prepareMetrics(); + System.out.println(metrics.size() + " metrics scraped"); // Spot check some expected metrics and values // External response time should all be < 2000ms (includes cache hit loading time) - assertEquals(40.0, metrics.get("modelmesh_api_request_milliseconds_bucket{method=\"predict\",code=\"OK\",le=\"2000.0\",}")); + assertEquals(40.0, metrics.get("modelmesh_api_request_milliseconds_bucket{method=\"predict\",code=\"OK\",modelId=\"myModel\",vModelId=\"\",le=\"2000.0\",}")); // External response time should all be < 200ms (includes cache hit loading time) - assertEquals(40.0, metrics.get("modelmesh_invoke_model_milliseconds_bucket{method=\"predict\",code=\"OK\",le=\"200.0\",}")); + assertEquals(40.0, + metrics.get("modelmesh_invoke_model_milliseconds_bucket{method=\"predict\",code=\"OK\",modelId=\"myModel\",vModelId=\"\",le=\"120000.0\",}")); // Simulated model sizing time is < 200ms - assertEquals(1.0, metrics.get("modelmesh_model_sizing_milliseconds_bucket{le=\"200.0\",}")); + assertEquals(1.0, metrics.get("modelmesh_model_sizing_milliseconds_bucket{modelId=\"myModel\",vModelId=\"\",le=\"60000.0\",}")); // Simulated model sizing time is > 50ms - assertEquals(0.0, metrics.get("modelmesh_model_sizing_milliseconds_bucket{le=\"50.0\",}")); + assertEquals(0.0, metrics.get("modelmesh_model_sizing_milliseconds_bucket{modelId=\"myModel\",vModelId=\"\",le=\"50.0\",}")); // Simulated model size is between 64MiB and 256MiB - assertEquals(0.0, metrics.get("modelmesh_loaded_model_size_bytes_bucket{le=\"6.7108864E7\",}")); - assertEquals(1.0, metrics.get("modelmesh_loaded_model_size_bytes_bucket{le=\"2.68435456E8\",}")); + assertEquals(0.0, metrics.get("modelmesh_loaded_model_size_bytes_bucket{modelId=\"myModel\",vModelId=\"\",le=\"6.7108864E7\",}")); + assertEquals(1.0, metrics.get("modelmesh_loaded_model_size_bytes_bucket{modelId=\"myModel\",vModelId=\"\",le=\"2.68435456E8\",}")); // One model is loaded - assertEquals(1.0, metrics.get("modelmesh_models_loaded_total")); assertEquals(1.0, metrics.get("modelmesh_instance_models_total")); // Histogram counts should reflect the two payload sizes (30 small, 10 large) - assertEquals(30.0, metrics.get("modelmesh_request_size_bytes_bucket{method=\"predict\",code=\"OK\",le=\"128.0\",}")); - assertEquals(40.0, metrics.get("modelmesh_request_size_bytes_bucket{method=\"predict\",code=\"OK\",le=\"2097152.0\",}")); - assertEquals(30.0, metrics.get("modelmesh_response_size_bytes_bucket{method=\"predict\",code=\"OK\",le=\"128.0\",}")); - assertEquals(40.0, metrics.get("modelmesh_response_size_bytes_bucket{method=\"predict\",code=\"OK\",le=\"2097152.0\",}")); + assertEquals(30.0, metrics.get("modelmesh_request_size_bytes_bucket{method=\"predict\",code=\"OK\",modelId=\"myModel\",vModelId=\"\",le=\"128.0\",}")); + assertEquals(40.0, metrics.get("modelmesh_request_size_bytes_bucket{method=\"predict\",code=\"OK\",modelId=\"myModel\",vModelId=\"\",le=\"2097152.0\",}")); // Memory metrics assertTrue(metrics.containsKey("netty_pool_mem_allocated_bytes{area=\"direct\",}"));