Skip to content

Commit

Permalink
Fix: Prevent resetting latest flag of real-time analysis when startin…
Browse files Browse the repository at this point in the history
…g historical analysis (#1287) (#1288)

This PR addresses a bug where starting a historical analysis after a real-time analysis on the same detector caused the real-time task’s latest flag to be incorrectly reset to false by the historical run.

The fix ensures that only the latest flags of the same analysis type are reset:
* Real-time analysis will only reset the latest flag of previous real-time analyses.
* Historical analysis will only reset the latest flag of previous historical analyses.

This PR also updated recencyEmphasis to have a minimum value of 2, aligning with RCF requirements.

Testing:
- Added an integration test to reproduce the bug and verified the fix.


(cherry picked from commit afd5da9)

Signed-off-by: Kaituo Li <kaituo@amazon.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 822aab8 commit 48b56eb
Show file tree
Hide file tree
Showing 13 changed files with 263 additions and 62 deletions.
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,8 @@ List<String> jacocoExclusions = [
'org.opensearch.ad.task.ADBatchTaskCache',
'org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker',
'org.opensearch.timeseries.util.TimeUtil',
'org.opensearch.ad.transport.ADHCImputeTransportAction',
'org.opensearch.timeseries.ml.RealTimeInferencer',
]


Expand Down
33 changes: 10 additions & 23 deletions src/main/java/org/opensearch/ad/task/ADTaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX;
import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN;
import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD;
import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES;
import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES;
import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES;
import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT;
Expand Down Expand Up @@ -1881,21 +1880,6 @@ private void maintainRunningHistoricalTask(ConcurrentLinkedQueue<ADTask> taskQue
}, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), AD_BATCH_TASK_THREAD_POOL_NAME);
}

/**
* Get list of task types.
* 1. If date range is null, will return all realtime task types
* 2. If date range is not null, will return all historical detector level tasks types
* if resetLatestTaskStateFlag is true; otherwise return all historical tasks types include
* HC entity level task type.
* @param dateRange detection date range
* @param resetLatestTaskStateFlag reset latest task state or not
* @return list of AD task types
*/
protected List<ADTaskType> getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag) {
// AD does not support run once
return getTaskTypes(dateRange, resetLatestTaskStateFlag, false);
}

