diff --git a/CHANGELOG.md b/CHANGELOG.md index 573d2485f2ac3..82f16e99fdd45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Recommissioning of zone. REST layer support. ([#4624](https://github.com/opensearch-project/OpenSearch/pull/4604)) - Build no-jdk distributions as part of release build ([#4902](https://github.com/opensearch-project/OpenSearch/pull/4902)) - Use getParameterCount instead of getParameterTypes ([#4821](https://github.com/opensearch-project/OpenSearch/pull/4821)) +- Added in-flight cancellation of SearchShardTask based on resource consumption ([#4565](https://github.com/opensearch-project/OpenSearch/pull/4565)) +- Added resource usage trackers for in-flight cancellation of SearchShardTask ([#4805](https://github.com/opensearch-project/OpenSearch/pull/4805)) +- Added search backpressure stats API ([#4932](https://github.com/opensearch-project/OpenSearch/pull/4932)) ### Dependencies - Bumps `com.diffplug.spotless` from 6.9.1 to 6.10.0 diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodeStats.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodeStats.java index 7f0ac615cc449..f28411e8b6446 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodeStats.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodeStats.java @@ -56,6 +56,7 @@ import org.opensearch.node.AdaptiveSelectionStats; import org.opensearch.script.ScriptCacheStats; import org.opensearch.script.ScriptStats; +import org.opensearch.search.backpressure.stats.SearchBackpressureStats; import org.opensearch.threadpool.ThreadPoolStats; import org.opensearch.transport.TransportStats; @@ -119,6 +120,9 @@ public class NodeStats extends BaseNodeResponse implements ToXContentFragment { @Nullable private ShardIndexingPressureStats shardIndexingPressureStats; + @Nullable + private 
SearchBackpressureStats searchBackpressureStats; + public NodeStats(StreamInput in) throws IOException { super(in); timestamp = in.readVLong(); @@ -156,6 +160,11 @@ public NodeStats(StreamInput in) throws IOException { shardIndexingPressureStats = null; } + if (in.getVersion().onOrAfter(Version.V_2_4_0)) { + searchBackpressureStats = in.readOptionalWriteable(SearchBackpressureStats::new); + } else { + searchBackpressureStats = null; + } } public NodeStats( @@ -176,7 +185,8 @@ public NodeStats( @Nullable AdaptiveSelectionStats adaptiveSelectionStats, @Nullable ScriptCacheStats scriptCacheStats, @Nullable IndexingPressureStats indexingPressureStats, - @Nullable ShardIndexingPressureStats shardIndexingPressureStats + @Nullable ShardIndexingPressureStats shardIndexingPressureStats, + @Nullable SearchBackpressureStats searchBackpressureStats ) { super(node); this.timestamp = timestamp; @@ -196,6 +206,7 @@ public NodeStats( this.scriptCacheStats = scriptCacheStats; this.indexingPressureStats = indexingPressureStats; this.shardIndexingPressureStats = shardIndexingPressureStats; + this.searchBackpressureStats = searchBackpressureStats; } public long getTimestamp() { @@ -305,6 +316,11 @@ public ShardIndexingPressureStats getShardIndexingPressureStats() { return shardIndexingPressureStats; } + @Nullable + public SearchBackpressureStats getSearchBackpressureStats() { + return searchBackpressureStats; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -336,6 +352,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_1_2_0)) { out.writeOptionalWriteable(shardIndexingPressureStats); } + if (out.getVersion().onOrAfter(Version.V_2_4_0)) { + out.writeOptionalWriteable(searchBackpressureStats); + } } @Override @@ -408,6 +427,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (getShardIndexingPressureStats() != null) { 
getShardIndexingPressureStats().toXContent(builder, params); } + if (getSearchBackpressureStats() != null) { + getSearchBackpressureStats().toXContent(builder, params); + } return builder; } } diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodesStatsRequest.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodesStatsRequest.java index babec0b7c119f..0f24f478abc51 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodesStatsRequest.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodesStatsRequest.java @@ -237,7 +237,8 @@ public enum Metric { ADAPTIVE_SELECTION("adaptive_selection"), SCRIPT_CACHE("script_cache"), INDEXING_PRESSURE("indexing_pressure"), - SHARD_INDEXING_PRESSURE("shard_indexing_pressure"); + SHARD_INDEXING_PRESSURE("shard_indexing_pressure"), + SEARCH_BACKPRESSURE("search_backpressure"); private String metricName; diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java index 644c7f02d45f0..2b08b0844064a 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java @@ -118,7 +118,8 @@ protected NodeStats nodeOperation(NodeStatsRequest nodeStatsRequest) { NodesStatsRequest.Metric.ADAPTIVE_SELECTION.containedIn(metrics), NodesStatsRequest.Metric.SCRIPT_CACHE.containedIn(metrics), NodesStatsRequest.Metric.INDEXING_PRESSURE.containedIn(metrics), - NodesStatsRequest.Metric.SHARD_INDEXING_PRESSURE.containedIn(metrics) + NodesStatsRequest.Metric.SHARD_INDEXING_PRESSURE.containedIn(metrics), + NodesStatsRequest.Metric.SEARCH_BACKPRESSURE.containedIn(metrics) ); } diff --git 
a/server/src/main/java/org/opensearch/action/admin/cluster/stats/TransportClusterStatsAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/stats/TransportClusterStatsAction.java index a13932e137ab0..8c6c5faf2ed14 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/stats/TransportClusterStatsAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/stats/TransportClusterStatsAction.java @@ -162,6 +162,7 @@ protected ClusterStatsNodeResponse nodeOperation(ClusterStatsNodeRequest nodeReq false, false, false, + false, false ); List shardsStats = new ArrayList<>(); diff --git a/server/src/main/java/org/opensearch/cluster/ClusterModule.java b/server/src/main/java/org/opensearch/cluster/ClusterModule.java index d1fb3fcce87fd..bf23c035635fd 100644 --- a/server/src/main/java/org/opensearch/cluster/ClusterModule.java +++ b/server/src/main/java/org/opensearch/cluster/ClusterModule.java @@ -97,7 +97,6 @@ import org.opensearch.script.ScriptMetadata; import org.opensearch.snapshots.SnapshotsInfoService; import org.opensearch.tasks.Task; -import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.tasks.TaskResultsService; import java.util.ArrayList; @@ -416,7 +415,6 @@ protected void configure() { bind(NodeMappingRefreshAction.class).asEagerSingleton(); bind(MappingUpdatedAction.class).asEagerSingleton(); bind(TaskResultsService.class).asEagerSingleton(); - bind(TaskResourceTrackingService.class).asEagerSingleton(); bind(AllocationDeciders.class).toInstance(allocationDeciders); bind(ShardsAllocator.class).toInstance(shardsAllocator); } diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index 0a67f94707fa2..6dfce7089bda8 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -42,6 +42,12 @@ 
import org.opensearch.index.ShardIndexingPressureMemoryManager; import org.opensearch.index.ShardIndexingPressureSettings; import org.opensearch.index.ShardIndexingPressureStore; +import org.opensearch.search.backpressure.settings.NodeDuressSettings; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.search.backpressure.settings.SearchShardTaskSettings; +import org.opensearch.search.backpressure.trackers.CpuUsageTracker; +import org.opensearch.search.backpressure.trackers.ElapsedTimeTracker; +import org.opensearch.search.backpressure.trackers.HeapUsageTracker; import org.opensearch.tasks.TaskManager; import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.watcher.ResourceWatcherService; @@ -583,7 +589,22 @@ public void apply(Settings value, Settings current, Settings previous) { ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS, IndexingPressure.MAX_INDEXING_BYTES, TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED, - TaskManager.TASK_RESOURCE_CONSUMERS_ENABLED + TaskManager.TASK_RESOURCE_CONSUMERS_ENABLED, + + // Settings related to search backpressure + SearchBackpressureSettings.SETTING_MODE, + SearchBackpressureSettings.SETTING_CANCELLATION_RATIO, + SearchBackpressureSettings.SETTING_CANCELLATION_RATE, + SearchBackpressureSettings.SETTING_CANCELLATION_BURST, + NodeDuressSettings.SETTING_NUM_SUCCESSIVE_BREACHES, + NodeDuressSettings.SETTING_CPU_THRESHOLD, + NodeDuressSettings.SETTING_HEAP_THRESHOLD, + SearchShardTaskSettings.SETTING_TOTAL_HEAP_PERCENT_THRESHOLD, + HeapUsageTracker.SETTING_HEAP_PERCENT_THRESHOLD, + HeapUsageTracker.SETTING_HEAP_VARIANCE_THRESHOLD, + HeapUsageTracker.SETTING_HEAP_MOVING_AVERAGE_WINDOW_SIZE, + CpuUsageTracker.SETTING_CPU_TIME_MILLIS_THRESHOLD, + ElapsedTimeTracker.SETTING_ELAPSED_TIME_MILLIS_THRESHOLD ) ) ); diff --git a/server/src/main/java/org/opensearch/common/util/MovingAverage.java 
/**
 * MovingAverage is used to calculate the moving average of last 'n' observations.
 *
 * <p>Writes go through the synchronized {@link #record(long)} method. The values exposed by
 * {@link #getAverage()}, {@link #getCount()} and {@link #isReady()} are kept in volatile fields so
 * that threads calling those getters without holding the monitor still observe the latest
 * published values.
 *
 * @opensearch.internal
 */
public class MovingAverage {
    private final int windowSize;
    private final long[] observations;

    // Published to lock-free readers, hence volatile. Only record() writes them, under the monitor.
    private volatile long count = 0;
    private volatile double average = 0;

    // Only touched inside record(), so the monitor alone is sufficient.
    private long sum = 0;

    /**
     * Creates a moving average over the last {@code windowSize} observations.
     *
     * @param windowSize number of most recent observations to average over
     * @throws IllegalArgumentException if {@code windowSize} is not greater than zero
     */
    public MovingAverage(int windowSize) {
        if (windowSize <= 0) {
            throw new IllegalArgumentException("window size must be greater than zero");
        }

        this.windowSize = windowSize;
        this.observations = new long[windowSize];
    }

    /**
     * Records a new observation and evicts the n-th last observation.
     *
     * @param value the observation to record
     * @return the moving average after including {@code value}
     */
    public synchronized double record(long value) {
        final long seen = count;
        final int slot = (int) (seen % windowSize);

        // Replace the oldest observation in the ring buffer and adjust the running sum by the difference.
        final long delta = value - observations[slot];
        observations[slot] = value;
        sum += delta;

        // Until the window has filled up, average over the observations recorded so far.
        final long updatedCount = seen + 1;
        final double updatedAverage = (double) sum / Math.min(updatedCount, windowSize);

        // Publish the average before the count: a reader that sees the new count (e.g. via isReady())
        // is then guaranteed to also see an average that is at least as recent.
        average = updatedAverage;
        count = updatedCount;
        return updatedAverage;
    }

    /**
     * Returns the most recently computed moving average, or zero if nothing has been recorded yet.
     */
    public double getAverage() {
        return average;
    }

    /**
     * Returns the total number of observations recorded so far.
     */
    public long getCount() {
        return count;
    }

    /**
     * Returns true once at least 'windowSize' observations have been recorded.
     */
    public boolean isReady() {
        return count >= windowSize;
    }
}
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.LongSupplier;

/**
 * TokenBucket is used to limit the number of operations at a constant rate while allowing for short bursts.
 *
 * <p>The bucket state is an immutable {@link State} snapshot held in an {@link AtomicReference} and
 * replaced with compare-and-set retry loops, so requests never block.
 *
 * @opensearch.internal
 */
public class TokenBucket {
    /**
     * Defines a monotonically increasing counter.
     *
     * Usage examples:
     * 1. clock = System::nanoTime can be used to perform rate-limiting per unit time
     * 2. clock = AtomicLong::get can be used to perform rate-limiting per unit number of operations
     */
    private final LongSupplier clock;

    /**
     * Defines the number of tokens added to the bucket per clock cycle.
     */
    private final double rate;

    /**
     * Defines the capacity and the maximum number of operations that can be performed per clock cycle before
     * the bucket runs out of tokens.
     */
    private final double burst;

    /**
     * Defines the current state of the token bucket.
     */
    private final AtomicReference<State> state;

    /**
     * Creates a bucket that starts full, i.e. with {@code burst} tokens.
     *
     * @param clock counter used to measure elapsed clock cycles
     * @param rate  tokens added per clock cycle
     * @param burst bucket capacity
     * @throws IllegalArgumentException if {@code rate} or {@code burst} is not greater than zero
     */
    public TokenBucket(LongSupplier clock, double rate, double burst) {
        this(clock, rate, burst, burst);
    }

    /**
     * Creates a bucket holding {@code initialTokens} tokens, capped at {@code burst}.
     *
     * @param clock         counter used to measure elapsed clock cycles
     * @param rate          tokens added per clock cycle
     * @param burst         bucket capacity
     * @param initialTokens tokens available immediately after construction
     * @throws IllegalArgumentException if {@code rate} or {@code burst} is not greater than zero
     */
    public TokenBucket(LongSupplier clock, double rate, double burst, double initialTokens) {
        // Written as !(x > 0) rather than (x <= 0) so that NaN is rejected as well.
        if (!(rate > 0.0)) {
            throw new IllegalArgumentException("rate must be greater than zero");
        }

        if (!(burst > 0.0)) {
            throw new IllegalArgumentException("burst must be greater than zero");
        }

        this.clock = Objects.requireNonNull(clock, "clock");
        this.rate = rate;
        this.burst = burst;
        this.state = new AtomicReference<>(new State(Math.min(initialTokens, burst), clock.getAsLong()));
    }

    /**
     * If there are enough tokens in the bucket, it requests/deducts 'n' tokens and returns true.
     * Otherwise, returns false without deducting anything; the refill for the elapsed clock cycles
     * is still applied.
     *
     * @param n number of tokens to deduct
     * @return true if the tokens were deducted, false if the bucket did not hold enough tokens
     * @throws IllegalArgumentException if {@code n} is not greater than zero
     */
    public boolean request(double n) {
        // !(n > 0) also rejects NaN, which would otherwise slip past the balance check below
        // and turn the stored token count into NaN.
        if (!(n > 0)) {
            throw new IllegalArgumentException("requested tokens must be greater than zero");
        }

        // Refill tokens. The clock is re-read on every retry so the snapshot that finally wins the
        // compare-and-set is built from a reading taken after the state it replaces.
        State currentState;
        State updatedState;
        do {
            currentState = state.get();
            long now = clock.getAsLong();
            double incr = (now - currentState.lastRefilledAt) * rate;
            updatedState = new State(Math.min(currentState.tokens + incr, burst), now);
        } while (state.compareAndSet(currentState, updatedState) == false);

        // Deduct tokens
        do {
            currentState = state.get();
            if (currentState.tokens < n) {
                return false;
            }
            updatedState = new State(currentState.tokens - n, currentState.lastRefilledAt);
        } while (state.compareAndSet(currentState, updatedState) == false);

        return true;
    }

    /**
     * Requests a single token.
     *
     * @return true if a token was deducted, false otherwise
     */
    public boolean request() {
        return request(1.0);
    }

    /**
     * Represents an immutable token bucket state.
     */
    private static class State {
        final double tokens;
        final long lastRefilledAt;

        public State(double tokens, long lastRefilledAt) {
            this.tokens = tokens;
            this.lastRefilledAt = lastRefilledAt;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            State state = (State) o;
            return Double.compare(state.tokens, tokens) == 0 && lastRefilledAt == state.lastRefilledAt;
        }

        @Override
        public int hashCode() {
            return Objects.hash(tokens, lastRefilledAt);
        }
    }
}
clusterService.setIndexingPressureService(indexingPressureService); + final TaskResourceTrackingService taskResourceTrackingService = new TaskResourceTrackingService( + settings, + clusterService.getClusterSettings(), + threadPool + ); + + final SearchBackpressureSettings searchBackpressureSettings = new SearchBackpressureSettings( + settings, + clusterService.getClusterSettings() + ); + + final SearchBackpressureService searchBackpressureService = new SearchBackpressureService( + searchBackpressureSettings, + taskResourceTrackingService, + threadPool + ); + final RecoverySettings recoverySettings = new RecoverySettings(settings, settingsModule.getClusterSettings()); RepositoriesModule repositoriesModule = new RepositoriesModule( this.environment, @@ -884,7 +905,8 @@ protected Node( responseCollectorService, searchTransportService, indexingPressureService, - searchModule.getValuesSourceRegistry().getUsageService() + searchModule.getValuesSourceRegistry().getUsageService(), + searchBackpressureService ); final SearchService searchService = newSearchService( @@ -942,6 +964,8 @@ protected Node( b.bind(AnalysisRegistry.class).toInstance(analysisModule.getAnalysisRegistry()); b.bind(IngestService.class).toInstance(ingestService); b.bind(IndexingPressureService.class).toInstance(indexingPressureService); + b.bind(TaskResourceTrackingService.class).toInstance(taskResourceTrackingService); + b.bind(SearchBackpressureService.class).toInstance(searchBackpressureService); b.bind(UsageService.class).toInstance(usageService); b.bind(AggregationUsageService.class).toInstance(searchModule.getValuesSourceRegistry().getUsageService()); b.bind(NamedWriteableRegistry.class).toInstance(namedWriteableRegistry); @@ -1105,6 +1129,7 @@ public Node start() throws NodeValidationException { injector.getInstance(SearchService.class).start(); injector.getInstance(FsHealthService.class).start(); nodeService.getMonitorService().start(); + nodeService.getSearchBackpressureService().start(); final 
ClusterService clusterService = injector.getInstance(ClusterService.class); @@ -1260,6 +1285,7 @@ private Node stop() { injector.getInstance(NodeConnectionsService.class).stop(); injector.getInstance(FsHealthService.class).stop(); nodeService.getMonitorService().stop(); + nodeService.getSearchBackpressureService().stop(); injector.getInstance(GatewayService.class).stop(); injector.getInstance(SearchService.class).stop(); injector.getInstance(TransportService.class).stop(); @@ -1319,6 +1345,7 @@ public synchronized void close() throws IOException { toClose.add(injector.getInstance(Discovery.class)); toClose.add(() -> stopWatch.stop().start("monitor")); toClose.add(nodeService.getMonitorService()); + toClose.add(nodeService.getSearchBackpressureService()); toClose.add(() -> stopWatch.stop().start("fsHealth")); toClose.add(injector.getInstance(FsHealthService.class)); toClose.add(() -> stopWatch.stop().start("gateway")); diff --git a/server/src/main/java/org/opensearch/node/NodeService.java b/server/src/main/java/org/opensearch/node/NodeService.java index ab98b47c7287b..f24e85d4ea117 100644 --- a/server/src/main/java/org/opensearch/node/NodeService.java +++ b/server/src/main/java/org/opensearch/node/NodeService.java @@ -53,6 +53,7 @@ import org.opensearch.plugins.PluginsService; import org.opensearch.script.ScriptService; import org.opensearch.search.aggregations.support.AggregationUsageService; +import org.opensearch.search.backpressure.SearchBackpressureService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -81,6 +82,7 @@ public class NodeService implements Closeable { private final SearchTransportService searchTransportService; private final IndexingPressureService indexingPressureService; private final AggregationUsageService aggregationUsageService; + private final SearchBackpressureService searchBackpressureService; private final Discovery discovery; @@ -101,7 +103,8 @@ public class NodeService implements 
Closeable { ResponseCollectorService responseCollectorService, SearchTransportService searchTransportService, IndexingPressureService indexingPressureService, - AggregationUsageService aggregationUsageService + AggregationUsageService aggregationUsageService, + SearchBackpressureService searchBackpressureService ) { this.settings = settings; this.threadPool = threadPool; @@ -119,6 +122,7 @@ public class NodeService implements Closeable { this.searchTransportService = searchTransportService; this.indexingPressureService = indexingPressureService; this.aggregationUsageService = aggregationUsageService; + this.searchBackpressureService = searchBackpressureService; clusterService.addStateApplier(ingestService); } @@ -169,7 +173,8 @@ public NodeStats stats( boolean adaptiveSelection, boolean scriptCache, boolean indexingPressure, - boolean shardIndexingPressure + boolean shardIndexingPressure, + boolean searchBackpressure ) { // for indices stats we want to include previous allocated shards stats as well (it will // only be applied to the sensible ones to use, like refresh/merge/flush/indexing stats) @@ -191,7 +196,8 @@ public NodeStats stats( adaptiveSelection ? responseCollectorService.getAdaptiveStats(searchTransportService.getPendingSearchRequests()) : null, scriptCache ? scriptService.cacheStats() : null, indexingPressure ? this.indexingPressureService.nodeStats() : null, - shardIndexingPressure ? this.indexingPressureService.shardStats(indices) : null + shardIndexingPressure ? this.indexingPressureService.shardStats(indices) : null, + searchBackpressure ? 
this.searchBackpressureService.nodeStats() : null ); } @@ -203,6 +209,10 @@ public MonitorService getMonitorService() { return monitorService; } + public SearchBackpressureService getSearchBackpressureService() { + return searchBackpressureService; + } + @Override public void close() throws IOException { IOUtils.close(indicesService); diff --git a/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureService.java b/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureService.java new file mode 100644 index 0000000000000..fd13198b957da --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureService.java @@ -0,0 +1,335 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.common.component.AbstractLifecycleComponent; +import org.opensearch.common.util.TokenBucket; +import org.opensearch.monitor.jvm.JvmStats; +import org.opensearch.monitor.process.ProcessProbe; +import org.opensearch.search.backpressure.settings.SearchBackpressureMode; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.search.backpressure.stats.SearchBackpressureStats; +import org.opensearch.search.backpressure.stats.SearchShardTaskStats; +import org.opensearch.search.backpressure.trackers.CpuUsageTracker; +import org.opensearch.search.backpressure.trackers.ElapsedTimeTracker; +import org.opensearch.search.backpressure.trackers.HeapUsageTracker; +import org.opensearch.search.backpressure.trackers.NodeDuressTracker; +import 
org.opensearch.search.backpressure.trackers.TaskResourceUsageTracker; +import org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackerType; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.tasks.TaskResourceTrackingService.TaskCompletionListener; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; +import java.util.stream.Collectors; + +/** + * SearchBackpressureService is responsible for monitoring and cancelling in-flight search tasks if they are + * breaching resource usage limits when the node is in duress. + * + * @opensearch.internal + */ +public class SearchBackpressureService extends AbstractLifecycleComponent + implements + TaskCompletionListener, + SearchBackpressureSettings.Listener { + private static final Logger logger = LogManager.getLogger(SearchBackpressureService.class); + + private volatile Scheduler.Cancellable scheduledFuture; + + private final SearchBackpressureSettings settings; + private final TaskResourceTrackingService taskResourceTrackingService; + private final ThreadPool threadPool; + private final LongSupplier timeNanosSupplier; + + private final List nodeDuressTrackers; + private final List taskResourceUsageTrackers; + + private final AtomicReference taskCancellationRateLimiter = new AtomicReference<>(); + private final AtomicReference taskCancellationRatioLimiter = new AtomicReference<>(); + + // Currently, only the state of SearchShardTask is being tracked. + // This can be generalized to Map once we start supporting cancellation of SearchTasks as well. 
+ private final SearchBackpressureState state = new SearchBackpressureState(); + + public SearchBackpressureService( + SearchBackpressureSettings settings, + TaskResourceTrackingService taskResourceTrackingService, + ThreadPool threadPool + ) { + this( + settings, + taskResourceTrackingService, + threadPool, + System::nanoTime, + List.of( + new NodeDuressTracker( + () -> ProcessProbe.getInstance().getProcessCpuPercent() / 100.0 >= settings.getNodeDuressSettings().getCpuThreshold() + ), + new NodeDuressTracker( + () -> JvmStats.jvmStats().getMem().getHeapUsedPercent() / 100.0 >= settings.getNodeDuressSettings().getHeapThreshold() + ) + ), + List.of(new CpuUsageTracker(settings), new HeapUsageTracker(settings), new ElapsedTimeTracker(settings, System::nanoTime)) + ); + } + + public SearchBackpressureService( + SearchBackpressureSettings settings, + TaskResourceTrackingService taskResourceTrackingService, + ThreadPool threadPool, + LongSupplier timeNanosSupplier, + List nodeDuressTrackers, + List taskResourceUsageTrackers + ) { + this.settings = settings; + this.settings.addListener(this); + this.taskResourceTrackingService = taskResourceTrackingService; + this.taskResourceTrackingService.addTaskCompletionListener(this); + this.threadPool = threadPool; + this.timeNanosSupplier = timeNanosSupplier; + this.nodeDuressTrackers = nodeDuressTrackers; + this.taskResourceUsageTrackers = taskResourceUsageTrackers; + + this.taskCancellationRateLimiter.set( + new TokenBucket(timeNanosSupplier, getSettings().getCancellationRateNanos(), getSettings().getCancellationBurst()) + ); + + this.taskCancellationRatioLimiter.set( + new TokenBucket(state::getCompletionCount, getSettings().getCancellationRatio(), getSettings().getCancellationBurst()) + ); + } + + void doRun() { + SearchBackpressureMode mode = getSettings().getMode(); + if (mode == SearchBackpressureMode.DISABLED) { + return; + } + + if (isNodeInDuress() == false) { + return; + } + + // We are only targeting in-flight 
cancellation of SearchShardTask for now. + List searchShardTasks = getSearchShardTasks(); + + // Force-refresh usage stats of these tasks before making a cancellation decision. + taskResourceTrackingService.refreshResourceStats(searchShardTasks.toArray(new Task[0])); + + // Skip cancellation if the increase in heap usage is not due to search requests. + if (isHeapUsageDominatedBySearch(searchShardTasks) == false) { + return; + } + + for (TaskCancellation taskCancellation : getTaskCancellations(searchShardTasks)) { + logger.debug( + "[{} mode] cancelling task [{}] due to high resource consumption [{}]", + mode.getName(), + taskCancellation.getTask().getId(), + taskCancellation.getReasonString() + ); + + if (mode != SearchBackpressureMode.ENFORCED) { + continue; + } + + // Independently remove tokens from both token buckets. + boolean rateLimitReached = taskCancellationRateLimiter.get().request() == false; + boolean ratioLimitReached = taskCancellationRatioLimiter.get().request() == false; + + // Stop cancelling tasks if there are no tokens in either of the two token buckets. + if (rateLimitReached && ratioLimitReached) { + logger.debug("task cancellation limit reached"); + state.incrementLimitReachedCount(); + break; + } + + taskCancellation.cancel(); + } + } + + /** + * Returns true if the node is in duress consecutively for the past 'n' observations. + */ + boolean isNodeInDuress() { + boolean isNodeInDuress = false; + int numSuccessiveBreaches = getSettings().getNodeDuressSettings().getNumSuccessiveBreaches(); + + for (NodeDuressTracker tracker : nodeDuressTrackers) { + if (tracker.check() >= numSuccessiveBreaches) { + isNodeInDuress = true; // not breaking the loop so that each tracker's streak gets updated. + } + } + + return isNodeInDuress; + } + + /** + * Returns true if the increase in heap usage is due to search requests. 
+ */ + boolean isHeapUsageDominatedBySearch(List searchShardTasks) { + long usage = searchShardTasks.stream().mapToLong(task -> task.getTotalResourceStats().getMemoryInBytes()).sum(); + long threshold = getSettings().getSearchShardTaskSettings().getTotalHeapBytesThreshold(); + if (usage < threshold) { + logger.debug("heap usage not dominated by search requests [{}/{}]", usage, threshold); + return false; + } + + return true; + } + + /** + * Filters and returns the list of currently running SearchShardTasks. + */ + List getSearchShardTasks() { + return taskResourceTrackingService.getResourceAwareTasks() + .values() + .stream() + .filter(task -> task instanceof SearchShardTask) + .map(task -> (SearchShardTask) task) + .collect(Collectors.toUnmodifiableList()); + } + + /** + * Returns a TaskCancellation wrapper containing the list of reasons (possibly zero), along with an overall + * cancellation score for the given task. Cancelling a task with a higher score has better chance of recovering the + * node from duress. + */ + TaskCancellation getTaskCancellation(CancellableTask task) { + List reasons = new ArrayList<>(); + List callbacks = new ArrayList<>(); + + for (TaskResourceUsageTracker tracker : taskResourceUsageTrackers) { + Optional reason = tracker.checkAndMaybeGetCancellationReason(task); + if (reason.isPresent()) { + reasons.add(reason.get()); + callbacks.add(tracker::incrementCancellations); + } + } + + if (task instanceof SearchShardTask) { + callbacks.add(state::incrementCancellationCount); + } + + return new TaskCancellation(task, reasons, callbacks); + } + + /** + * Returns a list of TaskCancellations sorted by descending order of their cancellation scores. 
+ */ + List getTaskCancellations(List tasks) { + return tasks.stream() + .map(this::getTaskCancellation) + .filter(TaskCancellation::isEligibleForCancellation) + .sorted(Comparator.reverseOrder()) + .collect(Collectors.toUnmodifiableList()); + } + + SearchBackpressureSettings getSettings() { + return settings; + } + + SearchBackpressureState getState() { + return state; + } + + @Override + public void onTaskCompleted(Task task) { + if (getSettings().getMode() == SearchBackpressureMode.DISABLED) { + return; + } + + if (task instanceof SearchShardTask == false) { + return; + } + + SearchShardTask searchShardTask = (SearchShardTask) task; + if (searchShardTask.isCancelled() == false) { + state.incrementCompletionCount(); + } + + List exceptions = new ArrayList<>(); + for (TaskResourceUsageTracker tracker : taskResourceUsageTrackers) { + try { + tracker.update(searchShardTask); + } catch (Exception e) { + exceptions.add(e); + } + } + ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptions); + } + + @Override + public void onCancellationRatioChanged() { + taskCancellationRatioLimiter.set( + new TokenBucket(state::getCompletionCount, getSettings().getCancellationRatio(), getSettings().getCancellationBurst()) + ); + } + + @Override + public void onCancellationRateChanged() { + taskCancellationRateLimiter.set( + new TokenBucket(timeNanosSupplier, getSettings().getCancellationRateNanos(), getSettings().getCancellationBurst()) + ); + } + + @Override + public void onCancellationBurstChanged() { + onCancellationRatioChanged(); + onCancellationRateChanged(); + } + + @Override + protected void doStart() { + scheduledFuture = threadPool.scheduleWithFixedDelay(() -> { + try { + doRun(); + } catch (Exception e) { + logger.debug("failure in search search backpressure", e); + } + }, getSettings().getInterval(), ThreadPool.Names.GENERIC); + } + + @Override + protected void doStop() { + if (scheduledFuture != null) { + scheduledFuture.cancel(); + } + } + + @Override + protected void 
doClose() throws IOException {} + + public SearchBackpressureStats nodeStats() { + List searchShardTasks = getSearchShardTasks(); + + SearchShardTaskStats searchShardTaskStats = new SearchShardTaskStats( + state.getCancellationCount(), + state.getLimitReachedCount(), + taskResourceUsageTrackers.stream() + .collect(Collectors.toUnmodifiableMap(t -> TaskResourceUsageTrackerType.fromName(t.name()), t -> t.stats(searchShardTasks))) + ); + + return new SearchBackpressureStats(searchShardTaskStats, getSettings().getMode()); + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureState.java b/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureState.java new file mode 100644 index 0000000000000..a62231ec29ede --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureState.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure; + +import java.util.concurrent.atomic.AtomicLong; + +/** + * Tracks the current state of task completions and cancellations. + * + * @opensearch.internal + */ +public class SearchBackpressureState { + /** + * The number of successful task completions. + */ + private final AtomicLong completionCount = new AtomicLong(); + + /** + * The number of task cancellations due to limit breaches. + */ + private final AtomicLong cancellationCount = new AtomicLong(); + + /** + * The number of times task cancellation limit was reached. 
/**
 * Holds the running counters for task completions, task cancellations, and the number of times the
 * cancellation limit was hit.
 *
 * @opensearch.internal
 */
public class SearchBackpressureState {
    /** The number of successful task completions. */
    private final AtomicLong completionCount = new AtomicLong();

    /** The number of task cancellations due to limit breaches. */
    private final AtomicLong cancellationCount = new AtomicLong();

    /** The number of times task cancellation limit was reached. */
    private final AtomicLong limitReachedCount = new AtomicLong();

    /** @return the current completion count */
    public long getCompletionCount() {
        return completionCount.get();
    }

    /** @return the current cancellation count */
    public long getCancellationCount() {
        return cancellationCount.get();
    }

    /** @return the current limit-reached count */
    public long getLimitReachedCount() {
        return limitReachedCount.get();
    }

    /** Adds one to the completion count and returns the updated value. */
    long incrementCompletionCount() {
        return completionCount.addAndGet(1L);
    }

    /** Adds one to the cancellation count and returns the updated value. */
    long incrementCancellationCount() {
        return cancellationCount.addAndGet(1L);
    }

    /** Adds one to the limit-reached count and returns the updated value. */
    long incrementLimitReachedCount() {
        return limitReachedCount.addAndGet(1L);
    }
}
+ */ + +package org.opensearch.search.backpressure.settings; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; + +/** + * Defines the settings for a node to be considered in duress. + * + * @opensearch.internal + */ +public class NodeDuressSettings { + private static class Defaults { + private static final int NUM_SUCCESSIVE_BREACHES = 3; + private static final double CPU_THRESHOLD = 0.9; + private static final double HEAP_THRESHOLD = 0.7; + } + + /** + * Defines the number of successive limit breaches after the node is marked "in duress". + */ + private volatile int numSuccessiveBreaches; + public static final Setting SETTING_NUM_SUCCESSIVE_BREACHES = Setting.intSetting( + "search_backpressure.node_duress.num_successive_breaches", + Defaults.NUM_SUCCESSIVE_BREACHES, + 1, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the CPU usage threshold (in percentage) for a node to be considered "in duress". + */ + private volatile double cpuThreshold; + public static final Setting SETTING_CPU_THRESHOLD = Setting.doubleSetting( + "search_backpressure.node_duress.cpu_threshold", + Defaults.CPU_THRESHOLD, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the heap usage threshold (in percentage) for a node to be considered "in duress". 
+ */ + private volatile double heapThreshold; + public static final Setting SETTING_HEAP_THRESHOLD = Setting.doubleSetting( + "search_backpressure.node_duress.heap_threshold", + Defaults.HEAP_THRESHOLD, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + public NodeDuressSettings(Settings settings, ClusterSettings clusterSettings) { + numSuccessiveBreaches = SETTING_NUM_SUCCESSIVE_BREACHES.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_NUM_SUCCESSIVE_BREACHES, this::setNumSuccessiveBreaches); + + cpuThreshold = SETTING_CPU_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_CPU_THRESHOLD, this::setCpuThreshold); + + heapThreshold = SETTING_HEAP_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_HEAP_THRESHOLD, this::setHeapThreshold); + } + + public int getNumSuccessiveBreaches() { + return numSuccessiveBreaches; + } + + private void setNumSuccessiveBreaches(int numSuccessiveBreaches) { + this.numSuccessiveBreaches = numSuccessiveBreaches; + } + + public double getCpuThreshold() { + return cpuThreshold; + } + + private void setCpuThreshold(double cpuThreshold) { + this.cpuThreshold = cpuThreshold; + } + + public double getHeapThreshold() { + return heapThreshold; + } + + private void setHeapThreshold(double heapThreshold) { + this.heapThreshold = heapThreshold; + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/settings/SearchBackpressureMode.java b/server/src/main/java/org/opensearch/search/backpressure/settings/SearchBackpressureMode.java new file mode 100644 index 0000000000000..a0e4e3c0d25aa --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/settings/SearchBackpressureMode.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
/**
 * Defines the search backpressure mode.
 */
public enum SearchBackpressureMode {
    /**
     * SearchBackpressureService is completely disabled.
     */
    DISABLED("disabled"),

    /**
     * SearchBackpressureService only monitors the resource usage of running tasks.
     */
    MONITOR_ONLY("monitor_only"),

    /**
     * SearchBackpressureService monitors and rejects tasks that exceed resource usage thresholds.
     */
    ENFORCED("enforced");

    private final String name;

    SearchBackpressureMode(String name) {
        this.name = name;
    }

    /**
     * Returns the external (settings/stats) name of this mode.
     */
    public String getName() {
        return name;
    }

    /**
     * Resolves a mode from its external name.
     * The lookup walks the enum constants, so each name is defined in exactly one place.
     *
     * @param name the external name, e.g. "monitor_only"
     * @return the matching mode
     * @throws IllegalArgumentException if {@code name} is null or does not match any mode
     */
    public static SearchBackpressureMode fromName(String name) {
        for (SearchBackpressureMode candidate : values()) {
            // equals() on the constant's own name is null-safe, so a null argument falls through to the throw.
            if (candidate.name.equals(name)) {
                return candidate;
            }
        }

        throw new IllegalArgumentException("Invalid SearchBackpressureMode: " + name);
    }
}
+ */ + +package org.opensearch.search.backpressure.settings; + +import org.opensearch.ExceptionsHelper; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +/** + * Settings related to search backpressure and cancellation of in-flight requests. + * + * @opensearch.internal + */ +public class SearchBackpressureSettings { + private static class Defaults { + private static final long INTERVAL_MILLIS = 1000; + private static final String MODE = "monitor_only"; + + private static final double CANCELLATION_RATIO = 0.1; + private static final double CANCELLATION_RATE = 0.003; + private static final double CANCELLATION_BURST = 10.0; + } + + /** + * Defines the interval (in millis) at which the SearchBackpressureService monitors and cancels tasks. + */ + private final TimeValue interval; + public static final Setting SETTING_INTERVAL_MILLIS = Setting.longSetting( + "search_backpressure.interval_millis", + Defaults.INTERVAL_MILLIS, + 1, + Setting.Property.NodeScope + ); + + /** + * Defines the search backpressure mode. It can be either "disabled", "monitor_only" or "enforced". + */ + private volatile SearchBackpressureMode mode; + public static final Setting SETTING_MODE = Setting.simpleString( + "search_backpressure.mode", + Defaults.MODE, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the percentage of tasks to cancel relative to the number of successful task completions. + * In other words, it is the number of tokens added to the bucket on each successful task completion. 
+ */ + private volatile double cancellationRatio; + public static final Setting SETTING_CANCELLATION_RATIO = Setting.doubleSetting( + "search_backpressure.cancellation_ratio", + Defaults.CANCELLATION_RATIO, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the number of tasks to cancel per unit time (in millis). + * In other words, it is the number of tokens added to the bucket each millisecond. + */ + private volatile double cancellationRate; + public static final Setting SETTING_CANCELLATION_RATE = Setting.doubleSetting( + "search_backpressure.cancellation_rate", + Defaults.CANCELLATION_RATE, + 0.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the maximum number of tasks that can be cancelled before being rate-limited. + */ + private volatile double cancellationBurst; + public static final Setting SETTING_CANCELLATION_BURST = Setting.doubleSetting( + "search_backpressure.cancellation_burst", + Defaults.CANCELLATION_BURST, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Callback listeners. 
+ */ + public interface Listener { + void onCancellationRatioChanged(); + + void onCancellationRateChanged(); + + void onCancellationBurstChanged(); + } + + private final List listeners = new ArrayList<>(); + private final Settings settings; + private final ClusterSettings clusterSettings; + private final NodeDuressSettings nodeDuressSettings; + private final SearchShardTaskSettings searchShardTaskSettings; + + public SearchBackpressureSettings(Settings settings, ClusterSettings clusterSettings) { + this.settings = settings; + this.clusterSettings = clusterSettings; + this.nodeDuressSettings = new NodeDuressSettings(settings, clusterSettings); + this.searchShardTaskSettings = new SearchShardTaskSettings(settings, clusterSettings); + + interval = new TimeValue(SETTING_INTERVAL_MILLIS.get(settings)); + + mode = SearchBackpressureMode.fromName(SETTING_MODE.get(settings)); + clusterSettings.addSettingsUpdateConsumer(SETTING_MODE, s -> this.setMode(SearchBackpressureMode.fromName(s))); + + cancellationRatio = SETTING_CANCELLATION_RATIO.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_CANCELLATION_RATIO, this::setCancellationRatio); + + cancellationRate = SETTING_CANCELLATION_RATE.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_CANCELLATION_RATE, this::setCancellationRate); + + cancellationBurst = SETTING_CANCELLATION_BURST.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_CANCELLATION_BURST, this::setCancellationBurst); + } + + public void addListener(Listener listener) { + listeners.add(listener); + } + + public Settings getSettings() { + return settings; + } + + public ClusterSettings getClusterSettings() { + return clusterSettings; + } + + public NodeDuressSettings getNodeDuressSettings() { + return nodeDuressSettings; + } + + public SearchShardTaskSettings getSearchShardTaskSettings() { + return searchShardTaskSettings; + } + + public TimeValue getInterval() { + return interval; + } + + public 
SearchBackpressureMode getMode() { + return mode; + } + + public void setMode(SearchBackpressureMode mode) { + this.mode = mode; + } + + public double getCancellationRatio() { + return cancellationRatio; + } + + private void setCancellationRatio(double cancellationRatio) { + this.cancellationRatio = cancellationRatio; + notifyListeners(Listener::onCancellationRatioChanged); + } + + public double getCancellationRate() { + return cancellationRate; + } + + public double getCancellationRateNanos() { + return getCancellationRate() / TimeUnit.MILLISECONDS.toNanos(1); // rate per nanoseconds + } + + private void setCancellationRate(double cancellationRate) { + this.cancellationRate = cancellationRate; + notifyListeners(Listener::onCancellationRateChanged); + } + + public double getCancellationBurst() { + return cancellationBurst; + } + + private void setCancellationBurst(double cancellationBurst) { + this.cancellationBurst = cancellationBurst; + notifyListeners(Listener::onCancellationBurstChanged); + } + + private void notifyListeners(Consumer consumer) { + List exceptions = new ArrayList<>(); + + for (Listener listener : listeners) { + try { + consumer.accept(listener); + } catch (Exception e) { + exceptions.add(e); + } + } + + ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptions); + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/settings/SearchShardTaskSettings.java b/server/src/main/java/org/opensearch/search/backpressure/settings/SearchShardTaskSettings.java new file mode 100644 index 0000000000000..7e40f1c0eab53 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/settings/SearchShardTaskSettings.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.backpressure.settings; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.monitor.jvm.JvmStats; + +/** + * Defines the settings related to the cancellation of SearchShardTasks. + * + * @opensearch.internal + */ +public class SearchShardTaskSettings { + private static final long HEAP_SIZE_BYTES = JvmStats.jvmStats().getMem().getHeapMax().getBytes(); + + private static class Defaults { + private static final double TOTAL_HEAP_PERCENT_THRESHOLD = 0.05; + } + + /** + * Defines the heap usage threshold (in percentage) for the sum of heap usages across all search shard tasks + * before in-flight cancellation is applied. + */ + private volatile double totalHeapPercentThreshold; + public static final Setting SETTING_TOTAL_HEAP_PERCENT_THRESHOLD = Setting.doubleSetting( + "search_backpressure.search_shard_task.total_heap_percent_threshold", + Defaults.TOTAL_HEAP_PERCENT_THRESHOLD, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + public SearchShardTaskSettings(Settings settings, ClusterSettings clusterSettings) { + totalHeapPercentThreshold = SETTING_TOTAL_HEAP_PERCENT_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_TOTAL_HEAP_PERCENT_THRESHOLD, this::setTotalHeapPercentThreshold); + } + + public double getTotalHeapPercentThreshold() { + return totalHeapPercentThreshold; + } + + public long getTotalHeapBytesThreshold() { + return (long) (HEAP_SIZE_BYTES * getTotalHeapPercentThreshold()); + } + + private void setTotalHeapPercentThreshold(double totalHeapPercentThreshold) { + this.totalHeapPercentThreshold = totalHeapPercentThreshold; + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/settings/package-info.java b/server/src/main/java/org/opensearch/search/backpressure/settings/package-info.java new file mode 100644 index 
0000000000000..a853a139b096b --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/settings/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * This package contains settings for search backpressure. + */ +package org.opensearch.search.backpressure.settings; diff --git a/server/src/main/java/org/opensearch/search/backpressure/stats/SearchBackpressureStats.java b/server/src/main/java/org/opensearch/search/backpressure/stats/SearchBackpressureStats.java new file mode 100644 index 0000000000000..3aec0dfc579c5 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/stats/SearchBackpressureStats.java @@ -0,0 +1,63 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.stats; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentFragment; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.search.backpressure.settings.SearchBackpressureMode; + +import java.io.IOException; +import java.util.Objects; + +/** + * Stats related to search backpressure. 
+ */ +public class SearchBackpressureStats implements ToXContentFragment, Writeable { + private final SearchShardTaskStats searchShardTaskStats; + private final SearchBackpressureMode mode; + + public SearchBackpressureStats(SearchShardTaskStats searchShardTaskStats, SearchBackpressureMode mode) { + this.searchShardTaskStats = searchShardTaskStats; + this.mode = mode; + } + + public SearchBackpressureStats(StreamInput in) throws IOException { + this(new SearchShardTaskStats(in), SearchBackpressureMode.fromName(in.readString())); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject("search_backpressure") + .field("search_shard_task", searchShardTaskStats) + .field("mode", mode.getName()) + .endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + searchShardTaskStats.writeTo(out); + out.writeString(mode.getName()); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SearchBackpressureStats that = (SearchBackpressureStats) o; + return searchShardTaskStats.equals(that.searchShardTaskStats) && mode == that.mode; + } + + @Override + public int hashCode() { + return Objects.hash(searchShardTaskStats, mode); + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/stats/SearchShardTaskStats.java b/server/src/main/java/org/opensearch/search/backpressure/stats/SearchShardTaskStats.java new file mode 100644 index 0000000000000..4d532cfb12f80 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/stats/SearchShardTaskStats.java @@ -0,0 +1,98 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.backpressure.stats; + +import org.opensearch.common.collect.MapBuilder; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentObject; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.search.backpressure.trackers.CpuUsageTracker; +import org.opensearch.search.backpressure.trackers.ElapsedTimeTracker; +import org.opensearch.search.backpressure.trackers.HeapUsageTracker; +import org.opensearch.search.backpressure.trackers.TaskResourceUsageTracker; +import org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackerType; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +/** + * Stats related to cancelled search shard tasks. + */ +public class SearchShardTaskStats implements ToXContentObject, Writeable { + private final long cancellationCount; + private final long limitReachedCount; + private final Map resourceUsageTrackerStats; + + public SearchShardTaskStats( + long cancellationCount, + long limitReachedCount, + Map resourceUsageTrackerStats + ) { + this.cancellationCount = cancellationCount; + this.limitReachedCount = limitReachedCount; + this.resourceUsageTrackerStats = resourceUsageTrackerStats; + } + + public SearchShardTaskStats(StreamInput in) throws IOException { + this.cancellationCount = in.readVLong(); + this.limitReachedCount = in.readVLong(); + + MapBuilder builder = new MapBuilder<>(); + builder.put(TaskResourceUsageTrackerType.CPU_USAGE_TRACKER, in.readOptionalWriteable(CpuUsageTracker.Stats::new)); + builder.put(TaskResourceUsageTrackerType.HEAP_USAGE_TRACKER, in.readOptionalWriteable(HeapUsageTracker.Stats::new)); + builder.put(TaskResourceUsageTrackerType.ELAPSED_TIME_TRACKER, in.readOptionalWriteable(ElapsedTimeTracker.Stats::new)); + this.resourceUsageTrackerStats = builder.immutableMap(); + 
} + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.startObject("resource_tracker_stats"); + for (Map.Entry entry : resourceUsageTrackerStats.entrySet()) { + builder.field(entry.getKey().getName(), entry.getValue()); + } + builder.endObject(); + + builder.startObject("cancellation_stats") + .field("cancellation_count", cancellationCount) + .field("cancellation_limit_reached_count", limitReachedCount) + .endObject(); + + return builder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(cancellationCount); + out.writeVLong(limitReachedCount); + + out.writeOptionalWriteable(resourceUsageTrackerStats.get(TaskResourceUsageTrackerType.CPU_USAGE_TRACKER)); + out.writeOptionalWriteable(resourceUsageTrackerStats.get(TaskResourceUsageTrackerType.HEAP_USAGE_TRACKER)); + out.writeOptionalWriteable(resourceUsageTrackerStats.get(TaskResourceUsageTrackerType.ELAPSED_TIME_TRACKER)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SearchShardTaskStats that = (SearchShardTaskStats) o; + return cancellationCount == that.cancellationCount + && limitReachedCount == that.limitReachedCount + && resourceUsageTrackerStats.equals(that.resourceUsageTrackerStats); + } + + @Override + public int hashCode() { + return Objects.hash(cancellationCount, limitReachedCount, resourceUsageTrackerStats); + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/stats/package-info.java b/server/src/main/java/org/opensearch/search/backpressure/stats/package-info.java new file mode 100644 index 0000000000000..514b274c2cf1a --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/stats/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions 
made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * This package contains models required for the search backpressure stats API response. + */ +package org.opensearch.search.backpressure.stats; diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/CpuUsageTracker.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/CpuUsageTracker.java new file mode 100644 index 0000000000000..21bb3af32ae08 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/CpuUsageTracker.java @@ -0,0 +1,143 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +import org.opensearch.common.settings.Setting; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackerType.CPU_USAGE_TRACKER; + +/** + * CpuUsageTracker evaluates if the task has consumed too many CPU cycles than allowed. + * + * @opensearch.internal + */ +public class CpuUsageTracker extends TaskResourceUsageTracker { + private static class Defaults { + private static final long CPU_TIME_MILLIS_THRESHOLD = 15000; + } + + /** + * Defines the CPU usage threshold (in millis) for an individual task before it is considered for cancellation. 
+ */ + private volatile long cpuTimeMillisThreshold; + public static final Setting SETTING_CPU_TIME_MILLIS_THRESHOLD = Setting.longSetting( + "search_backpressure.search_shard_task.cpu_time_millis_threshold", + Defaults.CPU_TIME_MILLIS_THRESHOLD, + 0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + public CpuUsageTracker(SearchBackpressureSettings settings) { + this.cpuTimeMillisThreshold = SETTING_CPU_TIME_MILLIS_THRESHOLD.get(settings.getSettings()); + settings.getClusterSettings().addSettingsUpdateConsumer(SETTING_CPU_TIME_MILLIS_THRESHOLD, this::setCpuTimeMillisThreshold); + } + + @Override + public String name() { + return CPU_USAGE_TRACKER.getName(); + } + + @Override + public Optional checkAndMaybeGetCancellationReason(Task task) { + long usage = task.getTotalResourceStats().getCpuTimeInNanos(); + long threshold = getCpuTimeNanosThreshold(); + + if (usage < threshold) { + return Optional.empty(); + } + + return Optional.of( + new TaskCancellation.Reason( + "cpu usage exceeded [" + + new TimeValue(usage, TimeUnit.NANOSECONDS) + + " >= " + + new TimeValue(threshold, TimeUnit.NANOSECONDS) + + "]", + 1 // TODO: fine-tune the cancellation score/weight + ) + ); + } + + public long getCpuTimeNanosThreshold() { + return TimeUnit.MILLISECONDS.toNanos(cpuTimeMillisThreshold); + } + + public void setCpuTimeMillisThreshold(long cpuTimeMillisThreshold) { + this.cpuTimeMillisThreshold = cpuTimeMillisThreshold; + } + + @Override + public TaskResourceUsageTracker.Stats stats(List activeTasks) { + long currentMax = activeTasks.stream().mapToLong(t -> t.getTotalResourceStats().getCpuTimeInNanos()).max().orElse(0); + long currentAvg = (long) activeTasks.stream().mapToLong(t -> t.getTotalResourceStats().getCpuTimeInNanos()).average().orElse(0); + return new Stats(getCancellations(), currentMax, currentAvg); + } + + /** + * Stats related to CpuUsageTracker. 
+ */ + public static class Stats implements TaskResourceUsageTracker.Stats { + private final long cancellationCount; + private final long currentMax; + private final long currentAvg; + + public Stats(long cancellationCount, long currentMax, long currentAvg) { + this.cancellationCount = cancellationCount; + this.currentMax = currentMax; + this.currentAvg = currentAvg; + } + + public Stats(StreamInput in) throws IOException { + this(in.readVLong(), in.readVLong(), in.readVLong()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject() + .field("cancellation_count", cancellationCount) + .humanReadableField("current_max_millis", "current_max", new TimeValue(currentMax, TimeUnit.NANOSECONDS)) + .humanReadableField("current_avg_millis", "current_avg", new TimeValue(currentAvg, TimeUnit.NANOSECONDS)) + .endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(cancellationCount); + out.writeVLong(currentMax); + out.writeVLong(currentAvg); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Stats stats = (Stats) o; + return cancellationCount == stats.cancellationCount && currentMax == stats.currentMax && currentAvg == stats.currentAvg; + } + + @Override + public int hashCode() { + return Objects.hash(cancellationCount, currentMax, currentAvg); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/ElapsedTimeTracker.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/ElapsedTimeTracker.java new file mode 100644 index 0000000000000..10e53e2bce5ae --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/ElapsedTimeTracker.java @@ -0,0 +1,148 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this 
file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +import org.opensearch.common.settings.Setting; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.function.LongSupplier; + +import static org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackerType.ELAPSED_TIME_TRACKER; + +/** + * ElapsedTimeTracker evaluates if the task has been running for more time than allowed. + * + * @opensearch.internal + */ +public class ElapsedTimeTracker extends TaskResourceUsageTracker { + private static class Defaults { + private static final long ELAPSED_TIME_MILLIS_THRESHOLD = 30000; + } + + /** + * Defines the elapsed time threshold (in millis) for an individual task before it is considered for cancellation. 
+ */ + private volatile long elapsedTimeMillisThreshold; + public static final Setting SETTING_ELAPSED_TIME_MILLIS_THRESHOLD = Setting.longSetting( + "search_backpressure.search_shard_task.elapsed_time_millis_threshold", + Defaults.ELAPSED_TIME_MILLIS_THRESHOLD, + 0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + private final LongSupplier timeNanosSupplier; + + public ElapsedTimeTracker(SearchBackpressureSettings settings, LongSupplier timeNanosSupplier) { + this.timeNanosSupplier = timeNanosSupplier; + this.elapsedTimeMillisThreshold = SETTING_ELAPSED_TIME_MILLIS_THRESHOLD.get(settings.getSettings()); + settings.getClusterSettings().addSettingsUpdateConsumer(SETTING_ELAPSED_TIME_MILLIS_THRESHOLD, this::setElapsedTimeMillisThreshold); + } + + @Override + public String name() { + return ELAPSED_TIME_TRACKER.getName(); + } + + @Override + public Optional checkAndMaybeGetCancellationReason(Task task) { + long usage = timeNanosSupplier.getAsLong() - task.getStartTimeNanos(); + long threshold = getElapsedTimeNanosThreshold(); + + if (usage < threshold) { + return Optional.empty(); + } + + return Optional.of( + new TaskCancellation.Reason( + "elapsed time exceeded [" + + new TimeValue(usage, TimeUnit.NANOSECONDS) + + " >= " + + new TimeValue(threshold, TimeUnit.NANOSECONDS) + + "]", + 1 // TODO: fine-tune the cancellation score/weight + ) + ); + } + + public long getElapsedTimeNanosThreshold() { + return TimeUnit.MILLISECONDS.toNanos(elapsedTimeMillisThreshold); + } + + public void setElapsedTimeMillisThreshold(long elapsedTimeMillisThreshold) { + this.elapsedTimeMillisThreshold = elapsedTimeMillisThreshold; + } + + @Override + public TaskResourceUsageTracker.Stats stats(List activeTasks) { + long now = timeNanosSupplier.getAsLong(); + long currentMax = activeTasks.stream().mapToLong(t -> now - t.getStartTimeNanos()).max().orElse(0); + long currentAvg = (long) activeTasks.stream().mapToLong(t -> now - t.getStartTimeNanos()).average().orElse(0); + 
return new Stats(getCancellations(), currentMax, currentAvg); + } + + /** + * Stats related to ElapsedTimeTracker. + */ + public static class Stats implements TaskResourceUsageTracker.Stats { + private final long cancellationCount; + private final long currentMax; + private final long currentAvg; + + public Stats(long cancellationCount, long currentMax, long currentAvg) { + this.cancellationCount = cancellationCount; + this.currentMax = currentMax; + this.currentAvg = currentAvg; + } + + public Stats(StreamInput in) throws IOException { + this(in.readVLong(), in.readVLong(), in.readVLong()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject() + .field("cancellation_count", cancellationCount) + .humanReadableField("current_max_millis", "current_max", new TimeValue(currentMax, TimeUnit.NANOSECONDS)) + .humanReadableField("current_avg_millis", "current_avg", new TimeValue(currentAvg, TimeUnit.NANOSECONDS)) + .endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(cancellationCount); + out.writeVLong(currentMax); + out.writeVLong(currentAvg); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Stats stats = (Stats) o; + return cancellationCount == stats.cancellationCount && currentMax == stats.currentMax && currentAvg == stats.currentAvg; + } + + @Override + public int hashCode() { + return Objects.hash(cancellationCount, currentMax, currentAvg); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/HeapUsageTracker.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/HeapUsageTracker.java new file mode 100644 index 0000000000000..d1a264609e522 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/HeapUsageTracker.java @@ -0,0 +1,216 @@ +/* + * 
SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +import org.opensearch.common.settings.Setting; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.common.util.MovingAverage; +import org.opensearch.monitor.jvm.JvmStats; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import static org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackerType.HEAP_USAGE_TRACKER; + +/** + * HeapUsageTracker evaluates if the task has consumed more heap than allowed. + * It also compares the task's heap usage against a historical moving average of previously completed tasks. + * + * @opensearch.internal + */ +public class HeapUsageTracker extends TaskResourceUsageTracker { + private static final long HEAP_SIZE_BYTES = JvmStats.jvmStats().getMem().getHeapMax().getBytes(); + + private static class Defaults { + private static final double HEAP_PERCENT_THRESHOLD = 0.005; + private static final double HEAP_VARIANCE_THRESHOLD = 2.0; + private static final int HEAP_MOVING_AVERAGE_WINDOW_SIZE = 100; + } + + /** + * Defines the heap usage threshold (as a fraction of the maximum heap size, from 0.0 to 1.0) for an individual task before it is considered for cancellation.
+ */ + private volatile double heapPercentThreshold; + public static final Setting SETTING_HEAP_PERCENT_THRESHOLD = Setting.doubleSetting( + "search_backpressure.search_shard_task.heap_percent_threshold", + Defaults.HEAP_PERCENT_THRESHOLD, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the heap usage variance for an individual task before it is considered for cancellation. + * A task is considered for cancellation when taskHeapUsage is greater than or equal to heapUsageMovingAverage * variance. + */ + private volatile double heapVarianceThreshold; + public static final Setting SETTING_HEAP_VARIANCE_THRESHOLD = Setting.doubleSetting( + "search_backpressure.search_shard_task.heap_variance", + Defaults.HEAP_VARIANCE_THRESHOLD, + 0.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the window size to calculate the moving average of heap usage of completed tasks. The minimum is 1 because MovingAverage rejects a window size of zero. + */ + private volatile int heapMovingAverageWindowSize; + public static final Setting SETTING_HEAP_MOVING_AVERAGE_WINDOW_SIZE = Setting.intSetting( + "search_backpressure.search_shard_task.heap_moving_average_window_size", + Defaults.HEAP_MOVING_AVERAGE_WINDOW_SIZE, + 1, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + private final AtomicReference movingAverageReference; + + public HeapUsageTracker(SearchBackpressureSettings settings) { + heapPercentThreshold = SETTING_HEAP_PERCENT_THRESHOLD.get(settings.getSettings()); + settings.getClusterSettings().addSettingsUpdateConsumer(SETTING_HEAP_PERCENT_THRESHOLD, this::setHeapPercentThreshold); + + heapVarianceThreshold = SETTING_HEAP_VARIANCE_THRESHOLD.get(settings.getSettings()); + settings.getClusterSettings().addSettingsUpdateConsumer(SETTING_HEAP_VARIANCE_THRESHOLD, this::setHeapVarianceThreshold); + + heapMovingAverageWindowSize = SETTING_HEAP_MOVING_AVERAGE_WINDOW_SIZE.get(settings.getSettings()); + settings.getClusterSettings() +
.addSettingsUpdateConsumer(SETTING_HEAP_MOVING_AVERAGE_WINDOW_SIZE, this::setHeapMovingAverageWindowSize); + + this.movingAverageReference = new AtomicReference<>(new MovingAverage(heapMovingAverageWindowSize)); + } + + @Override + public String name() { + return HEAP_USAGE_TRACKER.getName(); + } + + @Override + public void update(Task task) { + movingAverageReference.get().record(task.getTotalResourceStats().getMemoryInBytes()); + } + + @Override + public Optional checkAndMaybeGetCancellationReason(Task task) { + MovingAverage movingAverage = movingAverageReference.get(); + + // There haven't been enough measurements. + if (movingAverage.isReady() == false) { + return Optional.empty(); + } + + double currentUsage = task.getTotalResourceStats().getMemoryInBytes(); + double averageUsage = movingAverage.getAverage(); + double allowedUsage = averageUsage * getHeapVarianceThreshold(); + + if (currentUsage < getHeapBytesThreshold() || currentUsage < allowedUsage) { + return Optional.empty(); + } + + return Optional.of( + new TaskCancellation.Reason( + "heap usage exceeded [" + new ByteSizeValue((long) currentUsage) + " >= " + new ByteSizeValue((long) allowedUsage) + "]", + (int) (currentUsage / averageUsage) // TODO: fine-tune the cancellation score/weight + ) + ); + } + + public long getHeapBytesThreshold() { + return (long) (HEAP_SIZE_BYTES * heapPercentThreshold); + } + + public void setHeapPercentThreshold(double heapPercentThreshold) { + this.heapPercentThreshold = heapPercentThreshold; + } + + public double getHeapVarianceThreshold() { + return heapVarianceThreshold; + } + + public void setHeapVarianceThreshold(double heapVarianceThreshold) { + this.heapVarianceThreshold = heapVarianceThreshold; + } + + public void setHeapMovingAverageWindowSize(int heapMovingAverageWindowSize) { + this.heapMovingAverageWindowSize = heapMovingAverageWindowSize; + this.movingAverageReference.set(new MovingAverage(heapMovingAverageWindowSize)); + } + + @Override + public 
TaskResourceUsageTracker.Stats stats(List activeTasks) { + long currentMax = activeTasks.stream().mapToLong(t -> t.getTotalResourceStats().getMemoryInBytes()).max().orElse(0); + long currentAvg = (long) activeTasks.stream().mapToLong(t -> t.getTotalResourceStats().getMemoryInBytes()).average().orElse(0); + return new Stats(getCancellations(), currentMax, currentAvg, (long) movingAverageReference.get().getAverage()); + } + + /** + * Stats related to HeapUsageTracker. + */ + public static class Stats implements TaskResourceUsageTracker.Stats { + private final long cancellationCount; + private final long currentMax; + private final long currentAvg; + private final long rollingAvg; + + public Stats(long cancellationCount, long currentMax, long currentAvg, long rollingAvg) { + this.cancellationCount = cancellationCount; + this.currentMax = currentMax; + this.currentAvg = currentAvg; + this.rollingAvg = rollingAvg; + } + + public Stats(StreamInput in) throws IOException { + this(in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject() + .field("cancellation_count", cancellationCount) + .humanReadableField("current_max_bytes", "current_max", new ByteSizeValue(currentMax)) + .humanReadableField("current_avg_bytes", "current_avg", new ByteSizeValue(currentAvg)) + .humanReadableField("rolling_avg_bytes", "rolling_avg", new ByteSizeValue(rollingAvg)) + .endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(cancellationCount); + out.writeVLong(currentMax); + out.writeVLong(currentAvg); + out.writeVLong(rollingAvg); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Stats stats = (Stats) o; + return cancellationCount == stats.cancellationCount + && currentMax == stats.currentMax 
+ && currentAvg == stats.currentAvg + && rollingAvg == stats.rollingAvg; + } + + @Override + public int hashCode() { + return Objects.hash(cancellationCount, currentMax, currentAvg, rollingAvg); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/NodeDuressTracker.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/NodeDuressTracker.java new file mode 100644 index 0000000000000..8e35c724a8fef --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/NodeDuressTracker.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +import org.opensearch.common.util.Streak; + +import java.util.function.BooleanSupplier; + +/** + * NodeDuressTracker is used to check if the node is in duress. + * + * @opensearch.internal + */ +public class NodeDuressTracker { + /** + * Tracks the number of consecutive breaches. + */ + private final Streak breaches = new Streak(); + + /** + * Predicate that returns true when the node is in duress. + */ + private final BooleanSupplier isNodeInDuress; + + public NodeDuressTracker(BooleanSupplier isNodeInDuress) { + this.isNodeInDuress = isNodeInDuress; + } + + /** + * Evaluates the predicate and returns the number of consecutive breaches. 
+ */ + public int check() { + return breaches.record(isNodeInDuress.getAsBoolean()); + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/TaskResourceUsageTracker.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/TaskResourceUsageTracker.java new file mode 100644 index 0000000000000..cbbb751b996be --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/TaskResourceUsageTracker.java @@ -0,0 +1,63 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentObject; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.tasks.Task; + +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; + +/** + * TaskResourceUsageTracker is used to track completions and cancellations of search related tasks. + * + * @opensearch.internal + */ +public abstract class TaskResourceUsageTracker { + /** + * Counts the number of cancellations made due to this tracker. + */ + private final AtomicLong cancellations = new AtomicLong(); + + public long incrementCancellations() { + return cancellations.incrementAndGet(); + } + + public long getCancellations() { + return cancellations.get(); + } + + /** + * Returns a unique name for this tracker. + */ + public abstract String name(); + + /** + * Notifies the tracker to update its state when a task execution completes. + */ + public void update(Task task) {} + + /** + * Returns the cancellation reason for the given task, if it's eligible for cancellation. + */ + public abstract Optional checkAndMaybeGetCancellationReason(Task task); + + /** + * Returns the tracker's state as seen in the stats API. 
+ */ + public abstract Stats stats(List activeTasks); + + /** + * Represents the tracker's state as seen in the stats API. + */ + public interface Stats extends ToXContentObject, Writeable {} +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/TaskResourceUsageTrackerType.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/TaskResourceUsageTrackerType.java new file mode 100644 index 0000000000000..2211d28ad30c0 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/TaskResourceUsageTrackerType.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +/** + * Defines the type of TaskResourceUsageTracker. + */ +public enum TaskResourceUsageTrackerType { + CPU_USAGE_TRACKER("cpu_usage_tracker"), + HEAP_USAGE_TRACKER("heap_usage_tracker"), + ELAPSED_TIME_TRACKER("elapsed_time_tracker"); + + private final String name; + + TaskResourceUsageTrackerType(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public static TaskResourceUsageTrackerType fromName(String name) { + switch (name) { + case "cpu_usage_tracker": + return CPU_USAGE_TRACKER; + case "heap_usage_tracker": + return HEAP_USAGE_TRACKER; + case "elapsed_time_tracker": + return ELAPSED_TIME_TRACKER; + } + + throw new IllegalArgumentException("Invalid TaskResourceUsageTrackerType: " + name); + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/package-info.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/package-info.java new file mode 100644 index 0000000000000..da0532421391e --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/package-info.java @@ -0,0 +1,12 @@ +/* + * 
SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * This package contains trackers to check if resource usage limits are breached on a node or task level. + */ +package org.opensearch.search.backpressure.trackers; diff --git a/server/src/main/java/org/opensearch/tasks/CancellableTask.java b/server/src/main/java/org/opensearch/tasks/CancellableTask.java index 439be2b630e84..336f5b1f4c244 100644 --- a/server/src/main/java/org/opensearch/tasks/CancellableTask.java +++ b/server/src/main/java/org/opensearch/tasks/CancellableTask.java @@ -71,7 +71,7 @@ public CancellableTask( /** * This method is called by the task manager when this task is cancelled. */ - final void cancel(String reason) { + public void cancel(String reason) { assert reason != null; if (cancelled.compareAndSet(false, true)) { this.reason = reason; diff --git a/server/src/main/java/org/opensearch/tasks/TaskCancellation.java b/server/src/main/java/org/opensearch/tasks/TaskCancellation.java new file mode 100644 index 0000000000000..d09312f38e3eb --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskCancellation.java @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.ExceptionsHelper; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * TaskCancellation represents a task eligible for cancellation. + * It doesn't guarantee that the task will actually get cancelled or not; that decision is left to the caller. + * + * It contains a list of cancellation reasons along with callbacks that are invoked when cancel() is called. 
+ * + * @opensearch.internal + */ +public class TaskCancellation implements Comparable { + private final CancellableTask task; + private final List reasons; + private final List onCancelCallbacks; + + public TaskCancellation(CancellableTask task, List reasons, List onCancelCallbacks) { + this.task = task; + this.reasons = reasons; + this.onCancelCallbacks = onCancelCallbacks; + } + + public CancellableTask getTask() { + return task; + } + + public List getReasons() { + return reasons; + } + + public String getReasonString() { + return reasons.stream().map(Reason::getMessage).collect(Collectors.joining(", ")); + } + + /** + * Cancels the task and invokes all onCancelCallbacks. + */ + public void cancel() { + if (isEligibleForCancellation() == false) { + return; + } + + task.cancel(getReasonString()); + + List exceptions = new ArrayList<>(); + for (Runnable callback : onCancelCallbacks) { + try { + callback.run(); + } catch (Exception e) { + exceptions.add(e); + } + } + ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptions); + } + + /** + * Returns the sum of all cancellation scores. + * + * A zero score indicates no reason to cancel the task. + * A task with a higher score suggests greater possibility of recovering the node when it is cancelled. + */ + public int totalCancellationScore() { + return reasons.stream().mapToInt(Reason::getCancellationScore).sum(); + } + + /** + * A task is eligible for cancellation if it has one or more cancellation reasons, and is not already cancelled. + */ + public boolean isEligibleForCancellation() { + return (task.isCancelled() == false) && (reasons.size() > 0); + } + + @Override + public int compareTo(TaskCancellation other) { + return Integer.compare(totalCancellationScore(), other.totalCancellationScore()); + } + + /** + * Represents the cancellation reason for a task. 
+ */ + public static class Reason { + private final String message; + private final int cancellationScore; + + public Reason(String message, int cancellationScore) { + this.message = message; + this.cancellationScore = cancellationScore; + } + + public String getMessage() { + return message; + } + + public int getCancellationScore() { + return cancellationScore; + } + } +} diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index c3cad117390e4..b4806b531429e 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; @@ -50,6 +51,7 @@ public class TaskResourceTrackingService implements RunnableTaskExecutionListene private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); + private final List taskCompletionListeners = new ArrayList<>(); private final ThreadPool threadPool; private volatile boolean taskResourceTrackingEnabled; @@ -116,6 +118,16 @@ public void stopTracking(Task task) { } finally { resourceAwareTasks.remove(task.getId()); } + + List exceptions = new ArrayList<>(); + for (TaskCompletionListener listener : taskCompletionListeners) { + try { + listener.onTaskCompleted(task); + } catch (Exception e) { + exceptions.add(e); + } + } + ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptions); } /** @@ -245,4 +257,14 @@ private 
ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { return storedContext; } + /** + * Listener that gets invoked when a task execution completes. + */ + public interface TaskCompletionListener { + void onTaskCompleted(Task task); + } + + public void addTaskCompletionListener(TaskCompletionListener listener) { + this.taskCompletionListeners.add(listener); + } } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java index 380f8ce581e53..eb2c9602333b8 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java @@ -709,6 +709,7 @@ public static NodeStats createNodeStats() { adaptiveSelectionStats, scriptCacheStats, null, + null, null ); } diff --git a/server/src/test/java/org/opensearch/cluster/DiskUsageTests.java b/server/src/test/java/org/opensearch/cluster/DiskUsageTests.java index 8259971d1b695..0bbe1dda95945 100644 --- a/server/src/test/java/org/opensearch/cluster/DiskUsageTests.java +++ b/server/src/test/java/org/opensearch/cluster/DiskUsageTests.java @@ -185,6 +185,7 @@ public void testFillDiskUsage() { null, null, null, + null, null ), new NodeStats( @@ -205,6 +206,7 @@ public void testFillDiskUsage() { null, null, null, + null, null ), new NodeStats( @@ -225,6 +227,7 @@ public void testFillDiskUsage() { null, null, null, + null, null ) ); @@ -276,6 +279,7 @@ public void testFillDiskUsageSomeInvalidValues() { null, null, null, + null, null ), new NodeStats( @@ -296,6 +300,7 @@ public void testFillDiskUsageSomeInvalidValues() { null, null, null, + null, null ), new NodeStats( @@ -316,6 +321,7 @@ public void testFillDiskUsageSomeInvalidValues() { null, null, null, + null, null ) ); diff --git a/server/src/test/java/org/opensearch/common/util/MovingAverageTests.java 
b/server/src/test/java/org/opensearch/common/util/MovingAverageTests.java new file mode 100644 index 0000000000000..415058992e081 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/util/MovingAverageTests.java @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +import org.opensearch.test.OpenSearchTestCase; + +public class MovingAverageTests extends OpenSearchTestCase { + + public void testMovingAverage() { + MovingAverage ma = new MovingAverage(5); + + // No observations + assertEquals(0.0, ma.getAverage(), 0.0); + assertEquals(0, ma.getCount()); + + // Not enough observations + ma.record(1); + ma.record(2); + ma.record(3); + assertEquals(2.0, ma.getAverage(), 0.0); + assertEquals(3, ma.getCount()); + assertFalse(ma.isReady()); + + // Enough observations + ma.record(4); + ma.record(5); + ma.record(6); + assertEquals(4, ma.getAverage(), 0.0); + assertEquals(6, ma.getCount()); + assertTrue(ma.isReady()); + } + + public void testMovingAverageWithZeroSize() { + try { + new MovingAverage(0); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("window size must be greater than zero")); + return; + } + + fail("exception should have been thrown"); + } +} diff --git a/server/src/test/java/org/opensearch/common/util/StreakTests.java b/server/src/test/java/org/opensearch/common/util/StreakTests.java new file mode 100644 index 0000000000000..682a28d3a3a8b --- /dev/null +++ b/server/src/test/java/org/opensearch/common/util/StreakTests.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.common.util; + +import org.opensearch.test.OpenSearchTestCase; + +public class StreakTests extends OpenSearchTestCase { + + public void testStreak() { + Streak streak = new Streak(); + + // Streak starts with zero. + assertEquals(0, streak.length()); + + // Streak increases on successive successful events. + streak.record(true); + assertEquals(1, streak.length()); + streak.record(true); + assertEquals(2, streak.length()); + streak.record(true); + assertEquals(3, streak.length()); + + // Streak resets to zero after an unsuccessful event. + streak.record(false); + assertEquals(0, streak.length()); + } +} diff --git a/server/src/test/java/org/opensearch/common/util/TokenBucketTests.java b/server/src/test/java/org/opensearch/common/util/TokenBucketTests.java new file mode 100644 index 0000000000000..a52e97cdd835c --- /dev/null +++ b/server/src/test/java/org/opensearch/common/util/TokenBucketTests.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.LongSupplier; + +public class TokenBucketTests extends OpenSearchTestCase { + + public void testTokenBucket() { + AtomicLong mockTimeNanos = new AtomicLong(); + LongSupplier mockTimeNanosSupplier = mockTimeNanos::get; + + // Token bucket that refills at 2 tokens/second and allows short bursts up to 3 operations. + TokenBucket tokenBucket = new TokenBucket(mockTimeNanosSupplier, 2.0 / TimeUnit.SECONDS.toNanos(1), 3); + + // Three operations succeed, fourth fails. 
+ assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertFalse(tokenBucket.request()); + + // Clock moves ahead by one second. Two operations succeed, third fails. + mockTimeNanos.addAndGet(TimeUnit.SECONDS.toNanos(1)); + assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertFalse(tokenBucket.request()); + + // Clock moves ahead by half a second. One operation succeeds, second fails. + mockTimeNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(500)); + assertTrue(tokenBucket.request()); + assertFalse(tokenBucket.request()); + + // Clock moves ahead by many seconds, but the token bucket should be capped at the 'burst' capacity. + mockTimeNanos.addAndGet(TimeUnit.SECONDS.toNanos(10)); + assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertFalse(tokenBucket.request()); + + // Ability to request fractional tokens. + mockTimeNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(250)); + assertFalse(tokenBucket.request(1.0)); + assertTrue(tokenBucket.request(0.5)); + } + + public void testTokenBucketWithInvalidRate() { + try { + new TokenBucket(System::nanoTime, -1, 2); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("rate must be greater than zero")); + return; + } + + fail("exception should have been thrown"); + } + + public void testTokenBucketWithInvalidBurst() { + try { + new TokenBucket(System::nanoTime, 1, 0); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("burst must be greater than zero")); + return; + } + + fail("exception should have been thrown"); + } +} diff --git a/server/src/test/java/org/opensearch/search/backpressure/SearchBackpressureServiceTests.java b/server/src/test/java/org/opensearch/search/backpressure/SearchBackpressureServiceTests.java new file mode 100644 index 0000000000000..07a962c6824ca --- /dev/null +++ 
b/server/src/test/java/org/opensearch/search/backpressure/SearchBackpressureServiceTests.java
@@ -0,0 +1,270 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.backpressure;
+
+import org.opensearch.action.search.SearchShardTask;
+import org.opensearch.common.io.stream.StreamInput;
+import org.opensearch.common.io.stream.StreamOutput;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.search.backpressure.settings.SearchBackpressureMode;
+import org.opensearch.search.backpressure.settings.SearchBackpressureSettings;
+import org.opensearch.search.backpressure.settings.SearchShardTaskSettings;
+import org.opensearch.search.backpressure.trackers.NodeDuressTracker;
+import org.opensearch.common.xcontent.XContentBuilder;
+import org.opensearch.search.backpressure.stats.SearchBackpressureStats;
+import org.opensearch.search.backpressure.stats.SearchShardTaskStats;
+import org.opensearch.search.backpressure.trackers.TaskResourceUsageTracker;
+import org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackerType;
+import org.opensearch.tasks.CancellableTask;
+import org.opensearch.tasks.Task;
+import org.opensearch.tasks.TaskCancellation;
+import org.opensearch.tasks.TaskResourceTrackingService;
+import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.threadpool.ThreadPool;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.LongSupplier;
+
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.opensearch.search.backpressure.SearchBackpressureTestHelpers.createMockTaskWithResourceStats;
+
+/**
+ * Unit tests for {@link SearchBackpressureService}: node duress detection, tracker updates on task
+ * completion, and rate-limited in-flight cancellation together with the stats it reports.
+ */
+public class SearchBackpressureServiceTests extends OpenSearchTestCase {
+
+    /** The node is reported in duress only on the third consecutive check that sees the trackers breached. */
+    public void testIsNodeInDuress() {
+        TaskResourceTrackingService mockTaskResourceTrackingService = mock(TaskResourceTrackingService.class);
+        ThreadPool mockThreadPool = mock(ThreadPool.class);
+
+        // Values are set explicitly below before the first isNodeInDuress() call reads them.
+        AtomicReference<Double> cpuUsage = new AtomicReference<>();
+        AtomicReference<Double> heapUsage = new AtomicReference<>();
+        NodeDuressTracker cpuUsageTracker = new NodeDuressTracker(() -> cpuUsage.get() >= 0.5);
+        NodeDuressTracker heapUsageTracker = new NodeDuressTracker(() -> heapUsage.get() >= 0.5);
+
+        SearchBackpressureSettings settings = new SearchBackpressureSettings(
+            Settings.EMPTY,
+            new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)
+        );
+
+        SearchBackpressureService service = new SearchBackpressureService(
+            settings,
+            mockTaskResourceTrackingService,
+            mockThreadPool,
+            System::nanoTime,
+            List.of(cpuUsageTracker, heapUsageTracker),
+            Collections.emptyList()
+        );
+
+        // Node not in duress.
+        cpuUsage.set(0.0);
+        heapUsage.set(0.0);
+        assertFalse(service.isNodeInDuress());
+
+        // Node in duress; but not for many consecutive data points.
+        cpuUsage.set(1.0);
+        heapUsage.set(1.0);
+        assertFalse(service.isNodeInDuress());
+
+        // Node in duress for consecutive data points.
+        assertFalse(service.isNodeInDuress());
+        assertTrue(service.isNodeInDuress());
+
+        // Node not in duress anymore.
+        cpuUsage.set(0.0);
+        heapUsage.set(0.0);
+        assertFalse(service.isNodeInDuress());
+    }
+
+    /** Only completed {@link SearchShardTask}s are counted and forwarded to the resource usage trackers. */
+    public void testTrackerStateUpdateOnTaskCompletion() {
+        TaskResourceTrackingService mockTaskResourceTrackingService = mock(TaskResourceTrackingService.class);
+        ThreadPool mockThreadPool = mock(ThreadPool.class);
+        LongSupplier mockTimeNanosSupplier = () -> TimeUnit.SECONDS.toNanos(1234);
+        TaskResourceUsageTracker mockTaskResourceUsageTracker = mock(TaskResourceUsageTracker.class);
+
+        SearchBackpressureSettings settings = new SearchBackpressureSettings(
+            Settings.EMPTY,
+            new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)
+        );
+
+        SearchBackpressureService service = new SearchBackpressureService(
+            settings,
+            mockTaskResourceTrackingService,
+            mockThreadPool,
+            mockTimeNanosSupplier,
+            Collections.emptyList(),
+            List.of(mockTaskResourceUsageTracker)
+        );
+
+        // Record task completions to update the tracker state. Tasks other than SearchShardTask are ignored.
+        service.onTaskCompleted(createMockTaskWithResourceStats(CancellableTask.class, 100, 200));
+        for (int i = 0; i < 100; i++) {
+            service.onTaskCompleted(createMockTaskWithResourceStats(SearchShardTask.class, 100, 200));
+        }
+        assertEquals(100, service.getState().getCompletionCount());
+        verify(mockTaskResourceUsageTracker, times(100)).update(any());
+    }
+
+    /**
+     * Walks the cancellation loop through the burst limit, the task-completion based refill and the
+     * time based refill, then compares the reported stats with the expected counters.
+     */
+    public void testInFlightCancellation() {
+        TaskResourceTrackingService mockTaskResourceTrackingService = mock(TaskResourceTrackingService.class);
+        ThreadPool mockThreadPool = mock(ThreadPool.class);
+        AtomicLong mockTime = new AtomicLong(0);
+        LongSupplier mockTimeNanosSupplier = mockTime::get;
+        NodeDuressTracker mockNodeDuressTracker = new NodeDuressTracker(() -> true);
+
+        TaskResourceUsageTracker mockTaskResourceUsageTracker = new TaskResourceUsageTracker() {
+            @Override
+            public String name() {
+                return TaskResourceUsageTrackerType.CPU_USAGE_TRACKER.getName();
+            }
+
+            @Override
+            public void update(Task task) {}
+
+            @Override
+            public Optional<TaskCancellation.Reason> checkAndMaybeGetCancellationReason(Task task) {
+                if (task.getTotalResourceStats().getCpuTimeInNanos() < 300) {
+                    return Optional.empty();
+                }
+
+                return Optional.of(new TaskCancellation.Reason("limits exceeded", 5));
+            }
+
+            // NOTE(review): the element type of 'activeTasks' is not visible in this patch; left raw - confirm
+            // against TaskResourceUsageTracker#stats.
+            @Override
+            public Stats stats(List activeTasks) {
+                return new MockStats(getCancellations());
+            }
+        };
+
+        // Mocking 'settings' with predictable rate limiting thresholds.
+        SearchBackpressureSettings settings = spy(
+            new SearchBackpressureSettings(
+                Settings.builder()
+                    .put(SearchBackpressureSettings.SETTING_MODE.getKey(), "enforced")
+                    .put(SearchBackpressureSettings.SETTING_CANCELLATION_RATIO.getKey(), 0.1)
+                    .put(SearchBackpressureSettings.SETTING_CANCELLATION_RATE.getKey(), 0.003)
+                    .put(SearchBackpressureSettings.SETTING_CANCELLATION_BURST.getKey(), 10.0)
+                    .build(),
+                new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)
+            )
+        );
+
+        SearchBackpressureService service = new SearchBackpressureService(
+            settings,
+            mockTaskResourceTrackingService,
+            mockThreadPool,
+            mockTimeNanosSupplier,
+            List.of(mockNodeDuressTracker),
+            List.of(mockTaskResourceUsageTracker)
+        );
+
+        // Run two iterations so that node is marked 'in duress' from the third iteration onwards.
+        service.doRun();
+        service.doRun();
+
+        // Mocking 'settings' with predictable totalHeapBytesThreshold so that cancellation logic doesn't get skipped.
+        long taskHeapUsageBytes = 500;
+        SearchShardTaskSettings shardTaskSettings = mock(SearchShardTaskSettings.class);
+        when(shardTaskSettings.getTotalHeapBytesThreshold()).thenReturn(taskHeapUsageBytes);
+        when(settings.getSearchShardTaskSettings()).thenReturn(shardTaskSettings);
+
+        // Create a mix of low and high resource usage tasks (60 low + 15 high resource usage tasks).
+        Map<Long, Task> activeTasks = new HashMap<>();
+        for (long i = 0; i < 75; i++) {
+            if (i % 5 == 0) {
+                activeTasks.put(i, createMockTaskWithResourceStats(SearchShardTask.class, 500, taskHeapUsageBytes));
+            } else {
+                activeTasks.put(i, createMockTaskWithResourceStats(SearchShardTask.class, 100, taskHeapUsageBytes));
+            }
+        }
+        doReturn(activeTasks).when(mockTaskResourceTrackingService).getResourceAwareTasks();
+
+        // There are 15 tasks eligible for cancellation but only 10 will be cancelled (burst limit).
+        service.doRun();
+        assertEquals(10, service.getState().getCancellationCount());
+        assertEquals(1, service.getState().getLimitReachedCount());
+
+        // If the clock or completed task count haven't made sufficient progress, we'll continue to be rate-limited.
+        service.doRun();
+        assertEquals(10, service.getState().getCancellationCount());
+        assertEquals(2, service.getState().getLimitReachedCount());
+
+        // Simulate task completion to replenish some tokens.
+        // This will add 2 tokens (task count delta * cancellationRatio) to 'rateLimitPerTaskCompletion'.
+        for (int i = 0; i < 20; i++) {
+            service.onTaskCompleted(createMockTaskWithResourceStats(SearchShardTask.class, 100, taskHeapUsageBytes));
+        }
+        service.doRun();
+        assertEquals(12, service.getState().getCancellationCount());
+        assertEquals(3, service.getState().getLimitReachedCount());
+
+        // Fast-forward the clock by one second to replenish some tokens.
+        // This will add 3 tokens (time delta * rate) to 'rateLimitPerTime'.
+        mockTime.addAndGet(TimeUnit.SECONDS.toNanos(1));
+        service.doRun();
+        assertEquals(15, service.getState().getCancellationCount());
+        assertEquals(3, service.getState().getLimitReachedCount()); // no more tasks to cancel; limit not reached
+
+        // Verify search backpressure stats.
+        SearchBackpressureStats expectedStats = new SearchBackpressureStats(
+            new SearchShardTaskStats(15, 3, Map.of(TaskResourceUsageTrackerType.CPU_USAGE_TRACKER, new MockStats(15))),
+            SearchBackpressureMode.ENFORCED
+        );
+        SearchBackpressureStats actualStats = service.nodeStats();
+        assertEquals(expectedStats, actualStats);
+    }
+
+    /** Minimal tracker stats carrying only a cancellation count; value equality lets the test compare stats. */
+    private static class MockStats implements TaskResourceUsageTracker.Stats {
+        private final long cancellationCount;
+
+        public MockStats(long cancellationCount) {
+            this.cancellationCount = cancellationCount;
+        }
+
+        public MockStats(StreamInput in) throws IOException {
+            this(in.readVLong());
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return builder.startObject().field("cancellation_count", cancellationCount).endObject();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeVLong(cancellationCount);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            MockStats mockStats = (MockStats) o;
+            return cancellationCount == mockStats.cancellationCount;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(cancellationCount);
+        }
+    }
+}
diff --git a/server/src/test/java/org/opensearch/search/backpressure/stats/SearchBackpressureStatsTests.java b/server/src/test/java/org/opensearch/search/backpressure/stats/SearchBackpressureStatsTests.java
new file mode 100644
index 0000000000000..2665a6d5e05aa
--- /dev/null
+++ b/server/src/test/java/org/opensearch/search/backpressure/stats/SearchBackpressureStatsTests.java
@@ -0,0 +1,32 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.backpressure.stats;
+
+import org.opensearch.common.io.stream.Writeable;
+import org.opensearch.search.backpressure.settings.SearchBackpressureMode;
+import org.opensearch.test.AbstractWireSerializingTestCase;
+
+/**
+ * Wire serialization tests for {@link SearchBackpressureStats}, built from random instances.
+ */
+public class SearchBackpressureStatsTests extends AbstractWireSerializingTestCase<SearchBackpressureStats> {
+    @Override
+    protected Writeable.Reader<SearchBackpressureStats> instanceReader() {
+        return SearchBackpressureStats::new;
+    }
+
+    @Override
+    protected SearchBackpressureStats createTestInstance() {
+        return randomInstance();
+    }
+
+    /** Builds stats from random shard task stats and a randomly chosen backpressure mode. */
+    public static SearchBackpressureStats randomInstance() {
+        return new SearchBackpressureStats(
+            SearchShardTaskStatsTests.randomInstance(),
+            randomFrom(SearchBackpressureMode.DISABLED, SearchBackpressureMode.MONITOR_ONLY, SearchBackpressureMode.ENFORCED)
+        );
+    }
+}
diff --git a/server/src/test/java/org/opensearch/search/backpressure/stats/SearchShardTaskStatsTests.java b/server/src/test/java/org/opensearch/search/backpressure/stats/SearchShardTaskStatsTests.java
new file mode 100644
index 0000000000000..d5bc9398492eb
--- /dev/null
+++ b/server/src/test/java/org/opensearch/search/backpressure/stats/SearchShardTaskStatsTests.java
@@ -0,0 +1,44 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.backpressure.stats;
+
+import org.opensearch.common.io.stream.Writeable;
+import org.opensearch.search.backpressure.trackers.CpuUsageTracker;
+import org.opensearch.search.backpressure.trackers.ElapsedTimeTracker;
+import org.opensearch.search.backpressure.trackers.HeapUsageTracker;
+import org.opensearch.search.backpressure.trackers.TaskResourceUsageTracker;
+import org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackerType;
+import org.opensearch.test.AbstractWireSerializingTestCase;
+
+import java.util.Map;
+
+/**
+ * Wire serialization tests for {@link SearchShardTaskStats}, built from random instances.
+ */
+public class SearchShardTaskStatsTests extends AbstractWireSerializingTestCase<SearchShardTaskStats> {
+    @Override
+    protected Writeable.Reader<SearchShardTaskStats> instanceReader() {
+        return SearchShardTaskStats::new;
+    }
+
+    @Override
+    protected SearchShardTaskStats createTestInstance() {
+        return randomInstance();
+    }
+
+    /** Builds stats with random counters and one random stats entry per tracker type (CPU, heap, elapsed time). */
+    public static SearchShardTaskStats randomInstance() {
+        Map<TaskResourceUsageTrackerType, TaskResourceUsageTracker.Stats> resourceUsageTrackerStats = Map.of(
+            TaskResourceUsageTrackerType.CPU_USAGE_TRACKER,
+            new CpuUsageTracker.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
+            TaskResourceUsageTrackerType.HEAP_USAGE_TRACKER,
+            new HeapUsageTracker.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
+            TaskResourceUsageTrackerType.ELAPSED_TIME_TRACKER,
+            new ElapsedTimeTracker.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong())
+        );
+
+        return new SearchShardTaskStats(randomNonNegativeLong(), randomNonNegativeLong(), resourceUsageTrackerStats);
+    }
+}
diff --git a/server/src/test/java/org/opensearch/search/backpressure/trackers/CpuUsageTrackerTests.java b/server/src/test/java/org/opensearch/search/backpressure/trackers/CpuUsageTrackerTests.java
new file mode 100644
index 0000000000000..c790fb2e60eea
--- /dev/null
+++ b/server/src/test/java/org/opensearch/search/backpressure/trackers/CpuUsageTrackerTests.java
@@ -0,0 +1,48 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+
* + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Optional; + +import static org.opensearch.search.backpressure.SearchBackpressureTestHelpers.createMockTaskWithResourceStats; + +public class CpuUsageTrackerTests extends OpenSearchTestCase { + private static final SearchBackpressureSettings mockSettings = new SearchBackpressureSettings( + Settings.builder() + .put(CpuUsageTracker.SETTING_CPU_TIME_MILLIS_THRESHOLD.getKey(), 15) // 15 ms + .build(), + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + + public void testEligibleForCancellation() { + Task task = createMockTaskWithResourceStats(SearchShardTask.class, 200000000, 200); + CpuUsageTracker tracker = new CpuUsageTracker(mockSettings); + + Optional reason = tracker.checkAndMaybeGetCancellationReason(task); + assertTrue(reason.isPresent()); + assertEquals(1, reason.get().getCancellationScore()); + assertEquals("cpu usage exceeded [200ms >= 15ms]", reason.get().getMessage()); + } + + public void testNotEligibleForCancellation() { + Task task = createMockTaskWithResourceStats(SearchShardTask.class, 5000000, 200); + CpuUsageTracker tracker = new CpuUsageTracker(mockSettings); + + Optional reason = tracker.checkAndMaybeGetCancellationReason(task); + assertFalse(reason.isPresent()); + } +} diff --git a/server/src/test/java/org/opensearch/search/backpressure/trackers/ElapsedTimeTrackerTests.java 
b/server/src/test/java/org/opensearch/search/backpressure/trackers/ElapsedTimeTrackerTests.java
new file mode 100644
index 0000000000000..67ed6059a1914
--- /dev/null
+++ b/server/src/test/java/org/opensearch/search/backpressure/trackers/ElapsedTimeTrackerTests.java
@@ -0,0 +1,49 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.backpressure.trackers;
+
+import org.opensearch.action.search.SearchShardTask;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.search.backpressure.settings.SearchBackpressureSettings;
+import org.opensearch.tasks.Task;
+import org.opensearch.tasks.TaskCancellation;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.Optional;
+
+import static org.opensearch.search.backpressure.SearchBackpressureTestHelpers.createMockTaskWithResourceStats;
+
+/**
+ * Unit tests for {@link ElapsedTimeTracker} using a 100 ms threshold and a fixed clock of 200 ms (in nanoseconds).
+ */
+public class ElapsedTimeTrackerTests extends OpenSearchTestCase {
+
+    private static final SearchBackpressureSettings mockSettings = new SearchBackpressureSettings(
+        Settings.builder()
+            .put(ElapsedTimeTracker.SETTING_ELAPSED_TIME_MILLIS_THRESHOLD.getKey(), 100)  // 100 ms
+            .build(),
+        new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)
+    );
+
+    /** A task started at time 0 has been running for 200 ms and yields a cancellation reason with score 1. */
+    public void testEligibleForCancellation() {
+        Task task = createMockTaskWithResourceStats(SearchShardTask.class, 1, 1, 0);
+        ElapsedTimeTracker tracker = new ElapsedTimeTracker(mockSettings, () -> 200000000);
+
+        Optional<TaskCancellation.Reason> reason = tracker.checkAndMaybeGetCancellationReason(task);
+        assertTrue(reason.isPresent());
+        assertEquals(1, reason.get().getCancellationScore());
+        assertEquals("elapsed time exceeded [200ms >= 100ms]", reason.get().getMessage());
+    }
+
+    /** A task started at 150 ms has been running for only 50 ms and yields no cancellation reason. */
+    public void testNotEligibleForCancellation() {
+        Task task = createMockTaskWithResourceStats(SearchShardTask.class, 1, 1, 150000000);
+        ElapsedTimeTracker tracker = new ElapsedTimeTracker(mockSettings, () -> 200000000);
+
+        Optional<TaskCancellation.Reason> reason = tracker.checkAndMaybeGetCancellationReason(task);
+        assertFalse(reason.isPresent());
+    }
+}
diff --git a/server/src/test/java/org/opensearch/search/backpressure/trackers/HeapUsageTrackerTests.java b/server/src/test/java/org/opensearch/search/backpressure/trackers/HeapUsageTrackerTests.java
new file mode 100644
index 0000000000000..b9967da22fbf1
--- /dev/null
+++ b/server/src/test/java/org/opensearch/search/backpressure/trackers/HeapUsageTrackerTests.java
@@ -0,0 +1,83 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.backpressure.trackers;
+
+import org.opensearch.action.search.SearchShardTask;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.search.backpressure.settings.SearchBackpressureSettings;
+import org.opensearch.tasks.Task;
+import org.opensearch.tasks.TaskCancellation;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.Optional;
+
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+import static org.opensearch.search.backpressure.SearchBackpressureTestHelpers.createMockTaskWithResourceStats;
+
+/**
+ * Unit tests for {@link HeapUsageTracker} with a stubbed heap bytes threshold, a variance threshold of 2.0
+ * and a moving average window of 100 observations.
+ */
+public class HeapUsageTrackerTests extends OpenSearchTestCase {
+    private static final long HEAP_BYTES_THRESHOLD = 100;
+    private static final int HEAP_MOVING_AVERAGE_WINDOW_SIZE = 100;
+
+    private static final SearchBackpressureSettings mockSettings = new SearchBackpressureSettings(
+        Settings.builder()
+            .put(HeapUsageTracker.SETTING_HEAP_VARIANCE_THRESHOLD.getKey(), 2.0)
+            .put(HeapUsageTracker.SETTING_HEAP_MOVING_AVERAGE_WINDOW_SIZE.getKey(), HEAP_MOVING_AVERAGE_WINDOW_SIZE)
+            .build(),
+        new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)
+    );
+
+    /** After a full window of 50-byte observations, a 200-byte task yields a cancellation reason with score 4. */
+    public void testEligibleForCancellation() {
+        HeapUsageTracker tracker = spy(new HeapUsageTracker(mockSettings));
+        when(tracker.getHeapBytesThreshold()).thenReturn(HEAP_BYTES_THRESHOLD);
+        Task task = createMockTaskWithResourceStats(SearchShardTask.class, 1, 50);
+
+        // Record enough observations to make the moving average 'ready'.
+        for (int i = 0; i < HEAP_MOVING_AVERAGE_WINDOW_SIZE; i++) {
+            tracker.update(task);
+        }
+
+        // Task that has heap usage >= heapBytesThreshold and (movingAverage * heapVariance).
+        task = createMockTaskWithResourceStats(SearchShardTask.class, 1, 200);
+        Optional<TaskCancellation.Reason> reason = tracker.checkAndMaybeGetCancellationReason(task);
+        assertTrue(reason.isPresent());
+        assertEquals(4, reason.get().getCancellationScore());
+        assertEquals("heap usage exceeded [200b >= 100b]", reason.get().getMessage());
+    }
+
+    /** No reason is produced without enough observations, below the threshold, or below (movingAverage * heapVariance). */
+    public void testNotEligibleForCancellation() {
+        Task task;
+        Optional<TaskCancellation.Reason> reason;
+        HeapUsageTracker tracker = spy(new HeapUsageTracker(mockSettings));
+        when(tracker.getHeapBytesThreshold()).thenReturn(HEAP_BYTES_THRESHOLD);
+
+        // Task with heap usage < heapBytesThreshold.
+        task = createMockTaskWithResourceStats(SearchShardTask.class, 1, 99);
+
+        // Not enough observations.
+        reason = tracker.checkAndMaybeGetCancellationReason(task);
+        assertFalse(reason.isPresent());
+
+        // Record enough observations to make the moving average 'ready'.
+        for (int i = 0; i < HEAP_MOVING_AVERAGE_WINDOW_SIZE; i++) {
+            tracker.update(task);
+        }
+
+        // Task with heap usage < heapBytesThreshold should not be cancelled.
+        reason = tracker.checkAndMaybeGetCancellationReason(task);
+        assertFalse(reason.isPresent());
+
+        // Task with heap usage between heapBytesThreshold and (movingAverage * heapVariance) should not be cancelled.
+        // The lower bound is the threshold itself so this case never repeats the below-threshold case above.
+        double allowedHeapUsage = 99.0 * 2.0;
+        task = createMockTaskWithResourceStats(SearchShardTask.class, 1, randomLongBetween(HEAP_BYTES_THRESHOLD, (long) allowedHeapUsage - 1));
+        reason = tracker.checkAndMaybeGetCancellationReason(task);
+        assertFalse(reason.isPresent());
+    }
+}
diff --git a/server/src/test/java/org/opensearch/search/backpressure/trackers/NodeDuressTrackerTests.java b/server/src/test/java/org/opensearch/search/backpressure/trackers/NodeDuressTrackerTests.java
new file mode 100644
index 0000000000000..472ba95566523
--- /dev/null
+++ b/server/src/test/java/org/opensearch/search/backpressure/trackers/NodeDuressTrackerTests.java
@@ -0,0 +1,35 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.backpressure.trackers;
+
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * Unit tests for {@link NodeDuressTracker}.
+ */
+public class NodeDuressTrackerTests extends OpenSearchTestCase {
+
+    /** check() returns a streak that grows while the condition holds and drops to zero once it stops holding. */
+    public void testNodeDuressTracker() {
+        AtomicReference<Double> cpuUsage = new AtomicReference<>(0.0);
+        NodeDuressTracker tracker = new NodeDuressTracker(() -> cpuUsage.get() >= 0.5);
+
+        // Node not in duress.
+        assertEquals(0, tracker.check());
+
+        // Node in duress; the streak must keep increasing.
+        cpuUsage.set(0.7);
+        assertEquals(1, tracker.check());
+        assertEquals(2, tracker.check());
+        assertEquals(3, tracker.check());
+
+        // Node not in duress anymore.
+        cpuUsage.set(0.3);
+        assertEquals(0, tracker.check());
+        assertEquals(0, tracker.check());
+    }
+}
diff --git a/server/src/test/java/org/opensearch/tasks/TaskCancellationTests.java b/server/src/test/java/org/opensearch/tasks/TaskCancellationTests.java
new file mode 100644
index 0000000000000..e74f89c905499
--- /dev/null
+++ b/server/src/test/java/org/opensearch/tasks/TaskCancellationTests.java
@@ -0,0 +1,77 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.tasks;
+
+import org.opensearch.action.search.SearchShardTask;
+import org.opensearch.search.backpressure.trackers.TaskResourceUsageTracker;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Unit tests for {@link TaskCancellation}: score aggregation, eligibility and callback invocation on cancel.
+ */
+public class TaskCancellationTests extends OpenSearchTestCase {
+
+    /** Cancelling without reasons touches no tracker; with reasons, only the registered callbacks are invoked. */
+    public void testTaskCancellation() {
+        SearchShardTask mockTask = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
+
+        TaskResourceUsageTracker mockTracker1 = createMockTaskResourceUsageTracker("mock_tracker_1");
+        TaskResourceUsageTracker mockTracker2 = createMockTaskResourceUsageTracker("mock_tracker_2");
+        TaskResourceUsageTracker mockTracker3 = createMockTaskResourceUsageTracker("mock_tracker_3");
+
+        List<TaskCancellation.Reason> reasons = new ArrayList<>();
+        // NOTE(review): the callback element type is not visible in this patch; Runnable is assumed - confirm
+        // against the TaskCancellation constructor.
+        List<Runnable> callbacks = List.of(mockTracker1::incrementCancellations, mockTracker2::incrementCancellations);
+        TaskCancellation taskCancellation = new TaskCancellation(mockTask, reasons, callbacks);
+
+        // Task does not have any reason to be cancelled.
+        assertEquals(0, taskCancellation.totalCancellationScore());
+        assertFalse(taskCancellation.isEligibleForCancellation());
+        taskCancellation.cancel();
+        assertEquals(0, mockTracker1.getCancellations());
+        assertEquals(0, mockTracker2.getCancellations());
+        assertEquals(0, mockTracker3.getCancellations());
+
+        // Task has one or more reasons to be cancelled.
+        reasons.add(new TaskCancellation.Reason("limits exceeded 1", 10));
+        reasons.add(new TaskCancellation.Reason("limits exceeded 2", 20));
+        reasons.add(new TaskCancellation.Reason("limits exceeded 3", 5));
+        assertEquals(35, taskCancellation.totalCancellationScore());
+        assertTrue(taskCancellation.isEligibleForCancellation());
+
+        // Cancel the task and validate the cancellation reason and invocation of callbacks.
+        taskCancellation.cancel();
+        assertTrue(mockTask.getReasonCancelled().contains("limits exceeded 1, limits exceeded 2, limits exceeded 3"));
+        assertEquals(1, mockTracker1.getCancellations());
+        assertEquals(1, mockTracker2.getCancellations());
+        assertEquals(0, mockTracker3.getCancellations());
+    }
+
+    /** Creates a tracker with the given name that never proposes a cancellation and reports no stats. */
+    private static TaskResourceUsageTracker createMockTaskResourceUsageTracker(String name) {
+        return new TaskResourceUsageTracker() {
+            @Override
+            public String name() {
+                return name;
+            }
+
+            @Override
+            public void update(Task task) {}
+
+            @Override
+            public Optional<TaskCancellation.Reason> checkAndMaybeGetCancellationReason(Task task) {
+                return Optional.empty();
+            }
+
+            // NOTE(review): the element type of 'activeTasks' is not visible in this patch; left raw - confirm
+            // against TaskResourceUsageTracker#stats.
+            @Override
+            public Stats stats(List activeTasks) {
+                return null;
+            }
+        };
+    }
+}
diff --git a/test/framework/src/main/java/org/opensearch/cluster/MockInternalClusterInfoService.java b/test/framework/src/main/java/org/opensearch/cluster/MockInternalClusterInfoService.java
index 149d7ccdaaae0..1c9514a48c752 100644
--- a/test/framework/src/main/java/org/opensearch/cluster/MockInternalClusterInfoService.java
+++ b/test/framework/src/main/java/org/opensearch/cluster/MockInternalClusterInfoService.java
@@ -114,7 +114,8 @@ List adjustNodesStats(List nodesStats)
{
                 nodeStats.getAdaptiveSelectionStats(),
                 nodeStats.getScriptCacheStats(),
                 nodeStats.getIndexingPressureStats(),
-                nodeStats.getShardIndexingPressureStats()
+                nodeStats.getShardIndexingPressureStats(),
+                nodeStats.getSearchBackpressureStats()
             );
         }).collect(Collectors.toList());
     }
diff --git a/test/framework/src/main/java/org/opensearch/search/backpressure/SearchBackpressureTestHelpers.java b/test/framework/src/main/java/org/opensearch/search/backpressure/SearchBackpressureTestHelpers.java
new file mode 100644
index 0000000000000..ba3653d0b4a84
--- /dev/null
+++ b/test/framework/src/main/java/org/opensearch/search/backpressure/SearchBackpressureTestHelpers.java
@@ -0,0 +1,47 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.backpressure;
+
+import org.opensearch.tasks.CancellableTask;
+import org.opensearch.tasks.TaskResourceUsage;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Factory helpers that build Mockito task mocks with fixed resource usage for search backpressure tests.
+ */
+public class SearchBackpressureTestHelpers extends OpenSearchTestCase {
+
+    /**
+     * Creates a mock task with the given CPU and heap usage and a start time of zero.
+     *
+     * @param type      task class to mock
+     * @param cpuUsage  first argument passed to {@link TaskResourceUsage}
+     * @param heapUsage second argument passed to {@link TaskResourceUsage}
+     * @return the mocked task
+     */
+    public static <T extends CancellableTask> T createMockTaskWithResourceStats(Class<T> type, long cpuUsage, long heapUsage) {
+        return createMockTaskWithResourceStats(type, cpuUsage, heapUsage, 0);
+    }
+
+    /**
+     * Creates a mock task whose total resource stats and start time are stubbed, and whose
+     * {@code isCancelled()} turns true once {@code cancel(String)} has been called.
+     *
+     * @param type           task class to mock
+     * @param cpuUsage       first argument passed to {@link TaskResourceUsage}
+     * @param heapUsage      second argument passed to {@link TaskResourceUsage}
+     * @param startTimeNanos value returned by {@code getStartTimeNanos()}
+     * @return the mocked task
+     */
+    public static <T extends CancellableTask> T createMockTaskWithResourceStats(
+        Class<T> type,
+        long cpuUsage,
+        long heapUsage,
+        long startTimeNanos
+    ) {
+        T task = mock(type);
+        when(task.getTotalResourceStats()).thenReturn(new TaskResourceUsage(cpuUsage, heapUsage));
+        when(task.getStartTimeNanos()).thenReturn(startTimeNanos);
+
+        // A mock keeps no state of its own, so the cancelled flag is held here and shared by both answers.
+        AtomicBoolean isCancelled = new AtomicBoolean(false);
+        doAnswer(invocation -> {
+            isCancelled.set(true);
+            return null;
+        }).when(task).cancel(anyString());
+        doAnswer(invocation -> isCancelled.get()).when(task).isCancelled();
+
+        return task;
+    }
+}
diff --git a/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java b/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java
index 122dadeb152bb..f32a0a46ff69c 100644
--- a/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java
+++ b/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java
@@ -2625,6 +2625,7 @@ public void ensureEstimatedStats() {
                 false,
                 false,
                 false,
+                false,
                 false
             );
             assertThat(