Skip to content

Commit

Permalink
Bump RCF Version and Fix Default Rules Bug in AnomalyDetector (opense…
Browse files Browse the repository at this point in the history
…arch-project#1334)

* Updated RCF version to the latest release.
* Fixed a bug in AnomalyDetector where default rules were not applied when the user provided an empty ruleset.

Testing:
* Added unit tests to cover the bug fix

Signed-off-by: Kaituo Li <kaituo@amazon.com>
  • Loading branch information
kaituo committed Oct 11, 2024
1 parent 1a8da83 commit cdb0dcd
Show file tree
Hide file tree
Showing 15 changed files with 452 additions and 28 deletions.
9 changes: 3 additions & 6 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ dependencies {
implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2'
implementation group: 'commons-lang', name: 'commons-lang', version: '2.6'
implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.12.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.1.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.1.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:4.1.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.2.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.2.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:4.2.0'

// we inherit jackson-core from opensearch core
implementation "com.fasterxml.jackson.core:jackson-databind:2.16.1"
Expand Down Expand Up @@ -700,9 +700,6 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler',
'org.opensearch.timeseries.transport.SingleStreamResultRequest',
'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
'org.opensearch.timeseries.transport.SuggestConfigParamRequest',
'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap',
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/ad/model/AnomalyDetector.java
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ public AnomalyDetector(

this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name();

this.rules = rules == null ? getDefaultRule() : rules;
this.rules = rules == null || rules.isEmpty() ? getDefaultRule() : rules;
}

/*
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/org/opensearch/timeseries/JobProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void process(Job jobParameter, JobExecutionContext context) {
* @param executionStartTime analysis start time
* @param executionEndTime analysis end time
* @param recorder utility to record job execution result
* @param detector associated detector accessor
* @param config associated config accessor
*/
public void runJob(
Job jobParameter,
Expand All @@ -209,7 +209,7 @@ public void runJob(
Instant executionStartTime,
Instant executionEndTime,
ExecuteResultResponseRecorderType recorder,
Config detector
Config config
) {
String configId = jobParameter.getName();
if (lock == null) {
Expand All @@ -222,7 +222,7 @@ public void runJob(
"Can't run job due to null lock",
false,
recorder,
detector
config
);
return;
}
Expand All @@ -243,7 +243,7 @@ public void runJob(
user,
roles,
recorder,
detector
config
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,18 @@ protected void executeRequest(FeatureRequest coldStartRequest, ActionListener<Vo
);
IntermediateResultType result = modelManager.getResult(currentSample, modelState, modelId, config, taskId);
resultSaver.saveResult(result, config, coldStartRequest, modelId);
}

// only load model to memory for real time analysis that has no task id
if (null == coldStartRequest.getTaskId()) {
boolean hosted = cacheProvider.hostIfPossible(configOptional.get(), modelState);
LOG
.debug(
hosted
? new ParameterizedMessage("Loaded model {}.", modelState.getModelId())
: new ParameterizedMessage("Failed to load model {}.", modelState.getModelId())
);
// only load model to memory for real time analysis that has no task id
if (null == coldStartRequest.getTaskId()) {
boolean hosted = cacheProvider.hostIfPossible(configOptional.get(), modelState);
LOG
.debug(
hosted
? new ParameterizedMessage("Loaded model {}.", modelState.getModelId())
: new ParameterizedMessage("Failed to load model {}.", modelState.getModelId())
);
}
}

} finally {
listener.onResponse(null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ public void stopJob(String configId, TransportService transportService, ActionLi
}));
}

private ActionListener<StopConfigResponse> stopConfigListener(
public ActionListener<StopConfigResponse> stopConfigListener(
String configId,
TransportService transportService,
ActionListener<JobResponse> listener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void bulk(String resultIndexOrAlias, List<ResultType> results, String con
} catch (Exception e) {
String error = "Failed to bulk index result";
LOG.error(error, e);
listener.onFailure(new TimeSeriesException(error, e));
listener.onFailure(new TimeSeriesException(configId, error, e));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ protected String genDetector(String datasetName, int intervalMinutes, int trainT
if (relative) {
thresholdType1 = "actual_over_expected_ratio";
thresholdType2 = "expected_over_actual_ratio";
value = 0.3;
value = 0.2;
} else {
thresholdType1 = "actual_over_expected_margin";
thresholdType2 = "expected_over_actual_margin";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void testRule() throws Exception {
minPrecision.put("Scottsdale", 0.5);
Map<String, Double> minRecall = new HashMap<>();
minRecall.put("Phoenix", 0.9);
minRecall.put("Scottsdale", 0.6);
minRecall.put("Scottsdale", 0.3);
verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void testRule() throws Exception {
minPrecision.put("Scottsdale", 0.5);
Map<String, Double> minRecall = new HashMap<>();
minRecall.put("Phoenix", 0.9);
minRecall.put("Scottsdale", 0.6);
minRecall.put("Scottsdale", 0.3);
verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20);
}
}
Expand Down
63 changes: 63 additions & 0 deletions src/test/java/org/opensearch/ad/ml/ADColdStartTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ad.ml;

import java.io.IOException;
import java.util.ArrayList;

import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.timeseries.TestHelpers;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;

public class ADColdStartTests extends OpenSearchTestCase {
private int baseDimensions = 1;
private int shingleSize = 8;
private int dimensions;

@Override
public void setUp() throws Exception {
super.setUp();
dimensions = baseDimensions * shingleSize;
}

/**
* Test if no explicit rule is provided, we apply 20% rule.
* @throws IOException when failing to constructor detector
*/
public void testEmptyRule() throws IOException {
AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(new ArrayList<>()).build();
ThresholdedRandomCutForest.Builder builder = new ThresholdedRandomCutForest.Builder<>()
.dimensions(dimensions)
.shingleSize(shingleSize);
ADColdStart.applyRule(builder, detector);

ThresholdedRandomCutForest forest = builder.build();
double[] ignore = forest.getPredictorCorrector().getIgnoreNearExpected();

// Specify a small delta for floating-point comparison
double delta = 1e-6;

assertArrayEquals("The double arrays are not equal", new double[] { 0, 0, 0.2, 0.2 }, ignore, delta);
}

public void testNullRule() throws IOException {
AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(null).build();
ThresholdedRandomCutForest.Builder builder = new ThresholdedRandomCutForest.Builder<>()
.dimensions(dimensions)
.shingleSize(shingleSize);
ADColdStart.applyRule(builder, detector);

ThresholdedRandomCutForest forest = builder.build();
double[] ignore = forest.getPredictorCorrector().getIgnoreNearExpected();

// Specify a small delta for floating-point comparison
double delta = 1e-6;

assertArrayEquals("The double arrays are not equal", new double[] { 0, 0, 0.2, 0.2 }, ignore, delta);
}
}
13 changes: 13 additions & 0 deletions src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,19 @@ public void testStatsAnomalyDetector() throws Exception {
.makeRequest(client(), "GET", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE + "/stats", ImmutableMap.of(), "", null);

assertEquals("Get stats failed", RestStatus.OK, TestHelpers.restStatus(statsResponse));

statsResponse = TestHelpers
.makeRequest(
client(),
"GET",
TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE
+ "/_local/stats/ad_execute_request_count,anomaly_detectors_index_status,ad_hc_execute_request_count,ad_hc_execute_failure_count,ad_execute_failure_count,models_checkpoint_index_status,anomaly_results_index_status",
ImmutableMap.of(),
"",
null
);

assertEquals("Get stats failed", RestStatus.OK, TestHelpers.restStatus(statsResponse));
}

public void testPreviewAnomalyDetector() throws Exception {
Expand Down
29 changes: 29 additions & 0 deletions src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.apache.lucene.search.TotalHits;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.Version;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.DocWriteResponse;
Expand Down Expand Up @@ -104,6 +105,7 @@
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.index.IndexNotFoundException;
Expand Down Expand Up @@ -136,6 +138,7 @@
import org.opensearch.timeseries.transport.JobResponse;
import org.opensearch.timeseries.transport.StatsNodeResponse;
import org.opensearch.timeseries.transport.StatsNodesResponse;
import org.opensearch.timeseries.transport.StopConfigResponse;
import org.opensearch.timeseries.util.ClientUtil;
import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
import org.opensearch.transport.TransportResponseHandler;
Expand Down Expand Up @@ -1544,4 +1547,30 @@ public void testDeleteTaskDocs() {
verify(adTaskCacheManager, times(1)).addDeletedTask(anyString());
verify(function, times(1)).execute();
}

public void testStopConfigListener_onResponse_failure() {
// Arrange
String configId = randomAlphaOfLength(5);
TransportService transportService = mock(TransportService.class);
@SuppressWarnings("unchecked")
ActionListener<JobResponse> listener = mock(ActionListener.class);

// Act
ActionListener<StopConfigResponse> stopConfigListener = indexAnomalyDetectorJobActionHandler
.stopConfigListener(configId, transportService, listener);
StopConfigResponse stopConfigResponse = mock(StopConfigResponse.class);
when(stopConfigResponse.success()).thenReturn(false);

stopConfigListener.onResponse(stopConfigResponse);

// Assert
ArgumentCaptor<OpenSearchStatusException> exceptionCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);

verify(adTaskManager, times(1))
.stopLatestRealtimeTask(eq(configId), eq(TaskState.FAILED), exceptionCaptor.capture(), eq(transportService), eq(listener));

OpenSearchStatusException capturedException = exceptionCaptor.getValue();
assertEquals("Failed to delete model", capturedException.getMessage());
assertEquals(RestStatus.INTERNAL_SERVER_ERROR, capturedException.status());
}
}
Loading

0 comments on commit cdb0dcd

Please sign in to comment.