@Override
protected BiCheckedFunction<XContentParser, String, ADTask, IOException> getTaskParser() {
return ADTask::parse;
Expand All @@ -1912,17 +1896,20 @@ public void createRunOnceTaskAndCleanupStaleTasks(
throw new UnsupportedOperationException("AD has no run once yet");
}

/**
* Get list of task types.
* 1. If date range is null, will return all realtime task types
* 2. If date range is not null, will return all historical detector level tasks types
*
* @param dateRange detection date range
* @return list of AD task types
*/
@Override
public List<ADTaskType> getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, boolean runOnce) {
public List<ADTaskType> getTaskTypes(DateRange dateRange, boolean runOnce) {
if (dateRange == null) {
return REALTIME_TASK_TYPES;
} else {
if (resetLatestTaskStateFlag) {
// return all task types include HC entity task to make sure we can reset all tasks latest flag
return ALL_HISTORICAL_TASK_TYPES;
} else {
return HISTORICAL_DETECTOR_TASK_TYPES;
}
return HISTORICAL_DETECTOR_TASK_TYPES;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.ad.caching.ADCacheProvider;
import org.opensearch.ad.ml.ADRealTimeInferencer;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.NodeStateManager;
import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
import org.opensearch.timeseries.cluster.HashRing;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
Expand Down Expand Up @@ -69,6 +71,7 @@ public class ADHCImputeTransportAction extends
private ADCacheProvider cache;
private NodeStateManager nodeStateManager;
private ADRealTimeInferencer adInferencer;
private HashRing hashRing;

@Inject
public ADHCImputeTransportAction(
Expand All @@ -78,7 +81,8 @@ public ADHCImputeTransportAction(
ActionFilters actionFilters,
ADCacheProvider priorityCache,
NodeStateManager nodeStateManager,
ADRealTimeInferencer adInferencer
ADRealTimeInferencer adInferencer,
HashRing hashRing
) {
super(
ADHCImputeAction.NAME,
Expand All @@ -94,6 +98,7 @@ public ADHCImputeTransportAction(
this.cache = priorityCache;
this.nodeStateManager = nodeStateManager;
this.adInferencer = adInferencer;
this.hashRing = hashRing;
}

@Override
Expand Down Expand Up @@ -131,9 +136,7 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
long executionEndTime = dataEndMillis + windowDelayMillis;
String taskId = nodeRequest.getRequest().getTaskId();
for (ModelState<ThresholdedRandomCutForest> modelState : cache.get().getAllModels(configId)) {
// execution end time (when job starts execution in this interval) >= last used time => the model state is updated in
// previous intervals
if (executionEndTime >= modelState.getLastUsedTime().toEpochMilli()) {
if (shouldProcessModelState(modelState, executionEndTime, clusterService, hashRing)) {
double[] nanArray = new double[featureSize];
Arrays.fill(nanArray, Double.NaN);
adInferencer
Expand All @@ -156,4 +159,46 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
}
}

/**
* Determines whether the model state should be processed based on various conditions.
*
* Conditions checked:
* - The model's last seen execution end time is not the minimum Instant value.
* - The current execution end time is greater than or equal to the model's last seen execution end time,
* indicating that the model state was updated in previous intervals.
* - The entity associated with the model state is present.
* - The owning node for real-time processing of the entity, with the same local version, is present in the hash ring.
* - The owning node for real-time processing matches the current local node.
*
* This method helps avoid processing model states that were already handled in previous intervals. The conditions
* ensure that only the relevant model states are processed while accounting for scenarios where processing can occur
* concurrently (e.g., during tests when multiple threads may operate quickly).
*
* @param modelState The current state of the model.
* @param executionEndTime The end time of the current execution interval.
* @param clusterService The service providing information about the current cluster node.
* @param hashRing The hash ring used to determine the owning node for real-time processing of entities.
* @return true if the model state should be processed; otherwise, false.
*/
private boolean shouldProcessModelState(
ModelState<ThresholdedRandomCutForest> modelState,
long executionEndTime,
ClusterService clusterService,
HashRing hashRing
) {
// Get the owning node for the entity in real-time processing from the hash ring
Optional<DiscoveryNode> owningNode = modelState.getEntity().isPresent()
? hashRing.getOwningNodeWithSameLocalVersionForRealtime(modelState.getEntity().get().toString())
: Optional.empty();

// Check if the model state conditions are met for processing
// We cannot use last used time as it will be updated whenever we update its priority in CacheBuffer.update when there is a
// PriorityCache.get.
return modelState.getLastSeenExecutionEndTime() != Instant.MIN
&& executionEndTime >= modelState.getLastSeenExecutionEndTime().toEpochMilli()
&& modelState.getEntity().isPresent()
&& owningNode.isPresent()
&& owningNode.get().getId().equals(clusterService.localNode().getId());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ public void createRunOnceTaskAndCleanupStaleTasks(
}

@Override
public List<ForecastTaskType> getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, boolean runOnce) {
public List<ForecastTaskType> getTaskTypes(DateRange dateRange, boolean runOnce) {
if (runOnce) {
return ForecastTaskType.RUN_ONCE_TASK_TYPES;
} else {
Expand Down
11 changes: 0 additions & 11 deletions src/main/java/org/opensearch/timeseries/ml/RealTimeInferencer.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.timeseries.ml;

import java.time.Instant;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -135,16 +134,6 @@ private boolean processWithTimeout(
}

private boolean tryProcess(Sample sample, ModelState<RCFModelType> modelState, Config config, String taskId, long curExecutionEnd) {
// execution end time (when job starts execution in this interval) >= last seen execution end time => the model state is updated in
// previous intervals
// This branch being true can happen while scheduled to waiting some other threads have already scored the same interval
// (e.g., during tests when everything happens fast)
// We cannot use last used time as it will be updated whenever we update its priority in CacheBuffer.update when there is a
// PriorityCache.get.
if (modelState.getLastSeenExecutionEndTime() != Instant.MIN
&& curExecutionEnd < modelState.getLastSeenExecutionEndTime().toEpochMilli()) {
return false;
}
String modelId = modelState.getModelId();
try {
RCFResultType result = modelManager.getResult(sample, modelState, modelId, config, taskId);
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/timeseries/model/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ protected Config(
return;
}

if (recencyEmphasis != null && (recencyEmphasis <= 0)) {
if (recencyEmphasis != null && recencyEmphasis <= 1) {
issueType = ValidationIssueType.RECENCY_EMPHASIS;
errorMessage = "recency emphasis has to be a positive integer";
errorMessage = "Recency emphasis must be an integer greater than 1.";
return;
}

Expand Down
9 changes: 6 additions & 3 deletions src/main/java/org/opensearch/timeseries/task/TaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,10 @@ public <T> void updateLatestFlagOfOldTasksAndCreateNewTask(
query.filter(new TermQueryBuilder(configIdFieldName, config.getId()));
query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true));
// make sure we reset all latest task as false when user switch from single entity to HC, vice versa.
query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(getTaskTypes(dateRange, true, runOnce))));
// Ensures that only the latest flags of the same analysis type are reset:
// Real-time analysis will only reset the latest flag of previous real-time analyses.
// Historical analysis will only reset the latest flag of previous historical analyses.
query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(getTaskTypes(dateRange, runOnce))));
updateByQueryRequest.setQuery(query);
updateByQueryRequest.setRefresh(true);
String script = String.format(Locale.ROOT, "ctx._source.%s=%s;", TimeSeriesTask.IS_LATEST_FIELD, false);
Expand Down Expand Up @@ -432,7 +435,7 @@ public <T> void getAndExecuteOnLatestConfigTask(
}

public List<TaskTypeEnum> getTaskTypes(DateRange dateRange) {
return getTaskTypes(dateRange, false, false);
return getTaskTypes(dateRange, false);
}

/**
Expand Down Expand Up @@ -1081,5 +1084,5 @@ public abstract void createRunOnceTaskAndCleanupStaleTasks(
ActionListener<TaskClass> listener
);

public abstract List<TaskTypeEnum> getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, boolean runOnce);
public abstract List<TaskTypeEnum> getTaskTypes(DateRange dateRange, boolean runOnce);
}
103 changes: 95 additions & 8 deletions src/test/java/org/opensearch/ad/AbstractADSyntheticDataTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import java.nio.charset.Charset;
import java.time.Duration;
import java.time.Instant;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeParseException;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -56,20 +58,22 @@ protected static class TrainResult {
// last data time in data
public Instant finalDataTime;

public TrainResult(String detectorId, List<JsonObject> data, int rawDataTrainTestSplit, Duration windowDelay, Instant trainTime) {
public TrainResult(
String detectorId,
List<JsonObject> data,
int rawDataTrainTestSplit,
Duration windowDelay,
Instant trainTime,
String timeStampField
) {
this.detectorId = detectorId;
this.data = data;
this.rawDataTrainTestSplit = rawDataTrainTestSplit;
this.windowDelay = windowDelay;
this.trainTime = trainTime;

this.firstDataTime = getDataTime(0);
this.finalDataTime = getDataTime(data.size() - 1);
}

private Instant getDataTime(int index) {
String finalTimeStr = data.get(index).get("timestamp").getAsString();
return Instant.ofEpochMilli(Long.parseLong(finalTimeStr));
this.firstDataTime = getDataTimeOfEpochMillis(timeStampField, data, 0);
this.finalDataTime = getDataTimeOfEpochMillis(timeStampField, data, data.size() - 1);
}
}

Expand Down Expand Up @@ -689,4 +693,87 @@ public static boolean areDoublesEqual(double d1, double d2) {
public interface ConditionChecker {
boolean checkCondition(JsonArray hits, int expectedSize);
}

protected static Instant getDataTimeOfEpochMillis(String timestampField, List<JsonObject> data, int index) {
String finalTimeStr = data.get(index).get(timestampField).getAsString();
return Instant.ofEpochMilli(Long.parseLong(finalTimeStr));
}

protected static Instant getDataTimeofISOFormat(String timestampField, List<JsonObject> data, int index) {
String finalTimeStr = data.get(index).get(timestampField).getAsString();

try {
// Attempt to parse as an ISO 8601 formatted string (e.g., "2019-11-01T00:00:00Z")
ZonedDateTime zonedDateTime = ZonedDateTime.parse(finalTimeStr, DateTimeFormatter.ISO_DATE_TIME);
return zonedDateTime.toInstant();
} catch (DateTimeParseException ex) {
throw new IllegalArgumentException("Invalid timestamp format: " + finalTimeStr, ex);
}
}

protected List<JsonObject> getTasks(String detectorId, int size, ConditionChecker checker, RestClient client)
throws InterruptedException {
Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/tasks/_search");

String jsonTemplate = "{\n"
+ " \"size\": %d,\n"
+ " \"query\": {\n"
+ " \"bool\": {\n"
+ " \"filter\": [\n"
+ " {\n"
+ " \"term\": {\n"
+ " \"detector_id\": \"%s\"\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ " }\n"
+ " }\n"
+ "}";

// try to get size + 10 results if there are that many
String formattedJson = String.format(Locale.ROOT, jsonTemplate, size + 10, detectorId);

request.setJsonEntity(formattedJson);

// wait until results are available
// max wait for 60_000 milliseconds
int maxWaitCycles = 30;
do {
try {
JsonArray hits = getHits(client, request);
if (hits != null && checker.checkCondition(hits, size)) {
List<JsonObject> res = new ArrayList<>();
for (int i = 0; i < hits.size(); i++) {
JsonObject source = hits.get(i).getAsJsonObject().get("_source").getAsJsonObject();
res.add(source);
}

return res;
} else {
LOG.info("wait for result, previous result: {}, size: {}", hits, hits.size());
}
Thread.sleep(2_000 * size);
} catch (Exception e) {
LOG.warn("Exception while waiting for result", e);
Thread.sleep(2_000 * size);
}
} while (maxWaitCycles-- >= 0);

// leave some debug information before returning empty
try {
String matchAll = "{\n" + " \"size\": 1000,\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}";
request.setJsonEntity(matchAll);
JsonArray hits = getHits(client, request);
LOG.info("Query: {}", formattedJson);
LOG.info("match all result: {}", hits);
} catch (Exception e) {
LOG.warn("Exception while waiting for match all result", e);
}

return new ArrayList<>();
}

protected static boolean getLatest(List<JsonObject> data, int index) {
return data.get(index).get("is_latest").getAsBoolean();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ protected TrainResult ingestTrainData(
long windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes();

Duration windowDelay = Duration.ofMinutes(windowDelayMinutes);
return new TrainResult(null, data, rawDataTrainTestSplit, windowDelay, trainTime);
return new TrainResult(null, data, rawDataTrainTestSplit, windowDelay, trainTime, "timestamp");
}

public Map<String, List<Entry<Instant, Instant>>> getAnomalyWindowsMap(String labelFileName) throws Exception {
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/org/opensearch/ad/e2e/MissingIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ protected TrainResult createDetector(
String detectorId = createDetector(client, detector);
LOG.info("Created detector {}", detectorId);

return new TrainResult(detectorId, data, trainTestSplit * numberOfEntities, windowDelay, trainTime);
return new TrainResult(detectorId, data, trainTestSplit * numberOfEntities, windowDelay, trainTime, "timestamp");
}

protected Duration getWindowDelay(long trainTimeMillis) {
Expand Down
Loading

0 comments on commit 48b56eb

Please sign in to comment.