Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix cache accounting in loading failures case #74

Merged
merged 1 commit into from
Dec 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ int getUnloadBufferWeight() {
return UNLOAD_BUFF.getWeight();
}

/**
 * Returns the current value of the {@code totalUnloadingWeight} field.
 * For instrumentation/diagnostics only.
 *
 * NOTE(review): the diagnostic caller that logs this value reads it while
 * holding the cache eviction lock; presumably that is needed for a
 * consistent read -- confirm before calling from elsewhere.
 *
 * @return the value of {@code totalUnloadingWeight} at the time of the call
 */
int getTotalUnloadingWeight() {
return totalUnloadingWeight;
}
/**
 * Returns the current value of the {@code totalModelCacheOccupancy} field.
 * For instrumentation/diagnostics only.
 *
 * NOTE(review): the diagnostic caller that logs this value reads it while
 * holding the cache eviction lock; presumably that is needed for a
 * consistent read -- confirm before calling from elsewhere.
 *
 * @return the value of {@code totalModelCacheOccupancy} at the time of the call
 */
long getTotalModelCacheOccupancy() {
return totalModelCacheOccupancy;
}

/**
 * Removes the entry keyed by {@code UNLOAD_BUFFER_CACHE_KEY} from the given map.
 * The map is modified in place; nothing is returned.
 *
 * @param entries map from which the {@code UNLOAD_BUFFER_CACHE_KEY} key is removed
 */
void removeUnloadBufferEntry(Map<String, ?> entries) { //TODO TBD maybe static
entries.remove(UNLOAD_BUFFER_CACHE_KEY);
}
Expand Down
74 changes: 54 additions & 20 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ public abstract class ModelMesh extends ThriftService
protected /*final*/ long defaultAssumeLoadedAfterMs;

// time after which loading failure records expire (allowing for re-attempts)
public final long LOAD_FAILURE_EXPIRY_MS = getLongParameter(LOAD_FAILURE_EXPIRY_ENV_VAR, 600_000L); // default 10mins
public final long LOAD_FAILURE_EXPIRY_MS = getLongParameter(LOAD_FAILURE_EXPIRY_ENV_VAR, 900_000L); // default 15mins
// shorter expiry time for "in use" models (receiving recent requests)
public final long IN_USE_LOAD_FAILURE_EXPIRY_MS = (LOAD_FAILURE_EXPIRY_MS * 2) / 3;
public final long IN_USE_LOAD_FAILURE_EXPIRY_MS = LOAD_FAILURE_EXPIRY_MS / 2;
public static final int MAX_LOAD_FAILURES = 3;
// if unable to invoke in this many places, don't continue to load
public static final int MAX_LOAD_LOCATIONS = 5;
Expand Down Expand Up @@ -2197,7 +2197,7 @@ public final void run() {
}
}

