Skip to content

Commit

Permalink
addressing PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
VedantMahabaleshwarkar committed Jun 2, 2023
1 parent 35df353 commit b111f0e
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 38 deletions.
26 changes: 16 additions & 10 deletions src/main/java/com/ibm/watson/modelmesh/Metrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.lang.reflect.Array;
import java.net.SocketAddress;
import java.nio.channels.DatagramChannel;
import java.util.Collections;
Expand Down Expand Up @@ -239,12 +238,12 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet
if (m == API_REQUEST_TIME || m == API_REQUEST_COUNT || m == INVOKE_MODEL_TIME
|| m == INVOKE_MODEL_COUNT || m == REQUEST_PAYLOAD_SIZE || m == RESPONSE_PAYLOAD_SIZE) {
if (this.enablePerModelMetrics && m.type != COUNTER_WITH_HISTO) {
builder.labelNames("method", "code", "modelId");
builder.labelNames("method", "code", "modelId", "vModelId");
} else {
builder.labelNames("method", "code");
}
} else if (this.enablePerModelMetrics && m.type != GAUGE && m.type != COUNTER && m.type != COUNTER_WITH_HISTO) {
builder.labelNames("modelId");
builder.labelNames("modelId", "vModelId");
}
Collector collector = builder.name(m.promName).help(m.description).create();
metricsMap.put(m, collector);
Expand Down Expand Up @@ -369,8 +368,8 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {

@Override
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId) {
if (enablePerModelMetrics) {
((Histogram) metricsMap.get(metric)).labels(modelId).observe(isNano ? elapsed / M : elapsed);
if (enablePerModelMetrics && modelId != null) {
((Histogram) metricsMap.get(metric)).labels(modelId, "").observe(isNano ? elapsed / M : elapsed);
} else {
((Histogram) metricsMap.get(metric)).observe(isNano ? elapsed / M : elapsed);
}
Expand All @@ -379,7 +378,7 @@ public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano,
@Override
public void logSizeEventMetric(Metric metric, long value, String modelId) {
if (enablePerModelMetrics) {
((Histogram) metricsMap.get(metric)).labels(modelId).observe(value * metric.newMultiplier);
((Histogram) metricsMap.get(metric)).labels(modelId, "").observe(value * metric.newMultiplier);
} else {
((Histogram) metricsMap.get(metric)).observe(value * metric.newMultiplier);
}
Expand All @@ -403,18 +402,25 @@ public void logRequestMetrics(boolean external, String name, long elapsedNanos,
final long elapsedMillis = elapsedNanos / M;
final Histogram timingHisto = (Histogram) metricsMap
.get(external ? API_REQUEST_TIME : INVOKE_MODEL_TIME);
String mId = vModelId == null ? modelId : vModelId;
int idx = shortNames ? name.indexOf('/') : -1;
String methodName = idx == -1 ? name : name.substring(idx + 1);
if (modelId == null) {
logger.error("invalid ModelId. Label value for ModelId will be left blank");
modelId = "";
}
if (vModelId == null) {
logger.debug("vModelId is empty, creating empty label");
vModelId = "";
}
if (enablePerModelMetrics) {
timingHisto.labels(methodName, code.name(), mId).observe(elapsedMillis);
timingHisto.labels(methodName, code.name(), modelId, vModelId).observe(elapsedMillis);
} else {
timingHisto.labels(methodName, code.name()).observe(elapsedMillis);
}
if (reqPayloadSize != -1) {
if (enablePerModelMetrics) {
((Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE))
.labels(methodName, code.name(), mId).observe(reqPayloadSize);
.labels(methodName, code.name(), modelId, vModelId).observe(reqPayloadSize);
} else {
((Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE))
.labels(methodName, code.name()).observe(reqPayloadSize);
Expand All @@ -423,7 +429,7 @@ public void logRequestMetrics(boolean external, String name, long elapsedNanos,
if (respPayloadSize != -1) {
if (enablePerModelMetrics) {
((Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE))
.labels(methodName, code.name(), mId).observe(respPayloadSize);
.labels(methodName, code.name(), modelId, vModelId).observe(respPayloadSize);
} else {
((Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE))
.labels(methodName, code.name()).observe(respPayloadSize);
Expand Down
40 changes: 25 additions & 15 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -3315,6 +3315,7 @@ protected Map<String, ServiceInstanceInfo> getMap(Object[] arr) {
static final String KNOWN_SIZE_CXT_KEY = "tas.known_size";
static final String UNBALANCED_KEY = "mmesh.unbalanced";
static final String DEST_INST_ID_KEY = "tas.dest_iid";
static final String VMODELID = "vmodelid";

// these are the possible values for the tas.internal context parameter
// it won't be set on requests from outside of the cluster, and will
Expand Down Expand Up @@ -3347,8 +3348,8 @@ public StatusInfo internalOperation(String modelId, boolean returnStatus, boolea
List<String> excludeInstances)
throws ModelNotFoundException, ModelLoadException, ModelNotHereException, InternalException {
try {
return (StatusInfo) invokeModel(modelId, null, internalOpRemoteMeth,
returnStatus, load, sync, lastUsed, excludeInstances); // <-- "args"
return (StatusInfo) invokeModel(modelId, false, null,
internalOpRemoteMeth, returnStatus, load, sync, lastUsed, excludeInstances); // <-- "args"
} catch (ModelNotFoundException | ModelLoadException | ModelNotHereException | InternalException e) {
throw e;
} catch (TException e) {
Expand Down Expand Up @@ -3416,8 +3417,8 @@ public StatusInfo internalOperation(String modelId, boolean returnStatus, boolea
* @throws TException
*/
@SuppressWarnings("unchecked")
protected Object invokeModel(final String modelId, final Method method, final Method remoteMeth,
final Object... args) throws ModelNotFoundException, ModelNotHereException, ModelLoadException, TException {
protected Object invokeModel(final String modelId, Boolean isVModel, final Method method,
final Method remoteMeth, final Object... args) throws ModelNotFoundException, ModelNotHereException, ModelLoadException, TException {

//verify parameter values
if (modelId == null || modelId.isEmpty()) {
Expand All @@ -3430,6 +3431,10 @@ protected Object invokeModel(final String modelId, final Method method, final Me
}

final String tasInternal = contextMap.get(TAS_INTERNAL_CXT_KEY);
String vModelId = "";
if (isVModel) {
vModelId = contextMap.get(VMODELID);
}
// Set the external request flag if it's not a tasInternal call or if
// tasInternal == INTERNAL_REQ. The latter is a new ensureLoaded
// invocation originating from within the cluster.
Expand Down Expand Up @@ -3502,7 +3507,7 @@ protected Object invokeModel(final String modelId, final Method method, final Me
throw new ModelNotHereException(instanceId, modelId);
}
try {
return invokeLocalModel(ce, method, args, modelId);
return invokeLocalModel(ce, method, args, modelId, isVModel);
} catch (ModelLoadException mle) {
mr = registry.get(modelId);
if (mr == null || !mr.loadFailedInInstance(instanceId)) {
Expand Down Expand Up @@ -3716,7 +3721,7 @@ protected Object invokeModel(final String modelId, final Method method, final Me
localInvokesInFlight.incrementAndGet();
}
try {
Object result = invokeLocalModel(cacheEntry, method, args, modelId);
Object result = invokeLocalModel(cacheEntry, method, args, modelId, isVModel);
return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result;
} finally {
if (!favourSelfForHits) {
Expand Down Expand Up @@ -3936,7 +3941,7 @@ else if (mr.getInstanceIds().containsKey(instanceId)) {

// invoke model
try {
Object result = invokeLocalModel(cacheEntry, method, args, modelId);
Object result = invokeLocalModel(cacheEntry, method, args, modelId, isVModel);
return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result;
} catch (ModelNotHereException e) {
if (loadTargetFilter != null) loadTargetFilter.remove(instanceId);
Expand Down Expand Up @@ -3989,10 +3994,9 @@ else if (mr.getInstanceIds().containsKey(instanceId)) {
throw t;
} finally {
if (methodStartNanos > 0L && metrics.isEnabled()) {
String[] extraLabels = new String[]{modelId};
// only logged here in non-grpc (legacy) mode
metrics.logRequestMetrics(true, getRequestMethodName(method, args),
nanoTime() - methodStartNanos, metricStatusCode, -1, -1, modelId, "");
nanoTime() - methodStartNanos, metricStatusCode, -1, -1, modelId, vModelId);
}
curThread.setName(threadNameBefore);
}
Expand Down Expand Up @@ -4122,13 +4126,15 @@ private Map<String, Long> filterIfReadOnly(Map<String, Long> instId) {
* instances inside and some out, and a request has been sent from outside the
* cluster to an instance inside (since it may land on an unintended instance in
* that case).
*
* @param isVModel TODO
* @throws TException TODO
* @throws ModelNotHereException if the specified destination instance isn't found
*/
protected Object forwardInvokeModel(String destId, String modelId, Method remoteMeth, Object... args)
throws TException {
destinationInstance.set(destId);
try {
//TODO: not sure what is happening here.. do I need to pass vmodelid to the remoteMeth.invoke?
return remoteMeth.invoke(directClient, ObjectArrays.concat(modelId, args));
} catch (Exception e) {
if (e instanceof InvocationTargetException) {
Expand Down Expand Up @@ -4404,17 +4410,17 @@ protected Object invokeRemoteModel(BaseModelMeshService.Iface client, Method met
return remoteMeth.invoke(client, ObjectArrays.concat(modelId, args));
}

protected Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args, String modelId)
protected Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args, String modelId, Boolean isVModel)
throws InterruptedException, TException {
Object result = invokeLocalModel(ce, method, args);
Object result = invokeLocalModel(ce, method, false, args);
// if this is an ensure-loaded request, check-for and trigger a "chained" load if necessary
if (method == null) {
triggerChainedLoadIfNecessary(modelId, result, args, ce.getWeight(), null);
}
return result;
}

private Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args)
private Object invokeLocalModel(CacheEntry<?> ce, Method method, Boolean isVModel, Object[] args)
throws InterruptedException, TException {

if (method == null) {
Expand All @@ -4429,7 +4435,11 @@ private Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args)
long now = currentTimeMillis();
ce.upgradePriority(now + 3600_000L, now + 7200_000L); // (2 hours in future)
}

Map<String, String> contextMap = ThreadContext.getCurrentContext();
String vModelId = null;
if (isVModel) {
vModelId = contextMap.get(VMODELID);
}
// The future-waiting timeouts should not be needed, request threads are interrupted when their
// timeouts/deadlines expire, and the model loading thread that it waits for has its own timeout.
// But we still set a large one as a safeguard (there can be pathalogical cases where model-loading
Expand Down Expand Up @@ -4529,7 +4539,7 @@ private Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args)
ce.afterInvoke(weight, tookNanos);
if (code != null && metrics.isEnabled()) {
metrics.logRequestMetrics(false, getRequestMethodName(method, args),
tookNanos, code, -1, -1, ce.modelId, "");
tookNanos, code, -1, -1, ce.modelId, vModelId);
}
}
}
Expand Down
35 changes: 26 additions & 9 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.FastThreadLocalThread;
import org.apache.thrift.TException;
import org.checkerframework.checker.units.qual.A;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
Expand Down Expand Up @@ -345,6 +344,10 @@ protected static void setUnbalancedLitelinksContextParam() {
ThreadContext.addContextEntry(ModelMesh.UNBALANCED_KEY, "true"); // unbalanced
}

protected static void setvModelIdLiteLinksContextParam(String vModelId) {
ThreadContext.addContextEntry(ModelMesh.VMODELID, vModelId);
}

// ----------------- concrete model management methods

@Override
Expand Down Expand Up @@ -438,18 +441,19 @@ ModelResponse callModel(String modelId, boolean isVModel, String methodName, Str
if (unbalanced) {
setUnbalancedLitelinksContextParam();
}
return delegate.callModel(modelId, methodName, headers, data);
return delegate.callModel(modelId, isVModel, methodName, headers, data);
}
String vModelId = modelId;
modelId = null;
boolean first = true;
while (true) {
modelId = vmm().resolveVModelId(vModelId, modelId);
setvModelIdLiteLinksContextParam(vModelId);
if (unbalanced) {
setUnbalancedLitelinksContextParam();
}
try {
return delegate.callModel(modelId, methodName, headers, data);
return delegate.callModel(modelId, true, methodName, headers, data);
} catch (ModelNotFoundException mnfe) {
if (!first) throw mnfe;
} catch (Exception e) {
Expand Down Expand Up @@ -784,12 +788,25 @@ public void onHalfClose() {
call.close(status, emptyMeta());
Metrics metrics = delegate.metrics;
if (metrics.isEnabled()) {
if (isVModel) {
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, "", Iterables.toString(modelIds));
} else {
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, Iterables.toString(modelIds), "");
Iterator<String> midIt = modelIds.iterator();
while (midIt.hasNext()) {
if (isVModel) {
String mId = null;
String vmId = midIt.next();
try {
mId = vmm().resolveVModelId(midIt.next(), mId);
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, mId, vmId);
}
catch (Exception e) {
logger.error("Could not resolve model id for vModelId" + vmId, e);
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, "", vmId);
}
} else {
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, midIt.next(), "");
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/com/ibm/watson/modelmesh/SidecarModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -1098,12 +1098,12 @@ public List<ByteBuffer> applyModelMulti(String modelId, List<ByteBuffer> input,
@SuppressWarnings("unchecked")
List<ByteBuffer> applyModel(String modelId, List<ByteBuffer> input, Map<String, String> metadata)
throws TException {
return (List<ByteBuffer>) invokeModel(modelId, localMeth, remoteMeth, input, metadata);
return (List<ByteBuffer>) invokeModel(modelId, false, localMeth, remoteMeth, input, metadata);
}

// refcount of provided ByteBuf should not be modified
ModelResponse callModel(String modelId, String methodName, Metadata headers, ByteBuf data) throws TException {
return (ModelResponse) invokeModel(modelId, directLocalMeth, remoteMeth, methodName, headers, data);
ModelResponse callModel(String modelId, Boolean isVModel, String methodName, Metadata headers, ByteBuf data) throws TException {
return (ModelResponse) invokeModel(modelId, isVModel, directLocalMeth, remoteMeth, methodName, headers, data);
}

@Idempotent
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/com/ibm/watson/modelmesh/DummyModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ protected ModelLoader<?> getLoader() {
@Override
public ByteBuffer applyModel(String modelId, ByteBuffer input, Map<String, String> metadata)
throws TException {
return (ByteBuffer) invokeModel(modelId, localMeth, remoteMeth, input, metadata);
return (ByteBuffer) invokeModel(modelId, false, localMeth, remoteMeth, input, metadata);
}

@Override
Expand Down

0 comments on commit b111f0e

Please sign in to comment.