From cb40de87cc6e3702aff334d15fc18d0b1880a05f Mon Sep 17 00:00:00 2001 From: Kaushal Kumar Date: Wed, 11 Sep 2024 12:12:04 -0700 Subject: [PATCH] Add cancellation framework changes in wlm (#15651) * cancellation related Signed-off-by: Kiran Prakash * Update CHANGELOG.md Signed-off-by: Kiran Prakash * add better cancellation reason Signed-off-by: Kiran Prakash * Update DefaultTaskCancellationTests.java Signed-off-by: Kiran Prakash * refactor Signed-off-by: Kiran Prakash * refactor Signed-off-by: Kiran Prakash * Update DefaultTaskCancellation.java Signed-off-by: Kiran Prakash * Update DefaultTaskCancellation.java Signed-off-by: Kiran Prakash * Update DefaultTaskCancellation.java Signed-off-by: Kiran Prakash * Update DefaultTaskSelectionStrategy.java Signed-off-by: Kiran Prakash * refactor Signed-off-by: Kiran Prakash * refactor node level threshold Signed-off-by: Kiran Prakash * use query group task Signed-off-by: Kaushal Kumar * code clean up and refactorings Signed-off-by: Kaushal Kumar * add unit tests and fix existing ones Signed-off-by: Kaushal Kumar * uncomment the test case Signed-off-by: Kaushal Kumar * update CHANGELOG Signed-off-by: Kaushal Kumar * fix imports Signed-off-by: Kaushal Kumar * refactor and add UTs for new constructs Signed-off-by: Kaushal Kumar * fix javadocs Signed-off-by: Kaushal Kumar * remove code clutter Signed-off-by: Kaushal Kumar * change annotation version and task selection strategy Signed-off-by: Kaushal Kumar * rename a util class Signed-off-by: Kaushal Kumar * remove wrappers from resource type Signed-off-by: Kaushal Kumar * apply spotless Signed-off-by: Kaushal Kumar * address comments Signed-off-by: Kaushal Kumar * add rename changes Signed-off-by: Kaushal Kumar * address comments Signed-off-by: Kaushal Kumar * refactor changes and logical bug fix Signed-off-by: Kaushal Kumar * address comments Signed-off-by: Kaushal Kumar --------- Signed-off-by: Kiran Prakash Signed-off-by: Kaushal Kumar Signed-off-by: Ankit Jain Co-authored-by: Kiran Prakash Co-authored-by: Ankit Jain --- CHANGELOG.md | 1 + .../cluster/metadata/QueryGroup.java | 3 +- .../wlm/QueryGroupLevelResourceUsageView.java | 12 +- .../org/opensearch/wlm/QueryGroupTask.java | 24 +- .../java/org/opensearch/wlm/ResourceType.java | 41 +- .../wlm/WorkloadManagementSettings.java | 2 + .../MaximumResourceTaskSelectionStrategy.java | 71 +++ .../QueryGroupTaskCancellationService.java | 205 +++++++ .../cancellation/TaskSelectionStrategy.java | 28 + .../wlm/cancellation/package-info.java | 12 + .../wlm/tracker/CpuUsageCalculator.java | 38 ++ .../wlm/tracker/MemoryUsageCalculator.java | 35 ++ ...QueryGroupResourceUsageTrackerService.java | 21 +- .../wlm/tracker/ResourceUsageCalculator.java | 34 ++ ...QueryGroupLevelResourceUsageViewTests.java | 49 +- .../org/opensearch/wlm/ResourceTypeTests.java | 19 - ...mumResourceTaskSelectionStrategyTests.java | 126 ++++ ...ueryGroupTaskCancellationServiceTests.java | 541 ++++++++++++++++++ .../tracker/ResourceUsageCalculatorTests.java | 70 +++ ...ceUsageCalculatorTrackerServiceTests.java} | 63 +- 20 files changed, 1297 insertions(+), 98 deletions(-) create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/MaximumResourceTaskSelectionStrategy.java create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationService.java create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/package-info.java create mode 100644 server/src/main/java/org/opensearch/wlm/tracker/CpuUsageCalculator.java create mode 100644 server/src/main/java/org/opensearch/wlm/tracker/MemoryUsageCalculator.java create mode 100644 server/src/main/java/org/opensearch/wlm/tracker/ResourceUsageCalculator.java create mode 100644 server/src/test/java/org/opensearch/wlm/cancellation/MaximumResourceTaskSelectionStrategyTests.java create mode 100644 server/src/test/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationServiceTests.java create mode 100644 server/src/test/java/org/opensearch/wlm/tracker/ResourceUsageCalculatorTests.java rename server/src/test/java/org/opensearch/wlm/tracker/{QueryGroupResourceUsageTrackerServiceTests.java => ResourceUsageCalculatorTrackerServiceTests.java} (69%) diff --git a/CHANGELOG.md b/CHANGELOG.md index c726b3e29cc40..21ce8ba4daf96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added - Add path prefix support to hashed prefix snapshots ([#15664](https://github.com/opensearch-project/OpenSearch/pull/15664)) +- [Workload Management] QueryGroup resource cancellation framework changes ([#15651](https://github.com/opensearch-project/OpenSearch/pull/15651)) ### Dependencies - Bump `com.azure:azure-identity` from 1.13.0 to 1.13.2 ([#15578](https://github.com/opensearch-project/OpenSearch/pull/15578)) diff --git a/server/src/main/java/org/opensearch/cluster/metadata/QueryGroup.java b/server/src/main/java/org/opensearch/cluster/metadata/QueryGroup.java index dcd96dceb4bf1..0eeafdc8f5eed 100644 --- a/server/src/main/java/org/opensearch/cluster/metadata/QueryGroup.java +++ b/server/src/main/java/org/opensearch/cluster/metadata/QueryGroup.java @@ -12,6 +12,7 @@ import org.opensearch.cluster.Diff; import org.opensearch.common.UUIDs; import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.common.annotation.PublicApi; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; @@ -41,7 +42,7 @@ * "updated_at": 4513232415 * } */ -@ExperimentalApi +@PublicApi(since = "2.18.0") public class QueryGroup extends AbstractDiffable implements ToXContentObject { public static final String _ID_STRING = "_id"; diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java b/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java index 7577c8573ec10..de213eaab64a8 100644 --- a/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java @@ -8,8 +8,6 @@ package org.opensearch.wlm; -import org.opensearch.tasks.Task; - import java.util.List; import java.util.Map; @@ -20,11 +18,11 @@ */ public class QueryGroupLevelResourceUsageView { // resourceUsage holds the resource usage data for a QueryGroup at a point in time - private final Map resourceUsage; + private final Map resourceUsage; // activeTasks holds the list of active tasks for a QueryGroup at a point in time - private final List activeTasks; + private final List activeTasks; - public QueryGroupLevelResourceUsageView(Map resourceUsage, List activeTasks) { + public QueryGroupLevelResourceUsageView(Map resourceUsage, List activeTasks) { this.resourceUsage = resourceUsage; this.activeTasks = activeTasks; } @@ -34,7 +32,7 @@ public QueryGroupLevelResourceUsageView(Map resourceUsage, L * * @return The map of resource usage data */ - public Map getResourceUsageData() { + public Map getResourceUsageData() { return resourceUsage; } @@ -43,7 +41,7 @@ public Map getResourceUsageData() { * * @return The list of active tasks */ - public List getActiveTasks() { + public List getActiveTasks() { return activeTasks; } } diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java b/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java index 4eb413be61b72..a1cb766579d43 100644 --- a/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.tasks.TaskId; @@ -17,6 +18,7 @@ import java.util.Map; import java.util.Optional; +import java.util.function.LongSupplier; import java.util.function.Supplier; import static org.opensearch.search.SearchService.NO_TIMEOUT; @@ -24,15 +26,17 @@ /** * Base class to define QueryGroup tasks */ +@PublicApi(since = "2.18.0") public class QueryGroupTask extends CancellableTask { private static final Logger logger = LogManager.getLogger(QueryGroupTask.class); public static final String QUERY_GROUP_ID_HEADER = "queryGroupId"; public static final Supplier DEFAULT_QUERY_GROUP_ID_SUPPLIER = () -> "DEFAULT_QUERY_GROUP"; + private final LongSupplier nanoTimeSupplier; private String queryGroupId; public QueryGroupTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) { - this(id, type, action, description, parentTaskId, headers, NO_TIMEOUT); + this(id, type, action, description, parentTaskId, headers, NO_TIMEOUT, System::nanoTime); } public QueryGroupTask( @@ -43,8 +47,22 @@ public QueryGroupTask( TaskId parentTaskId, Map headers, TimeValue cancelAfterTimeInterval + ) { + this(id, type, action, description, parentTaskId, headers, cancelAfterTimeInterval, System::nanoTime); + } + + public QueryGroupTask( + long id, + String type, + String action, + String description, + TaskId parentTaskId, + Map headers, + TimeValue cancelAfterTimeInterval, + LongSupplier nanoTimeSupplier ) { super(id, type, action, description, parentTaskId, headers, cancelAfterTimeInterval); + this.nanoTimeSupplier = nanoTimeSupplier; } /** @@ -69,6 +87,10 @@ public final void setQueryGroupId(final ThreadContext threadContext) { .orElse(DEFAULT_QUERY_GROUP_ID_SUPPLIER.get()); } + public long getElapsedTime() { + return nanoTimeSupplier.getAsLong() - getStartTimeNanos(); + } + @Override public boolean shouldCancelChildrenOnCancellation() { return false; diff --git a/server/src/main/java/org/opensearch/wlm/ResourceType.java b/server/src/main/java/org/opensearch/wlm/ResourceType.java index b2485988c6ef6..01d1d1fb63880 100644 --- a/server/src/main/java/org/opensearch/wlm/ResourceType.java +++ b/server/src/main/java/org/opensearch/wlm/ResourceType.java @@ -10,8 +10,9 @@ import org.opensearch.common.annotation.PublicApi; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.tasks.resourcetracker.ResourceStats; -import org.opensearch.tasks.Task; +import org.opensearch.wlm.tracker.CpuUsageCalculator; +import org.opensearch.wlm.tracker.MemoryUsageCalculator; +import org.opensearch.wlm.tracker.ResourceUsageCalculator; import java.io.IOException; import java.util.List; @@ -25,19 +26,25 @@ */ @PublicApi(since = "2.17.0") public enum ResourceType { - CPU("cpu", task -> task.getTotalResourceUtilization(ResourceStats.CPU), true), - MEMORY("memory", task -> task.getTotalResourceUtilization(ResourceStats.MEMORY), true); + CPU("cpu", true, CpuUsageCalculator.INSTANCE, WorkloadManagementSettings::getNodeLevelCpuCancellationThreshold), + MEMORY("memory", true, MemoryUsageCalculator.INSTANCE, WorkloadManagementSettings::getNodeLevelMemoryCancellationThreshold); private final String name; - private final Function getResourceUsage; private final boolean statsEnabled; - + private final ResourceUsageCalculator resourceUsageCalculator; + private final Function nodeLevelThresholdSupplier; private static List sortedValues = List.of(CPU, MEMORY); - ResourceType(String name, Function getResourceUsage, boolean statsEnabled) { + ResourceType( + String name, + boolean statsEnabled, + ResourceUsageCalculator resourceUsageCalculator, + Function nodeLevelThresholdSupplier + ) { this.name = name; - this.getResourceUsage = getResourceUsage; this.statsEnabled = statsEnabled; + this.resourceUsageCalculator = resourceUsageCalculator; + this.nodeLevelThresholdSupplier = nodeLevelThresholdSupplier; } /** @@ -62,20 +69,18 @@ public String getName() { return name; } - /** - * Gets the resource usage for a given resource type and task. - * - * @param task the task for which to calculate resource usage - * @return the resource usage - */ - public long getResourceUsage(Task task) { - return getResourceUsage.apply(task); - } - public boolean hasStatsEnabled() { return statsEnabled; } + public ResourceUsageCalculator getResourceUsageCalculator() { + return resourceUsageCalculator; + } + + public double getNodeLevelThreshold(WorkloadManagementSettings settings) { + return nodeLevelThresholdSupplier.apply(settings); + } + public static List getSortedValues() { return sortedValues; } diff --git a/server/src/main/java/org/opensearch/wlm/WorkloadManagementSettings.java b/server/src/main/java/org/opensearch/wlm/WorkloadManagementSettings.java index b104925df77b3..b3577c1b3219d 100644 --- a/server/src/main/java/org/opensearch/wlm/WorkloadManagementSettings.java +++ b/server/src/main/java/org/opensearch/wlm/WorkloadManagementSettings.java @@ -8,6 +8,7 @@ package org.opensearch.wlm; +import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; @@ -15,6 +16,7 @@ /** * Main class to declare Workload Management related settings */ +@PublicApi(since = "2.18.0") public class WorkloadManagementSettings { private static final Double DEFAULT_NODE_LEVEL_MEMORY_REJECTION_THRESHOLD = 0.8; private static final Double DEFAULT_NODE_LEVEL_MEMORY_CANCELLATION_THRESHOLD = 0.9; diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/MaximumResourceTaskSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/MaximumResourceTaskSelectionStrategy.java new file mode 100644 index 0000000000000..ffb326c07e7ac --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/MaximumResourceTaskSelectionStrategy.java @@ -0,0 +1,71 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.wlm.QueryGroupTask; +import org.opensearch.wlm.ResourceType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService.MIN_VALUE; + +/** + * Represents the highest resource consuming task first selection strategy. + */ +public class MaximumResourceTaskSelectionStrategy implements TaskSelectionStrategy { + + public MaximumResourceTaskSelectionStrategy() {} + + /** + * Returns a comparator that defines the sorting condition for tasks. + * This is the default implementation since the most resource consuming tasks are the likely to regress the performance. + * from resiliency point of view it makes sense to cancel them first + * + * @return The comparator + */ + private Comparator sortingCondition(ResourceType resourceType) { + return Comparator.comparingDouble(task -> resourceType.getResourceUsageCalculator().calculateTaskResourceUsage(task)); + } + + /** + * Selects tasks for cancellation based on the provided limit and resource type. + * The tasks are sorted based on the sorting condition and then selected until the accumulated resource usage reaches the limit. + * + * @param tasks The list of tasks from which to select + * @param limit The limit on the accumulated resource usage + * @param resourceType The type of resource to consider + * @return The list of selected tasks + * @throws IllegalArgumentException If the limit is less than zero + */ + public List selectTasksForCancellation(List tasks, double limit, ResourceType resourceType) { + if (limit < 0) { + throw new IllegalArgumentException("limit has to be greater than zero"); + } + if (limit < MIN_VALUE) { + return Collections.emptyList(); + } + + List sortedTasks = tasks.stream().sorted(sortingCondition(resourceType).reversed()).collect(Collectors.toList()); + + List selectedTasks = new ArrayList<>(); + double accumulated = 0; + for (QueryGroupTask task : sortedTasks) { + selectedTasks.add(task); + accumulated += resourceType.getResourceUsageCalculator().calculateTaskResourceUsage(task); + if ((accumulated - limit) > MIN_VALUE) { + break; + } + } + return selectedTasks; + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationService.java b/server/src/main/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationService.java new file mode 100644 index 0000000000000..a2c97c8d8635b --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationService.java @@ -0,0 +1,205 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.cluster.metadata.QueryGroup; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.wlm.MutableQueryGroupFragment.ResiliencyMode; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; +import org.opensearch.wlm.QueryGroupTask; +import org.opensearch.wlm.ResourceType; +import org.opensearch.wlm.WorkloadManagementSettings; +import org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.function.BooleanSupplier; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService.TRACKED_RESOURCES; + +/** + * Manages the cancellation of tasks enforced by QueryGroup thresholds on resource usage criteria. + * This class utilizes a strategy pattern through {@link MaximumResourceTaskSelectionStrategy} to identify tasks that exceed + * predefined resource usage limits and are therefore eligible for cancellation. + * + *

The cancellation process is initiated by evaluating the resource usage of each QueryGroup against its + * resource limits. Tasks that contribute to exceeding these limits are selected for cancellation based on the + * implemented task selection strategy.

+ * + *

Instances of this class are configured with a map linking QueryGroup IDs to their corresponding resource usage + * views, a set of active QueryGroups, and a task selection strategy. These components collectively facilitate the + * identification and cancellation of tasks that threaten to breach QueryGroup resource limits.

+ * + * @see MaximumResourceTaskSelectionStrategy + * @see QueryGroup + * @see ResourceType + */ +public class QueryGroupTaskCancellationService { + public static final double MIN_VALUE = 1e-9; + + private final WorkloadManagementSettings workloadManagementSettings; + private final TaskSelectionStrategy taskSelectionStrategy; + private final QueryGroupResourceUsageTrackerService resourceUsageTrackerService; + // a map of QueryGroupId to its corresponding QueryGroupLevelResourceUsageView object + Map queryGroupLevelResourceUsageViews; + private final Collection activeQueryGroups; + private final Collection deletedQueryGroups; + + public QueryGroupTaskCancellationService( + WorkloadManagementSettings workloadManagementSettings, + TaskSelectionStrategy taskSelectionStrategy, + QueryGroupResourceUsageTrackerService resourceUsageTrackerService, + Collection activeQueryGroups, + Collection deletedQueryGroups + ) { + this.workloadManagementSettings = workloadManagementSettings; + this.taskSelectionStrategy = taskSelectionStrategy; + this.resourceUsageTrackerService = resourceUsageTrackerService; + this.activeQueryGroups = activeQueryGroups; + this.deletedQueryGroups = deletedQueryGroups; + } + + /** + * Cancel tasks based on the implemented strategy. + */ + public final void cancelTasks(BooleanSupplier isNodeInDuress) { + queryGroupLevelResourceUsageViews = resourceUsageTrackerService.constructQueryGroupLevelUsageViews(); + // cancel tasks from QueryGroups that are in Enforced mode that are breaching their resource limits + cancelTasks(ResiliencyMode.ENFORCED); + // if the node is in duress, cancel tasks accordingly. + handleNodeDuress(isNodeInDuress); + } + + private void handleNodeDuress(BooleanSupplier isNodeInDuress) { + if (!isNodeInDuress.getAsBoolean()) { + return; + } + // List of tasks to be executed in order if the node is in duress + List> duressActions = List.of(v -> cancelTasksFromDeletedQueryGroups(), v -> cancelTasks(ResiliencyMode.SOFT)); + + for (Consumer duressAction : duressActions) { + if (!isNodeInDuress.getAsBoolean()) { + break; + } + duressAction.accept(null); + } + } + + private void cancelTasksFromDeletedQueryGroups() { + cancelTasks(getAllCancellableTasks(this.deletedQueryGroups)); + } + + /** + * Get all cancellable tasks from the QueryGroups. + * + * @return List of tasks that can be cancelled + */ + List getAllCancellableTasks(ResiliencyMode resiliencyMode) { + return getAllCancellableTasks( + activeQueryGroups.stream().filter(queryGroup -> queryGroup.getResiliencyMode() == resiliencyMode).collect(Collectors.toList()) + ); + } + + /** + * Get all cancellable tasks from the given QueryGroups. + * + * @return List of tasks that can be cancelled + */ + List getAllCancellableTasks(Collection queryGroups) { + List taskCancellations = new ArrayList<>(); + for (QueryGroup queryGroup : queryGroups) { + final List reasons = new ArrayList<>(); + List selectedTasks = new ArrayList<>(); + for (ResourceType resourceType : TRACKED_RESOURCES) { + // We need to consider the already selected tasks since those tasks also consumed the resources + double excessUsage = getExcessUsage(queryGroup, resourceType) - resourceType.getResourceUsageCalculator() + .calculateResourceUsage(selectedTasks); + if (excessUsage > MIN_VALUE) { + reasons.add(new TaskCancellation.Reason(generateReasonString(queryGroup, resourceType), 1)); + // TODO: We will need to add the cancellation callback for these resources for the queryGroup to reflect stats + + // Only add tasks not already added to avoid double cancellations + selectedTasks.addAll( + taskSelectionStrategy.selectTasksForCancellation(getTasksFor(queryGroup), excessUsage, resourceType) + .stream() + .filter(x -> selectedTasks.stream().noneMatch(y -> x.getId() != y.getId())) + .collect(Collectors.toList()) + ); + } + } + + if (!reasons.isEmpty()) { + taskCancellations.addAll( + selectedTasks.stream().map(task -> createTaskCancellation(task, reasons)).collect(Collectors.toList()) + ); + } + } + return taskCancellations; + } + + private String generateReasonString(QueryGroup queryGroup, ResourceType resourceType) { + final double currentUsage = getCurrentUsage(queryGroup, resourceType); + return "QueryGroup ID : " + + queryGroup.get_id() + + " breached the resource limit: (" + + currentUsage + + " > " + + queryGroup.getResourceLimits().get(resourceType) + + ") for resource type : " + + resourceType.getName(); + } + + private List getTasksFor(QueryGroup queryGroup) { + return queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getActiveTasks(); + } + + private void cancelTasks(ResiliencyMode resiliencyMode) { + cancelTasks(getAllCancellableTasks(resiliencyMode)); + } + + private void cancelTasks(List cancellableTasks) { + cancellableTasks.forEach(TaskCancellation::cancel); + } + + private TaskCancellation createTaskCancellation(CancellableTask task, List reasons) { + return new TaskCancellation(task, reasons, List.of(this::callbackOnCancel)); + } + + private double getExcessUsage(QueryGroup queryGroup, ResourceType resourceType) { + if (queryGroup.getResourceLimits().get(resourceType) == null + || !queryGroupLevelResourceUsageViews.containsKey(queryGroup.get_id())) { + return 0; + } + return getCurrentUsage(queryGroup, resourceType) - getNormalisedThreshold(queryGroup, resourceType); + } + + private double getCurrentUsage(QueryGroup queryGroup, ResourceType resourceType) { + final QueryGroupLevelResourceUsageView queryGroupResourceUsageView = queryGroupLevelResourceUsageViews.get(queryGroup.get_id()); + return queryGroupResourceUsageView.getResourceUsageData().get(resourceType); + } + + /** + * normalises configured value with respect to node level cancellation thresholds + * @param queryGroup instance + * @return normalised value with respect to node level cancellation thresholds + */ + private double getNormalisedThreshold(QueryGroup queryGroup, ResourceType resourceType) { + double nodeLevelCancellationThreshold = resourceType.getNodeLevelThreshold(workloadManagementSettings); + return queryGroup.getResourceLimits().get(resourceType) * nodeLevelCancellationThreshold; + } + + private void callbackOnCancel() { + // TODO Implement callback logic here mostly used for Stats + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java new file mode 100644 index 0000000000000..63fbf9b791a33 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.wlm.QueryGroupTask; +import org.opensearch.wlm.ResourceType; + +import java.util.List; + +/** + * This interface exposes a method which implementations can use + */ +public interface TaskSelectionStrategy { + /** + * Determines how the tasks are selected from the list of given tasks based on resource type + * @param tasks to select from + * @param limit min cumulative resource usage sum of selected tasks + * @param resourceType + * @return list of tasks + */ + List selectTasksForCancellation(List tasks, double limit, ResourceType resourceType); +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/package-info.java b/server/src/main/java/org/opensearch/wlm/cancellation/package-info.java new file mode 100644 index 0000000000000..1ce7b571e9a9c --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * Workload management resource based cancellation artifacts + */ +package org.opensearch.wlm.cancellation; diff --git a/server/src/main/java/org/opensearch/wlm/tracker/CpuUsageCalculator.java b/server/src/main/java/org/opensearch/wlm/tracker/CpuUsageCalculator.java new file mode 100644 index 0000000000000..05c84cd767b1f --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/tracker/CpuUsageCalculator.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.tracker; + +import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.wlm.QueryGroupTask; + +import java.util.List; + +/** + * class to help make cpu usage calculations for the query group + */ +public class CpuUsageCalculator extends ResourceUsageCalculator { + // This value should be initialised at the start time of the process and be used throughout the codebase + public static final int PROCESSOR_COUNT = Runtime.getRuntime().availableProcessors(); + public static final CpuUsageCalculator INSTANCE = new CpuUsageCalculator(); + + private CpuUsageCalculator() {} + + @Override + public double calculateResourceUsage(List tasks) { + double usage = tasks.stream().mapToDouble(this::calculateTaskResourceUsage).sum(); + + usage /= PROCESSOR_COUNT; + return usage; + } + + @Override + public double calculateTaskResourceUsage(QueryGroupTask task) { + return (1.0f * task.getTotalResourceUtilization(ResourceStats.CPU)) / task.getElapsedTime(); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/tracker/MemoryUsageCalculator.java b/server/src/main/java/org/opensearch/wlm/tracker/MemoryUsageCalculator.java new file mode 100644 index 0000000000000..fb66ff47f58d0 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/tracker/MemoryUsageCalculator.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.tracker; + +import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.monitor.jvm.JvmStats; +import org.opensearch.wlm.QueryGroupTask; + +import java.util.List; + +/** + * class to help make memory usage calculations for the query group + */ +public class MemoryUsageCalculator extends ResourceUsageCalculator { + public static final long HEAP_SIZE_BYTES = JvmStats.jvmStats().getMem().getHeapMax().getBytes(); + public static final MemoryUsageCalculator INSTANCE = new MemoryUsageCalculator(); + + private MemoryUsageCalculator() {} + + @Override + public double calculateResourceUsage(List tasks) { + return tasks.stream().mapToDouble(this::calculateTaskResourceUsage).sum(); + } + + @Override + public double calculateTaskResourceUsage(QueryGroupTask task) { + return (1.0f * task.getTotalResourceUtilization(ResourceStats.MEMORY)) / HEAP_SIZE_BYTES; + } +} diff --git a/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java b/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java index 15852b5bbe6a8..b23d9ff342139 100644 --- a/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java +++ b/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java @@ -8,7 +8,6 @@ package org.opensearch.wlm.tracker; -import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.wlm.QueryGroupLevelResourceUsageView; import org.opensearch.wlm.QueryGroupTask; @@ -25,7 +24,6 @@ * This class tracks resource usage per QueryGroup */ public class QueryGroupResourceUsageTrackerService { - public static final EnumSet TRACKED_RESOURCES = EnumSet.allOf(ResourceType.class); private final TaskResourceTrackingService taskResourceTrackingService; @@ -44,19 +42,16 @@ public QueryGroupResourceUsageTrackerService(TaskResourceTrackingService taskRes * @return Map of QueryGroup views */ public Map constructQueryGroupLevelUsageViews() { - final Map> tasksByQueryGroup = getTasksGroupedByQueryGroup(); + final Map> tasksByQueryGroup = getTasksGroupedByQueryGroup(); final Map queryGroupViews = new HashMap<>(); // Iterate over each QueryGroup entry - for (Map.Entry> queryGroupEntry : tasksByQueryGroup.entrySet()) { - // Compute the QueryGroup usage - final EnumMap queryGroupUsage = new EnumMap<>(ResourceType.class); + for (Map.Entry> queryGroupEntry : tasksByQueryGroup.entrySet()) { + // Compute the QueryGroup resource usage + final Map queryGroupUsage = new EnumMap<>(ResourceType.class); for (ResourceType resourceType : TRACKED_RESOURCES) { - long queryGroupResourceUsage = 0; - for (Task task : queryGroupEntry.getValue()) { - queryGroupResourceUsage += resourceType.getResourceUsage(task); - } - queryGroupUsage.put(resourceType, queryGroupResourceUsage); + double usage = resourceType.getResourceUsageCalculator().calculateResourceUsage(queryGroupEntry.getValue()); + queryGroupUsage.put(resourceType, usage); } // Add to the QueryGroup View @@ -73,12 +68,12 @@ public Map constructQueryGroupLevelUsa * * @return Map of tasks grouped by QueryGroup */ - private Map> getTasksGroupedByQueryGroup() { + private Map> getTasksGroupedByQueryGroup() { return taskResourceTrackingService.getResourceAwareTasks() .values() .stream() .filter(QueryGroupTask.class::isInstance) .map(QueryGroupTask.class::cast) - .collect(Collectors.groupingBy(QueryGroupTask::getQueryGroupId, Collectors.mapping(task -> (Task) task, Collectors.toList()))); + .collect(Collectors.groupingBy(QueryGroupTask::getQueryGroupId, Collectors.mapping(task -> task, Collectors.toList()))); } } diff --git a/server/src/main/java/org/opensearch/wlm/tracker/ResourceUsageCalculator.java b/server/src/main/java/org/opensearch/wlm/tracker/ResourceUsageCalculator.java new file mode 100644 index 0000000000000..bc8317cbfbf92 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/tracker/ResourceUsageCalculator.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.tracker; + +import org.opensearch.common.annotation.PublicApi; +import org.opensearch.wlm.QueryGroupTask; + +import java.util.List; + +/** + * This class is used to track query group level resource usage + */ +@PublicApi(since = "2.18.0") +public abstract class ResourceUsageCalculator { + /** + * calculates the current resource usage for the query group + * + * @param tasks list of tasks in the query group + */ + public abstract double calculateResourceUsage(List tasks); + + /** + * calculates the task level resource usage + * @param task QueryGroupTask + * @return task level resource usage + */ + public abstract double calculateTaskResourceUsage(QueryGroupTask task); +} diff --git a/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java b/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java index 532bf3de95bd6..0c7eb721806d5 100644 --- a/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java +++ b/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java @@ -8,23 +8,34 @@ package org.opensearch.wlm; -import org.opensearch.action.search.SearchAction; -import org.opensearch.core.tasks.TaskId; -import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.wlm.tracker.ResourceUsageCalculatorTrackerServiceTests; -import java.util.Collections; import java.util.List; import java.util.Map; +import static org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService.MIN_VALUE; +import static org.opensearch.wlm.tracker.CpuUsageCalculator.PROCESSOR_COUNT; +import static org.opensearch.wlm.tracker.MemoryUsageCalculator.HEAP_SIZE_BYTES; +import static org.opensearch.wlm.tracker.ResourceUsageCalculatorTests.createMockTaskWithResourceStats; +import static org.mockito.Mockito.mock; + public class QueryGroupLevelResourceUsageViewTests extends OpenSearchTestCase { - Map resourceUsage; - List activeTasks; + Map resourceUsage; + List activeTasks; + ResourceUsageCalculatorTrackerServiceTests.TestClock clock; + WorkloadManagementSettings settings; public void setUp() throws Exception { super.setUp(); - resourceUsage = Map.of(ResourceType.fromName("memory"), 34L, ResourceType.fromName("cpu"), 12L); - activeTasks = List.of(getRandomTask(4321)); + settings = mock(WorkloadManagementSettings.class); + clock = new ResourceUsageCalculatorTrackerServiceTests.TestClock(); + activeTasks = List.of(createMockTaskWithResourceStats(QueryGroupTask.class, 100, 200, 0, 1)); + clock.fastForwardBy(300); + double memoryUsage = 200.0 / HEAP_SIZE_BYTES; + double cpuUsage = 100.0 / (PROCESSOR_COUNT * 300.0); + + resourceUsage = Map.of(ResourceType.MEMORY, memoryUsage, ResourceType.CPU, cpuUsage); } public void testGetResourceUsageData() { @@ -32,7 +43,7 @@ public void testGetResourceUsageData() { resourceUsage, activeTasks ); - Map resourceUsageData = queryGroupLevelResourceUsageView.getResourceUsageData(); + Map resourceUsageData = queryGroupLevelResourceUsageView.getResourceUsageData(); assertTrue(assertResourceUsageData(resourceUsageData)); } @@ -41,23 +52,13 @@ public void testGetActiveTasks() { resourceUsage, activeTasks ); - List activeTasks = queryGroupLevelResourceUsageView.getActiveTasks(); + List activeTasks = queryGroupLevelResourceUsageView.getActiveTasks(); assertEquals(1, activeTasks.size()); - assertEquals(4321, activeTasks.get(0).getId()); + assertEquals(1, activeTasks.get(0).getId()); } - private boolean assertResourceUsageData(Map resourceUsageData) { - return resourceUsageData.get(ResourceType.fromName("memory")) == 34L && resourceUsageData.get(ResourceType.fromName("cpu")) == 12L; - } - - private Task getRandomTask(long id) { - return new Task( - id, - "transport", - SearchAction.NAME, - "test description", - new TaskId(randomLong() + ":" + randomLong()), - Collections.emptyMap() - ); + private boolean assertResourceUsageData(Map resourceUsageData) { + return (resourceUsageData.get(ResourceType.MEMORY) - 200.0 / HEAP_SIZE_BYTES) <= MIN_VALUE + && (resourceUsageData.get(ResourceType.CPU) - 100.0 / (300)) < MIN_VALUE; } } diff --git a/server/src/test/java/org/opensearch/wlm/ResourceTypeTests.java b/server/src/test/java/org/opensearch/wlm/ResourceTypeTests.java index 737cbb37b554c..16bd8b7e66266 100644 --- a/server/src/test/java/org/opensearch/wlm/ResourceTypeTests.java +++ b/server/src/test/java/org/opensearch/wlm/ResourceTypeTests.java @@ -8,14 +8,8 @@ package org.opensearch.wlm; -import org.opensearch.action.search.SearchShardTask; -import org.opensearch.core.tasks.resourcetracker.ResourceStats; -import org.opensearch.tasks.CancellableTask; import org.opensearch.test.OpenSearchTestCase; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - public class ResourceTypeTests extends OpenSearchTestCase { public void testFromName() { @@ -35,17 +29,4 @@ public void testGetName() { assertEquals("cpu", ResourceType.CPU.getName()); assertEquals("memory", ResourceType.MEMORY.getName()); } - - public void testGetResourceUsage() { - SearchShardTask mockTask = createMockTask(SearchShardTask.class, 100, 200); - assertEquals(100, ResourceType.CPU.getResourceUsage(mockTask)); - assertEquals(200, ResourceType.MEMORY.getResourceUsage(mockTask)); - } - - private T createMockTask(Class type, long cpuUsage, long heapUsage) { - T task = mock(type); - when(task.getTotalResourceUtilization(ResourceStats.CPU)).thenReturn(cpuUsage); - when(task.getTotalResourceUtilization(ResourceStats.MEMORY)).thenReturn(heapUsage); - return task; - } } diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/MaximumResourceTaskSelectionStrategyTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/MaximumResourceTaskSelectionStrategyTests.java new file mode 100644 index 0000000000000..dc79822c59c49 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/MaximumResourceTaskSelectionStrategyTests.java @@ -0,0 +1,126 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.action.search.SearchAction; +import org.opensearch.action.search.SearchTask; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; +import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.wlm.QueryGroupTask; +import org.opensearch.wlm.ResourceType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.IntStream; + +import static org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService.MIN_VALUE; +import static org.opensearch.wlm.tracker.MemoryUsageCalculator.HEAP_SIZE_BYTES; + +public class MaximumResourceTaskSelectionStrategyTests extends OpenSearchTestCase { + + public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsGreaterThanZero() { + MaximumResourceTaskSelectionStrategy testHighestResourceConsumingTaskFirstSelectionStrategy = + new MaximumResourceTaskSelectionStrategy(); + double reduceBy = 50000.0 / HEAP_SIZE_BYTES; + ResourceType resourceType = ResourceType.MEMORY; + List tasks = getListOfTasks(100); + List selectedTasks = testHighestResourceConsumingTaskFirstSelectionStrategy.selectTasksForCancellation( + tasks, + reduceBy, + resourceType + ); + assertFalse(selectedTasks.isEmpty()); + boolean sortedInDescendingResourceUsage = IntStream.range(0, selectedTasks.size() - 1) + .noneMatch( + index -> ResourceType.MEMORY.getResourceUsageCalculator() + .calculateTaskResourceUsage(selectedTasks.get(index)) < ResourceType.MEMORY.getResourceUsageCalculator() + .calculateTaskResourceUsage(selectedTasks.get(index + 1)) + ); + assertTrue(sortedInDescendingResourceUsage); + assertTrue(tasksUsageMeetsThreshold(selectedTasks, reduceBy)); + } + + public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsLesserThanZero() { + MaximumResourceTaskSelectionStrategy testHighestResourceConsumingTaskFirstSelectionStrategy = + new MaximumResourceTaskSelectionStrategy(); + double reduceBy = -50.0 / HEAP_SIZE_BYTES; + ResourceType resourceType = ResourceType.MEMORY; + List tasks = getListOfTasks(3); + try { + testHighestResourceConsumingTaskFirstSelectionStrategy.selectTasksForCancellation(tasks, reduceBy, resourceType); + } catch (Exception e) { + assertTrue(e instanceof IllegalArgumentException); + assertEquals("limit has to be greater than zero", e.getMessage()); + } + } + + public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsEqualToZero() { + MaximumResourceTaskSelectionStrategy testHighestResourceConsumingTaskFirstSelectionStrategy = + new MaximumResourceTaskSelectionStrategy(); + double reduceBy = 0.0; + ResourceType resourceType = ResourceType.MEMORY; + List tasks = getListOfTasks(50); + List selectedTasks = testHighestResourceConsumingTaskFirstSelectionStrategy.selectTasksForCancellation( + tasks, + reduceBy, + resourceType + ); + assertTrue(selectedTasks.isEmpty()); + } + + private boolean tasksUsageMeetsThreshold(List selectedTasks, double threshold) { + double memory = 0; + for (QueryGroupTask task : selectedTasks) { + memory += ResourceType.MEMORY.getResourceUsageCalculator().calculateTaskResourceUsage(task); + if ((memory - threshold) > MIN_VALUE) { + return true; + } + } + return false; + } + + private List getListOfTasks(int numberOfTasks) { + List tasks = new ArrayList<>(); + + while (tasks.size() < numberOfTasks) { + long id = randomLong(); + final QueryGroupTask task = getRandomSearchTask(id); + long initial_memory = randomLongBetween(1, 100); + + ResourceUsageMetric[] initialTaskResourceMetrics = new ResourceUsageMetric[] { + new ResourceUsageMetric(ResourceStats.MEMORY, initial_memory) }; + task.startThreadResourceTracking(id, ResourceStatsType.WORKER_STATS, initialTaskResourceMetrics); + + long memory = initial_memory + randomLongBetween(1, 10000); + + ResourceUsageMetric[] taskResourceMetrics = new ResourceUsageMetric[] { + new ResourceUsageMetric(ResourceStats.MEMORY, memory), }; + task.updateThreadResourceStats(id, ResourceStatsType.WORKER_STATS, taskResourceMetrics); + task.stopThreadResourceTracking(id, ResourceStatsType.WORKER_STATS); + tasks.add(task); + } + + return tasks; + } + + private QueryGroupTask getRandomSearchTask(long id) { + return new SearchTask( + id, + "transport", + SearchAction.NAME, + () -> "test description", + new TaskId(randomLong() + ":" + randomLong()), + Collections.emptyMap() + ); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationServiceTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationServiceTests.java new file mode 100644 index 0000000000000..f7a49235efc69 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationServiceTests.java @@ -0,0 +1,541 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.action.search.SearchAction; +import org.opensearch.cluster.metadata.QueryGroup; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.wlm.MutableQueryGroupFragment; +import org.opensearch.wlm.MutableQueryGroupFragment.ResiliencyMode; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; +import org.opensearch.wlm.QueryGroupTask; +import org.opensearch.wlm.ResourceType; +import org.opensearch.wlm.WorkloadManagementSettings; +import org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService; +import org.opensearch.wlm.tracker.ResourceUsageCalculatorTrackerServiceTests.TestClock; +import org.junit.Before; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class QueryGroupTaskCancellationServiceTests extends OpenSearchTestCase { + private static final String queryGroupId1 = "queryGroup1"; + private static final String queryGroupId2 = "queryGroup2"; + + private TestClock clock; + + private Map queryGroupLevelViews; + private Set activeQueryGroups; + private Set deletedQueryGroups; + private QueryGroupTaskCancellationService taskCancellation; + private WorkloadManagementSettings workloadManagementSettings; + private QueryGroupResourceUsageTrackerService resourceUsageTrackerService; + + @Before + public void setup() { + workloadManagementSettings = mock(WorkloadManagementSettings.class); + queryGroupLevelViews = new HashMap<>(); + activeQueryGroups = new HashSet<>(); + deletedQueryGroups = new HashSet<>(); + + clock = new TestClock(); + when(workloadManagementSettings.getNodeLevelCpuCancellationThreshold()).thenReturn(0.9); + when(workloadManagementSettings.getNodeLevelMemoryCancellationThreshold()).thenReturn(0.9); + resourceUsageTrackerService = mock(QueryGroupResourceUsageTrackerService.class); + taskCancellation = new QueryGroupTaskCancellationService( + workloadManagementSettings, + new MaximumResourceTaskSelectionStrategy(), + resourceUsageTrackerService, + activeQueryGroups, + deletedQueryGroups + ); + } + + public void testGetCancellableTasksFrom_setupAppropriateCancellationReasonAndScore() { + ResourceType resourceType = ResourceType.CPU; + double cpuUsage = 0.11; + double memoryUsage = 0.0; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + clock.fastForwardBy(1000); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + when(mockView.getResourceUsageData()).thenReturn(Map.of(resourceType, cpuUsage, ResourceType.MEMORY, memoryUsage)); + queryGroupLevelViews.put(queryGroupId1, mockView); + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(List.of(queryGroup1)); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + assertEquals(1, cancellableTasksFrom.get(0).getReasons().get(0).getCancellationScore()); + } + + public void testGetCancellableTasksFrom_returnsTasksWhenBreachingThreshold() { + ResourceType resourceType = ResourceType.CPU; + double cpuUsage = 0.11; + double memoryUsage = 0.0; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + when(mockView.getResourceUsageData()).thenReturn(Map.of(resourceType, cpuUsage, ResourceType.MEMORY, memoryUsage)); + queryGroupLevelViews.put(queryGroupId1, mockView); + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(List.of(queryGroup1)); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + } + + public void testGetCancellableTasksFrom_returnsTasksWhenBreachingThresholdForMemory() { + ResourceType resourceType = ResourceType.MEMORY; + double cpuUsage = 0.0; + double memoryUsage = 0.11; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + when(mockView.getResourceUsageData()).thenReturn(Map.of(ResourceType.CPU, cpuUsage, resourceType, memoryUsage)); + + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + } + + public void testGetCancellableTasksFrom_returnsNoTasksWhenNotBreachingThreshold() { + ResourceType resourceType = ResourceType.CPU; + double cpuUsage = 0.81; + double memoryUsage = 0.0; + Double threshold = 0.9; + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + when(mockView.getResourceUsageData()).thenReturn(Map.of(ResourceType.CPU, cpuUsage, ResourceType.MEMORY, memoryUsage)); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(List.of(queryGroup1)); + assertTrue(cancellableTasksFrom.isEmpty()); + } + + public void testGetCancellableTasksFrom_filtersQueryGroupCorrectly() { + ResourceType resourceType = ResourceType.CPU; + double usage = 0.02; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + QueryGroupTaskCancellationService taskCancellation = new QueryGroupTaskCancellationService( + workloadManagementSettings, + new MaximumResourceTaskSelectionStrategy(), + resourceUsageTrackerService, + activeQueryGroups, + deletedQueryGroups + ); + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.SOFT); + assertEquals(0, cancellableTasksFrom.size()); + } + + public void testCancelTasks_cancelsGivenTasks() { + ResourceType resourceType = ResourceType.CPU; + double cpuUsage = 0.011; + double memoryUsage = 0.011; + + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold, ResourceType.MEMORY, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + when(mockView.getResourceUsageData()).thenReturn(Map.of(ResourceType.CPU, cpuUsage, ResourceType.MEMORY, memoryUsage)); + + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + QueryGroupTaskCancellationService taskCancellation = new QueryGroupTaskCancellationService( + workloadManagementSettings, + new MaximumResourceTaskSelectionStrategy(), + resourceUsageTrackerService, + activeQueryGroups, + deletedQueryGroups + ); + + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + + when(resourceUsageTrackerService.constructQueryGroupLevelUsageViews()).thenReturn(queryGroupLevelViews); + taskCancellation.cancelTasks(() -> false); + assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled()); + } + + public void testCancelTasks_cancelsTasksFromDeletedQueryGroups() { + ResourceType resourceType = ResourceType.CPU; + double activeQueryGroupCpuUsage = 0.01; + double activeQueryGroupMemoryUsage = 0.0; + double deletedQueryGroupCpuUsage = 0.01; + double deletedQueryGroupMemoryUsage = 0.0; + Double threshold = 0.01; + + QueryGroup activeQueryGroup = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroup deletedQueryGroup = new QueryGroup( + "testQueryGroup", + queryGroupId2, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView1 = createResourceUsageViewMock(); + QueryGroupLevelResourceUsageView mockView2 = createResourceUsageViewMock( + resourceType, + deletedQueryGroupCpuUsage, + List.of(1000, 1001) + ); + + when(mockView1.getResourceUsageData()).thenReturn( + Map.of(ResourceType.CPU, activeQueryGroupCpuUsage, ResourceType.MEMORY, activeQueryGroupMemoryUsage) + ); + when(mockView2.getResourceUsageData()).thenReturn( + Map.of(ResourceType.CPU, deletedQueryGroupCpuUsage, ResourceType.MEMORY, deletedQueryGroupMemoryUsage) + ); + queryGroupLevelViews.put(queryGroupId1, mockView1); + queryGroupLevelViews.put(queryGroupId2, mockView2); + + activeQueryGroups.add(activeQueryGroup); + deletedQueryGroups.add(deletedQueryGroup); + + QueryGroupTaskCancellationService taskCancellation = new QueryGroupTaskCancellationService( + workloadManagementSettings, + new MaximumResourceTaskSelectionStrategy(), + resourceUsageTrackerService, + activeQueryGroups, + deletedQueryGroups + ); + + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + + List cancellableTasksFromDeletedQueryGroups = taskCancellation.getAllCancellableTasks(List.of(deletedQueryGroup)); + assertEquals(2, cancellableTasksFromDeletedQueryGroups.size()); + assertEquals(1000, cancellableTasksFromDeletedQueryGroups.get(0).getTask().getId()); + assertEquals(1001, cancellableTasksFromDeletedQueryGroups.get(1).getTask().getId()); + + when(resourceUsageTrackerService.constructQueryGroupLevelUsageViews()).thenReturn(queryGroupLevelViews); + taskCancellation.cancelTasks(() -> true); + + assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled()); + assertTrue(cancellableTasksFromDeletedQueryGroups.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFromDeletedQueryGroups.get(1).getTask().isCancelled()); + } + + public void testCancelTasks_does_not_cancelTasksFromDeletedQueryGroups_whenNodeNotInDuress() { + ResourceType resourceType = ResourceType.CPU; + double activeQueryGroupCpuUsage = 0.11; + double activeQueryGroupMemoryUsage = 0.0; + double deletedQueryGroupCpuUsage = 0.11; + double deletedQueryGroupMemoryUsage = 0.0; + + Double threshold = 0.01; + + QueryGroup activeQueryGroup = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroup deletedQueryGroup = new QueryGroup( + "testQueryGroup", + queryGroupId2, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView1 = createResourceUsageViewMock(); + QueryGroupLevelResourceUsageView mockView2 = createResourceUsageViewMock( + resourceType, + deletedQueryGroupCpuUsage, + List.of(1000, 1001) + ); + + when(mockView1.getResourceUsageData()).thenReturn( + Map.of(ResourceType.CPU, activeQueryGroupCpuUsage, ResourceType.MEMORY, activeQueryGroupMemoryUsage) + ); + when(mockView2.getResourceUsageData()).thenReturn( + Map.of(ResourceType.CPU, deletedQueryGroupCpuUsage, ResourceType.MEMORY, deletedQueryGroupMemoryUsage) + ); + + queryGroupLevelViews.put(queryGroupId1, mockView1); + queryGroupLevelViews.put(queryGroupId2, mockView2); + activeQueryGroups.add(activeQueryGroup); + deletedQueryGroups.add(deletedQueryGroup); + + QueryGroupTaskCancellationService taskCancellation = new QueryGroupTaskCancellationService( + workloadManagementSettings, + new MaximumResourceTaskSelectionStrategy(), + resourceUsageTrackerService, + activeQueryGroups, + deletedQueryGroups + ); + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + + List cancellableTasksFromDeletedQueryGroups = taskCancellation.getAllCancellableTasks(List.of(deletedQueryGroup)); + assertEquals(2, cancellableTasksFromDeletedQueryGroups.size()); + assertEquals(1000, cancellableTasksFromDeletedQueryGroups.get(0).getTask().getId()); + assertEquals(1001, cancellableTasksFromDeletedQueryGroups.get(1).getTask().getId()); + + when(resourceUsageTrackerService.constructQueryGroupLevelUsageViews()).thenReturn(queryGroupLevelViews); + taskCancellation.cancelTasks(() -> false); + + assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled()); + assertFalse(cancellableTasksFromDeletedQueryGroups.get(0).getTask().isCancelled()); + assertFalse(cancellableTasksFromDeletedQueryGroups.get(1).getTask().isCancelled()); + } + + public void testCancelTasks_cancelsGivenTasks_WhenNodeInDuress() { + ResourceType resourceType = ResourceType.CPU; + double cpuUsage1 = 0.11; + double memoryUsage1 = 0.0; + double cpuUsage2 = 0.11; + double memoryUsage2 = 0.0; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroup queryGroup2 = new QueryGroup( + "testQueryGroup", + queryGroupId2, + new MutableQueryGroupFragment(ResiliencyMode.SOFT, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView1 = createResourceUsageViewMock(); + when(mockView1.getResourceUsageData()).thenReturn(Map.of(ResourceType.CPU, cpuUsage1, ResourceType.MEMORY, memoryUsage1)); + queryGroupLevelViews.put(queryGroupId1, mockView1); + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + when(mockView.getActiveTasks()).thenReturn(List.of(getRandomSearchTask(5678), getRandomSearchTask(8765))); + when(mockView.getResourceUsageData()).thenReturn(Map.of(ResourceType.CPU, cpuUsage2, ResourceType.MEMORY, memoryUsage2)); + queryGroupLevelViews.put(queryGroupId2, mockView); + Collections.addAll(activeQueryGroups, queryGroup1, queryGroup2); + + QueryGroupTaskCancellationService taskCancellation = new QueryGroupTaskCancellationService( + workloadManagementSettings, + new MaximumResourceTaskSelectionStrategy(), + resourceUsageTrackerService, + activeQueryGroups, + deletedQueryGroups + ); + + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + + List cancellableTasksFrom1 = taskCancellation.getAllCancellableTasks(ResiliencyMode.SOFT); + assertEquals(2, cancellableTasksFrom1.size()); + assertEquals(5678, cancellableTasksFrom1.get(0).getTask().getId()); + assertEquals(8765, cancellableTasksFrom1.get(1).getTask().getId()); + + when(resourceUsageTrackerService.constructQueryGroupLevelUsageViews()).thenReturn(queryGroupLevelViews); + taskCancellation.cancelTasks(() -> true); + assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled()); + assertTrue(cancellableTasksFrom1.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom1.get(1).getTask().isCancelled()); + } + + public void testGetAllCancellableTasks_ReturnsNoTasksWhenNotBreachingThresholds() { + ResourceType resourceType = ResourceType.CPU; + double queryGroupCpuUsage = 0.09; + double queryGroupMemoryUsage = 0.0; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + when(mockView.getResourceUsageData()).thenReturn( + Map.of(ResourceType.CPU, queryGroupCpuUsage, ResourceType.MEMORY, queryGroupMemoryUsage) + ); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List allCancellableTasks = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED); + assertTrue(allCancellableTasks.isEmpty()); + } + + public void testGetAllCancellableTasks_ReturnsTasksWhenBreachingThresholds() { + ResourceType resourceType = ResourceType.CPU; + double cpuUsage = 0.11; + double memoryUsage = 0.0; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + when(mockView.getResourceUsageData()).thenReturn(Map.of(ResourceType.CPU, cpuUsage, ResourceType.MEMORY, memoryUsage)); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List allCancellableTasks = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED); + assertEquals(2, allCancellableTasks.size()); + assertEquals(1234, allCancellableTasks.get(0).getTask().getId()); + assertEquals(4321, allCancellableTasks.get(1).getTask().getId()); + } + + public void testGetCancellableTasksFrom_doesNotReturnTasksWhenQueryGroupIdNotFound() { + ResourceType resourceType = ResourceType.CPU; + double usage = 0.11; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup1", + queryGroupId1, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + QueryGroup queryGroup2 = new QueryGroup( + "testQueryGroup2", + queryGroupId2, + new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(resourceType, threshold)), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + activeQueryGroups.add(queryGroup2); + taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(List.of(queryGroup2)); + assertEquals(0, cancellableTasksFrom.size()); + } + + private QueryGroupLevelResourceUsageView createResourceUsageViewMock() { + QueryGroupLevelResourceUsageView mockView = mock(QueryGroupLevelResourceUsageView.class); + when(mockView.getActiveTasks()).thenReturn(List.of(getRandomSearchTask(1234), getRandomSearchTask(4321))); + return mockView; + } + + private QueryGroupLevelResourceUsageView createResourceUsageViewMock(ResourceType resourceType, double usage, Collection ids) { + QueryGroupLevelResourceUsageView mockView = mock(QueryGroupLevelResourceUsageView.class); + when(mockView.getResourceUsageData()).thenReturn(Collections.singletonMap(resourceType, usage)); + when(mockView.getActiveTasks()).thenReturn(ids.stream().map(this::getRandomSearchTask).collect(Collectors.toList())); + return mockView; + } + + private QueryGroupTask getRandomSearchTask(long id) { + return new QueryGroupTask( + id, + "transport", + SearchAction.NAME, + "test description", + new TaskId(randomLong() + ":" + randomLong()), + Collections.emptyMap(), + null, + clock::getTime + ); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/tracker/ResourceUsageCalculatorTests.java b/server/src/test/java/org/opensearch/wlm/tracker/ResourceUsageCalculatorTests.java new file mode 100644 index 0000000000000..21d9717a1aaca --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/tracker/ResourceUsageCalculatorTests.java @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.tracker; + +import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.wlm.QueryGroupTask; +import org.opensearch.wlm.ResourceType; +import org.opensearch.wlm.tracker.ResourceUsageCalculatorTrackerServiceTests.TestClock; + +import java.util.List; + +import static org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService.MIN_VALUE; +import static org.opensearch.wlm.tracker.CpuUsageCalculator.PROCESSOR_COUNT; +import static org.opensearch.wlm.tracker.MemoryUsageCalculator.HEAP_SIZE_BYTES; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ResourceUsageCalculatorTests extends OpenSearchTestCase { + + public void testQueryGroupCpuUsage() { + TestClock clock = new TestClock(); + long fastForwardTime = PROCESSOR_COUNT * 200L; + clock.fastForwardBy(fastForwardTime); + + double expectedQueryGroupCpuUsage = 1.0 / PROCESSOR_COUNT; + + QueryGroupTask mockTask = createMockTaskWithResourceStats(QueryGroupTask.class, fastForwardTime, 200, 0, 123); + when(mockTask.getElapsedTime()).thenReturn(fastForwardTime); + double actualUsage = ResourceType.CPU.getResourceUsageCalculator().calculateResourceUsage(List.of(mockTask)); + assertEquals(expectedQueryGroupCpuUsage, actualUsage, MIN_VALUE); + + double taskResourceUsage = ResourceType.CPU.getResourceUsageCalculator().calculateTaskResourceUsage(mockTask); + assertEquals(1.0, taskResourceUsage, MIN_VALUE); + } + + public void testQueryGroupMemoryUsage() { + QueryGroupTask mockTask = createMockTaskWithResourceStats(QueryGroupTask.class, 100, 200, 0, 123); + double actualMemoryUsage = ResourceType.MEMORY.getResourceUsageCalculator().calculateResourceUsage(List.of(mockTask)); + double expectedMemoryUsage = 200.0 / HEAP_SIZE_BYTES; + + assertEquals(expectedMemoryUsage, actualMemoryUsage, MIN_VALUE); + assertEquals( + 200.0 / HEAP_SIZE_BYTES, + ResourceType.MEMORY.getResourceUsageCalculator().calculateTaskResourceUsage(mockTask), + MIN_VALUE + ); + } + + public static T createMockTaskWithResourceStats( + Class type, + long cpuUsage, + long heapUsage, + long startTimeNanos, + long taskId + ) { + T task = mock(type); + when(task.getTotalResourceUtilization(ResourceStats.CPU)).thenReturn(cpuUsage); + when(task.getTotalResourceUtilization(ResourceStats.MEMORY)).thenReturn(heapUsage); + when(task.getStartTimeNanos()).thenReturn(startTimeNanos); + when(task.getId()).thenReturn(taskId); + return task; + } +} diff --git a/server/src/test/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerServiceTests.java b/server/src/test/java/org/opensearch/wlm/tracker/ResourceUsageCalculatorTrackerServiceTests.java similarity index 69% rename from server/src/test/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerServiceTests.java rename to server/src/test/java/org/opensearch/wlm/tracker/ResourceUsageCalculatorTrackerServiceTests.java index ca2891cb532f2..fe72bd6e710c8 100644 --- a/server/src/test/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerServiceTests.java +++ b/server/src/test/java/org/opensearch/wlm/tracker/ResourceUsageCalculatorTrackerServiceTests.java @@ -9,10 +9,8 @@ package org.opensearch.wlm.tracker; import org.opensearch.action.search.SearchShardTask; -import org.opensearch.action.search.SearchTask; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.tasks.resourcetracker.ResourceStats; -import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; @@ -21,6 +19,7 @@ import org.opensearch.wlm.QueryGroupLevelResourceUsageView; import org.opensearch.wlm.QueryGroupTask; import org.opensearch.wlm.ResourceType; +import org.opensearch.wlm.WorkloadManagementSettings; import org.junit.After; import org.junit.Before; @@ -31,18 +30,38 @@ import java.util.concurrent.atomic.AtomicBoolean; import static org.opensearch.wlm.QueryGroupTask.QUERY_GROUP_ID_HEADER; +import static org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService.MIN_VALUE; +import static org.opensearch.wlm.tracker.CpuUsageCalculator.PROCESSOR_COUNT; +import static org.opensearch.wlm.tracker.MemoryUsageCalculator.HEAP_SIZE_BYTES; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class QueryGroupResourceUsageTrackerServiceTests extends OpenSearchTestCase { +public class ResourceUsageCalculatorTrackerServiceTests extends OpenSearchTestCase { TestThreadPool threadPool; TaskResourceTrackingService mockTaskResourceTrackingService; QueryGroupResourceUsageTrackerService queryGroupResourceUsageTrackerService; + WorkloadManagementSettings settings; + + public static class TestClock { + long time; + + public void fastForwardBy(long nanos) { + time += nanos; + } + + public long getTime() { + return time; + } + } + + TestClock clock; @Before public void setup() { + clock = new TestClock(); + settings = mock(WorkloadManagementSettings.class); threadPool = new TestThreadPool(getTestName()); mockTaskResourceTrackingService = mock(TaskResourceTrackingService.class); queryGroupResourceUsageTrackerService = new QueryGroupResourceUsageTrackerService(mockTaskResourceTrackingService); @@ -55,16 +74,24 @@ public void cleanup() { public void testConstructQueryGroupLevelViews_CreatesQueryGroupLevelUsageView_WhenTasksArePresent() { List queryGroupIds = List.of("queryGroup1", "queryGroup2", "queryGroup3"); + clock.fastForwardBy(2000); Map activeSearchShardTasks = createActiveSearchShardTasks(queryGroupIds); when(mockTaskResourceTrackingService.getResourceAwareTasks()).thenReturn(activeSearchShardTasks); + Map stringQueryGroupLevelResourceUsageViewMap = queryGroupResourceUsageTrackerService .constructQueryGroupLevelUsageViews(); for (String queryGroupId : queryGroupIds) { assertEquals( - 400, - (long) stringQueryGroupLevelResourceUsageViewMap.get(queryGroupId).getResourceUsageData().get(ResourceType.MEMORY) + (400 * 1.0f) / HEAP_SIZE_BYTES, + stringQueryGroupLevelResourceUsageViewMap.get(queryGroupId).getResourceUsageData().get(ResourceType.MEMORY), + MIN_VALUE + ); + assertEquals( + (200 * 1.0f) / (PROCESSOR_COUNT * 2000), + stringQueryGroupLevelResourceUsageViewMap.get(queryGroupId).getResourceUsageData().get(ResourceType.CPU), + MIN_VALUE ); assertEquals(2, stringQueryGroupLevelResourceUsageViewMap.get(queryGroupId).getActiveTasks().size()); } @@ -78,14 +105,23 @@ public void testConstructQueryGroupLevelViews_CreatesQueryGroupLevelUsageView_Wh public void testConstructQueryGroupLevelUsageViews_WithTasksHavingDifferentResourceUsage() { Map activeSearchShardTasks = new HashMap<>(); + clock.fastForwardBy(2000); activeSearchShardTasks.put(1L, createMockTask(SearchShardTask.class, 100, 200, "queryGroup1")); activeSearchShardTasks.put(2L, createMockTask(SearchShardTask.class, 200, 400, "queryGroup1")); when(mockTaskResourceTrackingService.getResourceAwareTasks()).thenReturn(activeSearchShardTasks); - Map queryGroupViews = queryGroupResourceUsageTrackerService .constructQueryGroupLevelUsageViews(); - assertEquals(600, (long) queryGroupViews.get("queryGroup1").getResourceUsageData().get(ResourceType.MEMORY)); + assertEquals( + (double) 600 / HEAP_SIZE_BYTES, + queryGroupViews.get("queryGroup1").getResourceUsageData().get(ResourceType.MEMORY), + MIN_VALUE + ); + assertEquals( + ((double) 300) / (PROCESSOR_COUNT * 2000), + queryGroupViews.get("queryGroup1").getResourceUsageData().get(ResourceType.CPU), + MIN_VALUE + ); assertEquals(2, queryGroupViews.get("queryGroup1").getActiveTasks().size()); } @@ -100,19 +136,16 @@ private Map createActiveSearchShardTasks(List queryGroupIds) return activeSearchShardTasks; } - private T createMockTask(Class type, long cpuUsage, long heapUsage, String queryGroupId) { + private T createMockTask(Class type, long cpuUsage, long heapUsage, String queryGroupId) { T task = mock(type); - if (task instanceof SearchTask || task instanceof SearchShardTask) { - // Stash the current thread context to ensure that any existing context is preserved and restored after setting the query group - // ID. - try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { - threadPool.getThreadContext().putHeader(QUERY_GROUP_ID_HEADER, queryGroupId); - ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); - } + try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putHeader(QUERY_GROUP_ID_HEADER, queryGroupId); + task.setQueryGroupId(threadPool.getThreadContext()); } when(task.getTotalResourceUtilization(ResourceStats.CPU)).thenReturn(cpuUsage); when(task.getTotalResourceUtilization(ResourceStats.MEMORY)).thenReturn(heapUsage); when(task.getStartTimeNanos()).thenReturn((long) 0); + when(task.getElapsedTime()).thenReturn(clock.getTime()); AtomicBoolean isCancelled = new AtomicBoolean(false); doAnswer(invocation -> {