private void waitForSpaceToLoad(int required) throws Exception {
private void waitForSpaceToLoad(final int required) throws Exception {
//assert unloadManager != null;
// here state == WAITING; we wait if necessary for cache space to become available
// -- specifically that we can add back our prior subtraction from the aggregate
Expand Down Expand Up @@ -2372,7 +2372,10 @@ private boolean setFailed(Throwable t) {

// Called at most once, by the thread which moved the state to FAILED.
// unloadDelay == -1L means load wasn't attempted so don't unload at all
private boolean failed(Throwable t, long unloadDelay) {
private void failed(Throwable t, long unloadDelay) {
if (state != FAILED) {
return;
}

boolean isShuttingDown = shuttingDown;
if (!isShuttingDown) {
Expand All @@ -2382,9 +2385,21 @@ private boolean failed(Throwable t, long unloadDelay) {
int weightBefore;
synchronized (CacheEntry.this) {
weightBefore = getWeight();
updateWeight(FAILED_WEIGHT);
if (!isShuttingDown && unloadDelay >= 0) {
triggerUnload(weightBefore - FAILED_WEIGHT, unloadDelay);
int weightReduction = weightBefore - FAILED_WEIGHT;
if (weightReduction != 0) {
if (unloadManager != null) {
// Though this method is written for the new entry adjustment, the accounting
// is identical to what we need here when reducing the size of the failed
// entry and triggering an unload tied to that space reduction.
unloadManager.adjustNewEntrySpaceRequest(-weightReduction, this, false);
if (!isShuttingDown && unloadDelay >= 0) {
triggerUnload(weightReduction, unloadDelay);
} else {
unloadManager.unloadComplete(weightReduction, true, modelId);
}
} else {
updateWeight(FAILED_WEIGHT);
}
}
}

Expand All @@ -2393,7 +2408,7 @@ private boolean failed(Throwable t, long unloadDelay) {
ModelRecord mr = registry.getOrStrongIfAbsent(modelId);
while (true) {
if (mr == null) {
return true; // deleted while loading
return; // deleted while loading
}
if (lastUsed <= 0L) {
lastUsed = mr.getLastUsed();
Expand Down Expand Up @@ -2438,7 +2453,6 @@ private boolean failed(Throwable t, long unloadDelay) {
+ modelId, e);
}
}
return true;
}

// Used when there is a failure updating the KV registry prior to a load attempt
Expand Down Expand Up @@ -3609,10 +3623,12 @@ protected Object invokeModel(final String modelId, final Method method, final Me
ModelLoadException mle = newModelLoadException(
"KV store error attempting to prune model record: " + e,
KVSTORE_LOAD_FAILURE, e);
CacheEntry<?> failedEntry = new CacheEntry<>(modelId, mr, mle);
cacheEntry = unloadManager != null
? unloadManager.insertFailedPlaceholderEntry(modelId, failedEntry, mr.getLastUsed())
: runtimeCache.putIfAbsent(modelId, failedEntry, mr.getLastUsed());
if (io.grpc.Status.fromThrowable(e).getCode() != Code.CANCELLED) {
CacheEntry<?> failedEntry = new CacheEntry<>(modelId, mr, mle);
cacheEntry = unloadManager != null
? unloadManager.insertFailedPlaceholderEntry(modelId, failedEntry, mr.getLastUsed())
: runtimeCache.putIfAbsent(modelId, failedEntry, mr.getLastUsed());
}
if (cacheEntry == null) {
throw mle;
}
Expand Down Expand Up @@ -5223,7 +5239,7 @@ private ModelRecord handleUnexpectedFailedCacheEntry(CacheEntry<?> ce, ModelReco
|| ((ModelLoadException) failure).getTimeout() != KVSTORE_LOAD_FAILURE) {
// We assume that this is an expired entry yet to be cleaned up
if (ce.remove()) {
logger.info("Removed kv-store failure cache entry for model " + ce.modelId);
logger.info("Removed residual failed cache entry for model " + ce.modelId);
}
return mr;
}
Expand Down Expand Up @@ -5315,11 +5331,20 @@ protected void publishInstanceRecord(boolean force, boolean preShutdown) throws
long oldest = runtimeCache.oldestTime();
long cap = runtimeCache.capacity(), used = runtimeCache.weightedSize();
int count = runtimeCache.size(); //TODO maybe don't get every time
int unloadBufferWeight = -1, totalUnloadingWeight = -1;
long totalCacheOccupancy = -1;
if (unloadManager != null) {
runtimeCache.getEvictionLock().lock();
try {
unloadBufferWeight = unloadManager.getUnloadBufferWeight();
totalUnloadingWeight = unloadManager.getTotalUnloadingWeight();
totalCacheOccupancy = unloadManager.getTotalModelCacheOccupancy();
} finally {
runtimeCache.getEvictionLock().unlock();;
}
// remove unloading buffer weight from published values
int weight = unloadManager.getUnloadBufferWeight();
cap -= weight;
used -= weight;
cap -= unloadBufferWeight;
used -= unloadBufferWeight;
count--;
}
if (oldest == -1L) {
Expand Down Expand Up @@ -5395,7 +5420,13 @@ else if (curRec.isShuttingDown() == isShuttingDown
InstanceRecord existRec = instanceInfo.conditionalSetAndGet(instanceId, curRec, sessionId);
if (existRec == curRec) {
curRec.setActionableUpdate();
logger.info("Published new instance record: " + curRec);
String message = "Published new instance record: " + curRec;
if (unloadBufferWeight != -1) {
// Also log some internal values to help identify cache accounting anomalies
message += ", UBW=" + unloadBufferWeight + ", TUW=" + totalUnloadingWeight
+ ", TCO=" + totalCacheOccupancy;
}
logger.info(message);
lastPublished = now;
// our own record in clusterState will subsequently be
// updated via the instanceInfo listener.
Expand Down Expand Up @@ -5982,7 +6013,10 @@ public void run() {
}
}
j++;
if (loaded && !remLoaded) {
if (remFailed && ce != null && ce.isFailed()) {
// Also remove expired failure records from local cache
ce.remove();
} else if (loaded && !remLoaded) {
if (lastUsed <= 0L) {
lastUsed = runtimeCache.getLastUsedTime(modelId);
}
Expand Down Expand Up @@ -6027,7 +6061,7 @@ public void run() {
weightRemoved += weight;
}
} catch (Exception e) {
logger.error("Janitor exception while scaling copies"
logger.error("Janitor exception while scaling down copies"
+ (modelId != null ? " for model " + modelId : ""), e);
}
}
Expand Down