Skip to content

Commit

Permalink
Fix race condition in PageListener
Browse files Browse the repository at this point in the history
This PR
- Introduced an `AtomicInteger` called `pagesInFlight` to track the number of pages currently being processed. 
- Incremented `pagesInFlight` before processing each page and decremented it after processing is complete
- Adjusted the condition in `scheduleImputeHCTask` to check both `pagesInFlight.get() == 0` (all pages have been processed) and `sentOutPages.get() == receivedPages.get()` (all responses have been received) before scheduling the `imputeHC` task. 
- Removed the previous final check in `onResponse` that decided when to schedule `imputeHC`, relying instead on the updated counters for accurate synchronization.

These changes address the race condition where `sentOutPages` might not have been incremented in time before checking whether to schedule the `imputeHC` task. By accurately tracking the number of in-flight pages and sent pages, we ensure that `imputeHC` is executed only after all pages have been fully processed and all responses have been received.

Testing done:
1. Reproduced the race condition by starting two detectors with imputation. This causes an out of order illegal argument exception from RCF due to this race condition. Also verified the change fixed the problem.
2. added an IT for the above scenario.

Signed-off-by: Kaituo Li <kaituo@amazon.com>
  • Loading branch information
kaituo committed Oct 23, 2024
1 parent da73506 commit 64da191
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ protected void doExecute(Task task, ResultBulkRequestType request, ActionListene
// all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure).
long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes();
float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits;
@SuppressWarnings("rawtypes")
List<? extends ResultWriteRequest> results = request.getResults();

if (results == null || results.size() < 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ class PageListener implements ActionListener<CompositeRetriever.Page> {
private String taskId;
private AtomicInteger receivedPages;
private AtomicInteger sentOutPages;
// By introducing pagesInFlight and incrementing it in the main thread before asynchronous processing begins,
// we ensure that the count of in-flight pages is accurate at all times. This allows us to reliably determine
// when all pages have been processed.
private AtomicInteger pagesInFlight;

PageListener(PageIterator pageIterator, Config config, long dataStartTime, long dataEndTime, String taskId) {
this.pageIterator = pageIterator;
Expand All @@ -220,14 +224,21 @@ class PageListener implements ActionListener<CompositeRetriever.Page> {
this.taskId = taskId;
this.receivedPages = new AtomicInteger();
this.sentOutPages = new AtomicInteger();
this.pagesInFlight = new AtomicInteger();
}

@Override
public void onResponse(CompositeRetriever.Page entityFeatures) {
// start processing next page after sending out features for previous page
if (pageIterator.hasNext()) {
pageIterator.next(this);
} else if (config.getImputationOption() != null) {
scheduleImputeHCTask();
}

// Increment pagesInFlight to track the processing of this page
pagesInFlight.incrementAndGet();

if (entityFeatures != null && false == entityFeatures.isEmpty()) {
LOG
.info(
Expand Down Expand Up @@ -309,19 +320,15 @@ public void onResponse(CompositeRetriever.Page entityFeatures) {
} catch (Exception e) {
LOG.error("Unexpected exception", e);
handleException(e);
} finally {
// Decrement pagesInFlight after processing is complete
pagesInFlight.decrementAndGet();
}
});
}

if (!pageIterator.hasNext() && config.getImputationOption() != null) {
if (sentOutPages.get() > 0) {
// at least 1 page sent out. Wait until all responses are back.
scheduleImputeHCTask();
} else {
// no data in current interval. Send out impute request right away.
imputeHC(dataStartTime, dataEndTime, configId, taskId);
}

} else {
// No entity features to process
// Decrement pagesInFlight immediately
pagesInFlight.decrementAndGet();
}
}

Expand Down Expand Up @@ -358,7 +365,10 @@ private void scheduleImputeHCTask() {

@Override
public void run() {
if (sentOutPages.get() == receivedPages.get()) {
// By using pagesInFlight in the condition within scheduleImputeHCTask, we ensure that imputeHC
// is executed only after all pages have been processed (pagesInFlight.get() == 0) and all
// responses have been received (sentOutPages.get() == receivedPages.get()).
if (pagesInFlight.get() == 0 && sentOutPages.get() == receivedPages.get()) {
if (!sent.get()) {
// since we don't know when cancel will succeed, need sent to ensure imputeHC is only called once
sent.set(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ protected String genDetector(
long windowDelayMinutes,
boolean hc,
ImputationMethod imputation,
long trainTimeMillis
long trainTimeMillis,
String name
) {
StringBuilder sb = new StringBuilder();
// common part
Expand Down
35 changes: 30 additions & 5 deletions src/test/java/org/opensearch/ad/e2e/MissingIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,27 @@ protected TrainResult createAndStartRealTimeDetector(
List<JsonObject> data,
ImputationMethod imputation,
boolean hc,
long trainTimeMillis
long trainTimeMillis,
String name
) throws Exception {
TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis);
TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, name);
List<JsonObject> result = startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, true);
recordLastSeenFromResult(result);

return trainResult;
}

protected TrainResult createAndStartRealTimeDetector(
int numberOfEntities,
int trainTestSplit,
List<JsonObject> data,
ImputationMethod imputation,
boolean hc,
long trainTimeMillis
) throws Exception {
return createAndStartRealTimeDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, "test");
}

protected TrainResult createAndStartHistoricalDetector(
int numberOfEntities,
int trainTestSplit,
Expand Down Expand Up @@ -115,12 +127,13 @@ protected TrainResult createDetector(
List<JsonObject> data,
ImputationMethod imputation,
boolean hc,
long trainTimeMillis
long trainTimeMillis,
String name
) throws Exception {
Instant trainTime = Instant.ofEpochMilli(trainTimeMillis);

Duration windowDelay = getWindowDelay(trainTimeMillis);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), hc, imputation, trainTimeMillis);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), hc, imputation, trainTimeMillis, name);

RestClient client = client();
String detectorId = createDetector(client, detector);
Expand All @@ -129,6 +142,17 @@ protected TrainResult createDetector(
return new TrainResult(detectorId, data, trainTestSplit * numberOfEntities, windowDelay, trainTime, "timestamp");
}

protected TrainResult createDetector(
int numberOfEntities,
int trainTestSplit,
List<JsonObject> data,
ImputationMethod imputation,
boolean hc,
long trainTimeMillis
) throws Exception {
return createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, "test");
}

protected Duration getWindowDelay(long trainTimeMillis) {
/*
* AD accepts windowDelay in the unit of minutes. Thus, we need to convert the delay in minutes. This will
Expand Down Expand Up @@ -156,7 +180,8 @@ protected abstract String genDetector(
long windowDelayMinutes,
boolean hc,
ImputationMethod imputation,
long trainTimeMillis
long trainTimeMillis,
String name
);

protected abstract AbstractSyntheticDataTest.GenData genData(
Expand Down
75 changes: 71 additions & 4 deletions src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,80 @@ public void testHCPrevious() throws Exception {
);
}

/**
* test we start two HC detector with zero imputation consecutively.
* We expect there is no out of order error from RCF.
* @throws Exception
*/
public void testDoubleHCZero() throws Exception {
lastSeen.clear();
int numberOfEntities = 2;

AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.NO_MISSING_DATA;
ImputationMethod method = ImputationMethod.ZERO;

AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode);

// only ingest train data to avoid validation error as we use latest data time as starting point.
// otherwise, we will have too many missing points.
ingestUniformSingleFeatureData(
trainTestSplit + numberOfEntities * 6, // we only need a few to verify and trigger train.
dataGenerated.data
);

TrainResult trainResult1 = createAndStartRealTimeDetector(
numberOfEntities,
trainTestSplit,
dataGenerated.data,
method,
true,
dataGenerated.testStartTime,
"test1"
);

TrainResult trainResult2 = createAndStartRealTimeDetector(
numberOfEntities,
trainTestSplit,
dataGenerated.data,
method,
true,
dataGenerated.testStartTime,
"test2"
);

runTest(
dataGenerated.testStartTime,
dataGenerated,
trainResult1.windowDelay,
trainResult1.detectorId,
numberOfEntities,
mode,
method,
3,
true
);

runTest(
dataGenerated.testStartTime,
dataGenerated,
trainResult2.windowDelay,
trainResult2.detectorId,
numberOfEntities,
mode,
method,
3,
true
);
}

@Override
protected String genDetector(
int trainTestSplit,
long windowDelayMinutes,
boolean hc,
ImputationMethod imputation,
long trainTimeMillis
long trainTimeMillis,
String name
) {
StringBuilder sb = new StringBuilder();

Expand Down Expand Up @@ -185,7 +252,7 @@ protected String genDetector(
// common part
sb
.append(
"{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\""
"{ \"name\": \"%s\", \"description\": \"test\", \"time_field\": \"timestamp\""
+ ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_id\": \"feature2\", \"feature_name\": \"feature 2\", \"feature_enabled\": "
+ "\"true\", \"aggregation_query\": { \"Feature2\": { \"avg\": { \"field\": \"data\" } } } },"
+ featureWithFilter
Expand Down Expand Up @@ -226,9 +293,9 @@ protected String genDetector(
sb.append("\"schema_version\": 0}");

if (hc) {
return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1, categoricalField);
return String.format(Locale.ROOT, sb.toString(), name, datasetName, intervalMinutes, trainTestSplit - 1, categoricalField);
} else {
return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1);
return String.format(Locale.ROOT, sb.toString(), name, datasetName, intervalMinutes, trainTestSplit - 1);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void testSingleStream() throws Exception {
);

Duration windowDelay = getWindowDelay(dataGenerated.testStartTime);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), false, method, dataGenerated.testStartTime);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), false, method, dataGenerated.testStartTime, "test");

Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong());
Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong());
Expand Down Expand Up @@ -63,7 +63,7 @@ public void testHC() throws Exception {
);

Duration windowDelay = getWindowDelay(dataGenerated.testStartTime);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), true, method, dataGenerated.testStartTime);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), true, method, dataGenerated.testStartTime, "test");

Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong());
Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong());
Expand Down

0 comments on commit 64da191

Please sign in to comment.