From f3a4813159869f910b5f6255c1e207f517f18811 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Mon, 6 Nov 2023 12:08:18 +0100 Subject: [PATCH 01/21] Cleanup duplication and dead code in o.e.action.search (#101789) Removing some obvious code duplication and dead code found during today's test failure fixing. --- .../search/AbstractSearchAsyncAction.java | 27 ++---------- .../search/CanMatchPreFilterSearchPhase.java | 41 +++++-------------- .../action/search/ClearScrollResponse.java | 2 +- .../action/search/DfsQueryPhase.java | 3 +- .../action/search/MultiSearchRequest.java | 1 - .../search/MultiSearchRequestBuilder.java | 7 ---- .../search/OpenPointInTimeResponse.java | 14 ------- .../action/search/ParsedScrollId.java | 9 +--- .../search/QueryPhaseResultConsumer.java | 2 +- .../search/RestOpenPointInTimeAction.java | 3 +- .../action/search/SearchContextId.java | 2 +- .../action/search/SearchPhase.java | 25 +++++++++++ .../action/search/SearchPhaseController.java | 2 +- .../action/search/SearchRequestBuilder.java | 28 ------------- .../action/search/SearchResponse.java | 12 ++---- .../search/SearchScrollAsyncAction.java | 3 +- .../action/search/SearchTransportService.java | 20 +++++---- .../action/search/SearchType.java | 2 +- .../search/TransportMultiSearchAction.java | 17 +------- .../TransportOpenPointInTimeAction.java | 2 +- .../action/search/TransportSearchAction.java | 26 +++--------- .../action/search/TransportSearchHelper.java | 2 +- .../action/search/ParsedScrollIdTests.java | 2 +- .../search/SearchScrollAsyncActionTests.java | 2 +- 24 files changed, 74 insertions(+), 180 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 2f3266f9e0099..b56cb0ca5926c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -228,27 +228,7 @@ public final void run() { skipShard(iterator); } if (shardsIts.size() > 0) { - assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (request.allowPartialSearchResults() == false) { - final StringBuilder missingShards = new StringBuilder(); - // Fail-fast verification of all shards being available - for (int index = 0; index < shardsIts.size(); index++) { - final SearchShardIterator shardRoutings = shardsIts.get(index); - if (shardRoutings.size() == 0) { - if (missingShards.length() > 0) { - missingShards.append(", "); - } - missingShards.append(shardRoutings.shardId()); - } - } - if (missingShards.length() > 0) { - // Status red - shard is missing all copies and would produce partial results for an index search - final String msg = "Search rejected due to missing shards [" - + missingShards - + "]. 
Consider using `allow_partial_search_results` setting to bypass this error."; - throw new SearchPhaseExecutionException(getName(), msg, null, ShardSearchFailure.EMPTY_ARRAY); - } - } + doCheckNoMissingShards(getName(), request, shardsIts); Version version = request.minCompatibleShardNode(); if (version != null && Version.CURRENT.minimumCompatibilityVersion().equals(version) == false) { if (checkMinimumVersion(shardsIts) == false) { @@ -434,7 +414,6 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase.getName()), cause); } onPhaseFailure(currentPhase, "Partial shards failure", null); - return; } else { int discrepancy = getNumShards() - successfulOps.get(); assert discrepancy > 0 : "discrepancy: " + discrepancy; @@ -449,8 +428,8 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha ); } onPhaseFailure(currentPhase, "Partial shards failure (" + discrepancy + " shards unavailable)", null); - return; } + return; } if (logger.isTraceEnabled()) { final String resultsFrom = results.getSuccessfulResults() @@ -840,7 +819,7 @@ void executeNext(Runnable runnable, Thread originalThread) { private static final class PendingExecutions { private final int permits; private int permitsTaken = 0; - private ArrayDeque queue = new ArrayDeque<>(); + private final ArrayDeque queue = new ArrayDeque<>(); PendingExecutions(int permits) { assert permits > 0 : "not enough permits: " + permits; diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index cef6bf92cc5e6..6e553f254ee8b 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -31,7 +31,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.Transport; -import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -127,7 +126,7 @@ private static boolean assertSearchCoordinationThread() { } @Override - public void run() throws IOException { + public void run() { assert assertSearchCoordinationThread(); checkNoMissingShards(); Version version = request.minCompatibleShardNode(); @@ -159,9 +158,7 @@ private void runCoordinatorRewritePhase() { ); final ShardSearchRequest request = canMatchNodeRequest.createShardSearchRequest(buildShardLevelRequest(searchShardIterator)); if (searchShardIterator.prefiltered()) { - CanMatchShardResponse result = new CanMatchShardResponse(searchShardIterator.skip() == false, null); - result.setShardIndex(request.shardRequestIndex()); - results.consumeResult(result, () -> {}); + consumeResult(searchShardIterator.skip() == false, request); continue; } boolean canMatch = true; @@ -178,9 +175,7 @@ private void runCoordinatorRewritePhase() { if (canMatch) { matchedShardLevelRequests.add(searchShardIterator); } else { - CanMatchShardResponse result = new CanMatchShardResponse(canMatch, null); - result.setShardIndex(request.shardRequestIndex()); - results.consumeResult(result, () -> {}); + consumeResult(false, request); } } if (matchedShardLevelRequests.isEmpty()) { @@ -190,29 +185,15 @@ private void runCoordinatorRewritePhase() { } } + private void consumeResult(boolean canMatch, ShardSearchRequest request) { + 
CanMatchShardResponse result = new CanMatchShardResponse(canMatch, null); + result.setShardIndex(request.shardRequestIndex()); + results.consumeResult(result, () -> {}); + } + private void checkNoMissingShards() { assert assertSearchCoordinationThread(); - assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (request.allowPartialSearchResults() == false) { - final StringBuilder missingShards = new StringBuilder(); - // Fail-fast verification of all shards being available - for (int index = 0; index < shardsIts.size(); index++) { - final SearchShardIterator shardRoutings = shardsIts.get(index); - if (shardRoutings.size() == 0) { - if (missingShards.length() > 0) { - missingShards.append(", "); - } - missingShards.append(shardRoutings.shardId()); - } - } - if (missingShards.length() > 0) { - // Status red - shard is missing all copies and would produce partial results for an index search - final String msg = "Search rejected due to missing shards [" - + missingShards - + "]. Consider using `allow_partial_search_results` setting to bypass this error."; - throw new SearchPhaseExecutionException(getName(), msg, null, ShardSearchFailure.EMPTY_ARRAY); - } - } + doCheckNoMissingShards(getName(), request, shardsIts); } private Map> groupByNode(GroupShardsIterator shards) { @@ -425,7 +406,7 @@ public void onFailure(Exception e) { } @Override - protected void doRun() throws IOException { + protected void doRun() { CanMatchPreFilterSearchPhase.this.run(); } }); diff --git a/server/src/main/java/org/elasticsearch/action/search/ClearScrollResponse.java b/server/src/main/java/org/elasticsearch/action/search/ClearScrollResponse.java index 0a7b53ea8b9c4..8b1116951df82 100644 --- a/server/src/main/java/org/elasticsearch/action/search/ClearScrollResponse.java +++ b/server/src/main/java/org/elasticsearch/action/search/ClearScrollResponse.java @@ -85,7 +85,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws /** * Parse the clear scroll response body into a new {@link ClearScrollResponse} object */ - public static ClosePointInTimeResponse fromXContent(XContentParser parser) throws IOException { + public static ClosePointInTimeResponse fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index dca269f06a3d3..e010e840d3f2d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -24,7 +24,6 @@ import org.elasticsearch.search.vectors.KnnScoreDocQueryBuilder; import org.elasticsearch.transport.Transport; -import java.io.IOException; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -71,7 +70,7 @@ final class DfsQueryPhase extends SearchPhase { } @Override - public void run() throws IOException { + public void run() { // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs // to free up memory early final CountedCollector counter = new CountedCollector<>( diff --git a/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequest.java index e7d6eca23498f..cadcd6ca57334 100644 --- 
a/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequest.java @@ -51,7 +51,6 @@ */ public class MultiSearchRequest extends ActionRequest implements CompositeIndicesRequest { private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(RestSearchAction.class); - public static final String TYPES_DEPRECATION_MESSAGE = "[types removal]" + " Specifying types in search requests is deprecated."; public static final String FIRST_LINE_EMPTY_DEPRECATION_MESSAGE = "support for empty first line before any action metadata in msearch API is deprecated " + "and will be removed in the next major version"; diff --git a/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequestBuilder.java b/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequestBuilder.java index 6f1e8d429edab..57c536f3d371e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequestBuilder.java +++ b/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequestBuilder.java @@ -63,11 +63,4 @@ public MultiSearchRequestBuilder setIndicesOptions(IndicesOptions indicesOptions return this; } - /** - * Sets how many search requests specified in this multi search requests are allowed to be ran concurrently. - */ - public MultiSearchRequestBuilder setMaxConcurrentSearchRequests(int maxConcurrentSearchRequests) { - request().maxConcurrentSearchRequests(maxConcurrentSearchRequests); - return this; - } } diff --git a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java index c6463bcb00f67..92a2a1503aefc 100644 --- a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java +++ b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java @@ -11,27 +11,16 @@ import org.elasticsearch.action.ActionResponse; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; import java.util.Objects; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; - public final class OpenPointInTimeResponse extends ActionResponse implements ToXContentObject { private static final ParseField ID = new ParseField("id"); - private static final ConstructingObjectParser PARSER; - - static { - PARSER = new ConstructingObjectParser<>("open_point_in_time", true, a -> new OpenPointInTimeResponse((String) a[0])); - PARSER.declareField(constructorArg(), (parser, context) -> parser.text(), ID, ObjectParser.ValueType.STRING); - } private final String pointInTimeId; public OpenPointInTimeResponse(String pointInTimeId) { @@ -60,7 +49,4 @@ public String getPointInTimeId() { return pointInTimeId; } - public static OpenPointInTimeResponse fromXContent(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } } diff --git a/server/src/main/java/org/elasticsearch/action/search/ParsedScrollId.java b/server/src/main/java/org/elasticsearch/action/search/ParsedScrollId.java index ca68b1865495d..a9f3502bfa631 100644 
--- a/server/src/main/java/org/elasticsearch/action/search/ParsedScrollId.java +++ b/server/src/main/java/org/elasticsearch/action/search/ParsedScrollId.java @@ -16,22 +16,15 @@ public class ParsedScrollId { public static final String QUERY_AND_FETCH_TYPE = "queryAndFetch"; - private final String source; - private final String type; private final SearchContextIdForNode[] context; - ParsedScrollId(String source, String type, SearchContextIdForNode[] context) { - this.source = source; + ParsedScrollId(String type, SearchContextIdForNode[] context) { this.type = type; this.context = context; } - public String getSource() { - return source; - } - public String getType() { return type; } diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java index f78d5f4005755..ee956b5179902 100644 --- a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java @@ -520,7 +520,7 @@ private record MergeResult( private static class MergeTask { private final List emptyResults; private QuerySearchResult[] buffer; - private long aggsBufferSize; + private final long aggsBufferSize; private Runnable next; private MergeTask(QuerySearchResult[] buffer, long aggsBufferSize, List emptyResults, Runnable next) { diff --git a/server/src/main/java/org/elasticsearch/action/search/RestOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/RestOpenPointInTimeAction.java index 5de59cc6ce878..815deac07dfcd 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RestOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/RestOpenPointInTimeAction.java @@ -18,7 +18,6 @@ import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; -import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.POST; @@ -37,7 +36,7 @@ public List routes() { } @Override - public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { + public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) { final String[] indices = Strings.splitStringByCommaToArray(request.param("index")); final OpenPointInTimeRequest openRequest = new OpenPointInTimeRequest(indices); openRequest.indicesOptions(IndicesOptions.fromRequest(request, OpenPointInTimeRequest.DEFAULT_INDICES_OPTIONS)); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java b/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java index 2b7105cffe2bb..f10650a6401d6 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java @@ -41,7 +41,7 @@ public final class SearchContextId { private final Map shards; private final Map aliasFilter; - private transient Set contextIds; + private final transient Set contextIds; SearchContextId(Map shards, Map aliasFilter) { this.shards = shards; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java index 88da2fdfa3a9e..9d3eadcc42bf9 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java +++ 
b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java @@ -7,6 +7,7 @@ */ package org.elasticsearch.action.search; +import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.core.CheckedRunnable; import java.io.IOException; @@ -37,4 +38,28 @@ public void start() { throw new UncheckedIOException(e); } } + + static void doCheckNoMissingShards(String phaseName, SearchRequest request, GroupShardsIterator shardsIts) { + assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; + if (request.allowPartialSearchResults() == false) { + final StringBuilder missingShards = new StringBuilder(); + // Fail-fast verification of all shards being available + for (int index = 0; index < shardsIts.size(); index++) { + final SearchShardIterator shardRoutings = shardsIts.get(index); + if (shardRoutings.size() == 0) { + if (missingShards.isEmpty() == false) { + missingShards.append(", "); + } + missingShards.append(shardRoutings.shardId()); + } + } + if (missingShards.isEmpty() == false) { + // Status red - shard is missing all copies and would produce partial results for an index search + final String msg = "Search rejected due to missing shards [" + + missingShards + + "]. Consider using `allow_partial_search_results` setting to bypass this error."; + throw new SearchPhaseExecutionException(phaseName, msg, null, ShardSearchFailure.EMPTY_ARRAY); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index fb554232503f2..5af5c4c2ec602 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -667,7 +667,7 @@ private static void validateMergeSortValueFormats(Collection statsGroups) { - sourceBuilder().stats(statsGroups); - return this; - } - /** * Indicates whether the response should contain the stored _source for every hit */ diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java b/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java index b6a9179b1e956..56b58cd8ced6c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java @@ -144,10 +144,6 @@ public RestStatus status() { return RestStatus.status(successfulShards, totalShards, shardFailures); } - public SearchResponseSections getInternalResponse() { - return internalResponse; - } - /** * The search hits. */ @@ -387,7 +383,7 @@ public static SearchResponse innerFromXContent(XContentParser parser) throws IOE } } else if (token == Token.START_ARRAY) { if (RestActions.FAILURES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - while ((token = parser.nextToken()) != Token.END_ARRAY) { + while (parser.nextToken() != Token.END_ARRAY) { failures.add(ShardSearchFailure.fromXContent(parser)); } } else { @@ -479,7 +475,7 @@ public static final class Clusters implements ToXContentFragment, Writeable { private final Map clusterInfo; // not Writeable since it is only needed on the (primary) CCS coordinator - private transient Boolean ccsMinimizeRoundtrips; + private final transient Boolean ccsMinimizeRoundtrips; /** * For use with cross-cluster searches. 
@@ -985,7 +981,7 @@ public static class Builder { private List failures; private TimeValue took; private Boolean timedOut; - private Cluster original; + private final Cluster original; public Builder(Cluster copyFrom) { this.original = copyFrom; @@ -1167,7 +1163,7 @@ public static Cluster fromXContent(String clusterAlias, XContentParser parser) t } } else if (token == Token.START_ARRAY) { if (RestActions.FAILURES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - while ((token = parser.nextToken()) != Token.END_ARRAY) { + while (parser.nextToken() != Token.END_ARRAY) { failures.add(ShardSearchFailure.fromXContent(parser)); } } else { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java index 35aae0764e251..df16c107a2619 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java @@ -23,7 +23,6 @@ import org.elasticsearch.transport.RemoteClusterService; import org.elasticsearch.transport.Transport; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -230,7 +229,7 @@ protected SearchPhase sendResponsePhase( ) { return new SearchPhase("fetch") { @Override - public void run() throws IOException { + public void run() { sendResponse(queryPhase, fetchResults); } }; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index d02958567a873..800ad7afbb8db 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -51,6 +51,7 @@ import org.elasticsearch.transport.TransportActionProxy; import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponseHandler; @@ -366,7 +367,7 @@ public Map getPendingSearchRequests() { } static class ScrollFreeContextRequest extends TransportRequest { - private ShardSearchContextId contextId; + private final ShardSearchContextId contextId; ScrollFreeContextRequest(ShardSearchContextId contextId) { this.contextId = Objects.requireNonNull(contextId); @@ -390,7 +391,7 @@ public ShardSearchContextId id() { } static class SearchFreeContextRequest extends ScrollFreeContextRequest implements IndicesRequest { - private OriginalIndices originalIndices; + private final OriginalIndices originalIndices; SearchFreeContextRequest(OriginalIndices originalIndices, ShardSearchContextId id) { super(id); @@ -428,7 +429,7 @@ public IndicesOptions indicesOptions() { public static class SearchFreeContextResponse extends TransportResponse { - private boolean freed; + private final boolean freed; SearchFreeContextResponse(StreamInput in) throws IOException { freed = in.readBoolean(); @@ -541,13 +542,16 @@ public static void registerRequestHandler(TransportService transportService, Sea ); TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new); + TransportRequestHandler shardFetchHandler = 
(request, channel, task) -> searchService.executeFetchPhase( + request, + (SearchShardTask) task, + new ChannelActionListener<>(channel) + ); transportService.registerRequestHandler( FETCH_ID_SCROLL_ACTION_NAME, EsExecutors.DIRECT_EXECUTOR_SERVICE, ShardFetchRequest::new, - (request, channel, task) -> { - searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)); - } + shardFetchHandler ); TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, true, FetchSearchResult::new); @@ -557,9 +561,7 @@ public static void registerRequestHandler(TransportService transportService, Sea true, true, ShardFetchSearchRequest::new, - (request, channel, task) -> { - searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)); - } + shardFetchHandler ); TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, true, FetchSearchResult::new); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchType.java b/server/src/main/java/org/elasticsearch/action/search/SearchType.java index 519f1ce98a7b6..8e6511db62136 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchType.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchType.java @@ -39,7 +39,7 @@ public enum SearchType { */ public static final SearchType[] CURRENTLY_SUPPORTED = { QUERY_THEN_FETCH, DFS_QUERY_THEN_FETCH }; - private byte id; + private final byte id; SearchType(byte id) { this.id = id; diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportMultiSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportMultiSearchAction.java index a4a35789db258..a2324010876bf 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportMultiSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportMultiSearchAction.java @@ -16,7 +16,6 @@ import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -47,13 +46,7 @@ public TransportMultiSearchAction( ActionFilters actionFilters, NodeClient client ) { - super( - MultiSearchAction.NAME, - transportService, - actionFilters, - (Writeable.Reader) MultiSearchRequest::new, - EsExecutors.DIRECT_EXECUTOR_SERVICE - ); + super(MultiSearchAction.NAME, transportService, actionFilters, MultiSearchRequest::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.threadPool = threadPool; this.clusterService = clusterService; this.allocatedProcessors = EsExecutors.allocatedProcessors(settings); @@ -70,13 +63,7 @@ public TransportMultiSearchAction( LongSupplier relativeTimeProvider, NodeClient client ) { - super( - MultiSearchAction.NAME, - transportService, - actionFilters, - (Writeable.Reader) MultiSearchRequest::new, - EsExecutors.DIRECT_EXECUTOR_SERVICE - ); + super(MultiSearchAction.NAME, transportService, actionFilters, MultiSearchRequest::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.threadPool = threadPool; this.clusterService = clusterService; this.allocatedProcessors = allocatedProcessors; diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java 
b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java index aeb71a3b03d8f..ae3c735e079e9 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java @@ -268,7 +268,7 @@ public void writeTo(StreamOutput out) throws IOException { private class ShardOpenReaderRequestHandler implements TransportRequestHandler { @Override - public void messageReceived(ShardOpenReaderRequest request, TransportChannel channel, Task task) throws Exception { + public void messageReceived(ShardOpenReaderRequest request, TransportChannel channel, Task task) { searchService.openReaderContext( request.getShardId(), request.keepAlive, diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index a2739e2c2a85e..5030bd875a0f6 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -39,7 +39,6 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Setting; @@ -159,13 +158,7 @@ public TransportSearchAction( NamedWriteableRegistry namedWriteableRegistry, ExecutorSelector executorSelector ) { - super( - SearchAction.NAME, - transportService, - actionFilters, - (Writeable.Reader) SearchRequest::new, - EsExecutors.DIRECT_EXECUTOR_SERVICE - ); + super(SearchAction.NAME, transportService, actionFilters, SearchRequest::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.threadPool = threadPool; this.circuitBreaker = circuitBreakerService.getBreaker(CircuitBreaker.REQUEST); this.searchPhaseController = searchPhaseController; @@ -514,7 +507,7 @@ static void ccsRemoteReduce( clusterAlias, remoteClientResponseExecutor ); - remoteClusterClient.search(ccsSearchRequest, new ActionListener() { + remoteClusterClient.search(ccsSearchRequest, new ActionListener<>() { @Override public void onResponse(SearchResponse searchResponse) { // TODO: in CCS fail fast ticket we may need to fail the query if the cluster is marked as FAILED @@ -749,14 +742,7 @@ private static ActionListener createCCSListener( SearchResponse.Clusters clusters, ActionListener originalListener ) { - return new CCSActionListener( - clusterAlias, - skipUnavailable, - countDown, - exceptions, - clusters, - originalListener - ) { + return new CCSActionListener<>(clusterAlias, skipUnavailable, countDown, exceptions, clusters, originalListener) { @Override void innerOnResponse(SearchResponse searchResponse) { // TODO: in CCS fail fast ticket we may need to fail the query if the cluster gets marked as FAILED @@ -1417,7 +1403,6 @@ abstract static class CCSActionListener implements Acti private final AtomicReference exceptions; protected final SearchResponse.Clusters clusters; private final ActionListener originalListener; - protected final long startTime; /** * Used by both minimize_roundtrips true and false @@ -1436,7 +1421,6 @@ abstract static class CCSActionListener implements Acti this.exceptions = exceptions; this.clusters = clusters; 
this.originalListener = originalListener; - this.startTime = System.currentTimeMillis(); } @Override @@ -1454,12 +1438,12 @@ public final void onFailure(Exception e) { SearchResponse.Cluster cluster = clusters.getCluster(clusterAlias); if (skipUnavailable) { if (cluster != null) { - ccsClusterInfoUpdate(f, clusters, clusterAlias, skipUnavailable); + ccsClusterInfoUpdate(f, clusters, clusterAlias, true); } // skippedClusters.incrementAndGet(); } else { if (cluster != null) { - ccsClusterInfoUpdate(f, clusters, clusterAlias, skipUnavailable); + ccsClusterInfoUpdate(f, clusters, clusterAlias, false); } Exception exception = e; if (RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY.equals(clusterAlias) == false) { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchHelper.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchHelper.java index 632fbafa0536b..ffaecedb62bba 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchHelper.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchHelper.java @@ -93,7 +93,7 @@ static ParsedScrollId parseScrollId(String scrollId) { if (in.available() > 0) { throw new IllegalArgumentException("Not all bytes were read"); } - return new ParsedScrollId(scrollId, type, context); + return new ParsedScrollId(type, context); } catch (Exception e) { throw new IllegalArgumentException("Cannot parse scroll id", e); } diff --git a/server/src/test/java/org/elasticsearch/action/search/ParsedScrollIdTests.java b/server/src/test/java/org/elasticsearch/action/search/ParsedScrollIdTests.java index 6130435b4b181..a92cfdb1d02be 100644 --- a/server/src/test/java/org/elasticsearch/action/search/ParsedScrollIdTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/ParsedScrollIdTests.java @@ -26,7 +26,7 @@ public void testHasLocalIndices() { new ShardSearchContextId(randomAlphaOfLength(8), randomLong()) ); } - final ParsedScrollId parsedScrollId = new ParsedScrollId(randomAlphaOfLength(8), randomAlphaOfLength(8), searchContextIdForNodes); + final ParsedScrollId parsedScrollId = new ParsedScrollId(randomAlphaOfLength(8), searchContextIdForNodes); assertEquals(hasLocal, parsedScrollId.hasLocalIndices()); } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java index df33a5e18fce6..41e7a5c8ad1e1 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java @@ -458,7 +458,7 @@ protected void onFirstPhaseResult(int shardId, SearchAsyncActionTests.TestSearch private static ParsedScrollId getParsedScrollId(SearchContextIdForNode... 
idsForNodes) { List<SearchContextIdForNode> searchContextIdForNodes = Arrays.asList(idsForNodes); Collections.shuffle(searchContextIdForNodes, random()); - return new ParsedScrollId("", "test", searchContextIdForNodes.toArray(new SearchContextIdForNode[0])); + return new ParsedScrollId("test", searchContextIdForNodes.toArray(new SearchContextIdForNode[0])); } private ActionListener<SearchResponse> dummyListener() { From aa2f6e7c4e2a4bab312d37721dbee381c1e43210 Mon Sep 17 00:00:00 2001 From: Valeriy Khakhutskyy <1292899+valeriy42@users.noreply.github.com> Date: Mon, 6 Nov 2023 12:17:05 +0100 Subject: [PATCH 02/21] [ML] Use perAllocation and perDeployment memory usage in the model assignment planner (#98874) Building upon #98139, this PR extends the model assignment planning algorithms and the linear solver to use the extended memory fields. It also adds unit tests to verify the new behavior. The old unit tests needed adjusting because assignments now go through the estimateMemoryUsage routine, which computes 2*memoryBytes + 240 MB as the memory requirement; previously the unit tests simply used the memoryBytes field value. --- docs/changelog/98874.yaml | 5 + .../assignment/TrainedModelAssignment.java | 5 + .../TransportGetTrainedModelsStatsAction.java | 24 +- .../TrainedModelAssignmentClusterService.java | 7 +- .../TrainedModelAssignmentRebalancer.java | 36 +- .../planning/AbstractPreserveAllocations.java | 42 +- .../assignment/planning/AssignmentPlan.java | 139 +++- .../planning/AssignmentPlanner.java | 11 +- .../planning/LinearProgrammingPlanSolver.java | 29 +- .../planning/PreserveAllAllocations.java | 2 +- .../planning/PreserveOneAllocation.java | 2 +- .../RandomizedAssignmentRounding.java | 46 +- .../planning/ZoneAwareAssignmentPlanner.java | 16 +- ...TrainedModelAssignmentRebalancerTests.java | 81 +- .../planning/AssignmentPlanTests.java | 511 ++++++++++--- .../planning/AssignmentPlannerTests.java | 698 +++++++++++++++--- .../planning/PreserveAllAllocationsTests.java | 228 ++++-- .../planning/PreserveOneAllocationTests.java | 264 +++++-- .../ZoneAwareAssignmentPlannerTests.java | 126 +++- .../MlAssignmentPlannerUpgradeIT.java | 287 +++++++ 20 files changed, 2076 insertions(+), 483 deletions(-) create mode 100644 docs/changelog/98874.yaml create mode 100644 x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java diff --git a/docs/changelog/98874.yaml b/docs/changelog/98874.yaml new file mode 100644 index 0000000000000..e3eb7b5acc63f --- /dev/null +++ b/docs/changelog/98874.yaml @@ -0,0 +1,5 @@ +pr: 98874 +summary: Estimate the memory required to deploy trained models more accurately +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index f69be31939b32..d27d325a5c596 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -9,6 +9,7 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.SimpleDiffable; import org.elasticsearch.common.Randomness; @@ -96,6 +97,10 @@ public final 
class TrainedModelAssignment implements SimpleDiffable 0L ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( model.getModelId(), totalDefinitionLength, - model.getPerDeploymentMemoryBytes(), - model.getPerAllocationMemoryBytes(), + useNewMemoryFields ? model.getPerDeploymentMemoryBytes() : 0, + useNewMemoryFields ? model.getPerAllocationMemoryBytes() : 0, numberOfAllocations ) : 0L; modelSizeStatsByModelId.put( model.getModelId(), - new TrainedModelSizeStats( - totalDefinitionLength, - totalDefinitionLength > 0L - ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( - model.getModelId(), - totalDefinitionLength, - model.getPerDeploymentMemoryBytes(), - model.getPerAllocationMemoryBytes(), - numberOfAllocations - ) - : 0L - ) + new TrainedModelSizeStats(totalDefinitionLength, estimatedMemoryUsageBytes) ); } else { modelSizeStatsByModelId.put(model.getModelId(), new TrainedModelSizeStats(model.getModelSize(), 0)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java index 2caf338d2a3c7..fe4462d6556ee 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java @@ -47,6 +47,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; +import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper; import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer; @@ -76,6 +77,8 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0; public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0; + private static final TransportVersion NEW_ALLOCATION_MEMORY_VERSION = TransportVersions.V_8_500_064; + private final ClusterService clusterService; private final ThreadPool threadPool; private final NodeLoadDetector nodeLoadDetector; @@ -644,12 +647,14 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( Map nodeLoads = detectNodeLoads(nodes, currentState); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(currentState); + boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(TransportVersionUtils.getMinTransportVersion(currentState)); TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer( currentMetadata, nodeLoads, nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState), modelToAdd, - allocatedProcessorsScale + allocatedProcessorsScale, + useNewMemoryFields ); Set shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java 
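As a rough sketch of the arithmetic the commit message above contrasts, assuming the old requirement is the stated 2*memoryBytes + 240 MB and the new requirement adds a per-deployment base cost plus a per-allocation increment (the class and method names below are hypothetical, and whether the fixed overhead is still added in the new branch is an assumption of this sketch; the authoritative logic lives in StartTrainedModelDeploymentAction.estimateMemoryUsageBytes):

final class MemoryEstimateSketch {
    // 240 MB fixed overhead, per the commit message.
    static final long OVERHEAD_BYTES = 240L * 1024 * 1024;

    // Old estimate: twice the model definition size plus the fixed overhead,
    // independent of how many allocations are deployed.
    static long legacyEstimateBytes(long modelBytes) {
        return OVERHEAD_BYTES + 2 * modelBytes;
    }

    // New estimate (assumed shape): a per-deployment base cost plus a cost
    // that grows linearly with the number of allocations.
    static long perAllocationEstimateBytes(long perDeploymentBytes, long perAllocationBytes, int allocations) {
        return OVERHEAD_BYTES + perDeploymentBytes + (long) allocations * perAllocationBytes;
    }
}

For a 1 GB model the old estimate is a flat ~2.24 GB regardless of allocation count, while the new estimate grows with every allocation, which is what lets the planner trade allocations against remaining node memory later in this patch.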
index e1241dc8a93c3..6e6b447fcea3d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java @@ -52,18 +52,22 @@ class TrainedModelAssignmentRebalancer { private final Optional deploymentToAdd; private final int allocatedProcessorsScale; + private final boolean useNewMemoryFields; + TrainedModelAssignmentRebalancer( TrainedModelAssignmentMetadata currentMetadata, Map nodeLoads, Map, Collection> mlNodesByZone, Optional deploymentToAdd, - int allocatedProcessorsScale + int allocatedProcessorsScale, + boolean useNewMemoryFields ) { this.currentMetadata = Objects.requireNonNull(currentMetadata); this.nodeLoads = Objects.requireNonNull(nodeLoads); this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone); this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd); this.allocatedProcessorsScale = allocatedProcessorsScale; + this.useNewMemoryFields = useNewMemoryFields; } TrainedModelAssignmentMetadata.Builder rebalance() { @@ -138,9 +142,11 @@ private static void copyAssignments( AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id()); dest.assignModelToNode(m, originalNode, assignment.getValue()); if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) { + // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder // As the node has all its available memory we need to manually account memory of models with // current allocations. - dest.accountMemory(m, originalNode); + long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id())); + dest.accountMemory(m, originalNode, requiredMemory); } } } @@ -168,11 +174,14 @@ private AssignmentPlan computePlanForNormalPriorityModels( .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations())); return new AssignmentPlan.Deployment( assignment.getDeploymentId(), - assignment.getTaskParams().estimateMemoryUsageBytes(), + assignment.getTaskParams().getModelBytes(), assignment.getTaskParams().getNumberOfAllocations(), assignment.getTaskParams().getThreadsPerAllocation(), currentAssignments, - assignment.getMaxAssignedAllocations() + assignment.getMaxAssignedAllocations(), + // in the mixed cluster state use old memory fields to avoid unstable assignment plans + useNewMemoryFields ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, + useNewMemoryFields ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 ); }) .forEach(planDeployments::add); @@ -181,11 +190,14 @@ private AssignmentPlan computePlanForNormalPriorityModels( planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), - taskParams.estimateMemoryUsageBytes(), + taskParams.getModelBytes(), taskParams.getNumberOfAllocations(), taskParams.getThreadsPerAllocation(), Map.of(), - 0 + 0, + // in the mixed cluster state use old memory fields to avoid unstable assignment plans + useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0, + useNewMemoryFields ? 
taskParams.getPerAllocationMemoryBytes() : 0 ) ); } @@ -217,12 +229,14 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod .map( assignment -> new AssignmentPlan.Deployment( assignment.getDeploymentId(), - assignment.getTaskParams().estimateMemoryUsageBytes(), + assignment.getTaskParams().getModelBytes(), assignment.getTaskParams().getNumberOfAllocations(), assignment.getTaskParams().getThreadsPerAllocation(), findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory), assignment.getMaxAssignedAllocations(), - Priority.LOW + Priority.LOW, + (useNewMemoryFields == false) ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, + (useNewMemoryFields == false) ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 ) ) .forEach(planDeployments::add); @@ -231,12 +245,14 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), - taskParams.estimateMemoryUsageBytes(), + taskParams.getModelBytes(), taskParams.getNumberOfAllocations(), taskParams.getThreadsPerAllocation(), Map.of(), 0, - Priority.LOW + Priority.LOW, + (useNewMemoryFields == false) ? taskParams.getPerDeploymentMemoryBytes() : 0, + (useNewMemoryFields == false) ? taskParams.getPerAllocationMemoryBytes() : 0 ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java index 4843cc43d1187..026b433a8c2d4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java @@ -35,7 +35,8 @@ private Node modifyNodePreservingAllocations(Node n) { int coresUsed = 0; for (Deployment m : deployments) { if (m.currentAllocationsByNodeId().containsKey(n.id())) { - bytesUsed += m.memoryBytes(); + int allocations = m.currentAllocationsByNodeId().get(n.id()); + bytesUsed += m.estimateMemoryUsageBytes(allocations); coresUsed += calculateUsedCores(n, m); } } @@ -58,7 +59,9 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) { m.allocations() - calculatePreservedAllocations(m), m.threadsPerAllocation(), calculateAllocationsPerNodeToPreserve(m), - m.maxAssignedAllocations() + m.maxAssignedAllocations(), + m.perDeploymentMemoryBytes(), + m.perAllocationMemoryBytes() ); } @@ -67,28 +70,37 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) { // they will not match the models/nodes members we have in this class. // Therefore, we build a lookup table based on the ids so we can merge the plan // with its preserved allocations. 
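// For illustration, with hypothetical values: suppose deployment "d1" kept 2 preserved
// allocations on node "n1" and the freshly computed plan adds 1 more, i.e.
// plannedAssignmentsByModelNodeIdPair.get(Tuple.tuple("d1", "n1")) == 1.
// The merge below first re-adds the 2 preserved allocations (accounting their memory) and
// then assigns the extra allocation only if n1 still has enough memory and cores, so the
// merged plan ends up with 3 allocations of d1 on n1.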
- final Map<Tuple<String, String>, Integer> assignmentsByModelNodeIdPair = new HashMap<>(); + final Map<Tuple<String, String>, Integer> plannedAssignmentsByModelNodeIdPair = new HashMap<>(); for (Deployment m : assignmentPlan.models()) { Map<Node, Integer> assignments = assignmentPlan.assignments(m).orElse(Map.of()); for (Map.Entry<Node, Integer> nodeAssignment : assignments.entrySet()) { - assignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue()); + plannedAssignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue()); } } AssignmentPlan.Builder mergedPlanBuilder = AssignmentPlan.builder(nodes, deployments); - for (Deployment m : deployments) { - for (Node n : nodes) { - int allocations = assignmentsByModelNodeIdPair.getOrDefault(Tuple.tuple(m.id(), n.id()), 0); - if (m.currentAllocationsByNodeId().containsKey(n.id())) { - if (mergedPlanBuilder.getRemainingMemory(n) >= m.memoryBytes()) { - allocations += addPreservedAllocations(n, m); - // As the node has all its available memory we need to manually account memory of models with - // current allocations. - mergedPlanBuilder.accountMemory(m, n); + for (Node n : nodes) { + // TODO (#101612) Should the first loop happen in the builder constructor? + for (Deployment deploymentAllocationsToPreserve : deployments) { + + // If this deployment is already allocated on node n, preserve those allocations + int preservedAllocations = addPreservedAllocations(n, deploymentAllocationsToPreserve); + if (preservedAllocations > 0) { + long requiredMemory = deploymentAllocationsToPreserve.estimateMemoryUsageBytes(preservedAllocations); + if (mergedPlanBuilder.canAssign(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory)) { + mergedPlanBuilder.assignModelToNode(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory); } } - if (allocations > 0) { - mergedPlanBuilder.assignModelToNode(m, n, allocations); + } + for (Deployment deploymentNewAllocations : deployments) { + int newAllocations = plannedAssignmentsByModelNodeIdPair.getOrDefault( + Tuple.tuple(deploymentNewAllocations.id(), n.id()), + 0 + ); + + long requiredMemory = mergedPlanBuilder.getDeploymentMemoryRequirement(deploymentNewAllocations, n, newAllocations); + if (newAllocations > 0 && mergedPlanBuilder.canAssign(deploymentNewAllocations, n, newAllocations, requiredMemory)) { + mergedPlanBuilder.assignModelToNode(deploymentNewAllocations, n, newAllocations); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java index 72a83d7579463..1dce7f0bb46ba 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.Tuple; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; import java.util.ArrayList; @@ -36,18 +37,32 @@ public record Deployment( int threadsPerAllocation, Map<String, Integer> currentAllocationsByNodeId, int maxAssignedAllocations, - Priority priority + Priority priority, + long perDeploymentMemoryBytes, + long perAllocationMemoryBytes ) { public 
Deployment( String id, - long memoryBytes, + long modelBytes, int allocations, int threadsPerAllocation, Map currentAllocationsByNodeId, - int maxAssignedAllocations + int maxAssignedAllocations, + long perDeploymentMemoryBytes, + long perAllocationMemoryBytes ) { - this(id, memoryBytes, allocations, threadsPerAllocation, currentAllocationsByNodeId, maxAssignedAllocations, Priority.NORMAL); + this( + id, + modelBytes, + allocations, + threadsPerAllocation, + currentAllocationsByNodeId, + maxAssignedAllocations, + Priority.NORMAL, + perDeploymentMemoryBytes, + perAllocationMemoryBytes + ); } int getCurrentAssignedAllocations() { @@ -58,6 +73,60 @@ boolean hasEverBeenAllocated() { return maxAssignedAllocations > 0; } + public long estimateMemoryUsageBytes(int allocations) { + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + id, + memoryBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + allocations + ); + } + + long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) { + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + id, + memoryBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + allocationsNew + ) - StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + id, + memoryBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + allocationsOld + ); + + } + + long minimumMemoryRequiredBytes() { + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + id, + memoryBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + 1 + ); + } + + int findOptimalAllocations(int maxAllocations, long availableMemoryBytes) { + if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) { + return (int) Math.max( + Math.min(maxAllocations, Math.floorDiv(availableMemoryBytes - estimateMemoryUsageBytes(0), perAllocationMemoryBytes)), + 0 + ); + } + return maxAllocations; + } + + int findExcessAllocations(int maxAllocations, long availableMemoryBytes) { + if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) { + return (int) Math.min(maxAllocations, Math.floorDiv(availableMemoryBytes, perAllocationMemoryBytes)); + } + return maxAllocations; + } + @Override public String toString() { return id @@ -71,6 +140,8 @@ public String toString() { + currentAllocationsByNodeId + ") (max_assigned_allocations = " + maxAssignedAllocations + + ") (memory_usage = " + + ByteSizeValue.ofBytes(estimateMemoryUsageBytes(allocations)) + ")"; } }; @@ -304,19 +375,42 @@ int getRemainingAllocations(Deployment m) { } boolean canAssign(Deployment deployment, Node node, int allocations) { - return (isAlreadyAssigned(deployment, node) - || (deployment.memoryBytes() <= remainingNodeMemory.get(node)) - && (deployment.priority == Priority.LOW - || allocations * deployment.threadsPerAllocation() <= remainingNodeCores.get(node))); + long requiredMemory = getDeploymentMemoryRequirement(deployment, node, allocations); + return canAssign(deployment, node, allocations, requiredMemory); + } + + boolean canAssign(Deployment deployment, Node node, int allocations, long requiredMemory) { + return (requiredMemory <= remainingNodeMemory.get(node)) + && (deployment.priority == Priority.LOW || allocations * deployment.threadsPerAllocation() <= remainingNodeCores.get(node)); + } + + public long getDeploymentMemoryRequirement(Deployment deployment, Node node, int newAllocations) { + int assignedAllocations = getAssignedAllocations(deployment, node); + + if (assignedAllocations > 0) { + return 
deployment.estimateAdditionalMemoryUsageBytes(assignedAllocations, assignedAllocations + newAllocations); + } + return deployment.estimateMemoryUsageBytes(newAllocations); } public Builder assignModelToNode(Deployment deployment, Node node, int allocations) { + return assignModelToNode(deployment, node, allocations, getDeploymentMemoryRequirement(deployment, node, allocations)); + } + + public Builder assignModelToNode(Deployment deployment, Node node, int allocations, long requiredMemory) { if (allocations <= 0) { return this; } - if (isAlreadyAssigned(deployment, node) == false && deployment.memoryBytes() > remainingNodeMemory.get(node)) { + if (/*isAlreadyAssigned(deployment, node) == false + &&*/ requiredMemory > remainingNodeMemory.get(node)) { throw new IllegalArgumentException( - "not enough memory on node [" + node.id() + "] to assign model [" + deployment.id() + "]" + "not enough memory on node [" + + node.id() + + "] to assign [" + + allocations + + "] allocations to deployment [" + + deployment.id() + + "]" ); } if (deployment.priority == Priority.NORMAL && allocations * deployment.threadsPerAllocation() > remainingNodeCores.get(node)) { @@ -333,9 +427,9 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio ); } - long additionalModelMemory = isAlreadyAssigned(deployment, node) ? 0 : deployment.memoryBytes; assignments.get(deployment).compute(node, (n, remAllocations) -> remAllocations + allocations); - remainingNodeMemory.compute(node, (n, remMemory) -> remMemory - additionalModelMemory); + accountMemory(deployment, node, requiredMemory); + if (deployment.priority == Priority.NORMAL) { remainingNodeCores.compute(node, (n, remCores) -> remCores - allocations * deployment.threadsPerAllocation()); } @@ -347,9 +441,26 @@ private boolean isAlreadyAssigned(Deployment deployment, Node node) { return deployment.currentAllocationsByNodeId().containsKey(node.id()) || assignments.get(deployment).get(node) > 0; } + private int getAssignedAllocations(Deployment deployment, Node node) { + int currentAllocations = getCurrentAllocations(deployment, node); + int assignmentAllocations = assignments.get(deployment).get(node); + return currentAllocations + assignmentAllocations; + } + + private static int getCurrentAllocations(Deployment m, Node n) { + return m.currentAllocationsByNodeId.containsKey(n.id()) ? 
m.currentAllocationsByNodeId.get(n.id()) : 0; + } + public void accountMemory(Deployment m, Node n) { - remainingNodeMemory.computeIfPresent(n, (k, v) -> v - m.memoryBytes()); - if (remainingNodeMemory.get(n) < 0) { + // TODO (#101612) remove or refactor unused method + long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n)); + accountMemory(m, n, requiredMemory); + } + + public void accountMemory(Deployment m, Node n, long requiredMemory) { + // TODO (#101612) computation of required memory should be done internally + remainingNodeMemory.computeIfPresent(n, (k, v) -> v - requiredMemory); + if (remainingNodeMemory.containsKey(n) && remainingNodeMemory.get(n) < 0) { throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.id() + "]"); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java index 73b713cced32a..b1c017b1a784c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java @@ -115,8 +115,11 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat m.memoryBytes(), 1, m.threadsPerAllocation(), - m.currentAllocationsByNodeId(), - m.maxAssignedAllocations() + // don't rely on the current allocation + new HashMap<>(), + m.maxAssignedAllocations(), + m.perDeploymentMemoryBytes(), + m.perAllocationMemoryBytes() ) ) .toList(); @@ -145,7 +148,9 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat m.allocations(), m.threadsPerAllocation(), currentAllocationsByNodeId, - m.maxAssignedAllocations() + m.maxAssignedAllocations(), + m.perDeploymentMemoryBytes(), + m.perAllocationMemoryBytes() ); }).toList(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java index 90c5a2257d94d..bd97680e285cc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java @@ -68,6 +68,8 @@ class LinearProgrammingPlanSolver { private final Map normalizedMemoryPerNode; private final Map coresPerNode; private final Map normalizedMemoryPerModel; + private final Map normalizedMemoryPerAllocation; + private final Map normalizedMinimumDeploymentMemoryRequired; private final int maxNodeCores; private final long maxModelMemoryBytes; @@ -84,12 +86,17 @@ class LinearProgrammingPlanSolver { .filter(m -> m.threadsPerAllocation() <= maxNodeCores) .toList(); - maxModelMemoryBytes = this.deployments.stream().map(AssignmentPlan.Deployment::memoryBytes).max(Long::compareTo).orElse(1L); + // We use the maximum memory to deploy a model with one allocation as the normalization factor. 
+ maxModelMemoryBytes = this.deployments.stream().map(m -> m.minimumMemoryRequiredBytes()).max(Long::compareTo).orElse(1L); normalizedMemoryPerNode = this.nodes.stream() .collect(Collectors.toMap(Function.identity(), n -> n.availableMemoryBytes() / (double) maxModelMemoryBytes)); coresPerNode = this.nodes.stream().collect(Collectors.toMap(Function.identity(), Node::cores)); normalizedMemoryPerModel = this.deployments.stream() - .collect(Collectors.toMap(Function.identity(), m -> m.memoryBytes() / (double) maxModelMemoryBytes)); + .collect(Collectors.toMap(Function.identity(), m -> m.estimateMemoryUsageBytes(0) / (double) maxModelMemoryBytes)); + normalizedMemoryPerAllocation = this.deployments.stream() + .collect(Collectors.toMap(Function.identity(), m -> m.perAllocationMemoryBytes() / (double) maxModelMemoryBytes)); + normalizedMinimumDeploymentMemoryRequired = this.deployments.stream() + .collect(Collectors.toMap(Function.identity(), m -> m.minimumMemoryRequiredBytes() / (double) maxModelMemoryBytes)); } AssignmentPlan solvePlan(boolean useBinPackingOnly) { @@ -133,8 +140,8 @@ private double weightForAllocationVar( Node n, Map, Double> weights ) { - return (1 + weights.get(Tuple.tuple(m, n)) - (m.memoryBytes() > n.availableMemoryBytes() ? 10 : 0)) - L1 * normalizedMemoryPerModel - .get(m) / maxNodeCores; + return (1 + weights.get(Tuple.tuple(m, n)) - (m.minimumMemoryRequiredBytes() > n.availableMemoryBytes() ? 10 : 0)) - L1 + * normalizedMemoryPerModel.get(m) / maxNodeCores; } private Tuple, Double>, AssignmentPlan> calculateWeightsAndBinPackingPlan() { @@ -156,9 +163,9 @@ private Tuple, Double>, AssignmentPlan> calculateWei .sorted(Comparator.comparingDouble(n -> descendingSizeAnyFitsNodeOrder(n, m, assignmentPlan))) .toList(); for (Node n : orderedNodes) { - int allocations = Math.min( - assignmentPlan.getRemainingCores(n) / m.threadsPerAllocation(), - assignmentPlan.getRemainingAllocations(m) + int allocations = m.findOptimalAllocations( + Math.min(assignmentPlan.getRemainingCores(n) / m.threadsPerAllocation(), assignmentPlan.getRemainingAllocations(m)), + assignmentPlan.getRemainingMemory(n) ); if (allocations > 0 && assignmentPlan.canAssign(m, n, allocations)) { assignmentPlan.assignModelToNode(m, n, allocations); @@ -185,7 +192,8 @@ private Tuple, Double>, AssignmentPlan> calculateWei } private double descendingSizeAnyFitsModelOrder(AssignmentPlan.Deployment m) { - return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -normalizedMemoryPerModel.get(m) * m.threadsPerAllocation(); + return (m.currentAllocationsByNodeId().isEmpty() ? 
1 : 2) * -normalizedMinimumDeploymentMemoryRequired.get(m) * m + .threadsPerAllocation(); } private double descendingSizeAnyFitsNodeOrder(Node n, AssignmentPlan.Deployment m, AssignmentPlan.Builder assignmentPlan) { @@ -307,7 +315,10 @@ private boolean solveLinearProgram( List modelMemories = new ArrayList<>(); deployments.stream().filter(m -> m.currentAllocationsByNodeId().containsKey(n.id()) == false).forEach(m -> { allocations.add(allocationVars.get(Tuple.tuple(m, n))); - modelMemories.add(normalizedMemoryPerModel.get(m) * m.threadsPerAllocation() / (double) coresPerNode.get(n)); + modelMemories.add( + (normalizedMemoryPerModel.get(m) / (double) coresPerNode.get(n) + normalizedMemoryPerAllocation.get(m)) * m + .threadsPerAllocation() + ); }); model.addExpression("used_memory_on_node_" + n.id() + "_not_more_than_available") .upper(normalizedMemoryPerNode.get(n)) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java index f10ece8f5a593..72109941ad477 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java @@ -37,6 +37,6 @@ protected int calculatePreservedAllocations(Deployment m) { @Override protected int addPreservedAllocations(Node n, Deployment m) { - return m.currentAllocationsByNodeId().get(n.id()); + return m.currentAllocationsByNodeId().containsKey(n.id()) ? m.currentAllocationsByNodeId().get(n.id()) : 0; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java index 324e1a8d69a53..43b8860803596 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java @@ -37,6 +37,6 @@ protected int calculatePreservedAllocations(AssignmentPlan.Deployment m) { @Override protected int addPreservedAllocations(Node n, AssignmentPlan.Deployment m) { - return 1; + return m.currentAllocationsByNodeId().containsKey(n.id()) ? 
1 : 0; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java index dafc07099f850..8bdc99998a0c2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java @@ -135,8 +135,9 @@ private void assignUnderSubscribedNodes(Collection nodeSelection) { for (AssignmentPlan.Deployment m : deployments) { Tuple assignment = Tuple.tuple(m, n); if (assignments.get(assignment) > 0) { - totalModelMemory += m.memoryBytes(); - maxTotalThreads += (int) Math.ceil(allocations.get(assignment)) * m.threadsPerAllocation(); + int roundedAllocations = (int) Math.ceil(allocations.get(assignment)); + totalModelMemory += m.estimateMemoryUsageBytes(roundedAllocations); + maxTotalThreads += roundedAllocations * m.threadsPerAllocation(); assignedDeployments.add(m); } } @@ -199,9 +200,12 @@ private void assignExcessCores(Node n) { if (resourceTracker.remainingNodeCores.get(n) <= 0) { break; } - int extraAllocations = Math.min( - resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), - resourceTracker.remainingModelAllocations.get(m) + int extraAllocations = m.findExcessAllocations( + Math.min( + resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), + resourceTracker.remainingModelAllocations.get(m) + ), + resourceTracker.remainingNodeMemory.get(n) ); allocations.compute(Tuple.tuple(m, n), (k, v) -> v + extraAllocations); resourceTracker.assign(m, n, extraAllocations); @@ -211,7 +215,7 @@ private void assignExcessCores(Node n) { } private static double remainingModelOrder(AssignmentPlan.Deployment m) { - return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -m.memoryBytes(); + return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -m.minimumMemoryRequiredBytes(); } private boolean hasSoftAssignments(Node n) { @@ -275,15 +279,17 @@ private void doRandomizedRounding(List> s int roundedAllocations = random.nextDouble() < roundUpProbability ? 
(int) Math.ceil(allocations.get(assignment)) : (int) Math.floor(allocations.get(assignment)); - - if (m.memoryBytes() > resourceTracker.remainingNodeMemory.get(n) + if (m.estimateMemoryUsageBytes(roundedAllocations) > resourceTracker.remainingNodeMemory.get(n) || m.threadsPerAllocation() > resourceTracker.remainingNodeCores.get(n) || roundedAllocations == 0 || random.nextDouble() > assignments.get(assignment)) { unassign(assignment); assignUnderSubscribedNodes(Set.of(n)); } else { - roundedAllocations = Math.min(roundedAllocations, resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation()); + roundedAllocations = m.findOptimalAllocations( + Math.min(roundedAllocations, resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation()), + resourceTracker.remainingNodeMemory.get(n) + ); assignModelToNode(m, n, roundedAllocations); unassignOversizedModels(n); assignExcessCores(n); @@ -294,7 +300,8 @@ private void doRandomizedRounding(List> s private void unassignOversizedModels(Node n) { for (AssignmentPlan.Deployment m : deployments) { Tuple assignment = Tuple.tuple(m, n); - if (assignments.get(assignment) < 1.0 && m.memoryBytes() > resourceTracker.remainingNodeMemory.get(n)) { + int roundedAllocations = (int) Math.ceil(allocations.get(assignment)); + if (assignments.get(assignment) < 1.0 && m.minimumMemoryRequiredBytes() > resourceTracker.remainingNodeMemory.get(n)) { unassign(assignment); } } @@ -303,7 +310,11 @@ private void unassignOversizedModels(Node n) { private AssignmentPlan toPlan() { AssignmentPlan.Builder builder = AssignmentPlan.builder(nodes, deployments); for (Map.Entry, Integer> assignment : tryAssigningRemainingCores().entrySet()) { - builder.assignModelToNode(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue()); + // TODO (#101612) The model should be assigned to the node only when it is possible. This means, that canAssign should be + // integrated into the assignModelToNode. 
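/*
 * [Editor's note -- illustrative only. With the new fields, memory consumption grows with the
 * allocation count, so an assignment that was feasible when the fractional LP solution was
 * rounded may no longer fit by the time it is applied, and assignModelToNode(...) would then
 * throw IllegalArgumentException ("not enough memory on node ..."). The guard added on the
 * next line skips such assignments instead; schematically (variable names are a paraphrase):
 *
 *     if (builder.canAssign(deployment, node, allocations)) {  // re-check memory and cores
 *         builder.assignModelToNode(deployment, node, allocations);
 *     }
 * ]
 */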
+ if (builder.canAssign(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue())) { + builder.assignModelToNode(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue()); + } } return builder.build(); } @@ -338,7 +349,7 @@ private Map, Integer> tryAssigningRemaini .toList()) { for (Node n : nodes.stream() .filter( - n -> resourceTracker.remainingNodeMemory.get(n) >= m.memoryBytes() + n -> resourceTracker.remainingNodeMemory.get(n) >= m.minimumMemoryRequiredBytes() && resourceTracker.remainingNodeCores.get(n) >= m.threadsPerAllocation() && resultAllocations.get(Tuple.tuple(m, n)) == 0 ) @@ -354,10 +365,15 @@ private Map, Integer> tryAssigningRemaini ) ) .toList()) { - int assigningAllocations = Math.min( resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), - resourceTracker.remainingModelAllocations.get(m) + Math.min( + resourceTracker.remainingModelAllocations.get(m), + m.findOptimalAllocations( + resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), + resourceTracker.remainingModelAllocations.get(m) + ) + ) ); resourceTracker.assign(m, n, assigningAllocations); resultAllocations.put(Tuple.tuple(m, n), assigningAllocations); @@ -427,7 +443,7 @@ private static class ResourceTracker { void assign(AssignmentPlan.Deployment m, Node n, int allocations) { if (assignments.contains(Tuple.tuple(m, n)) == false) { assignments.add(Tuple.tuple(m, n)); - remainingNodeMemory.compute(n, (k, v) -> v - m.memoryBytes()); + remainingNodeMemory.compute(n, (k, v) -> v - m.estimateMemoryUsageBytes(allocations)); } remainingNodeCores.compute(n, (k, v) -> v - allocations * m.threadsPerAllocation()); remainingModelAllocations.compute(m, (k, v) -> v - allocations); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java index 9870aa93bf6ce..8c9499ca9e00c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java @@ -126,10 +126,12 @@ private AssignmentPlan computeZonePlan( modelIdToTargetAllocations.get(m.id()), m.threadsPerAllocation(), m.currentAllocationsByNodeId(), - // Only force assigning at least once previously assigned models that have not had any allocation yet (tryAssigningPreviouslyAssignedModels && modelIdToRemainingAllocations.get(m.id()) == m.allocations()) ? 
m.maxAssignedAllocations() - : 0 + : 0, + // Only force assigning at least once previously assigned models that have not had any allocation yet + m.perDeploymentMemoryBytes(), + m.perAllocationMemoryBytes() ) ) .toList(); @@ -151,7 +153,9 @@ private AssignmentPlan computePlanAcrossAllNodes(List plans) { m.allocations(), m.threadsPerAllocation(), allocationsByNodeIdByModelId.get(m.id()), - m.maxAssignedAllocations() + m.maxAssignedAllocations(), + m.perDeploymentMemoryBytes(), + m.perAllocationMemoryBytes() ) ) .toList(); @@ -180,9 +184,13 @@ private AssignmentPlan swapOriginalModelsInPlan( Node originalNode = originalNodeById.get(assignment.getKey().id()); planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue()); if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) { + // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder // As the node has all its available memory we need to manually account memory of models with // current allocations. - planBuilder.accountMemory(m, originalNode); + long requiredMemory = originalDeployment.estimateMemoryUsageBytes( + originalDeployment.currentAllocationsByNodeId().get(originalNode.id()) + ); + planBuilder.accountMemory(m, originalNode, requiredMemory); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java index 8ccf8839cfc08..334fdfbb8b922 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java @@ -44,7 +44,8 @@ public void testRebalance_GivenNoAssignments() { Map.of(), Map.of(), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments().isEmpty(), is(true)); } @@ -78,7 +79,8 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_ShouldMakeNoChanges() nodeLoads, Map.of(), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(currentMetadata, equalTo(result)); @@ -116,7 +118,8 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_GivenOutdatedRoutingEn nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -140,7 +143,7 @@ public void testRebalance_GivenModelToAddAlreadyExists() { .build(); expectThrows( ResourceAlreadyExistsException.class, - () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Map.of(), Optional.of(taskParams), 1).rebalance() + () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Map.of(), Optional.of(taskParams), 1, false).rebalance() ); } @@ -154,7 +157,8 @@ public void testRebalance_GivenFirstModelToAdd_NoMLNodes() throws Exception { Map.of(), Map.of(), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -181,7 +185,8 @@ public void testRebalance_GivenFirstModelToAdd_NotEnoughProcessors() throws Exce nodeLoads, Map.of(List.of(), List.of(node)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -217,7 +222,8 @@ public void 
testRebalance_GivenFirstModelToAdd_NotEnoughMemory() throws Exceptio nodeLoads, Map.of(), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -253,7 +259,8 @@ public void testRebalance_GivenFirstModelToAdd_ErrorDetectingNodeLoad() throws E nodeLoads, Map.of(), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -289,7 +296,8 @@ public void testRebalance_GivenProblemsOnMultipleNodes() throws Exception { nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -322,7 +330,8 @@ public void testRebalance_GivenFirstModelToAdd_FitsFully() throws Exception { nodeLoads, Map.of(List.of(), List.of(node1)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -361,7 +370,8 @@ public void testRebalance_GivenModelToAdd_AndPreviousAssignments_AndTwoNodes_All nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -425,7 +435,8 @@ public void testRebalance_GivenPreviousAssignments_AndNewNode() throws Exception nodeLoads, Map.of(List.of(), List.of(node1, node2, node3)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -489,7 +500,8 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -559,7 +571,8 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -608,7 +621,8 @@ public void testRebalance_GivenFailedAssignment_RestartsAssignment() throws Exce nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(1))); @@ -642,7 +656,8 @@ public void testRebalance_GivenLowPriorityModelToAdd_OnlyModel_NotEnoughMemory() nodeLoads, Map.of(), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(deploymentId); @@ -658,8 +673,8 @@ public void testRebalance_GivenLowPriorityModelToAdd_OnlyModel_NotEnoughMemory() public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessors() throws Exception { long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); - DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 1); - DiscoveryNode node2 = buildNode("node-2", nodeMemoryBytes, 1); + DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 8); + DiscoveryNode node2 = buildNode("node-2", nodeMemoryBytes, 8); Map nodeLoads = new HashMap<>(); nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); @@ -688,7 +703,8 @@ public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessor nodeLoads, Map.of(List.of("zone-1"), List.of(node1), List.of("zone-2"), List.of(node2)), Optional.of(taskParams1), - 1 + 1, + 
false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(deployment1); @@ -727,7 +743,8 @@ public void testRebalance_GivenMixedPriorityModels_NotEnoughMemoryForLowPriority nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); { @@ -780,7 +797,8 @@ public void testRebalance_GivenMixedPriorityModels_TwoZones_EachNodeCanHoldOneMo nodeLoads, Map.of(List.of("zone-1"), List.of(node1), List.of("zone-2"), List.of(node2)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); List assignedNodes = new ArrayList<>(); @@ -834,7 +852,8 @@ public void testRebalance_GivenModelUsingAllCpu_FittingLowPriorityModelCanStart( nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); { @@ -884,7 +903,8 @@ public void testRebalance_GivenMultipleLowPriorityModels_AndMultipleNodes() thro nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); { @@ -934,7 +954,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_EvictsLowPriorityModel( nodeLoads, Map.of(List.of(), List.of(node1)), Optional.of(taskParams2), - 1 + 1, + false ).rebalance().build(); { @@ -986,7 +1007,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelCanS nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams2), - 1 + 1, + false ).rebalance().build(); { @@ -1038,7 +1060,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelMust nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams2), - 1 + 1, + false ).rebalance().build(); { @@ -1084,7 +1107,8 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { nodeLoads, Map.of(List.of(), List.of(node)), Optional.of(taskParams), - 2 + 2, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -1106,7 +1130,8 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { nodeLoads, Map.of(List.of(), List.of(node)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); assignment = result.getDeploymentAssignment(modelId); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java index 3ecdd5000ba35..cbbb38f1d1ddd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.assignment.planning; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; @@ -24,109 +25,248 @@ public class AssignmentPlanTests extends ESTestCase { public void testBuilderCtor_GivenDuplicateNode() { Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0); expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n, 
n), List.of(m))); } public void testBuilderCtor_GivenDuplicateModel() { Node n = new Node("n_1", 100, 4); - Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0); + Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0); expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n), List.of(m, m))); } public void testAssignModelToNode_GivenNoPreviousAssignment() { - Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(350).getBytes(), 4); - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + { // old memory format + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(40).getBytes(), 1, 2, Map.of(), 0, 0, 0); - assertThat(builder.getRemainingCores(n), equalTo(4)); - assertThat(builder.getRemainingMemory(n), equalTo(100L)); - assertThat(builder.getRemainingAllocations(m), equalTo(1)); - assertThat(builder.getRemainingThreads(m), equalTo(2)); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - builder.assignModelToNode(m, n, 1); + assertThat(builder.getRemainingCores(n), equalTo(4)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); - assertThat(builder.getRemainingCores(n), equalTo(2)); - assertThat(builder.getRemainingMemory(n), equalTo(60L)); - assertThat(builder.getRemainingAllocations(m), equalTo(0)); - assertThat(builder.getRemainingThreads(m), equalTo(0)); + builder.assignModelToNode(m, n, 1); - AssignmentPlan plan = builder.build(); + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(30).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(0)); + assertThat(builder.getRemainingThreads(m), equalTo(0)); - assertThat(plan.models(), contains(m)); - assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } + { // new memory format + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(20).getBytes(), + 1, + 2, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(30).getBytes() + ); + + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + + assertThat(builder.getRemainingCores(n), equalTo(4)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); + + builder.assignModelToNode(m, n, 1); + + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(0L)); + assertThat(builder.getRemainingAllocations(m), equalTo(0)); + assertThat(builder.getRemainingThreads(m), equalTo(0)); + + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } } public void 
testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() { - Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 2, 2, Map.of("n_1", 1), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(350).getBytes(), 4); + { // old memory format + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 2, + Map.of("n_1", 1), + 0, + 0, + 0 + ); - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - builder.assignModelToNode(m, n, 1); + builder.assignModelToNode(m, n, 1); - assertThat(builder.getRemainingCores(n), equalTo(2)); - assertThat(builder.getRemainingMemory(n), equalTo(100L)); - assertThat(builder.getRemainingAllocations(m), equalTo(1)); - assertThat(builder.getRemainingThreads(m), equalTo(2)); + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); - AssignmentPlan plan = builder.build(); + AssignmentPlan plan = builder.build(); - assertThat(plan.models(), contains(m)); - assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } + { // new memory format + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(25).getBytes(), + 2, + 2, + Map.of("n_1", 1), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(25).getBytes() + ); + + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + + builder.assignModelToNode(m, n, 1); + + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(325).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); + + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + + } } public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() { - Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 2, 2, Map.of("n_1", 2), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 4); + { + // old memory format + Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 0, 0, 0); - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - builder.assignModelToNode(m, n, 1); + builder.assignModelToNode(m, n, 1); - assertThat(builder.getRemainingCores(n), equalTo(2)); - assertThat(builder.getRemainingMemory(n), equalTo(100L)); - assertThat(builder.getRemainingAllocations(m), equalTo(1)); - assertThat(builder.getRemainingThreads(m), equalTo(2)); + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(300).getBytes())); + 
assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); - AssignmentPlan plan = builder.build(); + AssignmentPlan plan = builder.build(); - assertThat(plan.models(), contains(m)); - assertThat(plan.satisfiesCurrentAssignments(), is(false)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(false)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } + { + // new memory format + Deployment m = new Deployment( + "m_1", + ByteSizeValue.ofMb(25).getBytes(), + 2, + 2, + Map.of("n_1", 2), + 0, + ByteSizeValue.ofMb(250).getBytes(), + ByteSizeValue.ofMb(25).getBytes() + ); + + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + + builder.assignModelToNode(m, n, 1); + + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(275).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); + + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(false)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } } public void testAssignModelToNode_GivenPreviouslyUnassignedModelDoesNotFit() { - Node n = new Node("n_1", 100, 4); - Deployment m = new AssignmentPlan.Deployment("m_1", 101, 2, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 1)); - assertThat(e.getMessage(), equalTo("not enough memory on node [n_1] to assign model [m_1]")); + assertThat(e.getMessage(), equalTo("not enough memory on node [n_1] to assign [1] allocations to deployment [m_1]")); } public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() { - Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 2, 2, Map.of("n_1", 1), 0); + { // old memory format + Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 2, + 2, + Map.of("n_1", 1), + 0, + 0, + 0 + ); - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - builder.assignModelToNode(m, n, 2); - AssignmentPlan plan = builder.build(); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - assertThat(plan.models(), contains(m)); - assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2))); + builder.assignModelToNode(m, n, 2); + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2))); + } + { // new memory format + Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 2, + Map.of("n_1", 1), + 0, + 
ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(5).getBytes() + ); + + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + + builder.assignModelToNode(m, n, 2); + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2))); + } } public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocation() { - Node n = new Node("n_1", 100, 4); - Deployment m = new AssignmentPlan.Deployment("m_1", 100, 5, 1, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 5)); @@ -138,8 +278,8 @@ public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocati } public void testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAllocation() { - Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 5); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 3)); @@ -151,13 +291,22 @@ public void testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAlloc } public void testAssignModelToNode_GivenSameModelAssignedTwice() { - Node n = new Node("n_1", 100, 8); - Deployment m = new AssignmentPlan.Deployment("m_1", 60, 4, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8); + Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 4, + 2, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); assertThat(builder.getRemainingCores(n), equalTo(8)); - assertThat(builder.getRemainingMemory(n), equalTo(100L)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(1000).getBytes())); assertThat(builder.getRemainingAllocations(m), equalTo(4)); assertThat(builder.getRemainingThreads(m), equalTo(8)); assertThat(builder.canAssign(m, n, 1), is(true)); @@ -165,7 +314,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { builder.assignModelToNode(m, n, 1); assertThat(builder.getRemainingCores(n), equalTo(6)); - assertThat(builder.getRemainingMemory(n), equalTo(40L)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(600).getBytes())); assertThat(builder.getRemainingAllocations(m), equalTo(3)); assertThat(builder.getRemainingThreads(m), equalTo(6)); assertThat(builder.canAssign(m, n, 2), is(true)); @@ -173,7 +322,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { builder.assignModelToNode(m, n, 2); assertThat(builder.getRemainingCores(n), equalTo(2)); - assertThat(builder.getRemainingMemory(n), equalTo(40L)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(500).getBytes())); 
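/*
 * [Editor's note -- worked example for the assertions above, assuming the memory model
 * inferred from this patch. The first allocation charges the deployment's full footprint,
 * while further allocations on the same node only add per-allocation memory:
 *
 *     long firstAssignmentMb = 50 + 300 + 1 * 50;  // model + per-deployment + one allocation = 400
 *     long twoMoreAllocationsMb = 2 * 50;          // incremental cost of two extra allocations = 100
 *     // 1000 - 400 == 600 MB remaining after the first call, 600 - 100 == 500 MB after the second
 * ]
 */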
assertThat(builder.getRemainingAllocations(m), equalTo(1)); assertThat(builder.getRemainingThreads(m), equalTo(2)); @@ -186,7 +335,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() { Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -194,17 +343,33 @@ public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() { } public void testCanAssign_GivenPreviouslyAssignedModelDoesNotFit() { - Node n = new Node("n_1", 100, 5); - Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of("n_1", 1), 0); - - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - - assertThat(builder.canAssign(m, n, 1), is(true)); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); + { + // old memory format + Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, 0, 0); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + assertThat(builder.canAssign(m, n, 1), is(true)); + } + { + // new memory format + Deployment m = new Deployment( + "m_1", + ByteSizeValue.ofMb(25).getBytes(), + 1, + 1, + Map.of("n_1", 1), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + assertThat(builder.canAssign(m, n, 1), is(true)); + } } public void testCanAssign_GivenEnoughMemory() { - Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -216,16 +381,25 @@ public void testCanAssign_GivenEnoughMemory() { public void testCompareTo_GivenDifferenceInPreviousAssignments() { AssignmentPlan planSatisfyingPreviousAssignments; AssignmentPlan planNotSatisfyingPreviousAssignments; - Node n = new Node("n_1", 100, 5); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { - Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 2), 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 2), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planSatisfyingPreviousAssignments = builder.build(); } { - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 3), 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 3, + 2, + Map.of("n_1", 3), + 0, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planNotSatisfyingPreviousAssignments = builder.build(); @@ -238,8 +412,17 @@ public void testCompareTo_GivenDifferenceInPreviousAssignments() { public void testCompareTo_GivenDifferenceInAllocations() { AssignmentPlan planWithMoreAllocations; AssignmentPlan 
planWithFewerAllocations; - Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 1), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 3, + 2, + Map.of("n_1", 1), + 0, + 0, + 0 + ); { AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -259,16 +442,25 @@ public void testCompareTo_GivenDifferenceInAllocations() { public void testCompareTo_GivenDifferenceInMemory() { AssignmentPlan planUsingMoreMemory; AssignmentPlan planUsingLessMemory; - Node n = new Node("n_1", 100, 5); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { - Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 1), 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 1), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planUsingMoreMemory = builder.build(); } { - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 99, 3, 2, Map.of("n_1", 1), 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(29).getBytes(), + 3, + 2, + Map.of("n_1", 1), + 0, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planUsingLessMemory = builder.build(); @@ -279,26 +471,96 @@ public void testCompareTo_GivenDifferenceInMemory() { } public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 20, 4, 1, Map.of(), 0); - AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) - .assignModelToNode(deployment1, node1, 1) - .assignModelToNode(deployment2, node2, 2) - .assignModelToNode(deployment3, node1, 2) - .assignModelToNode(deployment3, node2, 2) - .build(); - assertThat(plan.satisfiesAllModels(), is(true)); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + { + // old memory format + AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 1, + 2, + Map.of(), + 0, + 0, + 0 + ); + AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( + "m_2", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 1, + Map.of(), + 0, + 0, + 0 + ); + AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment( + "m_3", + ByteSizeValue.ofMb(20).getBytes(), + 4, + 1, + Map.of(), + 0, + 0, + 0 + ); + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) + .assignModelToNode(deployment1, node1, 1) + .assignModelToNode(deployment2, node2, 2) + .assignModelToNode(deployment3, node1, 2) + .assignModelToNode(deployment3, node2, 2) + .build(); + assertThat(plan.satisfiesAllModels(), is(true)); + } + { + // new memory format + AssignmentPlan.Deployment deployment1 = new 
AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 1, + 2, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( + "m_2", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment( + "m_3", + ByteSizeValue.ofMb(20).getBytes(), + 4, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) + .assignModelToNode(deployment1, node1, 1) + .assignModelToNode(deployment2, node2, 2) + .assignModelToNode(deployment3, node1, 2) + .assignModelToNode(deployment3, node2, 2) + .build(); + assertThat(plan.satisfiesAllModels(), is(true)); + } } public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 0); - Deployment deployment3 = new Deployment("m_3", 20, 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 2) @@ -309,11 +571,11 @@ public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() { } public void testArePreviouslyAssignedModelsAssigned_GivenTrue() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 3); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 30, 2, 1, Map.of(), 4); - AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 20, 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 1) @@ -322,10 +584,10 @@ public void testArePreviouslyAssignedModelsAssigned_GivenTrue() { } public void testArePreviouslyAssignedModelsAssigned_GivenFalse() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - 
AssignmentPlan.Deployment deployment1 = new Deployment("m_1", 50, 1, 2, Map.of(), 3); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 4); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) .assignModelToNode(deployment1, node1, 1) .build(); @@ -333,12 +595,39 @@ public void testArePreviouslyAssignedModelsAssigned_GivenFalse() { } public void testCountPreviouslyAssignedThatAreStillAssigned() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 3); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 4); - AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 20, 4, 1, Map.of(), 1); - AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment("m_4", 20, 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); + AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( + "m_2", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 1, + Map.of(), + 4, + 0, + 0 + ); + AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment( + "m_3", + ByteSizeValue.ofMb(20).getBytes(), + 4, + 1, + Map.of(), + 1, + 0, + 0 + ); + AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment( + "m_4", + ByteSizeValue.ofMb(20).getBytes(), + 4, + 1, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3, deployment4)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 1) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java index 82a291a8d9fb2..6a72ccf4c4445 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java @@ -33,50 +33,144 @@ public class AssignmentPlannerTests extends ESTestCase { + private static long scaleNodeSize(long nodeMemory) { + // 240 Mb is the size in StartTrainedModelDeploymentAction.MEMORY_OVERHEAD + return ByteSizeValue.ofMb(240 + 2 * nodeMemory).getBytes(); + } + public void testModelThatDoesNotFitInMemory() { - List nodes = List.of(new Node("n_1", 100, 4)); - Deployment deployment = new AssignmentPlan.Deployment("m_1", 101, 4, 1, Map.of(), 0); - AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); - assertThat(plan.assignments(deployment).isEmpty(), is(true)); + { // Without perDeploymentMemory and perAllocationMemory specified + List nodes = List.of(new 
Node("n_1", scaleNodeSize(50), 4)); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(51).getBytes(), 4, 1, Map.of(), 0, 0, 0); + AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); + assertThat(plan.assignments(deployment).isEmpty(), is(true)); + } + { // With perDeploymentMemory and perAllocationMemory specified + List nodes = List.of(new Node("n_1", scaleNodeSize(55), 4)); + Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 4, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(250).getBytes(), + ByteSizeValue.ofMb(51).getBytes() + ); + AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); + assertThat(plan.assignments(deployment).isEmpty(), is(true)); + } } public void testModelWithThreadsPerAllocationNotFittingOnAnyNode() { - List nodes = List.of(new Node("n_1", 100, 4), new Node("n_2", 100, 5)); - Deployment deployment = new AssignmentPlan.Deployment("m_1", 1, 1, 6, Map.of(), 0); + List nodes = List.of(new Node("n_1", scaleNodeSize(100), 4), new Node("n_2", scaleNodeSize(100), 5)); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(1).getBytes(), 1, 6, Map.of(), 0, 0, 0); AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); assertThat(plan.assignments(deployment).isEmpty(), is(true)); } public void testSingleModelThatFitsFullyOnSingleNode() { { - Node node = new Node("n_1", 100, 4); - Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 1, Map.of(), 0); + Node node = new Node("n_1", scaleNodeSize(100), 4); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0); + AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertModelFullyAssignedToNode(plan, deployment, node); + } + { + Node node = new Node("n_1", scaleNodeSize(1000), 8); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(1000).getBytes(), 8, 1, Map.of(), 0, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } { - Node node = new Node("n_1", 1000, 8); - Deployment deployment = new Deployment("m_1", 1000, 8, 1, Map.of(), 0); + Node node = new Node("n_1", scaleNodeSize(10000), 16); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(10000).getBytes(), + 1, + 16, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } { - Node node = new Node("n_1", 10000, 16); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 10000, 1, 16, Map.of(), 0); + Node node = new Node("n_1", scaleNodeSize(100), 4); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0); + AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertModelFullyAssignedToNode(plan, deployment, node); + } + } + + public void testSingleModelThatFitsFullyOnSingleNode_NewMemoryFields() { + { + Node node = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4); + Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 1, + 1, + Map.of(), + 0, + 
ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ); + AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertModelFullyAssignedToNode(plan, deployment, node); + } + { + Node node = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8); + Deployment deployment = new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 8, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(100).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } } public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFullyAssignedOnOneNode() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment = new Deployment("m_1", 100, 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 4); + AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 4, 1, Map.of(), 0, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); Map assignments = plan.assignments(deployment).get(); - if (assignments.get(node1) > 0) { + if (assignments.get(node1) != null) { + assertThat(assignments.get(node1), equalTo(4)); + } else { + assertThat(assignments.get(node2), equalTo(4)); + } + } + + public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFullyAssignedOnOneNode_NewMemoryFields() { + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + AssignmentPlan.Deployment deployment = new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 4, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(150).getBytes() + ); + + AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); + + Map assignments = plan.assignments(deployment).get(); + if (assignments.get(node1) != null) { assertThat(assignments.get(node1), equalTo(4)); } else { assertThat(assignments.get(node2), equalTo(4)); @@ -84,10 +178,53 @@ public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFully } public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation() { - AssignmentPlan.Deployment deployment = new Deployment("m_1", 30, 10, 1, Map.of(), 0); + AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 1, Map.of(), 0, 0, 0); + // Single node + { + Node node = new Node("n_1", scaleNodeSize(100), 4); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node), equalTo(4)); + } + // Two nodes + { + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 2); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1), equalTo(4)); + 
assertThat(assignments.get(node2), equalTo(2)); + } + // Three nodes + { + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 2); + Node node3 = new Node("n_3", scaleNodeSize(100), 3); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1), equalTo(4)); + assertThat(assignments.get(node2), equalTo(2)); + assertThat(assignments.get(node3), equalTo(3)); + } + } + + public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation_NewMemoryFields() { + AssignmentPlan.Deployment deployment = new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 10, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ); // Single node { - Node node = new Node("n_1", 100, 4); + Node node = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -95,8 +232,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA } // Two nodes { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 2); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(600).getBytes(), 2); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -105,9 +242,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA } // Three nodes { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 2); - Node node3 = new Node("n_3", 100, 3); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(600).getBytes(), 2); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(700).getBytes(), 3); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -118,14 +255,105 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA } public void testMultipleModelsAndNodesWithSingleSolution() { - Node node1 = new Node("n_1", 100, 7); - Node node2 = new Node("n_2", 100, 7); - Node node3 = new Node("n_3", 100, 2); - Node node4 = new Node("n_4", 100, 2); - Deployment deployment1 = new Deployment("m_1", 50, 2, 4, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 50, 2, 3, Map.of(), 0); - Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 50, 1, 2, Map.of(), 0); - AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment("m_4", 50, 2, 1, Map.of(), 0); + Node node1 = new Node("n_1", 2 * scaleNodeSize(50), 7); + Node node2 = new Node("n_2", 2 * scaleNodeSize(50), 7); + Node node3 = new Node("n_3", 2 * scaleNodeSize(50), 2); + Node node4 = new Node("n_4", 2 * 
scaleNodeSize(50), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 4, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 2, 3, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0); + Deployment deployment4 = new Deployment("m_4", ByteSizeValue.ofMb(50).getBytes(), 2, 1, Map.of(), 0, 0, 0); + + AssignmentPlan plan = new AssignmentPlanner( + List.of(node1, node2, node3, node4), + List.of(deployment1, deployment2, deployment3, deployment4) + ).computePlan(); + + { + assertThat(plan.assignments(deployment1).isPresent(), is(true)); + Map assignments = plan.assignments(deployment1).get(); + assertThat(assignments.get(node1), equalTo(1)); + assertThat(assignments.get(node2), equalTo(1)); + assertThat(assignments.get(node3), is(nullValue())); + assertThat(assignments.get(node4), is(nullValue())); + } + { + assertThat(plan.assignments(deployment2).isPresent(), is(true)); + Map assignments = plan.assignments(deployment2).get(); + assertThat(assignments.get(node1), equalTo(1)); + assertThat(assignments.get(node2), equalTo(1)); + assertThat(assignments.get(node3), is(nullValue())); + assertThat(assignments.get(node4), is(nullValue())); + } + { + assertThat(plan.assignments(deployment3).isPresent(), is(true)); + Map assignments = plan.assignments(deployment3).get(); + assertThat(assignments.get(node1), is(nullValue())); + assertThat(assignments.get(node2), is(nullValue())); + // Will either be on node 3 or 4 + Node assignedNode = assignments.get(node3) != null ? node3 : node4; + Node otherNode = assignedNode.equals(node3) ? node4 : node3; + assertThat(assignments.get(assignedNode), equalTo(1)); + assertThat(assignments.get(otherNode), is(nullValue())); + } + { + assertThat(plan.assignments(deployment4).isPresent(), is(true)); + Map assignments = plan.assignments(deployment4).get(); + assertThat(assignments.get(node1), is(nullValue())); + assertThat(assignments.get(node2), is(nullValue())); + // Will either be on node 3 or 4 + Node assignedNode = assignments.get(node3) != null ? node3 : node4; + Node otherNode = assignedNode.equals(node3) ? 
node4 : node3; + assertThat(assignments.get(assignedNode), equalTo(2)); + assertThat(assignments.get(otherNode), is(nullValue())); + } + } + + public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 7); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 7); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(900).getBytes(), 2); + Node node4 = new Node("n_4", ByteSizeValue.ofMb(900).getBytes(), 2); + Deployment deployment1 = new Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 2, + 4, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(50).getBytes(), + 2, + 3, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); + Deployment deployment3 = new Deployment( + "m_3", + ByteSizeValue.ofMb(50).getBytes(), + 1, + 2, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); + Deployment deployment4 = new Deployment( + "m_4", + ByteSizeValue.ofMb(50).getBytes(), + 2, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); AssignmentPlan plan = new AssignmentPlanner( List.of(node1, node2, node3, node4), @@ -173,10 +401,53 @@ public void testMultipleModelsAndNodesWithSingleSolution() { } public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation() { - Deployment deployment = new AssignmentPlan.Deployment("m_1", 30, 10, 3, Map.of(), 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 3, Map.of(), 0, 0, 0); + // Single node + { + Node node = new Node("n_1", scaleNodeSize(100), 4); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node), equalTo(1)); + } + // Two nodes + { + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 8); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1), equalTo(1)); + assertThat(assignments.get(node2), equalTo(2)); + } + // Three nodes + { + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 7); + Node node3 = new Node("n_3", scaleNodeSize(100), 15); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1), equalTo(1)); + assertThat(assignments.get(node2), equalTo(2)); + assertThat(assignments.get(node3), equalTo(5)); + } + } + + public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation_NewMemoryFields() { + Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 10, + 3, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + 
); // Single node { - Node node = new Node("n_1", 100, 4); + Node node = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -184,8 +455,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA } // Two nodes { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 8); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 8); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -194,9 +465,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA } // Three nodes { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 7); - Node node3 = new Node("n_3", 100, 15); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 7); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(800).getBytes(), 15); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -207,8 +478,17 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA } public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() { - Node node = new Node("n_1", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 30, 4, 1, Map.of("n_1", 4), 0); + Node node = new Node("n_1", scaleNodeSize(100), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 4, + 1, + Map.of("n_1", 4), + 0, + 0, + 0 + ); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(plan.assignments(deployment).isPresent(), is(true)); @@ -217,26 +497,117 @@ public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() { public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation() { List nodes = List.of( - new Node("n_1", ByteSizeValue.ofGb(6).getBytes(), 8), - new Node("n_2", ByteSizeValue.ofGb(6).getBytes(), 8), - new Node("n_3", ByteSizeValue.ofGb(6).getBytes(), 8), - new Node("n_4", ByteSizeValue.ofGb(6).getBytes(), 8), - new Node("n_5", ByteSizeValue.ofGb(16).getBytes(), 16), - new Node("n_6", ByteSizeValue.ofGb(8).getBytes(), 16) + new Node("n_1", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_2", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_3", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_4", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_5", ByteSizeValue.ofGb(64).getBytes(), 16), + new Node("n_6", ByteSizeValue.ofGb(32).getBytes(), 16) ); List deployments = List.of( - new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0), - new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0), - new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 
1, Map.of(), 0), - new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0), - new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0), - new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0), - new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0), - new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0), - new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0), - new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0), - new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0), - new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0) + new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0, 0, 0), + new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0), + new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0, 0, 0), + new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0, 0, 0), + new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0, 0, 0), + new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0, 0, 0), + new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0, 0, 0), + new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0), + new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0), + new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0), + new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0), + new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0) + ); + + AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan(); + + int usedCores = 0; + for (AssignmentPlan.Deployment m : deployments) { + Map assignments = assignmentPlan.assignments(m).orElse(Map.of()); + usedCores += assignments.values().stream().mapToInt(Integer::intValue).sum(); + } + assertThat(usedCores, equalTo(64)); + + assertPreviousAssignmentsAreSatisfied(deployments, assignmentPlan); + } + + public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_NewMemoryFields() { + List nodes = List.of( + new Node("n_1", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_2", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_3", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_4", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_5", ByteSizeValue.ofGb(64).getBytes(), 16), + new Node("n_6", ByteSizeValue.ofGb(32).getBytes(), 16) + ); + // Use mix of old and new memory fields + List deployments = List.of( + new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 10, + 1, + Map.of("n_1", 5), + 0, + ByteSizeValue.ofMb(400).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ), + new Deployment("m_2", ByteSizeValue.ofMb(100).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0), + new Deployment( + "m_3", + ByteSizeValue.ofMb(50).getBytes(), + 3, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ), + new Deployment( + "m_4", + ByteSizeValue.ofMb(50).getBytes(), + 4, + 1, + Map.of("n_3", 2), + 0, + ByteSizeValue.ofMb(400).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ), + new Deployment( + "m_5", + 
ByteSizeValue.ofMb(500).getBytes(), + 2, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(800).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ), + new Deployment( + "m_6", + ByteSizeValue.ofMb(50).getBytes(), + 12, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(50).getBytes(), + ByteSizeValue.ofMb(20).getBytes() + ), + new Deployment( + "m_7", + ByteSizeValue.ofMb(50).getBytes(), + 12, + 1, + Map.of("n_2", 6), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ), + new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0), + new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0), + new Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0), + new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0), + new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0) ); AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan(); @@ -297,6 +668,9 @@ public void testRandomBenchmark() { StopWatch stopWatch = new StopWatch(); stopWatch.start(); AssignmentPlan assignmentPlan = solver.computePlan(); + for (Node node : nodes) { + assertThat(assignmentPlan.getRemainingNodeMemory(node.id()), greaterThanOrEqualTo(0L)); + } stopWatch.stop(); Quality quality = computeQuality(nodes, deployments, assignmentPlan); @@ -336,7 +710,16 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode .stream() .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue)); previousModelsPlusNew.add( - new AssignmentPlan.Deployment(m.id(), m.memoryBytes(), m.allocations(), m.threadsPerAllocation(), previousAssignments, 0) + new AssignmentPlan.Deployment( + m.id(), + m.memoryBytes(), + m.allocations(), + m.threadsPerAllocation(), + previousAssignments, + 0, + 0, + 0 + ) ); } previousModelsPlusNew.add(randomModel("new")); @@ -347,18 +730,20 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode } public void testGivenLargerModelWithPreviousAssignmentsAndSmallerModelWithoutAssignments() { - Node node1 = new Node("n_1", ByteSizeValue.ofGb(2).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofGb(2).getBytes(), 2); - Node node3 = new Node("n_3", ByteSizeValue.ofGb(2).getBytes(), 2); + Node node1 = new Node("n_1", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); + Node node2 = new Node("n_2", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); + Node node3 = new Node("n_3", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); Deployment deployment1 = new AssignmentPlan.Deployment( "m_1", ByteSizeValue.ofMb(1200).getBytes(), 3, 1, Map.of("n_1", 2, "n_2", 1), + 0, + 0, 0 ); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0, 0, 0); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment1, deployment2)) .computePlan(); assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L)); @@ -381,15 +766,17 @@ public void testGivenLargerModelWithPreviousAssignmentsAndSmallerModelWithoutAss } public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously() { - Node node1 = new Node("n_1", ByteSizeValue.ofGb(4).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofGb(4).getBytes(), 2); + Node node1 = new Node("n_1", ByteSizeValue.ofGb(6).getBytes(), 2); + 
Node node2 = new Node("n_2", ByteSizeValue.ofGb(6).getBytes(), 2); AssignmentPlan.Deployment deployment1 = new Deployment( "m_1", ByteSizeValue.ofMb(1200).getBytes(), 3, 1, Map.of("n_1", 2, "n_2", 1), - 3 + 3, + 0, + 0 ); AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( "m_2", @@ -397,35 +784,84 @@ public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously( 1, 2, Map.of(), - 1 + 1, + 0, + 0 ); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2)).computePlan(); Map> indexedBasedPlan = convertToIdIndexed(assignmentPlan); assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2")); - assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); - assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + if (indexedBasedPlan.get("m_2").containsKey("n_1")) { + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_2", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_1", 1))); + } else { + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + } assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L)); assertThat(assignmentPlan.getRemainingNodeMemory("n_2"), greaterThanOrEqualTo(0L)); } public void testGivenPreviouslyAssignedModels_CannotAllBeAllocated() { - Node node1 = new Node("n_1", ByteSizeValue.ofGb(2).getBytes(), 2); - AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1); + Node node1 = new Node("n_1", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); + AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1, 0, 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1, 0, 0); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1), List.of(deployment1, deployment2)).computePlan(); assertThat(assignmentPlan.countPreviouslyAssignedModelsThatAreStillAssigned(), equalTo(1L)); } + public void testGivenClusterResize_AllocationShouldNotExceedMemoryConstraints() { + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); + + // First only start m_1 + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan(); + + Map> indexedBasedPlan = convertToIdIndexed(assignmentPlan); + assertThat(indexedBasedPlan.keySet(), hasItems("m_1")); + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + + // Then start m_2 + assignmentPlan = new AssignmentPlanner( + List.of(node1, node2), + Stream.concat(createModelsFromPlan(assignmentPlan).stream(), Stream.of(deployment2)).toList() + ).computePlan(); + + indexedBasedPlan = convertToIdIndexed(assignmentPlan); + assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2")); + 
assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + + // Then start m_3 + assignmentPlan = new AssignmentPlanner( + List.of(node1, node2), + Stream.concat(createModelsFromPlan(assignmentPlan).stream(), Stream.of(deployment3)).toList() + ).computePlan(); + + indexedBasedPlan = convertToIdIndexed(assignmentPlan); + assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3")); + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1))); + + // First, one node goes away. + assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan(); + assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L)); + } + public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(2600).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(2600).getBytes(), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan(); @@ -458,8 +894,8 @@ public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() { assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1))); // Now the cluster starts getting resized. - Node node3 = new Node("n_3", ByteSizeValue.ofMb(2400).getBytes(), 2); - Node node4 = new Node("n_4", ByteSizeValue.ofMb(2400).getBytes(), 2); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(2600).getBytes(), 2); + Node node4 = new Node("n_4", ByteSizeValue.ofMb(2600).getBytes(), 2); // First, one node goes away. 
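+ // Only n_1 (2600 MB, 2 cores) is left at this point, so the planner must shrink the existing assignments to fit a single node.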
assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan(); @@ -492,11 +928,65 @@ public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() { public void testGivenClusterResize_ShouldRemoveAllocatedModels() { // Ensure that plan is removing previously allocated models if not enough memory is available - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0, 0, 0); + + // Create a plan where all deployments are assigned at least once + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) + .computePlan(); + Map> indexedBasedPlan = convertToIdIndexed(assignmentPlan); + assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3")); + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1))); + assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L)); + assertThat(assignmentPlan.getRemainingNodeMemory(node2.id()), greaterThanOrEqualTo(0L)); + + // Now the cluster starts getting resized. Ensure that resources are not over-allocated. 
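+ // Following the old-format memory estimate used in this suite's comments (2 * model size + 240 MB overhead),
+ // m_1 alone needs 2 * 800 + 240 = 1840 MB, i.e. all of n_1, so m_2 and m_3 are expected to be dropped.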
+ assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan(); + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L)); + assertThat(assignmentPlan.getRemainingNodeCores(node1.id()), greaterThanOrEqualTo(0)); + + } + + public void testGivenClusterResize_ShouldRemoveAllocatedModels_NewMemoryFields() { + // Ensure that plan is removing previously allocated models if not enough memory is available + Node node1 = new Node("n_1", ByteSizeValue.ofMb(700).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 2); + Deployment deployment1 = new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 2, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(400).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(100).getBytes(), + 1, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(400).getBytes(), + ByteSizeValue.ofMb(150).getBytes() + ); + Deployment deployment3 = new Deployment( + "m_3", + ByteSizeValue.ofMb(50).getBytes(), + 1, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(250).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); // Create a plan where all deployments are assigned at least once AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) @@ -536,7 +1026,9 @@ public static List createModelsFromPlan(AssignmentPlan plan) { m.allocations(), m.threadsPerAllocation(), currentAllocations, - Math.max(m.maxAssignedAllocations(), totalAllocations) + Math.max(m.maxAssignedAllocations(), totalAllocations), + 0, + 0 ) ); } @@ -579,7 +1071,7 @@ public static List randomNodes(int scale, String nodeIdPrefix) { for (int i = 0; i < 1 + 3 * scale; i++) { int cores = randomIntBetween(2, 32); long memBytesPerCore = randomFrom(memBytesPerCoreValues); - nodes.add(new Node(nodeIdPrefix + "n_" + i, cores * memBytesPerCore, cores)); + nodes.add(new Node(nodeIdPrefix + "n_" + i, scaleNodeSize(ByteSizeValue.ofBytes(cores * memBytesPerCore).getMb()), cores)); } return nodes; } @@ -594,14 +1086,30 @@ public static List randomModels(int scale, double load) { public static Deployment randomModel(String idSuffix) { int allocations = randomIntBetween(1, 32); - return new Deployment( - "m_" + idSuffix, - randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(10).getBytes()), - randomIntBetween(1, 32), - randomIntBetween(1, 4), - Map.of(), - 0 - ); + // randomly choose between old and new memory fields format + if (randomBoolean()) { + return new Deployment( + "m_" + idSuffix, + randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(10).getBytes()), + randomIntBetween(1, 32), + randomIntBetween(1, 4), + Map.of(), + 0, + 0, + 0 + ); + } else { + return new Deployment( + "m_" + idSuffix, + randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()), + randomIntBetween(1, 32), + randomIntBetween(1, 4), + Map.of(), + 0, + randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()), + randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()) + ); + } } public static void assertPreviousAssignmentsAreSatisfied(List deployments, AssignmentPlan assignmentPlan) { @@ -628,7 +1136,7 @@ private void runTooManyNodesAndModels(int nodesSize, int modelsSize) { } List deployments = new ArrayList<>(); for (int i = 
0; i < modelsSize; i++) { - deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0)); + deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, 0, 0)); } // Check plan is computed without OOM exception diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java index 4a9b01e535d88..c45ce36394109 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.assignment.planning; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; @@ -14,7 +15,6 @@ import java.util.List; import java.util.Map; -import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -22,77 +22,179 @@ public class PreserveAllAllocationsTests extends ESTestCase { public void testGivenNoPreviousAssignments() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Deployment deployment1 = new Deployment("m_1", 30, 2, 1, Map.of(), 0); - Deployment deployment2 = new Deployment("m_2", 30, 2, 4, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( List.of(node1, node2), List.of(deployment1, deployment2) ); - - List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, contains(node1, node2)); - - List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations(); - assertThat(modelsPreservingAllocations, contains(deployment1, deployment2)); } public void testGivenPreviousAssignments() { - Node node1 = new Node("n_1", 100, 8); - Node node2 = new Node("n_2", 100, 8); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 30, 2, 1, Map.of("n_1", 1), 1); - Deployment deployment2 = new Deployment("m_2", 50, 6, 4, Map.of("n_1", 1, "n_2", 2), 3); - PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( - List.of(node1, node2), - List.of(deployment1, deployment2) - ); - - List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, hasSize(2)); - - assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); - assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(20L)); - assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); - - assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); - 
assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(50L));
- assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(0));
-
- List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations();
- assertThat(modelsPreservingAllocations, hasSize(2));
-
- assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1"));
- assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(30L));
- assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1));
- assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1));
- assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0)));
-
- assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2"));
- assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(50L));
- assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3));
- assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4));
- assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0)));
-
- AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2))
- .assignModelToNode(deployment1, node1, 2)
- .build();
- assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2)));
- assertThat(plan.assignments(deployment2).isEmpty(), is(true));
-
- plan = preserveAllAllocations.mergePreservedAllocations(plan);
-
- assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3)));
- assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2)));
- assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(20L));
- assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1));
- assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(50L));
- assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0));
+ {
+ // old memory format
+ Node node1 = new Node("n_1", ByteSizeValue.ofMb(640).getBytes(), 8);
+ Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8);
+ Deployment deployment1 = new AssignmentPlan.Deployment(
+ "m_1",
+ ByteSizeValue.ofMb(30).getBytes(),
+ 2,
+ 1,
+ Map.of("n_1", 1),
+ 1,
+ 0,
+ 0
+ );
+ Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, "n_2", 2), 3, 0, 0);
+ PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(
+ List.of(node1, node2),
+ List.of(deployment1, deployment2)
+ );
+
+ List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations();
+ assertThat(nodesPreservingAllocations, hasSize(2));
+
+ assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1"));
+ // 640 - [(2*30 + 240) + (2*50 + 240)] = 0 : deployments use 640 MB on node 1
+ assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(0L));
+ // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node
+ assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3));
+
+ assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2"));
+ // 640 - (50*2+240) = 300 : deployments use 340MB on the node
+ assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+ // 8 - (2*4) = 0 : preserving all allocations of deployment 2 should use 8 cores on the node
+ assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(0));
+
+ List modelsPreservingAllocations =
preserveAllAllocations.modelsPreservingAllocations();
+ assertThat(modelsPreservingAllocations, hasSize(2));
+
+ assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1"));
+ assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes()));
+ assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1));
+ assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1));
+ assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0)));
+
+ assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2"));
+ assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes()));
+ assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3));
+ assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4));
+ assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0)));
+
+ // Now we have a plan with 2 deployments assigned to 2 nodes.
+ // Note that deployment 1 already has 1 allocation on node 1, and it gets 2 more. That is more than the 2 allocations defined during
+ // initialization of deployment1, but we don't care at this point.
+ AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2))
+ .assignModelToNode(deployment1, node1, 2)
+ .build();
+ assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2)));
+ assertThat(plan.assignments(deployment2).isEmpty(), is(true));
+
+ plan = preserveAllAllocations.mergePreservedAllocations(plan);
+ assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3)));
+ assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2)));
+
+ // Node 1 already had deployments 1 and 2 assigned to it, so adding more allocations doesn't change memory usage.
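+ // In other words, remaining memory stays at 640 - [(2*30 + 240) + (2*50 + 240)] = 0.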
+ assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(0L));
+ // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node
+ assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1));
+ // Nothing changed for Node 2
+ assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+ // Nothing changed for Node 2
+ assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0));
+ }
+ {
+ // new memory format
+ Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8);
+ Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 8);
+ Deployment deployment1 = new AssignmentPlan.Deployment(
+ "m_1",
+ ByteSizeValue.ofMb(30).getBytes(),
+ 2,
+ 1,
+ Map.of("n_1", 1),
+ 1,
+ ByteSizeValue.ofMb(300).getBytes(),
+ ByteSizeValue.ofMb(10).getBytes()
+ );
+ Deployment deployment2 = new Deployment(
+ "m_2",
+ ByteSizeValue.ofMb(50).getBytes(),
+ 6,
+ 4,
+ Map.of("n_1", 1, "n_2", 2),
+ 3,
+ ByteSizeValue.ofMb(300).getBytes(),
+ ByteSizeValue.ofMb(10).getBytes()
+ );
+ PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(
+ List.of(node1, node2),
+ List.of(deployment1, deployment2)
+ );
+
+ List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations();
+ assertThat(nodesPreservingAllocations, hasSize(2));
+
+ assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1"));
+ // 1000 - [(30 + 300 + 10) + (50 + 300 + 10)] = 300 : deployments use 700 MB on node 1
+ assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+ // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node
+ assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3));
+
+ assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2"));
+ // 1000 - (50 + 300 + 2*10) = 630 : deployments use 370MB on the node
+ assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(630).getBytes()));
+ // 8 - (2*4) = 0 : preserving all allocations of deployment 2 should use 8 cores on the node
+ assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(0));
+
+ List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations();
+ assertThat(modelsPreservingAllocations, hasSize(2));
+
+ assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1"));
+ assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes()));
+ assertThat(modelsPreservingAllocations.get(0).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+ assertThat(modelsPreservingAllocations.get(0).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes()));
+ assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1));
+ assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1));
+ assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0)));
+
+ assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2"));
+ assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes()));
+ assertThat(modelsPreservingAllocations.get(1).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+ assertThat(modelsPreservingAllocations.get(1).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes()));
+ assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3));
+ assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(),
equalTo(4));
+ assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0)));
+
+ // Now we have a plan with 2 deployments assigned to 2 nodes.
+ // Note that deployment 1 already has 1 allocation on node 1, and it gets 2 more. That is more than the 2 allocations defined during
+ // initialization of deployment1, but we don't care at this point.
+ AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2))
+ .assignModelToNode(deployment1, node1, 2)
+ .build();
+ assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2)));
+ assertThat(plan.assignments(deployment2).isEmpty(), is(true));
+
+ plan = preserveAllAllocations.mergePreservedAllocations(plan);
+ assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3)));
+ assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2)));
+
+ // 1000 - ((30 + 300 + 3*10) + (50 + 300 + 10)) = 280 : deployments use 720 MB on node 1
+ assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(280).getBytes()));
+ // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node
+ assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1));
+ // Nothing changed for Node 2
+ assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(630).getBytes()));
+ // Nothing changed for Node 2
+ assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0));
+ }
}

public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() {
- Node node = new Node("n_1", 100, 4);
- AssignmentPlan.Deployment deployment = new Deployment("m_1", 30, 2, 2, Map.of("n_1", 2), 2);
+ Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4);
+ Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0);
PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(List.of(node), List.of(deployment));

AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build();
@@ -101,7 +203,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments
plan = preserveAllAllocations.mergePreservedAllocations(plan);
assertThat(plan.assignments(deployment).isPresent(), is(true));
assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 2)));
- assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(70L));
+ assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(100).getBytes()));
assertThat(plan.getRemainingNodeCores("n_1"), equalTo(0));
}
}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java
index d8c3b09422e92..f646bf5cb2e9d 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.ml.inference.assignment.planning;

+import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node;
@@ -22,10 +23,10 @@ public class
PreserveOneAllocationTests extends ESTestCase { public void testGivenNoPreviousAssignments() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 30, 2, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 30, 2, 4, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node1, node2), List.of(deployment1, deployment2)); List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); @@ -36,67 +37,204 @@ public void testGivenNoPreviousAssignments() { } public void testGivenPreviousAssignments() { - Node node1 = new Node("n_1", 100, 8); - Node node2 = new Node("n_2", 100, 8); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 30, 2, 1, Map.of("n_1", 1), 1); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 50, 6, 4, Map.of("n_1", 1, "n_2", 2), 3); - PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node1, node2), List.of(deployment1, deployment2)); - - List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, hasSize(2)); - - assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); - assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(20L)); - assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); - - assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); - assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(50L)); - assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(4)); - - List modelsPreservingAllocations = preserveOneAllocation.modelsPreservingAllocations(); - assertThat(modelsPreservingAllocations, hasSize(2)); - - assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); - assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(30L)); - assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); - - assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); - assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(50L)); - assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); - - AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) - .assignModelToNode(deployment1, node1, 2) - .assignModelToNode(deployment2, node2, 1) - .build(); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1))); - - plan = preserveOneAllocation.mergePreservedAllocations(plan); - - 
assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(20L)); - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); - assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(50L)); - assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); + { + // old memory format + Node node1 = new Node("n_1", ByteSizeValue.ofMb(640).getBytes(), 8); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of("n_1", 1), 1, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, "n_2", 2), 3, 0, 0); + PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation( + List.of(node1, node2), + List.of(deployment1, deployment2) + ); + + List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); + assertThat(nodesPreservingAllocations, hasSize(2)); + + assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); + // 640 - [(30*2+240)+(50*2+240)] = 0 : deployments use all memory on the node + assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(0L)); + // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node + assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); + + assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); + // 640 - (50*2+240) = 300 : deployments use 340MB on the node + assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); + // 8 - (1*4) = 4 : preserving 1 allocation of deployment 2 should use 4 cores on the node + assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(4)); + + List modelsPreservingAllocations = preserveOneAllocation.modelsPreservingAllocations(); + assertThat(modelsPreservingAllocations, hasSize(2)); + + assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); + assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes())); + assertThat(modelsPreservingAllocations.get(0).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); + assertThat(modelsPreservingAllocations.get(0).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); + assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); + + assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); + assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes())); + assertThat(modelsPreservingAllocations.get(1).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); + assertThat(modelsPreservingAllocations.get(1).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); + assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); + + // Now we have a plan with 2 deployments assigned to 2 nodes. 
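+ // The builder call below assigns 2 fresh allocations of deployment 1 to node 1 and 1 fresh allocation of
+ // deployment 2 to node 2; mergePreservedAllocations then folds the preserved allocations back into the plan.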
+ // Note that deployment 1 already has 1 allocation on node 1, and it gets 2 more. That is more than the 2 allocations defined during
+ // initialization of deployment1, but we don't care at this point.
+ AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2))
+ .assignModelToNode(deployment1, node1, 2)
+ .assignModelToNode(deployment2, node2, 1)
+ .build();
+ assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2)));
+ assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1)));
+
+ plan = preserveOneAllocation.mergePreservedAllocations(plan);
+
+ assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3)));
+ assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2)));
+ // Node 1 already had deployments 1 and 2 assigned to it, so adding more allocations doesn't change memory usage.
+ assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(0L));
+ // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node
+ assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1));
+ // Node 2 already had deployment 2 assigned to it, so adding more allocations doesn't change memory usage.
+ assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+ // 8 - [(1*4) + (1*4)] = 0 : deployment 2 should use all cores on the node
+ assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0));
+ }
+ {
+ // new memory format
+ Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8);
+ Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 8);
+ Deployment deployment1 = new Deployment(
+ "m_1",
+ ByteSizeValue.ofMb(30).getBytes(),
+ 2,
+ 1,
+ Map.of("n_1", 1),
+ 1,
+ ByteSizeValue.ofMb(300).getBytes(),
+ ByteSizeValue.ofMb(10).getBytes()
+ );
+ Deployment deployment2 = new Deployment(
+ "m_2",
+ ByteSizeValue.ofMb(50).getBytes(),
+ 6,
+ 4,
+ Map.of("n_1", 1, "n_2", 2),
+ 3,
+ ByteSizeValue.ofMb(300).getBytes(),
+ ByteSizeValue.ofMb(10).getBytes()
+ );
+ PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(
+ List.of(node1, node2),
+ List.of(deployment1, deployment2)
+ );
+
+ List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations();
+ assertThat(nodesPreservingAllocations, hasSize(2));
+
+ assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1"));
+ // 1000 - [(30 + 300 + 10) + (50 + 300 + 10)] = 300 : deployments use 700 MB on the node
+ assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+ // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node
+ assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3));
+
+ assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2"));
+ // 1000 - (50 + 300 + 2*10) = 630 : deployments use 370MB on the node
+ assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(630).getBytes()));
+ // 8 - (1*4) = 4 : preserving 1 allocation of deployment 2 should use 4 cores on the node
+ assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(4));
+
+ List modelsPreservingAllocations = preserveOneAllocation.modelsPreservingAllocations();
+ assertThat(modelsPreservingAllocations, hasSize(2));
+
+ assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1"));
+ assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes()));
+ assertThat(modelsPreservingAllocations.get(0).perDeploymentMemoryBytes(),
equalTo(ByteSizeValue.ofMb(300).getBytes()));
+ assertThat(modelsPreservingAllocations.get(0).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes()));
+ assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1));
+ assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1));
+ assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0)));
+
+ assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2"));
+ assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes()));
+ assertThat(modelsPreservingAllocations.get(1).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+ assertThat(modelsPreservingAllocations.get(1).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes()));
+ assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4));
+ assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4));
+ assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1)));
+
+ // Now we have a plan with 2 deployments assigned to 2 nodes.
+ // Note that deployment 1 already has 1 allocation on node 1, and it gets 2 more. That is more than the 2 allocations defined
+ // during initialization of deployment1, but we don't care at this point.
+ AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2))
+ .assignModelToNode(deployment1, node1, 2)
+ .assignModelToNode(deployment2, node2, 1)
+ .build();
+ assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2)));
+ assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1)));
+
+ plan = preserveOneAllocation.mergePreservedAllocations(plan);
+
+ assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3)));
+ assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2)));
+ // 1000 - [(30+300+3*10) + (50+300+10)] = 280 : deployments use 720MB on the node
+ assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(280).getBytes()));
+ // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node
+ assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1));
+ // 1000 - (50 + 300 + 2*10) = 630 : deployments use 370MB on the node
+ assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(630).getBytes()));
+ // 8 - [(1*4) + (1*4)] = 0 : deployment 2 should use all cores on the node
+ assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0));
+
+ }
}

public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() {
- Node node = new Node("n_1", 100, 4);
- AssignmentPlan.Deployment deployment = new Deployment("m_1", 30, 2, 2, Map.of("n_1", 2), 2);
- PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment));
-
- AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build();
- assertThat(plan.assignments(deployment).isEmpty(), is(true));
-
- plan = preserveOneAllocation.mergePreservedAllocations(plan);
- assertThat(plan.assignments(deployment).isPresent(), is(true));
- assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 1)));
- assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(70L));
- assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2));
+ {
+ // old memory format
+ Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4);
+ Deployment
deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0); + PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); + + AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); + assertThat(plan.assignments(deployment).isEmpty(), is(true)); + + plan = preserveOneAllocation.mergePreservedAllocations(plan); + assertThat(plan.assignments(deployment).isPresent(), is(true)); + assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 1))); + // 400 - (30*2 + 240) = 100 : deployments use 300MB on the node + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(100).getBytes())); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); + } + { + // new memory format + Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); + Deployment deployment = new Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 2, + Map.of("n_1", 2), + 2, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); + + AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); + assertThat(plan.assignments(deployment).isEmpty(), is(true)); + + plan = preserveOneAllocation.mergePreservedAllocations(plan); + assertThat(plan.assignments(deployment).isPresent(), is(true)); + assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 1))); + // 400 - (30 + 300 + 10) = 60 : deployments use 340MB on the node + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(60).getBytes())); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java index 7ceb8bbb86869..651e4764cb894 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java @@ -36,7 +36,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase { public void testGivenOneModel_OneNode_OneZone_DoesNotFit() { Node node = new Node("n_1", 100, 1); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0, 0, 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -44,8 +44,17 @@ public void testGivenOneModel_OneNode_OneZone_DoesNotFit() { } public void testGivenOneModel_OneNode_OneZone_FullyFits() { - Node node = new Node("n_1", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 2, 2, Map.of(), 0); + Node node = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 2, + 2, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -53,8 +62,17 
@@ public void testGivenOneModel_OneNode_OneZone_FullyFits() { } public void testGivenOneModel_OneNode_OneZone_PartiallyFits() { - Node node = new Node("n_1", 100, 5); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of(), 0); + Node node = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 3, + 2, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -64,9 +82,18 @@ public void testGivenOneModel_OneNode_OneZone_PartiallyFits() { } public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 1, + 2, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2)), @@ -82,9 +109,18 @@ public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() { } public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 2, 2, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 2, + 2, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z_1"), List.of(node1), List.of("z_2"), List.of(node2)), @@ -99,9 +135,18 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() { } public void testGivenOneModel_OneNodePerZone_TwoZones_PartiallyFits() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 3, 3, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 3, + 3, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z_1"), List.of(node1), List.of("z_2"), List.of(node2)), @@ -117,15 +162,15 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_PartiallyFits() { } public void testGivenThreeModels_TwoNodesPerZone_ThreeZones_FullyFit() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Node node3 = new Node("n_3", 100, 4); - Node node4 = new Node("n_4", 100, 4); - Node node5 = new Node("n_5", 100, 4); - Node node6 = new Node("n_6", 100, 4); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 25, 4, 1, Map.of(), 0); - Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 25, 6, 2, Map.of(), 0); - 
AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 25, 2, 3, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node4 = new Node("n_4", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node5 = new Node("n_5", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node6 = new Node("n_6", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 6, 2, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(30).getBytes(), 2, 3, Map.of(), 0, 0, 0); Map, List> nodesByZone = Map.of( List.of("z_1"), @@ -168,11 +213,11 @@ public void testGivenThreeModels_TwoNodesPerZone_ThreeZones_FullyFit() { } public void testGivenTwoModelsWithSingleAllocation_OneNode_ThreeZones() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Node node3 = new Node("n_3", 100, 4); - AssignmentPlan.Deployment deployment1 = new Deployment("m_1", 25, 1, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 25, 1, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2), List.of("z3"), List.of(node3)), @@ -203,7 +248,16 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode .stream() .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue)); previousModelsPlusNew.add( - new AssignmentPlan.Deployment(m.id(), m.memoryBytes(), m.allocations(), m.threadsPerAllocation(), previousAssignments, 0) + new AssignmentPlan.Deployment( + m.id(), + m.memoryBytes(), + m.allocations(), + m.threadsPerAllocation(), + previousAssignments, + 0, + 0, + 0 + ) ); } previousModelsPlusNew.add(randomModel("new")); @@ -214,11 +268,11 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode } public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOnce() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(2580).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", 
ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1, node2)), List.of(deployment1)) @@ -252,8 +306,8 @@ public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOn assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1))); // Now the cluster starts getting resized. - Node node3 = new Node("n_3", ByteSizeValue.ofMb(2400).getBytes(), 2); - Node node4 = new Node("n_4", ByteSizeValue.ofMb(2400).getBytes(), 2); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(5160).getBytes(), 2); + Node node4 = new Node("n_4", ByteSizeValue.ofMb(5160).getBytes(), 2); // First, one node goes away. assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1)), createModelsFromPlan(assignmentPlan)) diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java new file mode 100644 index 0000000000000..549ac23e16845 --- /dev/null +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java @@ -0,0 +1,287 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.upgrades; + +import org.elasticsearch.Version; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Strings; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.elasticsearch.client.WarningsHandler.PERMISSIVE; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class MlAssignmentPlannerUpgradeIT extends AbstractUpgradeTestCase { + + private Logger logger = LogManager.getLogger(MlAssignmentPlannerUpgradeIT.class); + + // See PyTorchModelIT for how this model was created + static final String BASE_64_ENCODED_MODEL = + "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp" + + "TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA" + + "AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW" + + "lpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473Jqhjh" + + "kAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Ele" + + "s+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07k" + + "umUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJ" + + 
"wA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa" + + "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq" + + "+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7" + + "ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3" + + "FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28" + + "UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWw" + + "vY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW" + + "9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0" + + "Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGts" + + "UEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEs" + + "BAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYn" + + "VnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsU" + + "EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe" + + "Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE" + + "AAJIEAAAAAA=="; + static final long RAW_MODEL_SIZE; // size of the model before base64 encoding + static { + RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length; + } + + public void testMlAssignmentPlannerUpgrade() throws Exception { + assumeTrue("NLP model deployments added in 8.0", isOriginalClusterVersionAtLeast(Version.V_8_0_0)); + + logger.info("Starting testMlAssignmentPlannerUpgrade, model size {}", RAW_MODEL_SIZE); + + switch (CLUSTER_TYPE) { + case OLD -> { + // setup deployments using old and new memory format + setupDeployments(); + + waitForDeploymentStarted("old_memory_format"); + waitForDeploymentStarted("new_memory_format"); + + // assert correct memory format is used + assertOldMemoryFormat("old_memory_format"); + if (isOriginalClusterVersionAtLeast(Version.V_8_11_0)) { + assertNewMemoryFormat("new_memory_format"); + } else { + assertOldMemoryFormat("new_memory_format"); + } + } + case MIXED -> { + ensureHealth(".ml-inference-*,.ml-config*", (request -> { + request.addParameter("wait_for_status", "yellow"); + request.addParameter("timeout", "70s"); + })); + waitForDeploymentStarted("old_memory_format"); + waitForDeploymentStarted("new_memory_format"); + + // assert correct memory format is used + assertOldMemoryFormat("old_memory_format"); + if (isOriginalClusterVersionAtLeast(Version.V_8_11_0)) { + assertNewMemoryFormat("new_memory_format"); + } else { + assertOldMemoryFormat("new_memory_format"); + } + + } + case UPGRADED -> { + ensureHealth(".ml-inference-*,.ml-config*", (request -> { + request.addParameter("wait_for_status", "yellow"); + request.addParameter("timeout", "70s"); + })); + waitForDeploymentStarted("old_memory_format"); + waitForDeploymentStarted("new_memory_format"); + + // assert correct memory format is used + assertOldMemoryFormat("old_memory_format"); + assertNewMemoryFormat("new_memory_format"); + + cleanupDeployments(); + } + } + } + + @SuppressWarnings("unchecked") + private void waitForDeploymentStarted(String 
modelId) throws Exception { + assertBusy(() -> { + var response = getTrainedModelStats(modelId); + Map map = entityAsMap(response); + List> stats = (List>) map.get("trained_model_stats"); + assertThat(stats, hasSize(1)); + var stat = stats.get(0); + assertThat(stat.toString(), XContentMapValues.extractValue("deployment_stats.state", stat), equalTo("started")); + }, 30, TimeUnit.SECONDS); + } + + @SuppressWarnings("unchecked") + private void assertOldMemoryFormat(String modelId) throws Exception { + var response = getTrainedModelStats(modelId); + Map map = entityAsMap(response); + List> stats = (List>) map.get("trained_model_stats"); + assertThat(stats, hasSize(1)); + var stat = stats.get(0); + Long expectedMemoryUsage = ByteSizeValue.ofMb(240).getBytes() + RAW_MODEL_SIZE * 2; + Integer actualMemoryUsage = (Integer) XContentMapValues.extractValue("model_size_stats.required_native_memory_bytes", stat); + assertThat( + Strings.format("Memory usage mismatch for the model %s in cluster state %s", modelId, CLUSTER_TYPE.toString()), + actualMemoryUsage, + equalTo(expectedMemoryUsage.intValue()) + ); + } + + @SuppressWarnings("unchecked") + private void assertNewMemoryFormat(String modelId) throws Exception { + var response = getTrainedModelStats(modelId); + Map map = entityAsMap(response); + List> stats = (List>) map.get("trained_model_stats"); + assertThat(stats, hasSize(1)); + var stat = stats.get(0); + Long expectedMemoryUsage = ByteSizeValue.ofMb(300).getBytes() + RAW_MODEL_SIZE + ByteSizeValue.ofMb(10).getBytes(); + Integer actualMemoryUsage = (Integer) XContentMapValues.extractValue("model_size_stats.required_native_memory_bytes", stat); + assertThat(stat.toString(), actualMemoryUsage.toString(), equalTo(expectedMemoryUsage.toString())); + } + + private Response getTrainedModelStats(String modelId) throws IOException { + Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/_stats"); + request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build()); + var response = client().performRequest(request); + assertOK(response); + return response; + } + + private Response infer(String input, String modelId) throws IOException { + Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer"); + request.setJsonEntity(Strings.format(""" + { "docs": [{"input":"%s"}] } + """, input)); + request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build()); + var response = client().performRequest(request); + assertOK(response); + return response; + } + + private void putModelDefinition(String modelId) throws IOException { + Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0"); + request.setJsonEntity(Strings.format(""" + {"total_definition_length":%s,"definition": "%s","total_parts": 1}""", RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL)); + client().performRequest(request); + } + + private void putVocabulary(List vocabulary, String modelId) throws IOException { + List vocabularyWithPad = new ArrayList<>(); + vocabularyWithPad.add("[PAD]"); + vocabularyWithPad.add("[UNK]"); + vocabularyWithPad.addAll(vocabulary); + String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(",")); + + Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary"); + request.setJsonEntity(Strings.format(""" + { "vocabulary": [%s] } + """, quotedWords)); + client().performRequest(request); + } + + private void setupDeployments() 
throws Exception {
+ createTrainedModel("old_memory_format", 0, 0);
+ putModelDefinition("old_memory_format");
+ putVocabulary(List.of("these", "are", "my", "words"), "old_memory_format");
+ startDeployment("old_memory_format");
+
+ createTrainedModel("new_memory_format", ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes());
+ putModelDefinition("new_memory_format");
+ putVocabulary(List.of("these", "are", "my", "words"), "new_memory_format");
+ startDeployment("new_memory_format");
+ }
+
+ private void cleanupDeployments() throws IOException {
+ stopDeployment("old_memory_format");
+ deleteTrainedModel("old_memory_format");
+ stopDeployment("new_memory_format");
+ deleteTrainedModel("new_memory_format");
+ }
+
+ private void createTrainedModel(String modelId, long perDeploymentMemoryBytes, long perAllocationMemoryBytes) throws IOException {
+ Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
+ if (perAllocationMemoryBytes > 0 && perDeploymentMemoryBytes > 0) {
+ request.setJsonEntity(Strings.format("""
+ {
+ "description": "simple model for testing",
+ "model_type": "pytorch",
+ "inference_config": {
+ "pass_through": {
+ "tokenization": {
+ "bert": {
+ "with_special_tokens": false
+ }
+ }
+ }
+ },
+ "metadata": {
+ "per_deployment_memory_bytes": %s,
+ "per_allocation_memory_bytes": %s
+ }
+ }""", perDeploymentMemoryBytes, perAllocationMemoryBytes));
+ } else {
+ request.setJsonEntity("""
+ {
+ "description": "simple model for testing",
+ "model_type": "pytorch",
+ "inference_config": {
+ "pass_through": {
+ "tokenization": {
+ "bert": {
+ "with_special_tokens": false
+ }
+ }
+ }
+ }
+ }""");
+ }
+ client().performRequest(request);
+ }
+
+ private void deleteTrainedModel(String modelId) throws IOException {
+ Request request = new Request("DELETE", "_ml/trained_models/" + modelId);
+ client().performRequest(request);
+ }
+
+ private Response startDeployment(String modelId) throws IOException {
+ return startDeployment(modelId, "started");
+ }
+
+ private Response startDeployment(String modelId, String waitForState) throws IOException {
+ Request request = new Request(
+ "POST",
+ "/_ml/trained_models/"
+ + modelId
+ + "/deployment/_start?timeout=40s&wait_for="
+ + waitForState
+ + "&inference_threads=1&model_threads=1"
+ );
+ request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build());
+ var response = client().performRequest(request);
+ assertOK(response);
+ return response;
+ }
+
+ private void stopDeployment(String modelId) throws IOException {
+ String endpoint = "/_ml/trained_models/" + modelId + "/deployment/_stop";
+ Request request = new Request("POST", endpoint);
+ client().performRequest(request);
+ }
+}

From 6ad48306d029e6e527c0481e2e9880bd2f06b239 Mon Sep 17 00:00:00 2001
From: Ievgen Degtiarenko
Date: Mon, 6 Nov 2023 13:24:19 +0100
Subject: [PATCH 03/21] Skip nodes that already contain a copy of this shard
 (#101210)

This change skips desired nodes that already contain a copy of the same
shard during the reconciliation process. This should make reconciliation
a tiny bit cheaper and allow us to log fewer, but more meaningful,
decisions for the rest of the nodes.
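In plain terms, the change pre-filters candidate nodes before consulting the
allocation deciders. A minimal sketch of the idea (illustrative only; the
loop variable `candidateNodes` is a stand-in, and the real change is in the
diff below):

    // Sketch: reject a desired node that already holds a copy of the shard up
    // front, since SameShardAllocationDecider would return NO for it anyway.
    for (RoutingNode candidate : candidateNodes) {
        if (candidate.getByShardId(shard.shardId()) != null) {
            continue; // a copy of this shard is already on this node
        }
        Decision decision = allocation.deciders().canAllocate(shard, candidate, allocation);
        // ... handle YES / THROTTLE / NO as before ...
    }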
---
 .../allocation/allocator/DesiredBalanceReconciler.java | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java
index f8c45026b1df7..048ade3ef86c5 100644
--- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java
+++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java
@@ -261,6 +261,12 @@ private void allocateUnassigned() {
 // desired node no longer exists
 continue;
 }
+ if (routingNode.getByShardId(shard.shardId()) != null) {
+ // The node already contains a copy of this shard.
+ // Skipping it lets us exclude the NO decisions from SameShardAllocationDecider and log only the more relevant
+ // NO or THROTTLE decisions that prevent the shard from starting on its assigned node.
+ continue;
+ }
 final var decision = allocation.deciders().canAllocate(shard, routingNode, allocation);
 switch (decision.type()) {
 case YES -> {
@@ -287,10 +293,10 @@ private void allocateUnassigned() {
 case THROTTLE -> {
 nodeIdsIterator.wasThrottled = true;
 unallocatedStatus = AllocationStatus.DECIDERS_THROTTLED;
- logger.trace("Couldn't assign shard [{}] to [{}]: {}", shard.shardId(), nodeId, decision);
+ logger.debug("Couldn't assign shard [{}] to [{}]: {}", shard.shardId(), nodeId, decision);
 }
 case NO -> {
- logger.trace("Couldn't assign shard [{}] to [{}]: {}", shard.shardId(), nodeId, decision);
+ logger.debug("Couldn't assign shard [{}] to [{}]: {}", shard.shardId(), nodeId, decision);
 }
 }
 }

From b158dc7ed00e5db8f86a858d29104f5e39779d85 Mon Sep 17 00:00:00 2001
From: David Turner
Date: Mon, 6 Nov 2023 12:58:06 +0000
Subject: [PATCH 04/21] Cancel via tasks API in RestActionCancellationIT
 (#101819)

We should get the same effect whether we cancel the task by closing the
REST connection or by using the tasks API. This commit generalises the
test to cover both cases.
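For reference, the tasks-API cancellation path that the new test branch
exercises looks roughly like this sketch (the endpoints and parameters match
those built by the test below; the node name and task id are made up):

    // List the cancellable tasks for the captured action on one node.
    Request listTasks = new Request("GET", "/_tasks");
    listTasks.addParameter("nodes", "node-0");                 // illustrative node name
    listTasks.addParameter("actions", "cluster:monitor/health");
    listTasks.addParameter("group_by", "none");

    // Cancel one of them and wait for the cancellation to complete.
    Request cancel = new Request("POST", "/_tasks/node-0:12345/_cancel"); // illustrative task id
    cancel.addParameter("wait_for_completion", null);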
--- .../http/RestActionCancellationIT.java | 94 +++++++++++++++---- 1 file changed, 75 insertions(+), 19 deletions(-) diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/RestActionCancellationIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/RestActionCancellationIT.java index 83e4582239075..d46868094907d 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/RestActionCancellationIT.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/RestActionCancellationIT.java @@ -9,78 +9,134 @@ package org.elasticsearch.http; import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; import org.elasticsearch.action.admin.cluster.health.ClusterHealthAction; import org.elasticsearch.action.admin.cluster.state.ClusterStateAction; import org.elasticsearch.action.admin.indices.alias.get.GetAliasesAction; import org.elasticsearch.action.admin.indices.recovery.RecoveryAction; import org.elasticsearch.action.support.CancellableActionTestPlugin; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.rest.ObjectPath; import java.util.Collection; import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener; import static org.elasticsearch.test.TaskAssertions.assertAllTasksHaveFinished; +import static org.hamcrest.Matchers.greaterThan; public class RestActionCancellationIT extends HttpSmokeTestCase { - public void testIndicesRecoveryRestCancellation() throws Exception { + public void testIndicesRecoveryRestCancellation() { createIndex("test"); ensureGreen("test"); runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_recovery"), RecoveryAction.NAME); } - public void testCatRecoveryRestCancellation() throws Exception { + public void testCatRecoveryRestCancellation() { createIndex("test"); ensureGreen("test"); runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_cat/recovery"), RecoveryAction.NAME); } - public void testClusterHealthRestCancellation() throws Exception { + public void testClusterHealthRestCancellation() { runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_cluster/health"), ClusterHealthAction.NAME); } - public void testClusterStateRestCancellation() throws Exception { + public void testClusterStateRestCancellation() { runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_cluster/state"), ClusterStateAction.NAME); } - public void testGetAliasesCancellation() throws Exception { - runRestActionCancellationTest(new Request("GET", "/_alias"), GetAliasesAction.NAME); + public void testGetAliasesCancellation() { + runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_alias"), GetAliasesAction.NAME); } - public void testCatAliasesCancellation() throws Exception { - runRestActionCancellationTest(new Request("GET", "/_cat/aliases"), GetAliasesAction.NAME); + public void testCatAliasesCancellation() { + 
runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_cat/aliases"), GetAliasesAction.NAME); } - private void runRestActionCancellationTest(Request request, String actionName) throws Exception { + private void runRestActionCancellationTest(Request request, String actionName) { final var node = usually() ? internalCluster().getRandomNodeName() : internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); try ( var restClient = createRestClient(node); var capturingAction = CancellableActionTestPlugin.capturingActionOnNode(actionName, node) ) { - expectThrows( - CancellationException.class, - () -> PlainActionFuture.get( - responseFuture -> capturingAction.captureAndCancel( - restClient.performRequestAsync(request, wrapAsRestResponseListener(responseFuture))::cancel - ), - 10, - TimeUnit.SECONDS - ) - ); + final var responseFuture = new PlainActionFuture(); + final var restInvocation = restClient.performRequestAsync(request, wrapAsRestResponseListener(responseFuture)); + + if (randomBoolean()) { + // cancel by aborting the REST request + capturingAction.captureAndCancel(restInvocation::cancel); + expectThrows(ExecutionException.class, CancellationException.class, () -> responseFuture.get(10, TimeUnit.SECONDS)); + } else { + // cancel via the task management API + final var cancelFuture = new PlainActionFuture(); + capturingAction.captureAndCancel( + () -> SubscribableListener + + .newForked( + l -> restClient.performRequestAsync( + getListTasksRequest(node, actionName), + wrapAsRestResponseListener(l.map(ObjectPath::createFromResponse)) + ) + ) + + .andThen((l, listTasksResponse) -> { + final var taskCount = listTasksResponse.evaluateArraySize("tasks"); + assertThat(taskCount, greaterThan(0)); + try (var listeners = new RefCountingListener(l)) { + for (int i = 0; i < taskCount; i++) { + final var taskPrefix = "tasks." 
+ i + "."; + assertTrue(listTasksResponse.evaluate(taskPrefix + "cancellable")); + assertFalse(listTasksResponse.evaluate(taskPrefix + "cancelled")); + restClient.performRequestAsync( + getCancelTaskRequest( + listTasksResponse.evaluate(taskPrefix + "node"), + listTasksResponse.evaluate(taskPrefix + "id") + ), + wrapAsRestResponseListener(listeners.acquire(HttpSmokeTestCase::assertOK)) + ); + } + } + }) + + .addListener(cancelFuture) + ); + cancelFuture.get(10, TimeUnit.SECONDS); + expectThrows(Exception.class, () -> responseFuture.get(10, TimeUnit.SECONDS)); + } + assertAllTasksHaveFinished(actionName); } catch (Exception e) { fail(e); } } + private static Request getListTasksRequest(String taskNode, String actionName) { + final var listTasksRequest = new Request(HttpGet.METHOD_NAME, "/_tasks"); + listTasksRequest.addParameter("nodes", taskNode); + listTasksRequest.addParameter("actions", actionName); + listTasksRequest.addParameter("group_by", "none"); + return listTasksRequest; + } + + private static Request getCancelTaskRequest(String taskNode, int taskId) { + final var cancelTaskRequest = new Request(HttpPost.METHOD_NAME, Strings.format("/_tasks/%s:%d/_cancel", taskNode, taskId)); + cancelTaskRequest.addParameter("wait_for_completion", null); + return cancelTaskRequest; + } + @Override protected Collection> nodePlugins() { return CollectionUtils.appendToCopy(super.nodePlugins(), CancellableActionTestPlugin.class); From 90d3672d1107e25fa1bb8ae42f70ebcefee7e0de Mon Sep 17 00:00:00 2001 From: Abdon Pijpelink Date: Mon, 6 Nov 2023 15:12:08 +0100 Subject: [PATCH 05/21] [DOCS] Default and max limits are now dynamic settings (#101831) * [DOCS] Default and max limits are now dynamic settings * Delete reference to Discover --- docs/reference/esql/esql-limitations.asciidoc | 16 ++----------- .../esql/processing-commands/limit.asciidoc | 24 +++++++++++++++---- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/docs/reference/esql/esql-limitations.asciidoc b/docs/reference/esql/esql-limitations.asciidoc index 303f9a337b6c4..85f107feeb8fd 100644 --- a/docs/reference/esql/esql-limitations.asciidoc +++ b/docs/reference/esql/esql-limitations.asciidoc @@ -10,20 +10,8 @@ === Result set size limit By default, an {esql} query returns up to 500 rows. You can increase the number -of rows up to 10,000 using the <> command. Queries do not return -more than 10,000 rows, regardless of the `LIMIT` command's value. - -This limit only applies to the number of rows that are retrieved by the query -and displayed in Discover. Queries and aggregations run on the full data set. - -To overcome this limitation: - -* Reduce the result set size by modifying the query to only return relevant -data. Use <> to select a smaller subset of the data. -* Shift any post-query processing to the query itself. You can use the {esql} -<> command to aggregate data in the query. -* Increase the limit with the `esql.query.result_truncation_max_size` static -cluster setting. +of rows up to 10,000 using the <> command. +include::processing-commands/limit.asciidoc[tag=limitation] [discrete] [[esql-supported-types]] diff --git a/docs/reference/esql/processing-commands/limit.asciidoc b/docs/reference/esql/processing-commands/limit.asciidoc index 457d5e9e65223..5f659fc493a75 100644 --- a/docs/reference/esql/processing-commands/limit.asciidoc +++ b/docs/reference/esql/processing-commands/limit.asciidoc @@ -17,11 +17,27 @@ The maximum number of rows to return. 
*Description* The `LIMIT` processing command enables you to limit the number of rows that are -returned. If not specified, `LIMIT` defaults to `500`. +returned. +// tag::limitation[] +Queries do not return more than 10,000 rows, regardless of the `LIMIT` command's +value. -A query does not return more than 10,000 rows, regardless of the `LIMIT` value. -You can change this with the `esql.query.result_truncation_max_size` static -cluster setting. +This limit only applies to the number of rows that are retrieved by the query. +Queries and aggregations run on the full data set. + +To overcome this limitation: + +* Reduce the result set size by modifying the query to only return relevant +data. Use <> to select a smaller subset of the data. +* Shift any post-query processing to the query itself. You can use the {esql} +<> command to aggregate data in the query. + +The default and maximum limits can be changed using these dynamic cluster +settings: + +* `esql.query.result_truncation_default_size` +* `esql.query.result_truncation_max_size` +// end::limitation[] *Example* From b74effe2bc19b8dc66e35cf684bbb6b4b2967ea3 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 6 Nov 2023 14:24:16 +0000 Subject: [PATCH 06/21] Introduce ESTestCase-createThreadPool (#101813) Adds `ESTestCase#createThreadPool` to create a `TestThreadPool` named after the running test, and fixes up the places that already had a slightly different `createThreadPool` utility. --- .../org/elasticsearch/test/ESTestCase.java | 6 ++++++ .../HuggingFaceElserActionTests.java | 4 ++-- .../external/http/HttpClientManagerTests.java | 4 ++-- .../external/http/HttpClientTests.java | 4 ++-- .../http/IdleConnectionEvictorTests.java | 4 ++-- .../xpack/inference/external/http/Utils.java | 21 +++++++------------ .../HttpRequestExecutorServiceTests.java | 4 ++-- .../sender/HttpRequestSenderFactoryTests.java | 4 ++-- .../http/sender/RequestTaskTests.java | 4 ++-- .../huggingface/HuggingFaceClientTests.java | 4 ++-- .../logging/ThrottlerManagerTests.java | 4 ++-- .../inference/logging/ThrottlerTests.java | 4 ++-- 12 files changed, 34 insertions(+), 33 deletions(-) diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index d3f01f03ed61c..5589e1b94281d 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -104,6 +104,8 @@ import org.elasticsearch.search.MockSearchService; import org.elasticsearch.test.junit.listeners.LoggingListener; import org.elasticsearch.test.junit.listeners.ReproduceInfoPrinter; +import org.elasticsearch.threadpool.ExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.LeakTracker; import org.elasticsearch.transport.netty4.Netty4Plugin; @@ -1263,6 +1265,10 @@ public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime, return breakSupplier.getAsBoolean(); } + protected TestThreadPool createThreadPool(ExecutorBuilder... executorBuilders) { + return new TestThreadPool(getTestName(), executorBuilders); + } + public static boolean terminate(ExecutorService... 
services) { boolean terminated = true; for (ExecutorService service : services) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceElserActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceElserActionTests.java index 713312204e65b..9809acf536c86 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceElserActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceElserActionTests.java @@ -35,9 +35,9 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.external.http.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -55,7 +55,7 @@ public class HuggingFaceElserActionTests extends ESTestCase { @Before public void init() throws Exception { webServer.start(); - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientManagerTests.java index 3e07bd773c65e..246e7d6d44c5a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientManagerTests.java @@ -26,7 +26,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createHttpPost; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.external.http.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.external.http.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.equalTo; @@ -46,7 +46,7 @@ public class HttpClientManagerTests extends ESTestCase { @Before public void init() throws Exception { webServer.start(); - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); } @After diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java index c72d9167a9e06..3a7ec9d1b0f55 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java @@ -42,7 +42,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.external.http.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; import static org.hamcrest.Matchers.equalTo; @@ -63,7 +63,7 @@ public class HttpClientTests extends ESTestCase { @Before public void init() throws Exception { webServer.start(); - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); } @After diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictorTests.java index dba80923c487d..a46586fa6121b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictorTests.java @@ -20,7 +20,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doAnswer; @@ -36,7 +36,7 @@ public class IdleConnectionEvictorTests extends ESTestCase { @Before public void init() { - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); } @After diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/Utils.java index becb0cc43e1e8..22c36fe38a25c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/Utils.java @@ -13,8 +13,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ScalingExecutorBuilder; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.DeprecationHandler; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; @@ -81,17 +79,14 @@ public static Map entityAsMap(InputStream body) throws IOExcepti } } - public static ThreadPool createThreadPool(String name) { - return new TestThreadPool( - name, - new ScalingExecutorBuilder( - UTILITY_THREAD_POOL_NAME, - 1, - 4, - TimeValue.timeValueMinutes(10), - false, - "xpack.inference.utility_thread_pool" - ) + public static ScalingExecutorBuilder inferenceUtilityPool() { + return new ScalingExecutorBuilder( + UTILITY_THREAD_POOL_NAME, + 1, + 4, + TimeValue.timeValueMinutes(10), + false, + "xpack.inference.utility_thread_pool" ); } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java index 245ce09848a7f..2dd31144b3bc2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java @@ -29,7 +29,7 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createHttpPost; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; @@ -45,7 +45,7 @@ public class HttpRequestExecutorServiceTests extends ESTestCase { @Before public void init() { - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); } @After diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java index 3434b951147d7..82c41794695fd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java @@ -35,7 +35,7 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createHttpPost; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.external.http.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -56,7 +56,7 @@ public class HttpRequestSenderFactoryTests extends ESTestCase { @Before public void init() throws Exception { webServer.start(); - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); threadRef.set(null); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index f3718954d8ad9..e6c47c891f0d7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -40,7 +40,7 @@ import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createConnectionManager; import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createHttpPost; import static 
org.elasticsearch.xpack.inference.external.http.HttpClientTests.emptyHttpSettings; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -63,7 +63,7 @@ public class RequestTaskTests extends ESTestCase { @Before public void init() throws Exception { webServer.start(); - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); } @After diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/huggingface/HuggingFaceClientTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/huggingface/HuggingFaceClientTests.java index 3463067143994..0cc97ca38de80 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/huggingface/HuggingFaceClientTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/huggingface/HuggingFaceClientTests.java @@ -31,9 +31,9 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.external.http.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.request.huggingface.HuggingFaceElserRequestTests.createRequest; import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; @@ -53,7 +53,7 @@ public class HuggingFaceClientTests extends ESTestCase { @Before public void init() throws Exception { webServer.start(); - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mockThrottlerManager()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerManagerTests.java index 01374d02a21c3..ba9e7851c9ad4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerManagerTests.java @@ -15,7 +15,7 @@ import org.junit.After; import org.junit.Before; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.external.http.Utils.mockClusterServiceEmpty; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -31,7 +31,7 @@ public class ThrottlerManagerTests extends ESTestCase { @Before public void init() { - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); } @After diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerTests.java index df95232ff85f7..27df66c54cd1c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerTests.java @@ -22,7 +22,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.inference.external.http.Utils.createThreadPool; +import static org.elasticsearch.xpack.inference.external.http.Utils.inferenceUtilityPool; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; @@ -40,7 +40,7 @@ public class ThrottlerTests extends ESTestCase { @Before public void init() { - threadPool = createThreadPool(getTestName()); + threadPool = createThreadPool(inferenceUtilityPool()); } @After From a35d7a02f687c9fcfea6cea7438d0cb5278a9829 Mon Sep 17 00:00:00 2001 From: Chris Cressman Date: Mon, 6 Nov 2023 09:55:19 -0500 Subject: [PATCH 07/21] [DOCS] Add Enterprise Search content to docs landing page (#101797) Revise the Elasticsearch docs landing page as follows: - Add Enterprise Search server as an operations concern - Add connectors and web crawler as ingestion concerns - Add search applications and search analytics as "search and analyze" concerns - For Search use case information, send readers to Search Labs rather than Enterprise Search --- docs/reference/landing-page.asciidoc | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/reference/landing-page.asciidoc b/docs/reference/landing-page.asciidoc index 1ddd0cfa28128..a53a5770fe030 100644 --- a/docs/reference/landing-page.asciidoc +++ b/docs/reference/landing-page.asciidoc @@ -105,6 +105,9 @@
   • Troubleshooting
+  • Enterprise Search server
@@ -119,6 +122,12 @@
   • Adding data to Elasticsearch
+  • Connectors
+  • Web crawler
   • Data streams
@@ -145,6 +154,12 @@
   • Query data with the Query DSL, ES|QL, EQL, or SQL
+  • Search applications
+  • Search analytics
   • Aggregations
@@ -207,7 +222,7 @@
-  [Search use case link: Enterprise Search]
+  [Search use case link: Search Labs]
From 461d004591325398dac6fbde08bbcceac9024952 Mon Sep 17 00:00:00 2001
From: Przemyslaw Gomulka <przemyslaw.gomulka@elastic.co>
Date: Mon, 6 Nov 2023 16:18:04 +0100
Subject: [PATCH 08/21] Fix parsing NDJson in RecordingApmServer (#101825)

Depending on the operating system, System.lineSeparator() differs;
however, a stream read as UTF_8 is always delimited by '\n'. This
commit fixes the parsing of the NDJSON that was failing on Windows
(see the sketch below).

Closes #101793
---
 .../test/apmintegration/RecordingApmServer.java | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/test/external-modules/apm-integration/src/javaRestTest/java/org/elasticsearch/test/apmintegration/RecordingApmServer.java b/test/external-modules/apm-integration/src/javaRestTest/java/org/elasticsearch/test/apmintegration/RecordingApmServer.java
index 4be16f02f2ce4..c3a8df2c4b150 100644
--- a/test/external-modules/apm-integration/src/javaRestTest/java/org/elasticsearch/test/apmintegration/RecordingApmServer.java
+++ b/test/external-modules/apm-integration/src/javaRestTest/java/org/elasticsearch/test/apmintegration/RecordingApmServer.java
@@ -17,13 +17,14 @@
 import org.elasticsearch.xcontent.spi.XContentProvider;
 import org.junit.rules.ExternalResource;
 
+import java.io.BufferedReader;
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.InputStreamReader;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.TimeUnit;
@@ -97,7 +98,7 @@ private void handle(HttpExchange exchange) throws IOException {
 
     private List<String> readJsonMessages(InputStream input) throws IOException {
         // parse NDJSON
-        return Arrays.stream(new String(input.readAllBytes(), StandardCharsets.UTF_8).split(System.lineSeparator())).toList();
+        return new BufferedReader(new InputStreamReader(input, StandardCharsets.UTF_8)).lines().toList();
     }
 
     public int getPort() {

From 63f29d41031d9cf976b5de6b9dfd4ccf8fc521e8 Mon Sep 17 00:00:00 2001
From: Valeriy Khakhutskyy <1292899+valeriy42@users.noreply.github.com>
Date: Mon, 6 Nov 2023 16:53:57 +0100
Subject: [PATCH 09/21] Revert "[ML] Use perAllocation and perDeployment
 memory usage in the model assignment planner (#98874)" (#101834)

There were a number of BWC test failures after the PR was merged today.
I'll revert it and investigate the failures locally.
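To make the separator issue fixed in #101825 above concrete: NDJSON is '\n'-delimited regardless of the platform that reads it, while System.lineSeparator() is "\r\n" on Windows. The following self-contained sketch contrasts the two approaches; the class name and the sample payload are illustrative only and do not appear in the patch.

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.List;

public class NdjsonSplitDemo {
    public static void main(String[] args) {
        // Two NDJSON messages, '\n'-delimited as the format prescribes.
        byte[] body = "{\"a\":1}\n{\"b\":2}\n".getBytes(StandardCharsets.UTF_8);
        String raw = new String(body, StandardCharsets.UTF_8);

        // Old approach: splits into 2 messages on Unix ("\n") but leaves a
        // single unparsed blob on Windows, where lineSeparator() is "\r\n".
        String[] split = raw.split(System.lineSeparator());
        System.out.println("split on lineSeparator(): " + split.length);

        // New approach: BufferedReader.lines() terminates lines on '\n',
        // '\r', or "\r\n", so it yields 2 messages on every platform.
        List<String> messages = new BufferedReader(
            new InputStreamReader(new ByteArrayInputStream(body), StandardCharsets.UTF_8)
        ).lines().toList();
        System.out.println("BufferedReader.lines(): " + messages.size());
    }
}

On a Unix JVM both counts print 2; on Windows the split() variant prints 1, matching the failure reported in #101793. Because lines() accepts all three line terminators, the parsing no longer depends on the host OS.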
Reverts #98874 --- docs/changelog/98874.yaml | 5 - .../assignment/TrainedModelAssignment.java | 5 - .../TransportGetTrainedModelsStatsAction.java | 24 +- .../TrainedModelAssignmentClusterService.java | 7 +- .../TrainedModelAssignmentRebalancer.java | 36 +- .../planning/AbstractPreserveAllocations.java | 42 +- .../assignment/planning/AssignmentPlan.java | 139 +--- .../planning/AssignmentPlanner.java | 11 +- .../planning/LinearProgrammingPlanSolver.java | 29 +- .../planning/PreserveAllAllocations.java | 2 +- .../planning/PreserveOneAllocation.java | 2 +- .../RandomizedAssignmentRounding.java | 46 +- .../planning/ZoneAwareAssignmentPlanner.java | 16 +- ...TrainedModelAssignmentRebalancerTests.java | 81 +- .../planning/AssignmentPlanTests.java | 511 +++---------- .../planning/AssignmentPlannerTests.java | 698 +++--------------- .../planning/PreserveAllAllocationsTests.java | 228 ++---- .../planning/PreserveOneAllocationTests.java | 264 ++----- .../ZoneAwareAssignmentPlannerTests.java | 126 +--- .../MlAssignmentPlannerUpgradeIT.java | 287 ------- 20 files changed, 483 insertions(+), 2076 deletions(-) delete mode 100644 docs/changelog/98874.yaml delete mode 100644 x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java diff --git a/docs/changelog/98874.yaml b/docs/changelog/98874.yaml deleted file mode 100644 index e3eb7b5acc63f..0000000000000 --- a/docs/changelog/98874.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 98874 -summary: Estimate the memory required to deploy trained models more accurately -area: Machine Learning -type: enhancement -issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index d27d325a5c596..f69be31939b32 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -9,7 +9,6 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.SimpleDiffable; import org.elasticsearch.common.Randomness; @@ -97,10 +96,6 @@ public final class TrainedModelAssignment implements SimpleDiffable 0L ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( model.getModelId(), totalDefinitionLength, - useNewMemoryFields ? model.getPerDeploymentMemoryBytes() : 0, - useNewMemoryFields ? model.getPerAllocationMemoryBytes() : 0, + model.getPerDeploymentMemoryBytes(), + model.getPerAllocationMemoryBytes(), numberOfAllocations ) : 0L; modelSizeStatsByModelId.put( model.getModelId(), - new TrainedModelSizeStats(totalDefinitionLength, estimatedMemoryUsageBytes) + new TrainedModelSizeStats( + totalDefinitionLength, + totalDefinitionLength > 0L + ? 
StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + model.getModelId(), + totalDefinitionLength, + model.getPerDeploymentMemoryBytes(), + model.getPerAllocationMemoryBytes(), + numberOfAllocations + ) + : 0L + ) ); } else { modelSizeStatsByModelId.put(model.getModelId(), new TrainedModelSizeStats(model.getModelSize(), 0)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java index fe4462d6556ee..2caf338d2a3c7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java @@ -47,7 +47,6 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; -import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper; import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer; @@ -77,8 +76,6 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0; public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0; - private static final TransportVersion NEW_ALLOCATION_MEMORY_VERSION = TransportVersions.V_8_500_064; - private final ClusterService clusterService; private final ThreadPool threadPool; private final NodeLoadDetector nodeLoadDetector; @@ -647,14 +644,12 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( Map nodeLoads = detectNodeLoads(nodes, currentState); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(currentState); - boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(TransportVersionUtils.getMinTransportVersion(currentState)); TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer( currentMetadata, nodeLoads, nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState), modelToAdd, - allocatedProcessorsScale, - useNewMemoryFields + allocatedProcessorsScale ); Set shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java index 6e6b447fcea3d..e1241dc8a93c3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java @@ -52,22 +52,18 @@ class TrainedModelAssignmentRebalancer { private final Optional deploymentToAdd; private final int allocatedProcessorsScale; - private final boolean useNewMemoryFields; - TrainedModelAssignmentRebalancer( TrainedModelAssignmentMetadata currentMetadata, Map nodeLoads, Map, Collection> mlNodesByZone, 
Optional deploymentToAdd, - int allocatedProcessorsScale, - boolean useNewMemoryFields + int allocatedProcessorsScale ) { this.currentMetadata = Objects.requireNonNull(currentMetadata); this.nodeLoads = Objects.requireNonNull(nodeLoads); this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone); this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd); this.allocatedProcessorsScale = allocatedProcessorsScale; - this.useNewMemoryFields = useNewMemoryFields; } TrainedModelAssignmentMetadata.Builder rebalance() { @@ -142,11 +138,9 @@ private static void copyAssignments( AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id()); dest.assignModelToNode(m, originalNode, assignment.getValue()); if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) { - // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder // As the node has all its available memory we need to manually account memory of models with // current allocations. - long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id())); - dest.accountMemory(m, originalNode, requiredMemory); + dest.accountMemory(m, originalNode); } } } @@ -174,14 +168,11 @@ private AssignmentPlan computePlanForNormalPriorityModels( .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations())); return new AssignmentPlan.Deployment( assignment.getDeploymentId(), - assignment.getTaskParams().getModelBytes(), + assignment.getTaskParams().estimateMemoryUsageBytes(), assignment.getTaskParams().getNumberOfAllocations(), assignment.getTaskParams().getThreadsPerAllocation(), currentAssignments, - assignment.getMaxAssignedAllocations(), - // in the mixed cluster state use old memory fields to avoid unstable assignment plans - useNewMemoryFields ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, - useNewMemoryFields ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 + assignment.getMaxAssignedAllocations() ); }) .forEach(planDeployments::add); @@ -190,14 +181,11 @@ private AssignmentPlan computePlanForNormalPriorityModels( planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), - taskParams.getModelBytes(), + taskParams.estimateMemoryUsageBytes(), taskParams.getNumberOfAllocations(), taskParams.getThreadsPerAllocation(), Map.of(), - 0, - // in the mixed cluster state use old memory fields to avoid unstable assignment plans - useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0, - useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0 + 0 ) ); } @@ -229,14 +217,12 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod .map( assignment -> new AssignmentPlan.Deployment( assignment.getDeploymentId(), - assignment.getTaskParams().getModelBytes(), + assignment.getTaskParams().estimateMemoryUsageBytes(), assignment.getTaskParams().getNumberOfAllocations(), assignment.getTaskParams().getThreadsPerAllocation(), findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory), assignment.getMaxAssignedAllocations(), - Priority.LOW, - (useNewMemoryFields == false) ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, - (useNewMemoryFields == false) ? 
assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 + Priority.LOW ) ) .forEach(planDeployments::add); @@ -245,14 +231,12 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), - taskParams.getModelBytes(), + taskParams.estimateMemoryUsageBytes(), taskParams.getNumberOfAllocations(), taskParams.getThreadsPerAllocation(), Map.of(), 0, - Priority.LOW, - (useNewMemoryFields == false) ? taskParams.getPerDeploymentMemoryBytes() : 0, - (useNewMemoryFields == false) ? taskParams.getPerAllocationMemoryBytes() : 0 + Priority.LOW ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java index 026b433a8c2d4..4843cc43d1187 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java @@ -35,8 +35,7 @@ private Node modifyNodePreservingAllocations(Node n) { int coresUsed = 0; for (Deployment m : deployments) { if (m.currentAllocationsByNodeId().containsKey(n.id())) { - int allocations = m.currentAllocationsByNodeId().get(n.id()); - bytesUsed += m.estimateMemoryUsageBytes(allocations); + bytesUsed += m.memoryBytes(); coresUsed += calculateUsedCores(n, m); } } @@ -59,9 +58,7 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) { m.allocations() - calculatePreservedAllocations(m), m.threadsPerAllocation(), calculateAllocationsPerNodeToPreserve(m), - m.maxAssignedAllocations(), - m.perDeploymentMemoryBytes(), - m.perAllocationMemoryBytes() + m.maxAssignedAllocations() ); } @@ -70,37 +67,28 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) { // they will not match the models/nodes members we have in this class. // Therefore, we build a lookup table based on the ids so we can merge the plan // with its preserved allocations. - final Map, Integer> plannedAssignmentsByModelNodeIdPair = new HashMap<>(); + final Map, Integer> assignmentsByModelNodeIdPair = new HashMap<>(); for (Deployment m : assignmentPlan.models()) { Map assignments = assignmentPlan.assignments(m).orElse(Map.of()); for (Map.Entry nodeAssignment : assignments.entrySet()) { - plannedAssignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue()); + assignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue()); } } AssignmentPlan.Builder mergedPlanBuilder = AssignmentPlan.builder(nodes, deployments); - for (Node n : nodes) { - // TODO (#101612) Should the first loop happen in the builder constructor? 
- for (Deployment deploymentAllocationsToPreserve : deployments) { - - // if the model m is already allocated on the node n and I want to preserve this allocation - int preservedAllocations = addPreservedAllocations(n, deploymentAllocationsToPreserve); - if (preservedAllocations > 0) { - long requiredMemory = deploymentAllocationsToPreserve.estimateMemoryUsageBytes(preservedAllocations); - if (mergedPlanBuilder.canAssign(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory)) { - mergedPlanBuilder.assignModelToNode(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory); + for (Deployment m : deployments) { + for (Node n : nodes) { + int allocations = assignmentsByModelNodeIdPair.getOrDefault(Tuple.tuple(m.id(), n.id()), 0); + if (m.currentAllocationsByNodeId().containsKey(n.id())) { + if (mergedPlanBuilder.getRemainingMemory(n) >= m.memoryBytes()) { + allocations += addPreservedAllocations(n, m); + // As the node has all its available memory we need to manually account memory of models with + // current allocations. + mergedPlanBuilder.accountMemory(m, n); } } - } - for (Deployment deploymentNewAllocations : deployments) { - int newAllocations = plannedAssignmentsByModelNodeIdPair.getOrDefault( - Tuple.tuple(deploymentNewAllocations.id(), n.id()), - 0 - ); - - long requiredMemory = mergedPlanBuilder.getDeploymentMemoryRequirement(deploymentNewAllocations, n, newAllocations); - if (newAllocations > 0 && mergedPlanBuilder.canAssign(deploymentNewAllocations, n, newAllocations, requiredMemory)) { - mergedPlanBuilder.assignModelToNode(deploymentNewAllocations, n, newAllocations); + if (allocations > 0) { + mergedPlanBuilder.assignModelToNode(m, n, allocations); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java index 1dce7f0bb46ba..72a83d7579463 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.Tuple; -import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; import java.util.ArrayList; @@ -37,32 +36,18 @@ public record Deployment( int threadsPerAllocation, Map currentAllocationsByNodeId, int maxAssignedAllocations, - Priority priority, - long perDeploymentMemoryBytes, - long perAllocationMemoryBytes + Priority priority ) { public Deployment( String id, - long modelBytes, + long memoryBytes, int allocations, int threadsPerAllocation, Map currentAllocationsByNodeId, - int maxAssignedAllocations, - long perDeploymentMemoryBytes, - long perAllocationMemoryBytes + int maxAssignedAllocations ) { - this( - id, - modelBytes, - allocations, - threadsPerAllocation, - currentAllocationsByNodeId, - maxAssignedAllocations, - Priority.NORMAL, - perDeploymentMemoryBytes, - perAllocationMemoryBytes - ); + this(id, memoryBytes, allocations, threadsPerAllocation, currentAllocationsByNodeId, maxAssignedAllocations, Priority.NORMAL); } int getCurrentAssignedAllocations() { @@ -73,60 +58,6 @@ boolean hasEverBeenAllocated() { return maxAssignedAllocations > 0; } - public long 
estimateMemoryUsageBytes(int allocations) { - return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( - id, - memoryBytes, - perDeploymentMemoryBytes, - perAllocationMemoryBytes, - allocations - ); - } - - long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) { - return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( - id, - memoryBytes, - perDeploymentMemoryBytes, - perAllocationMemoryBytes, - allocationsNew - ) - StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( - id, - memoryBytes, - perDeploymentMemoryBytes, - perAllocationMemoryBytes, - allocationsOld - ); - - } - - long minimumMemoryRequiredBytes() { - return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( - id, - memoryBytes, - perDeploymentMemoryBytes, - perAllocationMemoryBytes, - 1 - ); - } - - int findOptimalAllocations(int maxAllocations, long availableMemoryBytes) { - if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) { - return (int) Math.max( - Math.min(maxAllocations, Math.floorDiv(availableMemoryBytes - estimateMemoryUsageBytes(0), perAllocationMemoryBytes)), - 0 - ); - } - return maxAllocations; - } - - int findExcessAllocations(int maxAllocations, long availableMemoryBytes) { - if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) { - return (int) Math.min(maxAllocations, Math.floorDiv(availableMemoryBytes, perAllocationMemoryBytes)); - } - return maxAllocations; - } - @Override public String toString() { return id @@ -140,8 +71,6 @@ public String toString() { + currentAllocationsByNodeId + ") (max_assigned_allocations = " + maxAssignedAllocations - + ") (memory_usage = " - + ByteSizeValue.ofBytes(estimateMemoryUsageBytes(allocations)) + ")"; } }; @@ -375,42 +304,19 @@ int getRemainingAllocations(Deployment m) { } boolean canAssign(Deployment deployment, Node node, int allocations) { - long requiredMemory = getDeploymentMemoryRequirement(deployment, node, allocations); - return canAssign(deployment, node, allocations, requiredMemory); - } - - boolean canAssign(Deployment deployment, Node node, int allocations, long requiredMemory) { - return (requiredMemory <= remainingNodeMemory.get(node)) - && (deployment.priority == Priority.LOW || allocations * deployment.threadsPerAllocation() <= remainingNodeCores.get(node)); - } - - public long getDeploymentMemoryRequirement(Deployment deployment, Node node, int newAllocations) { - int assignedAllocations = getAssignedAllocations(deployment, node); - - if (assignedAllocations > 0) { - return deployment.estimateAdditionalMemoryUsageBytes(assignedAllocations, assignedAllocations + newAllocations); - } - return deployment.estimateMemoryUsageBytes(newAllocations); + return (isAlreadyAssigned(deployment, node) + || (deployment.memoryBytes() <= remainingNodeMemory.get(node)) + && (deployment.priority == Priority.LOW + || allocations * deployment.threadsPerAllocation() <= remainingNodeCores.get(node))); } public Builder assignModelToNode(Deployment deployment, Node node, int allocations) { - return assignModelToNode(deployment, node, allocations, getDeploymentMemoryRequirement(deployment, node, allocations)); - } - - public Builder assignModelToNode(Deployment deployment, Node node, int allocations, long requiredMemory) { if (allocations <= 0) { return this; } - if (/*isAlreadyAssigned(deployment, node) == false - &&*/ requiredMemory > remainingNodeMemory.get(node)) { + if (isAlreadyAssigned(deployment, node) == false && deployment.memoryBytes() > remainingNodeMemory.get(node)) { throw new 
IllegalArgumentException( - "not enough memory on node [" - + node.id() - + "] to assign [" - + allocations - + "] allocations to deployment [" - + deployment.id() - + "]" + "not enough memory on node [" + node.id() + "] to assign model [" + deployment.id() + "]" ); } if (deployment.priority == Priority.NORMAL && allocations * deployment.threadsPerAllocation() > remainingNodeCores.get(node)) { @@ -427,9 +333,9 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio ); } + long additionalModelMemory = isAlreadyAssigned(deployment, node) ? 0 : deployment.memoryBytes; assignments.get(deployment).compute(node, (n, remAllocations) -> remAllocations + allocations); - accountMemory(deployment, node, requiredMemory); - + remainingNodeMemory.compute(node, (n, remMemory) -> remMemory - additionalModelMemory); if (deployment.priority == Priority.NORMAL) { remainingNodeCores.compute(node, (n, remCores) -> remCores - allocations * deployment.threadsPerAllocation()); } @@ -441,26 +347,9 @@ private boolean isAlreadyAssigned(Deployment deployment, Node node) { return deployment.currentAllocationsByNodeId().containsKey(node.id()) || assignments.get(deployment).get(node) > 0; } - private int getAssignedAllocations(Deployment deployment, Node node) { - int currentAllocations = getCurrentAllocations(deployment, node); - int assignmentAllocations = assignments.get(deployment).get(node); - return currentAllocations + assignmentAllocations; - } - - private static int getCurrentAllocations(Deployment m, Node n) { - return m.currentAllocationsByNodeId.containsKey(n.id()) ? m.currentAllocationsByNodeId.get(n.id()) : 0; - } - public void accountMemory(Deployment m, Node n) { - // TODO (#101612) remove or refactor unused method - long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n)); - accountMemory(m, n, requiredMemory); - } - - public void accountMemory(Deployment m, Node n, long requiredMemory) { - // TODO (#101612) computation of required memory should be done internally - remainingNodeMemory.computeIfPresent(n, (k, v) -> v - requiredMemory); - if (remainingNodeMemory.containsKey(n) && remainingNodeMemory.get(n) < 0) { + remainingNodeMemory.computeIfPresent(n, (k, v) -> v - m.memoryBytes()); + if (remainingNodeMemory.get(n) < 0) { throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.id() + "]"); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java index b1c017b1a784c..73b713cced32a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java @@ -115,11 +115,8 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat m.memoryBytes(), 1, m.threadsPerAllocation(), - // don't rely on the current allocation - new HashMap<>(), - m.maxAssignedAllocations(), - m.perDeploymentMemoryBytes(), - m.perAllocationMemoryBytes() + m.currentAllocationsByNodeId(), + m.maxAssignedAllocations() ) ) .toList(); @@ -148,9 +145,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat m.allocations(), m.threadsPerAllocation(), currentAllocationsByNodeId, - m.maxAssignedAllocations(), - m.perDeploymentMemoryBytes(), 
- m.perAllocationMemoryBytes() + m.maxAssignedAllocations() ); }).toList(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java index bd97680e285cc..90c5a2257d94d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java @@ -68,8 +68,6 @@ class LinearProgrammingPlanSolver { private final Map normalizedMemoryPerNode; private final Map coresPerNode; private final Map normalizedMemoryPerModel; - private final Map normalizedMemoryPerAllocation; - private final Map normalizedMinimumDeploymentMemoryRequired; private final int maxNodeCores; private final long maxModelMemoryBytes; @@ -86,17 +84,12 @@ class LinearProgrammingPlanSolver { .filter(m -> m.threadsPerAllocation() <= maxNodeCores) .toList(); - // We use the maximum memory to deploy a model with one allocation as the normalization factor. - maxModelMemoryBytes = this.deployments.stream().map(m -> m.minimumMemoryRequiredBytes()).max(Long::compareTo).orElse(1L); + maxModelMemoryBytes = this.deployments.stream().map(AssignmentPlan.Deployment::memoryBytes).max(Long::compareTo).orElse(1L); normalizedMemoryPerNode = this.nodes.stream() .collect(Collectors.toMap(Function.identity(), n -> n.availableMemoryBytes() / (double) maxModelMemoryBytes)); coresPerNode = this.nodes.stream().collect(Collectors.toMap(Function.identity(), Node::cores)); normalizedMemoryPerModel = this.deployments.stream() - .collect(Collectors.toMap(Function.identity(), m -> m.estimateMemoryUsageBytes(0) / (double) maxModelMemoryBytes)); - normalizedMemoryPerAllocation = this.deployments.stream() - .collect(Collectors.toMap(Function.identity(), m -> m.perAllocationMemoryBytes() / (double) maxModelMemoryBytes)); - normalizedMinimumDeploymentMemoryRequired = this.deployments.stream() - .collect(Collectors.toMap(Function.identity(), m -> m.minimumMemoryRequiredBytes() / (double) maxModelMemoryBytes)); + .collect(Collectors.toMap(Function.identity(), m -> m.memoryBytes() / (double) maxModelMemoryBytes)); } AssignmentPlan solvePlan(boolean useBinPackingOnly) { @@ -140,8 +133,8 @@ private double weightForAllocationVar( Node n, Map, Double> weights ) { - return (1 + weights.get(Tuple.tuple(m, n)) - (m.minimumMemoryRequiredBytes() > n.availableMemoryBytes() ? 10 : 0)) - L1 - * normalizedMemoryPerModel.get(m) / maxNodeCores; + return (1 + weights.get(Tuple.tuple(m, n)) - (m.memoryBytes() > n.availableMemoryBytes() ? 
10 : 0)) - L1 * normalizedMemoryPerModel + .get(m) / maxNodeCores; } private Tuple, Double>, AssignmentPlan> calculateWeightsAndBinPackingPlan() { @@ -163,9 +156,9 @@ private Tuple, Double>, AssignmentPlan> calculateWei .sorted(Comparator.comparingDouble(n -> descendingSizeAnyFitsNodeOrder(n, m, assignmentPlan))) .toList(); for (Node n : orderedNodes) { - int allocations = m.findOptimalAllocations( - Math.min(assignmentPlan.getRemainingCores(n) / m.threadsPerAllocation(), assignmentPlan.getRemainingAllocations(m)), - assignmentPlan.getRemainingMemory(n) + int allocations = Math.min( + assignmentPlan.getRemainingCores(n) / m.threadsPerAllocation(), + assignmentPlan.getRemainingAllocations(m) ); if (allocations > 0 && assignmentPlan.canAssign(m, n, allocations)) { assignmentPlan.assignModelToNode(m, n, allocations); @@ -192,8 +185,7 @@ private Tuple, Double>, AssignmentPlan> calculateWei } private double descendingSizeAnyFitsModelOrder(AssignmentPlan.Deployment m) { - return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -normalizedMinimumDeploymentMemoryRequired.get(m) * m - .threadsPerAllocation(); + return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -normalizedMemoryPerModel.get(m) * m.threadsPerAllocation(); } private double descendingSizeAnyFitsNodeOrder(Node n, AssignmentPlan.Deployment m, AssignmentPlan.Builder assignmentPlan) { @@ -315,10 +307,7 @@ private boolean solveLinearProgram( List modelMemories = new ArrayList<>(); deployments.stream().filter(m -> m.currentAllocationsByNodeId().containsKey(n.id()) == false).forEach(m -> { allocations.add(allocationVars.get(Tuple.tuple(m, n))); - modelMemories.add( - (normalizedMemoryPerModel.get(m) / (double) coresPerNode.get(n) + normalizedMemoryPerAllocation.get(m)) * m - .threadsPerAllocation() - ); + modelMemories.add(normalizedMemoryPerModel.get(m) * m.threadsPerAllocation() / (double) coresPerNode.get(n)); }); model.addExpression("used_memory_on_node_" + n.id() + "_not_more_than_available") .upper(normalizedMemoryPerNode.get(n)) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java index 72109941ad477..f10ece8f5a593 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java @@ -37,6 +37,6 @@ protected int calculatePreservedAllocations(Deployment m) { @Override protected int addPreservedAllocations(Node n, Deployment m) { - return m.currentAllocationsByNodeId().containsKey(n.id()) ? 
m.currentAllocationsByNodeId().get(n.id()) : 0; + return m.currentAllocationsByNodeId().get(n.id()); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java index 43b8860803596..324e1a8d69a53 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java @@ -37,6 +37,6 @@ protected int calculatePreservedAllocations(AssignmentPlan.Deployment m) { @Override protected int addPreservedAllocations(Node n, AssignmentPlan.Deployment m) { - return m.currentAllocationsByNodeId().containsKey(n.id()) ? 1 : 0; + return 1; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java index 8bdc99998a0c2..dafc07099f850 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java @@ -135,9 +135,8 @@ private void assignUnderSubscribedNodes(Collection nodeSelection) { for (AssignmentPlan.Deployment m : deployments) { Tuple assignment = Tuple.tuple(m, n); if (assignments.get(assignment) > 0) { - int roundedAllocations = (int) Math.ceil(allocations.get(assignment)); - totalModelMemory += m.estimateMemoryUsageBytes(roundedAllocations); - maxTotalThreads += roundedAllocations * m.threadsPerAllocation(); + totalModelMemory += m.memoryBytes(); + maxTotalThreads += (int) Math.ceil(allocations.get(assignment)) * m.threadsPerAllocation(); assignedDeployments.add(m); } } @@ -200,12 +199,9 @@ private void assignExcessCores(Node n) { if (resourceTracker.remainingNodeCores.get(n) <= 0) { break; } - int extraAllocations = m.findExcessAllocations( - Math.min( - resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), - resourceTracker.remainingModelAllocations.get(m) - ), - resourceTracker.remainingNodeMemory.get(n) + int extraAllocations = Math.min( + resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), + resourceTracker.remainingModelAllocations.get(m) ); allocations.compute(Tuple.tuple(m, n), (k, v) -> v + extraAllocations); resourceTracker.assign(m, n, extraAllocations); @@ -215,7 +211,7 @@ private void assignExcessCores(Node n) { } private static double remainingModelOrder(AssignmentPlan.Deployment m) { - return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -m.minimumMemoryRequiredBytes(); + return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -m.memoryBytes(); } private boolean hasSoftAssignments(Node n) { @@ -279,17 +275,15 @@ private void doRandomizedRounding(List> s int roundedAllocations = random.nextDouble() < roundUpProbability ? 
(int) Math.ceil(allocations.get(assignment)) : (int) Math.floor(allocations.get(assignment)); - if (m.estimateMemoryUsageBytes(roundedAllocations) > resourceTracker.remainingNodeMemory.get(n) + + if (m.memoryBytes() > resourceTracker.remainingNodeMemory.get(n) || m.threadsPerAllocation() > resourceTracker.remainingNodeCores.get(n) || roundedAllocations == 0 || random.nextDouble() > assignments.get(assignment)) { unassign(assignment); assignUnderSubscribedNodes(Set.of(n)); } else { - roundedAllocations = m.findOptimalAllocations( - Math.min(roundedAllocations, resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation()), - resourceTracker.remainingNodeMemory.get(n) - ); + roundedAllocations = Math.min(roundedAllocations, resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation()); assignModelToNode(m, n, roundedAllocations); unassignOversizedModels(n); assignExcessCores(n); @@ -300,8 +294,7 @@ private void doRandomizedRounding(List> s private void unassignOversizedModels(Node n) { for (AssignmentPlan.Deployment m : deployments) { Tuple assignment = Tuple.tuple(m, n); - int roundedAllocations = (int) Math.ceil(allocations.get(assignment)); - if (assignments.get(assignment) < 1.0 && m.minimumMemoryRequiredBytes() > resourceTracker.remainingNodeMemory.get(n)) { + if (assignments.get(assignment) < 1.0 && m.memoryBytes() > resourceTracker.remainingNodeMemory.get(n)) { unassign(assignment); } } @@ -310,11 +303,7 @@ private void unassignOversizedModels(Node n) { private AssignmentPlan toPlan() { AssignmentPlan.Builder builder = AssignmentPlan.builder(nodes, deployments); for (Map.Entry, Integer> assignment : tryAssigningRemainingCores().entrySet()) { - // TODO (#101612) The model should be assigned to the node only when it is possible. This means, that canAssign should be - // integrated into the assignModelToNode. 
- if (builder.canAssign(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue())) { - builder.assignModelToNode(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue()); - } + builder.assignModelToNode(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue()); } return builder.build(); } @@ -349,7 +338,7 @@ private Map, Integer> tryAssigningRemaini .toList()) { for (Node n : nodes.stream() .filter( - n -> resourceTracker.remainingNodeMemory.get(n) >= m.minimumMemoryRequiredBytes() + n -> resourceTracker.remainingNodeMemory.get(n) >= m.memoryBytes() && resourceTracker.remainingNodeCores.get(n) >= m.threadsPerAllocation() && resultAllocations.get(Tuple.tuple(m, n)) == 0 ) @@ -365,15 +354,10 @@ private Map, Integer> tryAssigningRemaini ) ) .toList()) { + int assigningAllocations = Math.min( resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), - Math.min( - resourceTracker.remainingModelAllocations.get(m), - m.findOptimalAllocations( - resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), - resourceTracker.remainingModelAllocations.get(m) - ) - ) + resourceTracker.remainingModelAllocations.get(m) ); resourceTracker.assign(m, n, assigningAllocations); resultAllocations.put(Tuple.tuple(m, n), assigningAllocations); @@ -443,7 +427,7 @@ private static class ResourceTracker { void assign(AssignmentPlan.Deployment m, Node n, int allocations) { if (assignments.contains(Tuple.tuple(m, n)) == false) { assignments.add(Tuple.tuple(m, n)); - remainingNodeMemory.compute(n, (k, v) -> v - m.estimateMemoryUsageBytes(allocations)); + remainingNodeMemory.compute(n, (k, v) -> v - m.memoryBytes()); } remainingNodeCores.compute(n, (k, v) -> v - allocations * m.threadsPerAllocation()); remainingModelAllocations.compute(m, (k, v) -> v - allocations); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java index 8c9499ca9e00c..9870aa93bf6ce 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java @@ -126,12 +126,10 @@ private AssignmentPlan computeZonePlan( modelIdToTargetAllocations.get(m.id()), m.threadsPerAllocation(), m.currentAllocationsByNodeId(), + // Only force assigning at least once previously assigned models that have not had any allocation yet (tryAssigningPreviouslyAssignedModels && modelIdToRemainingAllocations.get(m.id()) == m.allocations()) ? 
m.maxAssignedAllocations() - : 0, - // Only force assigning at least once previously assigned models that have not had any allocation yet - m.perDeploymentMemoryBytes(), - m.perAllocationMemoryBytes() + : 0 ) ) .toList(); @@ -153,9 +151,7 @@ private AssignmentPlan computePlanAcrossAllNodes(List plans) { m.allocations(), m.threadsPerAllocation(), allocationsByNodeIdByModelId.get(m.id()), - m.maxAssignedAllocations(), - m.perDeploymentMemoryBytes(), - m.perAllocationMemoryBytes() + m.maxAssignedAllocations() ) ) .toList(); @@ -184,13 +180,9 @@ private AssignmentPlan swapOriginalModelsInPlan( Node originalNode = originalNodeById.get(assignment.getKey().id()); planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue()); if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) { - // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder // As the node has all its available memory we need to manually account memory of models with // current allocations. - long requiredMemory = originalDeployment.estimateMemoryUsageBytes( - originalDeployment.currentAllocationsByNodeId().get(originalNode.id()) - ); - planBuilder.accountMemory(m, originalNode, requiredMemory); + planBuilder.accountMemory(m, originalNode); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java index 334fdfbb8b922..8ccf8839cfc08 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java @@ -44,8 +44,7 @@ public void testRebalance_GivenNoAssignments() { Map.of(), Map.of(), Optional.empty(), - 1, - false + 1 ).rebalance().build(); assertThat(result.allAssignments().isEmpty(), is(true)); } @@ -79,8 +78,7 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_ShouldMakeNoChanges() nodeLoads, Map.of(), Optional.empty(), - 1, - false + 1 ).rebalance().build(); assertThat(currentMetadata, equalTo(result)); @@ -118,8 +116,7 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_GivenOutdatedRoutingEn nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.empty(), - 1, - false + 1 ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -143,7 +140,7 @@ public void testRebalance_GivenModelToAddAlreadyExists() { .build(); expectThrows( ResourceAlreadyExistsException.class, - () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Map.of(), Optional.of(taskParams), 1, false).rebalance() + () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Map.of(), Optional.of(taskParams), 1).rebalance() ); } @@ -157,8 +154,7 @@ public void testRebalance_GivenFirstModelToAdd_NoMLNodes() throws Exception { Map.of(), Map.of(), Optional.of(taskParams), - 1, - false + 1 ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -185,8 +181,7 @@ public void testRebalance_GivenFirstModelToAdd_NotEnoughProcessors() throws Exce nodeLoads, Map.of(List.of(), List.of(node)), Optional.of(taskParams), - 1, - false + 1 ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -222,8 +217,7 @@ public void 
testRebalance_GivenFirstModelToAdd_NotEnoughMemory() throws Exceptio nodeLoads, Map.of(), Optional.of(taskParams), - 1, - false + 1 ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -259,8 +253,7 @@ public void testRebalance_GivenFirstModelToAdd_ErrorDetectingNodeLoad() throws E nodeLoads, Map.of(), Optional.of(taskParams), - 1, - false + 1 ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -296,8 +289,7 @@ public void testRebalance_GivenProblemsOnMultipleNodes() throws Exception { nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams), - 1, - false + 1 ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -330,8 +322,7 @@ public void testRebalance_GivenFirstModelToAdd_FitsFully() throws Exception { nodeLoads, Map.of(List.of(), List.of(node1)), Optional.of(taskParams), - 1, - false + 1 ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -370,8 +361,7 @@ public void testRebalance_GivenModelToAdd_AndPreviousAssignments_AndTwoNodes_All nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams), - 1, - false + 1 ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -435,8 +425,7 @@ public void testRebalance_GivenPreviousAssignments_AndNewNode() throws Exception nodeLoads, Map.of(List.of(), List.of(node1, node2, node3)), Optional.empty(), - 1, - false + 1 ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -500,8 +489,7 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1, - false + 1 ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -571,8 +559,7 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1, - false + 1 ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -621,8 +608,7 @@ public void testRebalance_GivenFailedAssignment_RestartsAssignment() throws Exce nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1, - false + 1 ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(1))); @@ -656,8 +642,7 @@ public void testRebalance_GivenLowPriorityModelToAdd_OnlyModel_NotEnoughMemory() nodeLoads, Map.of(), Optional.of(taskParams), - 1, - false + 1 ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(deploymentId); @@ -673,8 +658,8 @@ public void testRebalance_GivenLowPriorityModelToAdd_OnlyModel_NotEnoughMemory() public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessors() throws Exception { long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); - DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 8); - DiscoveryNode node2 = buildNode("node-2", nodeMemoryBytes, 8); + DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 1); + DiscoveryNode node2 = buildNode("node-2", nodeMemoryBytes, 1); Map nodeLoads = new HashMap<>(); nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); @@ -703,8 +688,7 @@ public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessor nodeLoads, Map.of(List.of("zone-1"), List.of(node1), List.of("zone-2"), List.of(node2)), Optional.of(taskParams1), - 1, - 
false + 1 ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(deployment1); @@ -743,8 +727,7 @@ public void testRebalance_GivenMixedPriorityModels_NotEnoughMemoryForLowPriority nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1, - false + 1 ).rebalance().build(); { @@ -797,8 +780,7 @@ public void testRebalance_GivenMixedPriorityModels_TwoZones_EachNodeCanHoldOneMo nodeLoads, Map.of(List.of("zone-1"), List.of(node1), List.of("zone-2"), List.of(node2)), Optional.empty(), - 1, - false + 1 ).rebalance().build(); List assignedNodes = new ArrayList<>(); @@ -852,8 +834,7 @@ public void testRebalance_GivenModelUsingAllCpu_FittingLowPriorityModelCanStart( nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1, - false + 1 ).rebalance().build(); { @@ -903,8 +884,7 @@ public void testRebalance_GivenMultipleLowPriorityModels_AndMultipleNodes() thro nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.empty(), - 1, - false + 1 ).rebalance().build(); { @@ -954,8 +934,7 @@ public void testRebalance_GivenNormalPriorityModelToLoad_EvictsLowPriorityModel( nodeLoads, Map.of(List.of(), List.of(node1)), Optional.of(taskParams2), - 1, - false + 1 ).rebalance().build(); { @@ -1007,8 +986,7 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelCanS nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams2), - 1, - false + 1 ).rebalance().build(); { @@ -1060,8 +1038,7 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelMust nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams2), - 1, - false + 1 ).rebalance().build(); { @@ -1107,8 +1084,7 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { nodeLoads, Map.of(List.of(), List.of(node)), Optional.of(taskParams), - 2, - false + 2 ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -1130,8 +1106,7 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { nodeLoads, Map.of(List.of(), List.of(node)), Optional.of(taskParams), - 1, - false + 1 ).rebalance().build(); assignment = result.getDeploymentAssignment(modelId); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java index cbbb38f1d1ddd..3ecdd5000ba35 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.ml.inference.assignment.planning; -import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; @@ -25,248 +24,109 @@ public class AssignmentPlanTests extends ESTestCase { public void testBuilderCtor_GivenDuplicateNode() { Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0); expectThrows(IllegalArgumentException.class, () -> 
             AssignmentPlan.builder(List.of(n, n), List.of(m)));
     }
 
     public void testBuilderCtor_GivenDuplicateModel() {
         Node n = new Node("n_1", 100, 4);
-        Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0);
+        Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0);
         expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n), List.of(m, m)));
     }
 
     public void testAssignModelToNode_GivenNoPreviousAssignment() {
-        Node n = new Node("n_1", ByteSizeValue.ofMb(350).getBytes(), 4);
-
-        { // old memory format
-            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(40).getBytes(), 1, 2, Map.of(), 0, 0, 0);
-
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
-
-            assertThat(builder.getRemainingCores(n), equalTo(4));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes()));
-            assertThat(builder.getRemainingAllocations(m), equalTo(1));
-            assertThat(builder.getRemainingThreads(m), equalTo(2));
-
-            builder.assignModelToNode(m, n, 1);
-
-            assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(30).getBytes()));
-            assertThat(builder.getRemainingAllocations(m), equalTo(0));
-            assertThat(builder.getRemainingThreads(m), equalTo(0));
-
-            AssignmentPlan plan = builder.build();
-
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(true));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
-        }
-        { // new memory format
-            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(20).getBytes(),
-                1,
-                2,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(30).getBytes()
-            );
+        Node n = new Node("n_1", 100, 4);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0);
 
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
+        AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
 
-            assertThat(builder.getRemainingCores(n), equalTo(4));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes()));
-            assertThat(builder.getRemainingAllocations(m), equalTo(1));
-            assertThat(builder.getRemainingThreads(m), equalTo(2));
+        assertThat(builder.getRemainingCores(n), equalTo(4));
+        assertThat(builder.getRemainingMemory(n), equalTo(100L));
+        assertThat(builder.getRemainingAllocations(m), equalTo(1));
+        assertThat(builder.getRemainingThreads(m), equalTo(2));
 
-            builder.assignModelToNode(m, n, 1);
+        builder.assignModelToNode(m, n, 1);
 
-            assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(0L));
-            assertThat(builder.getRemainingAllocations(m), equalTo(0));
-            assertThat(builder.getRemainingThreads(m), equalTo(0));
+        assertThat(builder.getRemainingCores(n), equalTo(2));
+        assertThat(builder.getRemainingMemory(n), equalTo(60L));
+        assertThat(builder.getRemainingAllocations(m), equalTo(0));
+        assertThat(builder.getRemainingThreads(m), equalTo(0));
 
-            AssignmentPlan plan = builder.build();
+        AssignmentPlan plan = builder.build();
 
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(true));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
-        }
+        assertThat(plan.models(), contains(m));
+        assertThat(plan.satisfiesCurrentAssignments(), is(true));
+        assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
     }
 
     public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() {
-        Node n = new Node("n_1", ByteSizeValue.ofMb(350).getBytes(), 4);
-        { // old memory format
-            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(30).getBytes(),
-                2,
-                2,
-                Map.of("n_1", 1),
-                0,
-                0,
-                0
-            );
-
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
-
-            builder.assignModelToNode(m, n, 1);
-
-            assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes()));
-            assertThat(builder.getRemainingAllocations(m), equalTo(1));
-            assertThat(builder.getRemainingThreads(m), equalTo(2));
-
-            AssignmentPlan plan = builder.build();
-
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(true));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
-        }
-        { // new memory format
-            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(25).getBytes(),
-                2,
-                2,
-                Map.of("n_1", 1),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(25).getBytes()
-            );
-
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
+        Node n = new Node("n_1", 100, 4);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 2, 2, Map.of("n_1", 1), 0);
 
-            builder.assignModelToNode(m, n, 1);
+        AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
 
-            assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(325).getBytes()));
-            assertThat(builder.getRemainingAllocations(m), equalTo(1));
-            assertThat(builder.getRemainingThreads(m), equalTo(2));
+        builder.assignModelToNode(m, n, 1);
 
-            AssignmentPlan plan = builder.build();
+        assertThat(builder.getRemainingCores(n), equalTo(2));
+        assertThat(builder.getRemainingMemory(n), equalTo(100L));
+        assertThat(builder.getRemainingAllocations(m), equalTo(1));
+        assertThat(builder.getRemainingThreads(m), equalTo(2));
 
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(true));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
+        AssignmentPlan plan = builder.build();
 
-        }
+        assertThat(plan.models(), contains(m));
+        assertThat(plan.satisfiesCurrentAssignments(), is(true));
+        assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
     }
 
     public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() {
-        Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 4);
-        {
-            // old memory format
-            Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 0, 0, 0);
-
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
-
-            builder.assignModelToNode(m, n, 1);
-
-            assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(300).getBytes()));
-            assertThat(builder.getRemainingAllocations(m), equalTo(1));
-            assertThat(builder.getRemainingThreads(m), equalTo(2));
-
-            AssignmentPlan plan = builder.build();
-
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(false));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
-        }
-        {
-            // new memory format
-            Deployment m = new Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(25).getBytes(),
-                2,
-                2,
-                Map.of("n_1", 2),
-                0,
-                ByteSizeValue.ofMb(250).getBytes(),
-                ByteSizeValue.ofMb(25).getBytes()
-            );
+        Node n = new Node("n_1", 100, 4);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 2, 2, Map.of("n_1", 2), 0);
 
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
+        AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
 
-            builder.assignModelToNode(m, n, 1);
+        builder.assignModelToNode(m, n, 1);
 
-            assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(275).getBytes()));
-            assertThat(builder.getRemainingAllocations(m), equalTo(1));
-            assertThat(builder.getRemainingThreads(m), equalTo(2));
+        assertThat(builder.getRemainingCores(n), equalTo(2));
+        assertThat(builder.getRemainingMemory(n), equalTo(100L));
+        assertThat(builder.getRemainingAllocations(m), equalTo(1));
+        assertThat(builder.getRemainingThreads(m), equalTo(2));
 
-            AssignmentPlan plan = builder.build();
+        AssignmentPlan plan = builder.build();
 
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(false));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
-        }
+        assertThat(plan.models(), contains(m));
+        assertThat(plan.satisfiesCurrentAssignments(), is(false));
+        assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
     }
 
     public void testAssignModelToNode_GivenPreviouslyUnassignedModelDoesNotFit() {
-        Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4);
-        Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, 0, 0);
+        Node n = new Node("n_1", 100, 4);
+        Deployment m = new AssignmentPlan.Deployment("m_1", 101, 2, 2, Map.of(), 0);
 
         AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
         Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 1));
 
-        assertThat(e.getMessage(), equalTo("not enough memory on node [n_1] to assign [1] allocations to deployment [m_1]"));
+        assertThat(e.getMessage(), equalTo("not enough memory on node [n_1] to assign model [m_1]"));
     }
 
     public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() {
-        { // old memory format
-            Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4);
-            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(50).getBytes(),
-                2,
-                2,
-                Map.of("n_1", 1),
-                0,
-                0,
-                0
-            );
-
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
-
-            builder.assignModelToNode(m, n, 2);
-            AssignmentPlan plan = builder.build();
-
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(true));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2)));
-        }
-        { // new memory format
-            Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4);
-            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(30).getBytes(),
-                2,
-                2,
-                Map.of("n_1", 1),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(5).getBytes()
-            );
-
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
+        Node n = new Node("n_1", 100, 4);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 2, 2, Map.of("n_1", 1), 0);
 
-            builder.assignModelToNode(m, n, 2);
-            AssignmentPlan plan = builder.build();
+        AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
+
+        builder.assignModelToNode(m, n, 2);
+        AssignmentPlan plan = builder.build();
 
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(true));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2)));
-        }
+        assertThat(plan.models(), contains(m));
+        assertThat(plan.satisfiesCurrentAssignments(), is(true));
+        assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2)));
     }
 
     public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocation() {
-        Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4);
-        Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, 0, 0);
+        Node n = new Node("n_1", 100, 4);
+        Deployment m = new AssignmentPlan.Deployment("m_1", 100, 5, 1, Map.of(), 0);
 
         AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
         Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 5));
@@ -278,8 +138,8 @@ public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocati
     }
 
     public void testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAllocation() {
-        Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 5);
-        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0);
+        Node n = new Node("n_1", 100, 5);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of(), 0);
 
         AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
         Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 3));
@@ -291,22 +151,13 @@ public void testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAlloc
     }
 
     public void testAssignModelToNode_GivenSameModelAssignedTwice() {
-        Node n = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8);
-        Deployment m = new AssignmentPlan.Deployment(
-            "m_1",
-            ByteSizeValue.ofMb(50).getBytes(),
-            4,
-            2,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(300).getBytes(),
-            ByteSizeValue.ofMb(50).getBytes()
-        );
+        Node n = new Node("n_1", 100, 8);
+        Deployment m = new AssignmentPlan.Deployment("m_1", 60, 4, 2, Map.of(), 0);
 
         AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
 
         assertThat(builder.getRemainingCores(n), equalTo(8));
-        assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(1000).getBytes()));
+        assertThat(builder.getRemainingMemory(n), equalTo(100L));
         assertThat(builder.getRemainingAllocations(m), equalTo(4));
         assertThat(builder.getRemainingThreads(m), equalTo(8));
         assertThat(builder.canAssign(m, n, 1), is(true));
@@ -314,7 +165,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() {
         builder.assignModelToNode(m, n, 1);
 
         assertThat(builder.getRemainingCores(n), equalTo(6));
-        assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(600).getBytes()));
+        assertThat(builder.getRemainingMemory(n), equalTo(40L));
         assertThat(builder.getRemainingAllocations(m), equalTo(3));
         assertThat(builder.getRemainingThreads(m), equalTo(6));
         assertThat(builder.canAssign(m, n, 2), is(true));
@@ -322,7 +173,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() {
         builder.assignModelToNode(m, n, 2);
 
         assertThat(builder.getRemainingCores(n), equalTo(2));
-        assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(500).getBytes()));
+        assertThat(builder.getRemainingMemory(n), equalTo(40L));
         assertThat(builder.getRemainingAllocations(m), equalTo(1));
         assertThat(builder.getRemainingThreads(m), equalTo(2));
@@ -335,7 +186,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() {
 
     public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() {
         Node n = new Node("n_1", 100, 5);
-        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0, 0, 0);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0);
 
         AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
@@ -343,33 +194,17 @@ public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() {
     }
 
     public void testCanAssign_GivenPreviouslyAssignedModelDoesNotFit() {
-        Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5);
-        {
-            // old memory format
-            Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, 0, 0);
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
-            assertThat(builder.canAssign(m, n, 1), is(true));
-        }
-        {
-            // new memory format
-            Deployment m = new Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(25).getBytes(),
-                1,
-                1,
-                Map.of("n_1", 1),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(10).getBytes()
-            );
-            AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
-            assertThat(builder.canAssign(m, n, 1), is(true));
-        }
+        Node n = new Node("n_1", 100, 5);
+        Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of("n_1", 1), 0);
+
+        AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
+
+        assertThat(builder.canAssign(m, n, 1), is(true));
     }
 
     public void testCanAssign_GivenEnoughMemory() {
-        Node n = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5);
-        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0);
+        Node n = new Node("n_1", 100, 5);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of(), 0);
 
         AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
@@ -381,25 +216,16 @@ public void testCanAssign_GivenEnoughMemory() {
 
     public void testCompareTo_GivenDifferenceInPreviousAssignments() {
         AssignmentPlan planSatisfyingPreviousAssignments;
         AssignmentPlan planNotSatisfyingPreviousAssignments;
-        Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5);
+        Node n = new Node("n_1", 100, 5);
 
         {
-            Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 2), 0, 0, 0);
+            Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 2), 0);
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
             builder.assignModelToNode(m, n, 2);
             planSatisfyingPreviousAssignments = builder.build();
         }
         {
-            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(30).getBytes(),
-                3,
-                2,
-                Map.of("n_1", 3),
-                0,
-                0,
-                0
-            );
+            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 3), 0);
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
             builder.assignModelToNode(m, n, 2);
             planNotSatisfyingPreviousAssignments = builder.build();
         }
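Note on the assertions above: they pin down the builder's accounting rule. Assigning k allocations of a deployment that uses t threads per allocation consumes k * t node cores, while the deployment's memory is charged to a node only the first time the deployment lands on it, which is why remaining memory stays unchanged when the deployment already had a current allocation there, and why a second assignment of the same deployment to the same node leaves memory untouched. A minimal sketch of that rule follows; the class and field names are illustrative only, not the production AssignmentPlan.Builder:

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical bookkeeping, mirroring what the assertions above imply.
    class AccountingSketch {
        final Map<String, Long> remainingMemory = new HashMap<>();   // node id -> bytes
        final Map<String, Integer> remainingCores = new HashMap<>(); // node id -> cores

        void assign(String nodeId, String modelId, long modelBytes, int threads, int allocations, boolean alreadyOnNode) {
            if (alreadyOnNode == false) { // memory is charged once per (deployment, node) pair
                long mem = remainingMemory.get(nodeId);
                if (mem < modelBytes) {
                    throw new IllegalArgumentException("not enough memory on node [" + nodeId + "] to assign model [" + modelId + "]");
                }
                remainingMemory.put(nodeId, mem - modelBytes);
            }
            remainingCores.merge(nodeId, -allocations * threads, Integer::sum); // cores scale with allocations
        }
    }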
@@ -412,17 +238,8 @@ public void testCompareTo_GivenDifferenceInPreviousAssignments() {
     public void testCompareTo_GivenDifferenceInAllocations() {
         AssignmentPlan planWithMoreAllocations;
         AssignmentPlan planWithFewerAllocations;
-        Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5);
-        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
-            "m_1",
-            ByteSizeValue.ofMb(30).getBytes(),
-            3,
-            2,
-            Map.of("n_1", 1),
-            0,
-            0,
-            0
-        );
+        Node n = new Node("n_1", 100, 5);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 1), 0);
 
         {
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
@@ -442,25 +259,16 @@ public void testCompareTo_GivenDifferenceInAllocations() {
 
     public void testCompareTo_GivenDifferenceInMemory() {
         AssignmentPlan planUsingMoreMemory;
         AssignmentPlan planUsingLessMemory;
-        Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5);
+        Node n = new Node("n_1", 100, 5);
         {
-            Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 1), 0, 0, 0);
+            Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 1), 0);
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
             builder.assignModelToNode(m, n, 2);
             planUsingMoreMemory = builder.build();
         }
         {
-            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(29).getBytes(),
-                3,
-                2,
-                Map.of("n_1", 1),
-                0,
-                0,
-                0
-            );
+            AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 99, 3, 2, Map.of("n_1", 1), 0);
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
             builder.assignModelToNode(m, n, 2);
             planUsingLessMemory = builder.build();
         }
@@ -471,96 +279,26 @@ public void testCompareTo_GivenDifferenceInMemory() {
     }
 
     public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        {
-            // old memory format
-            AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(50).getBytes(),
-                1,
-                2,
-                Map.of(),
-                0,
-                0,
-                0
-            );
-            AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
-                "m_2",
-                ByteSizeValue.ofMb(30).getBytes(),
-                2,
-                1,
-                Map.of(),
-                0,
-                0,
-                0
-            );
-            AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment(
-                "m_3",
-                ByteSizeValue.ofMb(20).getBytes(),
-                4,
-                1,
-                Map.of(),
-                0,
-                0,
-                0
-            );
-            AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
-                .assignModelToNode(deployment1, node1, 1)
-                .assignModelToNode(deployment2, node2, 2)
-                .assignModelToNode(deployment3, node1, 2)
-                .assignModelToNode(deployment3, node2, 2)
-                .build();
-            assertThat(plan.satisfiesAllModels(), is(true));
-        }
-        {
-            // new memory format
-            AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(50).getBytes(),
-                1,
-                2,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(10).getBytes()
-            );
-            AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
-                "m_2",
-                ByteSizeValue.ofMb(30).getBytes(),
-                2,
-                1,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(10).getBytes()
-            );
-            AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment(
-                "m_3",
-                ByteSizeValue.ofMb(20).getBytes(),
-                4,
-                1,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(10).getBytes()
-            );
-            AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
-                .assignModelToNode(deployment1, node1, 1)
-                .assignModelToNode(deployment2, node2, 2)
-                .assignModelToNode(deployment3, node1, 2)
-                .assignModelToNode(deployment3, node2, 2)
-                .build();
-            assertThat(plan.satisfiesAllModels(), is(true));
-        }
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 0);
+        AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 0);
+        AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 20, 4, 1, Map.of(), 0);
+        AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
+            .assignModelToNode(deployment1, node1, 1)
+            .assignModelToNode(deployment2, node2, 2)
+            .assignModelToNode(deployment3, node1, 2)
+            .assignModelToNode(deployment3, node2, 2)
+            .build();
+        assertThat(plan.satisfiesAllModels(), is(true));
     }
 
     public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0);
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 0);
+        AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 0);
+        Deployment deployment3 = new Deployment("m_3", 20, 4, 1, Map.of(), 0);
         AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
             .assignModelToNode(deployment1, node1, 1)
             .assignModelToNode(deployment2, node2, 2)
@@ -571,11 +309,11 @@ public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() {
     }
 
     public void testArePreviouslyAssignedModelsAssigned_GivenTrue() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0);
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 3);
+        AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 30, 2, 1, Map.of(), 4);
+        AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 20, 4, 1, Map.of(), 0);
         AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
             .assignModelToNode(deployment1, node1, 1)
             .assignModelToNode(deployment2, node2, 1)
@@ -584,10 +322,10 @@ public void testArePreviouslyAssignedModelsAssigned_GivenTrue() {
     }
 
     public void testArePreviouslyAssignedModelsAssigned_GivenFalse() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0);
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        AssignmentPlan.Deployment deployment1 = new Deployment("m_1", 50, 1, 2, Map.of(), 3);
+        AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 4);
         AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2))
             .assignModelToNode(deployment1, node1, 1)
             .build();
@@ -595,39 +333,12 @@ public void testArePreviouslyAssignedModelsAssigned_GivenFalse() {
     }
 
     public void testCountPreviouslyAssignedThatAreStillAssigned() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0);
-        AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
-            "m_2",
-            ByteSizeValue.ofMb(30).getBytes(),
-            2,
-            1,
-            Map.of(),
-            4,
-            0,
-            0
-        );
-        AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment(
-            "m_3",
-            ByteSizeValue.ofMb(20).getBytes(),
-            4,
-            1,
-            Map.of(),
-            1,
-            0,
-            0
-        );
-        AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment(
-            "m_4",
-            ByteSizeValue.ofMb(20).getBytes(),
-            4,
-            1,
-            Map.of(),
-            0,
-            0,
-            0
-        );
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 3);
+        AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 4);
+        AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 20, 4, 1, Map.of(), 1);
+        AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment("m_4", 20, 4, 1, Map.of(), 0);
         AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3, deployment4))
             .assignModelToNode(deployment1, node1, 1)
             .assignModelToNode(deployment2, node2, 1)
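A readability note for the diffs above and below: the six-argument Deployment constructor is always used positionally. As far as can be reconstructed from the createModelsFromPlan helper further down, the order is (id, memoryBytes, allocations, threadsPerAllocation, currentAllocationsByNodeId, maxAssignedAllocations); treat the following as an illustrative reading, not an authoritative signature:

    // Illustrative only: "m_1" needs 40 bytes, wants 1 allocation of 2 threads each,
    // has no current allocations, and has never been assigned any allocations before.
    Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0);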
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java
index 6a72ccf4c4445..82a291a8d9fb2 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java
@@ -33,144 +33,50 @@ public class AssignmentPlannerTests extends ESTestCase {
 
-    private static long scaleNodeSize(long nodeMemory) {
-        // 240 Mb is the size in StartTrainedModelDeploymentAction.MEMORY_OVERHEAD
-        return ByteSizeValue.ofMb(240 + 2 * nodeMemory).getBytes();
-    }
-
     public void testModelThatDoesNotFitInMemory() {
-        { // Without perDeploymentMemory and perAllocationMemory specified
-            List<Node> nodes = List.of(new Node("n_1", scaleNodeSize(50), 4));
-            Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(51).getBytes(), 4, 1, Map.of(), 0, 0, 0);
-            AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan();
-            assertThat(plan.assignments(deployment).isEmpty(), is(true));
-        }
-        { // With perDeploymentMemory and perAllocationMemory specified
-            List<Node> nodes = List.of(new Node("n_1", scaleNodeSize(55), 4));
-            Deployment deployment = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(50).getBytes(),
-                4,
-                1,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(250).getBytes(),
-                ByteSizeValue.ofMb(51).getBytes()
-            );
-            AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan();
-            assertThat(plan.assignments(deployment).isEmpty(), is(true));
-        }
+        List<Node> nodes = List.of(new Node("n_1", 100, 4));
+        Deployment deployment = new AssignmentPlan.Deployment("m_1", 101, 4, 1, Map.of(), 0);
+        AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan();
+        assertThat(plan.assignments(deployment).isEmpty(), is(true));
     }
 
     public void testModelWithThreadsPerAllocationNotFittingOnAnyNode() {
-        List<Node> nodes = List.of(new Node("n_1", scaleNodeSize(100), 4), new Node("n_2", scaleNodeSize(100), 5));
-        Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(1).getBytes(), 1, 6, Map.of(), 0, 0, 0);
+        List<Node> nodes = List.of(new Node("n_1", 100, 4), new Node("n_2", 100, 5));
+        Deployment deployment = new AssignmentPlan.Deployment("m_1", 1, 1, 6, Map.of(), 0);
         AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan();
         assertThat(plan.assignments(deployment).isEmpty(), is(true));
     }
 
     public void testSingleModelThatFitsFullyOnSingleNode() {
         {
-            Node node = new Node("n_1", scaleNodeSize(100), 4);
-            Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0);
-            AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
-            assertModelFullyAssignedToNode(plan, deployment, node);
-        }
-        {
-            Node node = new Node("n_1", scaleNodeSize(1000), 8);
-            Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(1000).getBytes(), 8, 1, Map.of(), 0, 0, 0);
+            Node node = new Node("n_1", 100, 4);
+            Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 1, Map.of(), 0);
             AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
             assertModelFullyAssignedToNode(plan, deployment, node);
         }
         {
-            Node node = new Node("n_1", scaleNodeSize(10000), 16);
-            AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(10000).getBytes(),
-                1,
-                16,
-                Map.of(),
-                0,
-                0,
-                0
-            );
+            Node node = new Node("n_1", 1000, 8);
+            Deployment deployment = new Deployment("m_1", 1000, 8, 1, Map.of(), 0);
             AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
             assertModelFullyAssignedToNode(plan, deployment, node);
         }
         {
-            Node node = new Node("n_1", scaleNodeSize(100), 4);
-            Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0);
-            AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
-            assertModelFullyAssignedToNode(plan, deployment, node);
-        }
-    }
-
-    public void testSingleModelThatFitsFullyOnSingleNode_NewMemoryFields() {
-        {
-            Node node = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4);
-            Deployment deployment = new AssignmentPlan.Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(100).getBytes(),
-                1,
-                1,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(100).getBytes()
-            );
-            AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
-            assertModelFullyAssignedToNode(plan, deployment, node);
-        }
-        {
-            Node node = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8);
-            Deployment deployment = new Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(100).getBytes(),
-                8,
-                1,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(100).getBytes(),
-                ByteSizeValue.ofMb(100).getBytes()
-            );
+            Node node = new Node("n_1", 10000, 16);
+            AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 10000, 1, 16, Map.of(), 0);
             AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
             assertModelFullyAssignedToNode(plan, deployment, node);
         }
     }
 
     public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFullyAssignedOnOneNode() {
-        Node node1 = new Node("n_1", scaleNodeSize(100), 4);
-        Node node2 = new Node("n_2", scaleNodeSize(100), 4);
-        AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 4, 1, Map.of(), 0, 0, 0);
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        AssignmentPlan.Deployment deployment = new Deployment("m_1", 100, 4, 1, Map.of(), 0);
 
         AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan();
 
         Map<Node, Integer> assignments = plan.assignments(deployment).get();
-        if (assignments.get(node1) != null) {
-            assertThat(assignments.get(node1), equalTo(4));
-        } else {
-            assertThat(assignments.get(node2), equalTo(4));
-        }
-    }
-
-    public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFullyAssignedOnOneNode_NewMemoryFields() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        AssignmentPlan.Deployment deployment = new Deployment(
-            "m_1",
-            ByteSizeValue.ofMb(100).getBytes(),
-            4,
-            1,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(300).getBytes(),
-            ByteSizeValue.ofMb(150).getBytes()
-        );
-
-        AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan();
-
-        Map<Node, Integer> assignments = plan.assignments(deployment).get();
-        if (assignments.get(node1) != null) {
+        if (assignments.get(node1) > 0) {
             assertThat(assignments.get(node1), equalTo(4));
         } else {
             assertThat(assignments.get(node2), equalTo(4));
@@ -178,53 +84,10 @@ public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFully
     }
 
     public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation() {
-        AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 1, Map.of(), 0, 0, 0);
-        // Single node
-        {
-            Node node = new Node("n_1", scaleNodeSize(100), 4);
-            AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
-            assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
-            Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
-            assertThat(assignments.get(node), equalTo(4));
-        }
-        // Two nodes
-        {
-            Node node1 = new Node("n_1", scaleNodeSize(100), 4);
-            Node node2 = new Node("n_2", scaleNodeSize(100), 2);
-            AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan();
-            assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
-            Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
-            assertThat(assignments.get(node1), equalTo(4));
-            assertThat(assignments.get(node2), equalTo(2));
-        }
-        // Three nodes
-        {
-            Node node1 = new Node("n_1", scaleNodeSize(100), 4);
-            Node node2 = new Node("n_2", scaleNodeSize(100), 2);
-            Node node3 = new Node("n_3", scaleNodeSize(100), 3);
-            AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan();
-            assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
-            Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
-            assertThat(assignments.get(node1), equalTo(4));
-            assertThat(assignments.get(node2), equalTo(2));
-            assertThat(assignments.get(node3), equalTo(3));
-        }
-    }
-
-    public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation_NewMemoryFields() {
-        AssignmentPlan.Deployment deployment = new Deployment(
-            "m_1",
-            ByteSizeValue.ofMb(100).getBytes(),
-            10,
-            1,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(300).getBytes(),
-            ByteSizeValue.ofMb(100).getBytes()
-        );
+        AssignmentPlan.Deployment deployment = new Deployment("m_1", 30, 10, 1, Map.of(), 0);
         // Single node
         {
-            Node node = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4);
+            Node node = new Node("n_1", 100, 4);
             AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
             assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
             Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
@@ -232,8 +95,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA
         }
         // Two nodes
         {
-            Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4);
-            Node node2 = new Node("n_2", ByteSizeValue.ofMb(600).getBytes(), 2);
+            Node node1 = new Node("n_1", 100, 4);
+            Node node2 = new Node("n_2", 100, 2);
             AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan();
             assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
             Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
@@ -242,9 +105,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA
         }
         // Three nodes
         {
-            Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4);
-            Node node2 = new Node("n_2", ByteSizeValue.ofMb(600).getBytes(), 2);
-            Node node3 = new Node("n_3", ByteSizeValue.ofMb(700).getBytes(), 3);
+            Node node1 = new Node("n_1", 100, 4);
+            Node node2 = new Node("n_2", 100, 2);
+            Node node3 = new Node("n_3", 100, 3);
             AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan();
             assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
             Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
@@ -255,105 +118,14 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA
     }
 
     public void testMultipleModelsAndNodesWithSingleSolution() {
-        Node node1 = new Node("n_1", 2 * scaleNodeSize(50), 7);
-        Node node2 = new Node("n_2", 2 * scaleNodeSize(50), 7);
-        Node node3 = new Node("n_3", 2 * scaleNodeSize(50), 2);
-        Node node4 = new Node("n_4", 2 * scaleNodeSize(50), 2);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 4, Map.of(), 0, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 2, 3, Map.of(), 0, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0);
-        Deployment deployment4 = new Deployment("m_4", ByteSizeValue.ofMb(50).getBytes(), 2, 1, Map.of(), 0, 0, 0);
-
-        AssignmentPlan plan = new AssignmentPlanner(
-            List.of(node1, node2, node3, node4),
-            List.of(deployment1, deployment2, deployment3, deployment4)
-        ).computePlan();
-
-        {
-            assertThat(plan.assignments(deployment1).isPresent(), is(true));
-            Map<Node, Integer> assignments = plan.assignments(deployment1).get();
-            assertThat(assignments.get(node1), equalTo(1));
-            assertThat(assignments.get(node2), equalTo(1));
-            assertThat(assignments.get(node3), is(nullValue()));
-            assertThat(assignments.get(node4), is(nullValue()));
-        }
-        {
-            assertThat(plan.assignments(deployment2).isPresent(), is(true));
-            Map<Node, Integer> assignments = plan.assignments(deployment2).get();
-            assertThat(assignments.get(node1), equalTo(1));
-            assertThat(assignments.get(node2), equalTo(1));
-            assertThat(assignments.get(node3), is(nullValue()));
-            assertThat(assignments.get(node4), is(nullValue()));
-        }
-        {
-            assertThat(plan.assignments(deployment3).isPresent(), is(true));
-            Map<Node, Integer> assignments = plan.assignments(deployment3).get();
-            assertThat(assignments.get(node1), is(nullValue()));
-            assertThat(assignments.get(node2), is(nullValue()));
-            // Will either be on node 3 or 4
-            Node assignedNode = assignments.get(node3) != null ? node3 : node4;
-            Node otherNode = assignedNode.equals(node3) ? node4 : node3;
-            assertThat(assignments.get(assignedNode), equalTo(1));
-            assertThat(assignments.get(otherNode), is(nullValue()));
-        }
-        {
-            assertThat(plan.assignments(deployment4).isPresent(), is(true));
-            Map<Node, Integer> assignments = plan.assignments(deployment4).get();
-            assertThat(assignments.get(node1), is(nullValue()));
-            assertThat(assignments.get(node2), is(nullValue()));
-            // Will either be on node 3 or 4
-            Node assignedNode = assignments.get(node3) != null ? node3 : node4;
-            Node otherNode = assignedNode.equals(node3) ? node4 : node3;
-            assertThat(assignments.get(assignedNode), equalTo(2));
-            assertThat(assignments.get(otherNode), is(nullValue()));
-        }
-    }
-
-    public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 7);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 7);
-        Node node3 = new Node("n_3", ByteSizeValue.ofMb(900).getBytes(), 2);
-        Node node4 = new Node("n_4", ByteSizeValue.ofMb(900).getBytes(), 2);
-        Deployment deployment1 = new Deployment(
-            "m_1",
-            ByteSizeValue.ofMb(50).getBytes(),
-            2,
-            4,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(300).getBytes(),
-            ByteSizeValue.ofMb(50).getBytes()
-        );
-        Deployment deployment2 = new Deployment(
-            "m_2",
-            ByteSizeValue.ofMb(50).getBytes(),
-            2,
-            3,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(300).getBytes(),
-            ByteSizeValue.ofMb(50).getBytes()
-        );
-        Deployment deployment3 = new Deployment(
-            "m_3",
-            ByteSizeValue.ofMb(50).getBytes(),
-            1,
-            2,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(300).getBytes(),
-            ByteSizeValue.ofMb(50).getBytes()
-        );
-        Deployment deployment4 = new Deployment(
-            "m_4",
-            ByteSizeValue.ofMb(50).getBytes(),
-            2,
-            1,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(300).getBytes(),
-            ByteSizeValue.ofMb(50).getBytes()
-        );
+        Node node1 = new Node("n_1", 100, 7);
+        Node node2 = new Node("n_2", 100, 7);
+        Node node3 = new Node("n_3", 100, 2);
+        Node node4 = new Node("n_4", 100, 2);
+        Deployment deployment1 = new Deployment("m_1", 50, 2, 4, Map.of(), 0);
+        AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 50, 2, 3, Map.of(), 0);
+        Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 50, 1, 2, Map.of(), 0);
+        AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment("m_4", 50, 2, 1, Map.of(), 0);
 
         AssignmentPlan plan = new AssignmentPlanner(
             List.of(node1, node2, node3, node4),
@@ -401,53 +173,10 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() {
     }
 
     public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation() {
-        Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 3, Map.of(), 0, 0, 0);
-        // Single node
-        {
-            Node node = new Node("n_1", scaleNodeSize(100), 4);
-            AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
-            assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
-            Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
-            assertThat(assignments.get(node), equalTo(1));
-        }
-        // Two nodes
-        {
-            Node node1 = new Node("n_1", scaleNodeSize(100), 4);
-            Node node2 = new Node("n_2", scaleNodeSize(100), 8);
-            AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan();
-            assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
-            Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
-            assertThat(assignments.get(node1), equalTo(1));
-            assertThat(assignments.get(node2), equalTo(2));
-        }
-        // Three nodes
-        {
-            Node node1 = new Node("n_1", scaleNodeSize(100), 4);
-            Node node2 = new Node("n_2", scaleNodeSize(100), 7);
-            Node node3 = new Node("n_3", scaleNodeSize(100), 15);
-            AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan();
-            assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
-            Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
-            assertThat(assignments.get(node1), equalTo(1));
-            assertThat(assignments.get(node2), equalTo(2));
-            assertThat(assignments.get(node3), equalTo(5));
-        }
-    }
-
-    public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation_NewMemoryFields() {
-        Deployment deployment = new AssignmentPlan.Deployment(
-            "m_1",
-            ByteSizeValue.ofMb(50).getBytes(),
-            10,
-            3,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(300).getBytes(),
-            ByteSizeValue.ofMb(50).getBytes()
-        );
+        Deployment deployment = new AssignmentPlan.Deployment("m_1", 30, 10, 3, Map.of(), 0);
         // Single node
         {
-            Node node = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4);
+            Node node = new Node("n_1", 100, 4);
             AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
             assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
             Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
@@ -455,8 +184,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA
         }
         // Two nodes
        {
-            Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4);
-            Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 8);
+            Node node1 = new Node("n_1", 100, 4);
+            Node node2 = new Node("n_2", 100, 8);
             AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan();
             assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
             Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
@@ -465,9 +194,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA
         }
         // Three nodes
         {
-            Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4);
-            Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 7);
-            Node node3 = new Node("n_3", ByteSizeValue.ofMb(800).getBytes(), 15);
+            Node node1 = new Node("n_1", 100, 4);
+            Node node2 = new Node("n_2", 100, 7);
+            Node node3 = new Node("n_3", 100, 15);
             AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan();
             assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true));
             Map<Node, Integer> assignments = assignmentPlan.assignments(deployment).get();
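The expected counts in this test follow from integer division of a node's cores by the deployment's threads per allocation: with 3 threads per allocation, a 4-core node fits 4 / 3 = 1 allocation, an 8-core node 8 / 3 = 2, and a 15-core node 15 / 3 = 5, which is where the asserted 1/2/5 split comes from (the 7-core node likewise fits 2). A one-line sketch of that bound, illustrative rather than actual planner code:

    // Whole allocations that fit on a node, ignoring memory: 4/3 -> 1, 8/3 -> 2, 15/3 -> 5.
    static int allocationsThatFit(int nodeCores, int threadsPerAllocation) {
        return nodeCores / threadsPerAllocation;
    }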
@@ -478,17 +207,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA
     }
 
     public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() {
-        Node node = new Node("n_1", scaleNodeSize(100), 4);
-        AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
-            "m_1",
-            ByteSizeValue.ofMb(30).getBytes(),
-            4,
-            1,
-            Map.of("n_1", 4),
-            0,
-            0,
-            0
-        );
+        Node node = new Node("n_1", 100, 4);
+        AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 30, 4, 1, Map.of("n_1", 4), 0);
         AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
 
         assertThat(plan.assignments(deployment).isPresent(), is(true));
@@ -497,117 +217,26 @@ public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() {
 
     public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation() {
         List<Node> nodes = List.of(
-            new Node("n_1", ByteSizeValue.ofGb(18).getBytes(), 8),
-            new Node("n_2", ByteSizeValue.ofGb(18).getBytes(), 8),
-            new Node("n_3", ByteSizeValue.ofGb(18).getBytes(), 8),
-            new Node("n_4", ByteSizeValue.ofGb(18).getBytes(), 8),
-            new Node("n_5", ByteSizeValue.ofGb(64).getBytes(), 16),
-            new Node("n_6", ByteSizeValue.ofGb(32).getBytes(), 16)
+            new Node("n_1", ByteSizeValue.ofGb(6).getBytes(), 8),
+            new Node("n_2", ByteSizeValue.ofGb(6).getBytes(), 8),
+            new Node("n_3", ByteSizeValue.ofGb(6).getBytes(), 8),
+            new Node("n_4", ByteSizeValue.ofGb(6).getBytes(), 8),
+            new Node("n_5", ByteSizeValue.ofGb(16).getBytes(), 16),
+            new Node("n_6", ByteSizeValue.ofGb(8).getBytes(), 16)
         );
         List<Deployment> deployments = List.of(
-            new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0, 0, 0),
-            new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0),
-            new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0, 0, 0),
-            new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0, 0, 0),
-            new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0, 0, 0),
-            new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0, 0, 0),
-            new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0, 0, 0),
-            new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0),
-            new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0),
-            new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0),
-            new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0),
-            new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0)
-        );
-
-        AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan();
-
-        int usedCores = 0;
-        for (AssignmentPlan.Deployment m : deployments) {
-            Map<Node, Integer> assignments = assignmentPlan.assignments(m).orElse(Map.of());
-            usedCores += assignments.values().stream().mapToInt(Integer::intValue).sum();
-        }
-        assertThat(usedCores, equalTo(64));
-
-        assertPreviousAssignmentsAreSatisfied(deployments, assignmentPlan);
-    }
-
-    public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_NewMemoryFields() {
-        List<Node> nodes = List.of(
-            new Node("n_1", ByteSizeValue.ofGb(18).getBytes(), 8),
-            new Node("n_2", ByteSizeValue.ofGb(18).getBytes(), 8),
-            new Node("n_3", ByteSizeValue.ofGb(18).getBytes(), 8),
-            new Node("n_4", ByteSizeValue.ofGb(18).getBytes(), 8),
-            new Node("n_5", ByteSizeValue.ofGb(64).getBytes(), 16),
-            new Node("n_6", ByteSizeValue.ofGb(32).getBytes(), 16)
-        );
-        // Use mix of old and new memory fields
-        List<Deployment> deployments = List.of(
-            new Deployment(
-                "m_1",
-                ByteSizeValue.ofMb(100).getBytes(),
-                10,
-                1,
-                Map.of("n_1", 5),
-                0,
-                ByteSizeValue.ofMb(400).getBytes(),
-                ByteSizeValue.ofMb(100).getBytes()
-            ),
-            new Deployment("m_2", ByteSizeValue.ofMb(100).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0),
-            new Deployment(
-                "m_3",
-                ByteSizeValue.ofMb(50).getBytes(),
-                3,
-                1,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(50).getBytes()
-            ),
-            new Deployment(
-                "m_4",
-                ByteSizeValue.ofMb(50).getBytes(),
-                4,
-                1,
-                Map.of("n_3", 2),
-                0,
-                ByteSizeValue.ofMb(400).getBytes(),
-                ByteSizeValue.ofMb(100).getBytes()
-            ),
-            new Deployment(
-                "m_5",
-                ByteSizeValue.ofMb(500).getBytes(),
-                2,
-                1,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(800).getBytes(),
-                ByteSizeValue.ofMb(100).getBytes()
-            ),
-            new Deployment(
-                "m_6",
-                ByteSizeValue.ofMb(50).getBytes(),
-                12,
-                1,
-                Map.of(),
-                0,
-                ByteSizeValue.ofMb(50).getBytes(),
-                ByteSizeValue.ofMb(20).getBytes()
-            ),
-            new Deployment(
-                "m_7",
-                ByteSizeValue.ofMb(50).getBytes(),
-                12,
-                1,
-                Map.of("n_2", 6),
-                0,
-                ByteSizeValue.ofMb(300).getBytes(),
-                ByteSizeValue.ofMb(50).getBytes()
-            ),
-            new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0),
-            new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0),
-            new Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0),
-            new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0),
-            new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0)
+            new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0),
+            new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0),
+            new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0),
+            new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0),
+            new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0),
+            new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0),
+            new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0),
+            new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0),
+            new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0),
+            new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0),
+            new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0),
+            new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0)
         );
 
         AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan();
@@ -668,9 +297,6 @@ public void testRandomBenchmark() {
         StopWatch stopWatch = new StopWatch();
         stopWatch.start();
         AssignmentPlan assignmentPlan = solver.computePlan();
-        for (Node node : nodes) {
-            assertThat(assignmentPlan.getRemainingNodeMemory(node.id()), greaterThanOrEqualTo(0L));
-        }
         stopWatch.stop();
 
         Quality quality = computeQuality(nodes, deployments, assignmentPlan);
@@ -710,16 +336,7 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode
                 .stream()
                 .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue));
             previousModelsPlusNew.add(
-                new AssignmentPlan.Deployment(
-                    m.id(),
-                    m.memoryBytes(),
-                    m.allocations(),
-                    m.threadsPerAllocation(),
-                    previousAssignments,
-                    0,
-                    0,
-                    0
-                )
+                new AssignmentPlan.Deployment(m.id(), m.memoryBytes(), m.allocations(), m.threadsPerAllocation(), previousAssignments, 0)
             );
         }
         previousModelsPlusNew.add(randomModel("new"));
@@ -730,20 +347,18 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode
     }
 
     public void testGivenLargerModelWithPreviousAssignmentsAndSmallerModelWithoutAssignments() {
-        Node node1 = new Node("n_1", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2);
-        Node node2 = new Node("n_2", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2);
-        Node node3 = new Node("n_3", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2);
+        Node node1 = new Node("n_1", ByteSizeValue.ofGb(2).getBytes(), 2);
+        Node node2 = new Node("n_2", ByteSizeValue.ofGb(2).getBytes(), 2);
+        Node node3 = new Node("n_3", ByteSizeValue.ofGb(2).getBytes(), 2);
         Deployment deployment1 = new AssignmentPlan.Deployment(
             "m_1",
             ByteSizeValue.ofMb(1200).getBytes(),
             3,
             1,
             Map.of("n_1", 2, "n_2", 1),
-            0,
-            0,
             0
         );
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0);
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment1, deployment2))
             .computePlan();
         assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L));
@@ -766,17 +381,15 @@ public void testGivenLargerModelWithPreviousAssignmentsAndSmallerModelWithoutAss
     }
 
     public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofGb(6).getBytes(), 2);
-        Node node2 = new Node("n_2", ByteSizeValue.ofGb(6).getBytes(), 2);
+        Node node1 = new Node("n_1", ByteSizeValue.ofGb(4).getBytes(), 2);
+        Node node2 = new Node("n_2", ByteSizeValue.ofGb(4).getBytes(), 2);
         AssignmentPlan.Deployment deployment1 = new Deployment(
             "m_1",
             ByteSizeValue.ofMb(1200).getBytes(),
             3,
             1,
             Map.of("n_1", 2, "n_2", 1),
-            3,
-            0,
-            0
+            3
         );
         AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
             "m_2",
@@ -784,84 +397,35 @@ public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously(
             1,
             2,
             Map.of(),
-            1,
-            0,
-            0
+            1
         );
 
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2)).computePlan();
 
         Map<String, Map<String, Integer>> indexedBasedPlan = convertToIdIndexed(assignmentPlan);
         assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2"));
-        if (indexedBasedPlan.get("m_2").containsKey("n_1")) {
-            assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_2", 2)));
-            assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_1", 1)));
-        } else {
-            assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
-            assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1)));
-        }
+        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
+        assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1)));
 
         assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L));
         assertThat(assignmentPlan.getRemainingNodeMemory("n_2"), greaterThanOrEqualTo(0L));
     }
 
     public void testGivenPreviouslyAssignedModels_CannotAllBeAllocated() {
-        Node node1 = new Node("n_1", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2);
-        AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1, 0, 0);
-        AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1, 0, 0);
+        Node node1 = new Node("n_1", ByteSizeValue.ofGb(2).getBytes(), 2);
+        AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1);
+        AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1);
 
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1), List.of(deployment1, deployment2)).computePlan();
 
         assertThat(assignmentPlan.countPreviouslyAssignedModelsThatAreStillAssigned(), equalTo(1L));
     }
 
-    public void testGivenClusterResize_AllocationShouldNotExceedMemoryConstraints() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0);
-        Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0);
-
-        // First only start m_1
-        AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan();
-
-        Map<String, Map<String, Integer>> indexedBasedPlan = convertToIdIndexed(assignmentPlan);
-        assertThat(indexedBasedPlan.keySet(), hasItems("m_1"));
-        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
-
-        // Then start m_2
-        assignmentPlan = new AssignmentPlanner(
-            List.of(node1, node2),
-            Stream.concat(createModelsFromPlan(assignmentPlan).stream(), Stream.of(deployment2)).toList()
-        ).computePlan();
-
-        indexedBasedPlan = convertToIdIndexed(assignmentPlan);
-        assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2"));
-        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
-        assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1)));
-
-        // Then start m_3
-        assignmentPlan = new AssignmentPlanner(
-            List.of(node1, node2),
-            Stream.concat(createModelsFromPlan(assignmentPlan).stream(), Stream.of(deployment3)).toList()
-        ).computePlan();
-
-        indexedBasedPlan = convertToIdIndexed(assignmentPlan);
-        assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3"));
-        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
-        assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1)));
-        assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1)));
-
-        // First, one node goes away.
-        assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan();
-        assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L));
-    }
-
     public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() {
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(2600).getBytes(), 2);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(2600).getBytes(), 2);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0);
+        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2);
+        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2);
+        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0);
+        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0);
+        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0);
 
         // First only start m_1
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan();
@@ -894,8 +458,8 @@ public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() {
         assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1)));
 
         // Now the cluster starts getting resized.
-        Node node3 = new Node("n_3", ByteSizeValue.ofMb(2600).getBytes(), 2);
-        Node node4 = new Node("n_4", ByteSizeValue.ofMb(2600).getBytes(), 2);
+        Node node3 = new Node("n_3", ByteSizeValue.ofMb(2400).getBytes(), 2);
+        Node node4 = new Node("n_4", ByteSizeValue.ofMb(2400).getBytes(), 2);
 
         // First, one node goes away.
         assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan();
@@ -928,65 +492,11 @@ public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() {
 
     public void testGivenClusterResize_ShouldRemoveAllocatedModels() {
         // Ensure that plan is removing previously allocated models if not enough memory is available
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0, 0, 0);
-
-        // Create a plan where all deployments are assigned at least once
-        AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
-            .computePlan();
-        Map<String, Map<String, Integer>> indexedBasedPlan = convertToIdIndexed(assignmentPlan);
-        assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3"));
-        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
-        assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1)));
-        assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1)));
-        assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L));
-        assertThat(assignmentPlan.getRemainingNodeMemory(node2.id()), greaterThanOrEqualTo(0L));
-
-        // Now the cluster starts getting resized. Ensure that resources are not over-allocated.
-        assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan();
-        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
-        assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L));
-        assertThat(assignmentPlan.getRemainingNodeCores(node1.id()), greaterThanOrEqualTo(0));
-
-    }
-
-    public void testGivenClusterResize_ShouldRemoveAllocatedModels_NewMemoryFields() {
-        // Ensure that plan is removing previously allocated models if not enough memory is available
-        Node node1 = new Node("n_1", ByteSizeValue.ofMb(700).getBytes(), 2);
-        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 2);
-        Deployment deployment1 = new Deployment(
-            "m_1",
-            ByteSizeValue.ofMb(100).getBytes(),
-            2,
-            1,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(400).getBytes(),
-            ByteSizeValue.ofMb(100).getBytes()
-        );
-        Deployment deployment2 = new Deployment(
-            "m_2",
-            ByteSizeValue.ofMb(100).getBytes(),
-            1,
-            1,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(400).getBytes(),
-            ByteSizeValue.ofMb(150).getBytes()
-        );
-        Deployment deployment3 = new Deployment(
-            "m_3",
-            ByteSizeValue.ofMb(50).getBytes(),
-            1,
-            1,
-            Map.of(),
-            0,
-            ByteSizeValue.ofMb(250).getBytes(),
-            ByteSizeValue.ofMb(50).getBytes()
-        );
+        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2);
+        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2);
+        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0);
+        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0);
+        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0);
 
         // Create a plan where all deployments are assigned at least once
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
             .computePlan();
@@ -1026,9 +536,7 @@ public static List<Deployment> createModelsFromPlan(AssignmentPlan plan) {
                     m.allocations(),
                     m.threadsPerAllocation(),
                     currentAllocations,
-                    Math.max(m.maxAssignedAllocations(), totalAllocations),
-                    0,
-                    0
+                    Math.max(m.maxAssignedAllocations(), totalAllocations)
                 )
             );
         }
@@ -1071,7 +579,7 @@ public static List<Node> randomNodes(int scale, String nodeIdPrefix) {
         for (int i = 0; i < 1 + 3 * scale; i++) {
             int cores = randomIntBetween(2, 32);
             long memBytesPerCore = randomFrom(memBytesPerCoreValues);
-            nodes.add(new Node(nodeIdPrefix + "n_" + i, scaleNodeSize(ByteSizeValue.ofBytes(cores * memBytesPerCore).getMb()), cores));
+            nodes.add(new Node(nodeIdPrefix + "n_" + i, cores * memBytesPerCore, cores));
         }
         return nodes;
     }
@@ -1086,30 +594,14 @@ public static List<Deployment> randomModels(int scale, double load) {
 
     public static Deployment randomModel(String idSuffix) {
         int allocations = randomIntBetween(1, 32);
-        // randomly choose between old and new memory fields format
-        if (randomBoolean()) {
-            return new Deployment(
-                "m_" + idSuffix,
-                randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(10).getBytes()),
-                randomIntBetween(1, 32),
-                randomIntBetween(1, 4),
-                Map.of(),
-                0,
-                0,
-                0
-            );
-        } else {
-            return new Deployment(
-                "m_" + idSuffix,
-                randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()),
-                randomIntBetween(1, 32),
-                randomIntBetween(1, 4),
-                Map.of(),
-                0,
-                randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()),
-                randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes())
-            );
-        }
+        return new Deployment(
+            "m_" + idSuffix,
+            randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(10).getBytes()),
+            randomIntBetween(1, 32),
+            randomIntBetween(1, 4),
+            Map.of(),
+            0
+        );
     }
 
     public static void assertPreviousAssignmentsAreSatisfied(List<Deployment> deployments, AssignmentPlan assignmentPlan) {
@@ -1136,7 +628,7 @@ private void runTooManyNodesAndModels(int nodesSize, int modelsSize) {
         }
         List<Deployment> deployments = new ArrayList<>();
         for (int i = 0; i < modelsSize; i++) {
-            deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, 0, 0));
+            deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0));
         }
 
         // Check plan is computed without OOM exception
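The next file covers PreserveAllAllocations which, judging from the assertions below, subtracts previously assigned allocations from both node capacities and deployment allocation counts before planning, and then merges the preserved allocations back into the computed plan. A rough sketch of that merge step, with hypothetical helper names rather than the real API:

    // After planning the reduced problem, add preserved allocations back per node (illustrative).
    Map<Node, Integer> merged = new HashMap<>(plan.assignments(deployment).orElse(Map.of()));
    preservedAllocations(deployment).forEach((node, count) -> merged.merge(node, count, Integer::sum));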
return new Deployment( + "m_" + idSuffix, + randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(10).getBytes()), + randomIntBetween(1, 32), + randomIntBetween(1, 4), + Map.of(), + 0 + ); } public static void assertPreviousAssignmentsAreSatisfied(List deployments, AssignmentPlan assignmentPlan) { @@ -1136,7 +628,7 @@ private void runTooManyNodesAndModels(int nodesSize, int modelsSize) { } List deployments = new ArrayList<>(); for (int i = 0; i < modelsSize; i++) { - deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, 0, 0)); + deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0)); } // Check plan is computed without OOM exception diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java index c45ce36394109..4a9b01e535d88 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.ml.inference.assignment.planning; -import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; @@ -15,6 +14,7 @@ import java.util.List; import java.util.Map; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -22,179 +22,77 @@ public class PreserveAllAllocationsTests extends ESTestCase { public void testGivenNoPreviousAssignments() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); + Node node1 = new Node("n_1", 100, 4); + Node node2 = new Node("n_2", 100, 4); + Deployment deployment1 = new Deployment("m_1", 30, 2, 1, Map.of(), 0); + Deployment deployment2 = new Deployment("m_2", 30, 2, 4, Map.of(), 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( List.of(node1, node2), List.of(deployment1, deployment2) ); + + List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations(); + assertThat(nodesPreservingAllocations, contains(node1, node2)); + + List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations(); + assertThat(modelsPreservingAllocations, contains(deployment1, deployment2)); } public void testGivenPreviousAssignments() { - { - // old memory format - Node node1 = new Node("n_1", ByteSizeValue.ofMb(640).getBytes(), 8); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8); - Deployment deployment1 = new AssignmentPlan.Deployment( - "m_1", - ByteSizeValue.ofMb(30).getBytes(), - 2, - 1, - Map.of("n_1", 1), - 1, - 0, - 0 - ); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, 
"n_2", 2), 3, 0, 0); - PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( - List.of(node1, node2), - List.of(deployment1, deployment2) - ); - - List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, hasSize(2)); - - assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); - // 640 - [(2*30 + 240) + (2*50 + 240)] = 0: deployments use 640 MB on the node 1 - assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(0L)); - // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node - assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); - - assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); - // 640 - (50*2+240) = 300 : deployments use 340MB on the node - assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); - // 8 - (2*4) = 0 : preserving all allocation2 of deployment 2 should use 8 cores on the node - assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(0)); - - List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations(); - assertThat(modelsPreservingAllocations, hasSize(2)); - - assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); - assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes())); - assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); - - assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); - assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes())); - assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3)); - assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0))); - - // Now we have a plan with 2 deployments assigned to 2 nodes. - // Note that deployment 1 has already 1 allocation on node 1, and it gets 2 more. It's more than 2 allocations defined during - // initialization of deployment1, but we don't care at this point. - AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) - .assignModelToNode(deployment1, node1, 2) - .build(); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); - assertThat(plan.assignments(deployment2).isEmpty(), is(true)); - - plan = preserveAllAllocations.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); - - // Node 1 already had deployments 1 and 2 assigned to it so adding more allocation doesn't change memory usage. 
- assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(0L)); - // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); - // Nothing changed for Node 2 - assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(300).getBytes())); - // Nothing changed for Node 2 - assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); - } - { - // new memory format - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 8); - Deployment deployment1 = new AssignmentPlan.Deployment( - "m_1", - ByteSizeValue.ofMb(30).getBytes(), - 2, - 1, - Map.of("n_1", 1), - 1, - ByteSizeValue.ofMb(300).getBytes(), - ByteSizeValue.ofMb(10).getBytes() - ); - Deployment deployment2 = new Deployment( - "m_2", - ByteSizeValue.ofMb(50).getBytes(), - 6, - 4, - Map.of("n_1", 1, "n_2", 2), - 3, - ByteSizeValue.ofMb(300).getBytes(), - ByteSizeValue.ofMb(10).getBytes() - ); - PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( - List.of(node1, node2), - List.of(deployment1, deployment2) - ); - - List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, hasSize(2)); - - assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); - // 1000 - [(30 + 300+10) + (50 + 300 + 10)] = 300: deployments use 700 MB on the node 1 - assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); - // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node - assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); - - assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); - // 1000 - (50 + 300 + 2*10) = 630 : deployments use 370MB on the node - assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(630).getBytes())); - // 8 - (2*4) = 0 : preserving all allocation2 of deployment 2 should use 8 cores on the node - assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(0)); - - List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations(); - assertThat(modelsPreservingAllocations, hasSize(2)); - - assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); - assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes())); - assertThat(modelsPreservingAllocations.get(0).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); - assertThat(modelsPreservingAllocations.get(0).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes())); - assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); - - assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); - assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes())); - assertThat(modelsPreservingAllocations.get(1).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); - assertThat(modelsPreservingAllocations.get(1).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes())); - assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3)); - assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), 
equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0))); - - // Now we have a plan with 2 deployments assigned to 2 nodes. - // Note that deployment 1 has already 1 allocation on node 1, and it gets 2 more. It's more than 2 allocations defined during - // initialization of deployment1, but we don't care at this point. - AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) - .assignModelToNode(deployment1, node1, 2) - .build(); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); - assertThat(plan.assignments(deployment2).isEmpty(), is(true)); - - plan = preserveAllAllocations.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); - - // 1000 - ((30 + 300 + 3*10) + (50 + 300 + 10)) = 280 : deployments use 720 MB on the node 1 - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(280).getBytes())); - // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); - // Nothing changed for Node 2 - assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(630).getBytes())); - // Nothing changed for Node 2 - assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); - } + Node node1 = new Node("n_1", 100, 8); + Node node2 = new Node("n_2", 100, 8); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 30, 2, 1, Map.of("n_1", 1), 1); + Deployment deployment2 = new Deployment("m_2", 50, 6, 4, Map.of("n_1", 1, "n_2", 2), 3); + PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( + List.of(node1, node2), + List.of(deployment1, deployment2) + ); + + List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations(); + assertThat(nodesPreservingAllocations, hasSize(2)); + + assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); + assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(20L)); + assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); + + assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); + assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(50L)); + assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(0)); + + List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations(); + assertThat(modelsPreservingAllocations, hasSize(2)); + + assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); + assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(30L)); + assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); + + assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); + assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(50L)); + assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3)); + assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0))); + + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), 
List.of(deployment1, deployment2)) + .assignModelToNode(deployment1, node1, 2) + .build(); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); + assertThat(plan.assignments(deployment2).isEmpty(), is(true)); + + plan = preserveAllAllocations.mergePreservedAllocations(plan); + + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(20L)); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); + assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(50L)); + assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); } public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() { - Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); - Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0); + Node node = new Node("n_1", 100, 4); + AssignmentPlan.Deployment deployment = new Deployment("m_1", 30, 2, 2, Map.of("n_1", 2), 2); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(List.of(node), List.of(deployment)); AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); @@ -203,7 +101,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments plan = preserveAllAllocations.mergePreservedAllocations(plan); assertThat(plan.assignments(deployment).isPresent(), is(true)); assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 2))); - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(100).getBytes())); + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(70L)); assertThat(plan.getRemainingNodeCores("n_1"), equalTo(0)); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java index f646bf5cb2e9d..d8c3b09422e92 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.ml.inference.assignment.planning; -import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; @@ -23,10 +22,10 @@ public class PreserveOneAllocationTests extends ESTestCase { public void testGivenNoPreviousAssignments() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); + Node node1 = new Node("n_1", 100, 4); + Node node2 = new Node("n_2", 100, 4); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 30, 2, 1, Map.of(), 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 30, 2, 4, Map.of(), 0); PreserveOneAllocation 
preserveOneAllocation = new PreserveOneAllocation(List.of(node1, node2), List.of(deployment1, deployment2)); List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); @@ -37,204 +36,67 @@ public void testGivenNoPreviousAssignments() { } public void testGivenPreviousAssignments() { - { - // old memory format - Node node1 = new Node("n_1", ByteSizeValue.ofMb(640).getBytes(), 8); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of("n_1", 1), 1, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, "n_2", 2), 3, 0, 0); - PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation( - List.of(node1, node2), - List.of(deployment1, deployment2) - ); - - List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, hasSize(2)); - - assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); - // 640 - [(30*2+240)+(50*2+240)] = 0 : deployments use all memory on the node - assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(0L)); - // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node - assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); - - assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); - // 640 - (50*2+240) = 300 : deployments use 340MB on the node - assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); - // 8 - (1*4) = 4 : preserving 1 allocation of deployment 2 should use 4 cores on the node - assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(4)); - - List modelsPreservingAllocations = preserveOneAllocation.modelsPreservingAllocations(); - assertThat(modelsPreservingAllocations, hasSize(2)); - - assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); - assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes())); - assertThat(modelsPreservingAllocations.get(0).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); - assertThat(modelsPreservingAllocations.get(0).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); - assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); - - assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); - assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes())); - assertThat(modelsPreservingAllocations.get(1).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); - assertThat(modelsPreservingAllocations.get(1).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); - assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); - - // Now we have a plan with 2 deployments assigned to 2 nodes. - // Note that deployment 1 has already 1 allocation on node 1, and it gets 2 more. 
It's more than 2 allocations defined during - // initialization of deployment1, but we don't care at this point. - AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) - .assignModelToNode(deployment1, node1, 2) - .assignModelToNode(deployment2, node2, 1) - .build(); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1))); - - plan = preserveOneAllocation.mergePreservedAllocations(plan); - - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); - // Node 1 already had deployments 1 and 2 assigned to it so adding more allocation doesn't change memory usage. - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(0L)); - // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); - // Node 2 already had deployment 2 assigned to it so adding more allocation doesn't change memory usage. - assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(300).getBytes())); - // 8 - [(1*4) + (1*4)] = 4 : deployment 2 should use all cores on the node - assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); - } - { - // new memory format - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 8); - Deployment deployment1 = new Deployment( - "m_1", - ByteSizeValue.ofMb(30).getBytes(), - 2, - 1, - Map.of("n_1", 1), - 1, - ByteSizeValue.ofMb(300).getBytes(), - ByteSizeValue.ofMb(10).getBytes() - ); - Deployment deployment2 = new Deployment( - "m_2", - ByteSizeValue.ofMb(50).getBytes(), - 6, - 4, - Map.of("n_1", 1, "n_2", 2), - 3, - ByteSizeValue.ofMb(300).getBytes(), - ByteSizeValue.ofMb(10).getBytes() - ); - PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation( - List.of(node1, node2), - List.of(deployment1, deployment2) - ); - - List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, hasSize(2)); - - assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); - // 1000 - [(30+300+10)+(50 + 300 +10)] = 300 : deployments use 700 memory on the node - assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); - // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node - assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); - - assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); - // 1000 - (50 +300 + 2*10) = 630 : deployments use 340MB on the node - assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(630).getBytes())); - // 8 - (1*4) = 0 : preserving 1 allocation of deployment 2 should use 4 cores on the node - assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(4)); - - List modelsPreservingAllocations = preserveOneAllocation.modelsPreservingAllocations(); - assertThat(modelsPreservingAllocations, hasSize(2)); - - assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); - assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes())); - assertThat(modelsPreservingAllocations.get(0).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); - 
assertThat(modelsPreservingAllocations.get(0).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes())); - assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); - - assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); - assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes())); - assertThat(modelsPreservingAllocations.get(1).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); - assertThat(modelsPreservingAllocations.get(1).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes())); - assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); - - // Now we have a plan with 2 deployments assigned to 2 nodes. - // Note that deployment 1 has already 1 allocation on node 1, and it gets 2 more. It's more than 2 allocations defined during - // initialization of deployment1, but we don't care at this point. - AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) - .assignModelToNode(deployment1, node1, 2) - .assignModelToNode(deployment2, node2, 1) - .build(); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1))); - - plan = preserveOneAllocation.mergePreservedAllocations(plan); - - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); - // 1000 - [(30+300+3*10) + (50+300+10)] = 280 : deployments use 720MB on the node - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(280).getBytes())); - // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); - // 1000 - (50 + 300 + 2*10) = 630 : deployments use 370MB on the node - assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(630).getBytes())); - // 8 - [(1*4) + (1*4)] = 4 : deployment 2 should use all cores on the node - assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); - - } + Node node1 = new Node("n_1", 100, 8); + Node node2 = new Node("n_2", 100, 8); + AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 30, 2, 1, Map.of("n_1", 1), 1); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 50, 6, 4, Map.of("n_1", 1, "n_2", 2), 3); + PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node1, node2), List.of(deployment1, deployment2)); + + List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); + assertThat(nodesPreservingAllocations, hasSize(2)); + + assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); + assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(20L)); + assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); + + assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); + assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(50L)); + 
assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(4)); + + List modelsPreservingAllocations = preserveOneAllocation.modelsPreservingAllocations(); + assertThat(modelsPreservingAllocations, hasSize(2)); + + assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); + assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(30L)); + assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); + + assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); + assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(50L)); + assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); + + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) + .assignModelToNode(deployment1, node1, 2) + .assignModelToNode(deployment2, node2, 1) + .build(); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1))); + + plan = preserveOneAllocation.mergePreservedAllocations(plan); + + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(20L)); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); + assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(50L)); + assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); } public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() { - { - // old memory format - Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); - Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0); - PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); - - AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); - assertThat(plan.assignments(deployment).isEmpty(), is(true)); - - plan = preserveOneAllocation.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment).isPresent(), is(true)); - assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 1))); - // 400 - (30*2 + 240) = 100 : deployments use 300MB on the node - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(100).getBytes())); - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); - } - { - // new memory format - Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); - Deployment deployment = new Deployment( - "m_1", - ByteSizeValue.ofMb(30).getBytes(), - 2, - 2, - Map.of("n_1", 2), - 2, - ByteSizeValue.ofMb(300).getBytes(), - ByteSizeValue.ofMb(10).getBytes() - ); - PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); - - AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); - assertThat(plan.assignments(deployment).isEmpty(), is(true)); - - plan = preserveOneAllocation.mergePreservedAllocations(plan); - 
assertThat(plan.assignments(deployment).isPresent(), is(true)); - assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 1))); - // 400 - (30 + 300 + 10) = 60 : deployments use 340MB on the node - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(60).getBytes())); - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); - } + Node node = new Node("n_1", 100, 4); + AssignmentPlan.Deployment deployment = new Deployment("m_1", 30, 2, 2, Map.of("n_1", 2), 2); + PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); + + AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); + assertThat(plan.assignments(deployment).isEmpty(), is(true)); + + plan = preserveOneAllocation.mergePreservedAllocations(plan); + assertThat(plan.assignments(deployment).isPresent(), is(true)); + assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 1))); + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(70L)); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java index 651e4764cb894..7ceb8bbb86869 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java @@ -36,7 +36,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase { public void testGivenOneModel_OneNode_OneZone_DoesNotFit() { Node node = new Node("n_1", 100, 1); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -44,17 +44,8 @@ public void testGivenOneModel_OneNode_OneZone_DoesNotFit() { } public void testGivenOneModel_OneNode_OneZone_FullyFits() { - Node node = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( - "m_1", - ByteSizeValue.ofMb(100).getBytes(), - 2, - 2, - Map.of(), - 0, - 0, - 0 - ); + Node node = new Node("n_1", 100, 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 2, 2, Map.of(), 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -62,17 +53,8 @@ public void testGivenOneModel_OneNode_OneZone_FullyFits() { } public void testGivenOneModel_OneNode_OneZone_PartiallyFits() { - Node node = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( - "m_1", - ByteSizeValue.ofMb(100).getBytes(), - 3, - 2, - Map.of(), - 0, - 0, - 0 - ); + Node node = new Node("n_1", 100, 5); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of(), 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -82,18 +64,9 @@ public void testGivenOneModel_OneNode_OneZone_PartiallyFits() { } 
public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( - "m_1", - ByteSizeValue.ofMb(100).getBytes(), - 1, - 2, - Map.of(), - 0, - 0, - 0 - ); + Node node1 = new Node("n_1", 100, 4); + Node node2 = new Node("n_2", 100, 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2)), @@ -109,18 +82,9 @@ public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() { } public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( - "m_1", - ByteSizeValue.ofMb(100).getBytes(), - 2, - 2, - Map.of(), - 0, - 0, - 0 - ); + Node node1 = new Node("n_1", 100, 4); + Node node2 = new Node("n_2", 100, 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 2, 2, Map.of(), 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z_1"), List.of(node1), List.of("z_2"), List.of(node2)), @@ -135,18 +99,9 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() { } public void testGivenOneModel_OneNodePerZone_TwoZones_PartiallyFits() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( - "m_1", - ByteSizeValue.ofMb(100).getBytes(), - 3, - 3, - Map.of(), - 0, - 0, - 0 - ); + Node node1 = new Node("n_1", 100, 4); + Node node2 = new Node("n_2", 100, 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 3, 3, Map.of(), 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z_1"), List.of(node1), List.of("z_2"), List.of(node2)), @@ -162,15 +117,15 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_PartiallyFits() { } public void testGivenThreeModels_TwoNodesPerZone_ThreeZones_FullyFit() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Node node3 = new Node("n_3", ByteSizeValue.ofMb(1000).getBytes(), 4); - Node node4 = new Node("n_4", ByteSizeValue.ofMb(1000).getBytes(), 4); - Node node5 = new Node("n_5", ByteSizeValue.ofMb(1000).getBytes(), 4); - Node node6 = new Node("n_6", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 4, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 6, 2, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(30).getBytes(), 2, 3, Map.of(), 0, 0, 0); + Node node1 = new Node("n_1", 100, 4); + Node node2 = new Node("n_2", 100, 4); + Node node3 = new Node("n_3", 100, 4); + Node node4 = new Node("n_4", 100, 4); + Node node5 = new Node("n_5", 100, 4); + Node node6 = new Node("n_6", 100, 4); + AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 25, 4, 1, Map.of(), 0); + Deployment deployment2 = new 
AssignmentPlan.Deployment("m_2", 25, 6, 2, Map.of(), 0); + AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 25, 2, 3, Map.of(), 0); Map, List> nodesByZone = Map.of( List.of("z_1"), @@ -213,11 +168,11 @@ public void testGivenThreeModels_TwoNodesPerZone_ThreeZones_FullyFit() { } public void testGivenTwoModelsWithSingleAllocation_OneNode_ThreeZones() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Node node3 = new Node("n_3", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Node node1 = new Node("n_1", 100, 4); + Node node2 = new Node("n_2", 100, 4); + Node node3 = new Node("n_3", 100, 4); + AssignmentPlan.Deployment deployment1 = new Deployment("m_1", 25, 1, 1, Map.of(), 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 25, 1, 1, Map.of(), 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2), List.of("z3"), List.of(node3)), @@ -248,16 +203,7 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode .stream() .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue)); previousModelsPlusNew.add( - new AssignmentPlan.Deployment( - m.id(), - m.memoryBytes(), - m.allocations(), - m.threadsPerAllocation(), - previousAssignments, - 0, - 0, - 0 - ) + new AssignmentPlan.Deployment(m.id(), m.memoryBytes(), m.allocations(), m.threadsPerAllocation(), previousAssignments, 0) ); } previousModelsPlusNew.add(randomModel("new")); @@ -268,11 +214,11 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode } public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOnce() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(2580).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2); + AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0); + AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0); + AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0); // First only start m_1 AssignmentPlan assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1, node2)), List.of(deployment1)) @@ -306,8 +252,8 @@ public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOn assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1))); // Now the cluster starts getting resized. 
- Node node3 = new Node("n_3", ByteSizeValue.ofMb(5160).getBytes(), 2); - Node node4 = new Node("n_4", ByteSizeValue.ofMb(5160).getBytes(), 2); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(2400).getBytes(), 2); + Node node4 = new Node("n_4", ByteSizeValue.ofMb(2400).getBytes(), 2); // First, one node goes away. assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1)), createModelsFromPlan(assignmentPlan)) diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java deleted file mode 100644 index 549ac23e16845..0000000000000 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java +++ /dev/null @@ -1,287 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.upgrades; - -import org.elasticsearch.Version; -import org.elasticsearch.client.Request; -import org.elasticsearch.client.Response; -import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.xcontent.support.XContentMapValues; -import org.elasticsearch.core.Strings; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Base64; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; - -import static org.elasticsearch.client.WarningsHandler.PERMISSIVE; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; - -public class MlAssignmentPlannerUpgradeIT extends AbstractUpgradeTestCase { - - private Logger logger = LogManager.getLogger(MlAssignmentPlannerUpgradeIT.class); - - // See PyTorchModelIT for how this model was created - static final String BASE_64_ENCODED_MODEL = - "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp" - + "TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA" - + "AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW" - + "lpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473Jqhjh" - + "kAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Ele" - + "s+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07k" - + "umUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJ" - + "wA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa" - + "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq" - + "+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7" - + "ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3" - + "FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28" - + 
"UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWw" - + "vY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW" - + "9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0" - + "Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGts" - + "UEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEs" - + "BAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYn" - + "VnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsU" - + "EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe" - + "Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE" - + "AAJIEAAAAAA=="; - static final long RAW_MODEL_SIZE; // size of the model before base64 encoding - static { - RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length; - } - - public void testMlAssignmentPlannerUpgrade() throws Exception { - assumeTrue("NLP model deployments added in 8.0", isOriginalClusterVersionAtLeast(Version.V_8_0_0)); - - logger.info("Starting testMlAssignmentPlannerUpgrade, model size {}", RAW_MODEL_SIZE); - - switch (CLUSTER_TYPE) { - case OLD -> { - // setup deployments using old and new memory format - setupDeployments(); - - waitForDeploymentStarted("old_memory_format"); - waitForDeploymentStarted("new_memory_format"); - - // assert correct memory format is used - assertOldMemoryFormat("old_memory_format"); - if (isOriginalClusterVersionAtLeast(Version.V_8_11_0)) { - assertNewMemoryFormat("new_memory_format"); - } else { - assertOldMemoryFormat("new_memory_format"); - } - } - case MIXED -> { - ensureHealth(".ml-inference-*,.ml-config*", (request -> { - request.addParameter("wait_for_status", "yellow"); - request.addParameter("timeout", "70s"); - })); - waitForDeploymentStarted("old_memory_format"); - waitForDeploymentStarted("new_memory_format"); - - // assert correct memory format is used - assertOldMemoryFormat("old_memory_format"); - if (isOriginalClusterVersionAtLeast(Version.V_8_11_0)) { - assertNewMemoryFormat("new_memory_format"); - } else { - assertOldMemoryFormat("new_memory_format"); - } - - } - case UPGRADED -> { - ensureHealth(".ml-inference-*,.ml-config*", (request -> { - request.addParameter("wait_for_status", "yellow"); - request.addParameter("timeout", "70s"); - })); - waitForDeploymentStarted("old_memory_format"); - waitForDeploymentStarted("new_memory_format"); - - // assert correct memory format is used - assertOldMemoryFormat("old_memory_format"); - assertNewMemoryFormat("new_memory_format"); - - cleanupDeployments(); - } - } - } - - @SuppressWarnings("unchecked") - private void waitForDeploymentStarted(String modelId) throws Exception { - assertBusy(() -> { - var response = getTrainedModelStats(modelId); - Map map = entityAsMap(response); - List> stats = (List>) map.get("trained_model_stats"); - assertThat(stats, hasSize(1)); - var stat = stats.get(0); - assertThat(stat.toString(), XContentMapValues.extractValue("deployment_stats.state", stat), equalTo("started")); - }, 30, TimeUnit.SECONDS); - } - - @SuppressWarnings("unchecked") - private void assertOldMemoryFormat(String modelId) throws Exception { - var response = getTrainedModelStats(modelId); - Map 
map = entityAsMap(response); - List> stats = (List>) map.get("trained_model_stats"); - assertThat(stats, hasSize(1)); - var stat = stats.get(0); - Long expectedMemoryUsage = ByteSizeValue.ofMb(240).getBytes() + RAW_MODEL_SIZE * 2; - Integer actualMemoryUsage = (Integer) XContentMapValues.extractValue("model_size_stats.required_native_memory_bytes", stat); - assertThat( - Strings.format("Memory usage mismatch for the model %s in cluster state %s", modelId, CLUSTER_TYPE.toString()), - actualMemoryUsage, - equalTo(expectedMemoryUsage.intValue()) - ); - } - - @SuppressWarnings("unchecked") - private void assertNewMemoryFormat(String modelId) throws Exception { - var response = getTrainedModelStats(modelId); - Map map = entityAsMap(response); - List> stats = (List>) map.get("trained_model_stats"); - assertThat(stats, hasSize(1)); - var stat = stats.get(0); - Long expectedMemoryUsage = ByteSizeValue.ofMb(300).getBytes() + RAW_MODEL_SIZE + ByteSizeValue.ofMb(10).getBytes(); - Integer actualMemoryUsage = (Integer) XContentMapValues.extractValue("model_size_stats.required_native_memory_bytes", stat); - assertThat(stat.toString(), actualMemoryUsage.toString(), equalTo(expectedMemoryUsage.toString())); - } - - private Response getTrainedModelStats(String modelId) throws IOException { - Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/_stats"); - request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build()); - var response = client().performRequest(request); - assertOK(response); - return response; - } - - private Response infer(String input, String modelId) throws IOException { - Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer"); - request.setJsonEntity(Strings.format(""" - { "docs": [{"input":"%s"}] } - """, input)); - request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build()); - var response = client().performRequest(request); - assertOK(response); - return response; - } - - private void putModelDefinition(String modelId) throws IOException { - Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0"); - request.setJsonEntity(Strings.format(""" - {"total_definition_length":%s,"definition": "%s","total_parts": 1}""", RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL)); - client().performRequest(request); - } - - private void putVocabulary(List vocabulary, String modelId) throws IOException { - List vocabularyWithPad = new ArrayList<>(); - vocabularyWithPad.add("[PAD]"); - vocabularyWithPad.add("[UNK]"); - vocabularyWithPad.addAll(vocabulary); - String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(",")); - - Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary"); - request.setJsonEntity(Strings.format(""" - { "vocabulary": [%s] } - """, quotedWords)); - client().performRequest(request); - } - - private void setupDeployments() throws Exception { - createTrainedModel("old_memory_format", 0, 0); - putModelDefinition("old_memory_format"); - putVocabulary(List.of("these", "are", "my", "words"), "old_memory_format"); - startDeployment("old_memory_format"); - - createTrainedModel("new_memory_format", ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes()); - putModelDefinition("new_memory_format"); - putVocabulary(List.of("these", "are", "my", "words"), "new_memory_format"); - startDeployment("new_memory_format"); - } - - private void cleanupDeployments() throws 
IOException {
-        stopDeployment("old_memory_format");
-        deleteTrainedModel("old_memory_format");
-        stopDeployment("new_memory_format");
-        deleteTrainedModel("new_memory_format");
-    }
-
-    private void createTrainedModel(String modelId, long perDeploymentMemoryBytes, long perAllocationMemoryBytes) throws IOException {
-        Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
-        if (perAllocationMemoryBytes > 0 && perDeploymentMemoryBytes > 0) {
-            request.setJsonEntity(Strings.format("""
-                {
-                    "description": "simple model for testing",
-                    "model_type": "pytorch",
-                    "inference_config": {
-                        "pass_through": {
-                            "tokenization": {
-                                "bert": {
-                                    "with_special_tokens": false
-                                }
-                            }
-                        }
-                    },
-                    "metadata": {
-                        "per_deployment_memory_bytes": %s,
-                        "per_allocation_memory_bytes": %s
-                    }
-                }""", perDeploymentMemoryBytes, perAllocationMemoryBytes));
-        } else {
-            request.setJsonEntity("""
-                {
-                    "description": "simple model for testing",
-                    "model_type": "pytorch",
-                    "inference_config": {
-                        "pass_through": {
-                            "tokenization": {
-                                "bert": {
-                                    "with_special_tokens": false
-                                }
-                            }
-                        }
-                    }
-                }""");
-        }
-        client().performRequest(request);
-    }
-
-    private void deleteTrainedModel(String modelId) throws IOException {
-        Request request = new Request("DELETE", "_ml/trained_models/" + modelId);
-        client().performRequest(request);
-    }
-
-    private Response startDeployment(String modelId) throws IOException {
-        return startDeployment(modelId, "started");
-    }
-
-    private Response startDeployment(String modelId, String waitForState) throws IOException {
-        Request request = new Request(
-            "POST",
-            "/_ml/trained_models/"
-                + modelId
-                + "/deployment/_start?timeout=40s&wait_for="
-                + waitForState
-                + "&inference_threads=1&model_threads=1"
-        );
-        request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build());
-        var response = client().performRequest(request);
-        assertOK(response);
-        return response;
-    }
-
-    private void stopDeployment(String modelId) throws IOException {
-        String endpoint = "/_ml/trained_models/" + modelId + "/deployment/_stop";
-        Request request = new Request("POST", endpoint);
-        client().performRequest(request);
-    }
-}

From cef2b800251d8f4ad44f375ba0e58344d3f26f63 Mon Sep 17 00:00:00 2001
From: Jake Landis
Date: Mon, 6 Nov 2023 11:06:25 -0600
Subject: [PATCH 10/21] Fix memory leak from JWT cache (and fix the usage of
 the JWT auth cache) (#101799)

This commit fixes a memory leak and ensures that the JWT authentication
cache is actually used to short circuit authentication. When JWT
authentication is successful we cache a hash(JWT) -> User. On subsequent
authentication attempts we should short circuit some of the more
expensive validation thanks to this cache.

The existing cache key is a `BytesArray` constructed from
`jwtAuthenticationToken.getUserCredentialsHash()` (the hash value of the
JWT). However, `jwtAuthenticationToken` comes from the security context,
which is effectively a request scoped object. When the security context
goes out of scope, the value of
`jwtAuthenticationToken.getUserCredentialsHash()` is zero'ed out to help
keep sensitive data out of the heap. It is arguable whether zero'ing out
that data is useful, especially for a hashed value, but it is in line
with the internal contract/expectations.

Since the cache key is derived from
`jwtAuthenticationToken.getUserCredentialsHash()`, when that value is
zero'ed out, so is the cache key. This results in a cache key that
changes from a valid value to an empty byte array, leaving behind a junk
cache entry.
Subsequent authentication with the same JWT will result in a new cache
entry which will then follow the same pattern of getting zero'ed out.
This results in a useless cache with nothing but zero'ed out cache keys.
This negates any benefit of having a cache at all, in that a full
authentication is performed every time, which can be expensive for JWT
(especially since JWT requires role mappings and role mappings are not
cached).

Fortunately the default cache size is 100k (by count), so the actual
memory leak is technically capped, though the cap varies with the size
of the cache values. This is an approximate cap of ~55MB: ~3.125MB
(100k * 256 bits) for the sha256 cache keys plus ~50MB (100k * ~500
bytes) for the cache values; the cache can be larger if the cached
values are larger.

The fix here is to ensure the cache key used is a copy of the value that
is zero'ed out (made before it is zero'ed out).

fixes: https://github.com/elastic/elasticsearch/issues/101752
---
 docs/changelog/101799.yaml                    |  5 ++
 .../xpack/security/authc/jwt/JwtRealm.java    | 10 +++-
 .../authc/jwt/JwtRealmAuthenticateTests.java  | 20 +++++++
 .../security/authc/jwt/JwtRealmTestCase.java  | 55 ++++++++++++++-----
 4 files changed, 74 insertions(+), 16 deletions(-)
 create mode 100644 docs/changelog/101799.yaml

diff --git a/docs/changelog/101799.yaml b/docs/changelog/101799.yaml
new file mode 100644
index 0000000000000..a3ef5fb190177
--- /dev/null
+++ b/docs/changelog/101799.yaml
@@ -0,0 +1,5 @@
+pr: 101799
+summary: Fix memory leak from JWT cache (and fix the usage of the JWT auth cache)
+area: Authentication
+type: bug
+issues: []
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java
index d56d8dd6b968f..dea471846b9f4 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java
@@ -11,6 +11,7 @@
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesArray;
@@ -283,7 +284,9 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac
             return; // FAILED (secret is missing or mismatched)
         }
 
-        final BytesArray jwtCacheKey = isCacheEnabled() ? new BytesArray(jwtAuthenticationToken.getUserCredentialsHash()) : null;
+        final BytesArray jwtCacheKey = isCacheEnabled()
+            ? new BytesArray(new BytesRef(jwtAuthenticationToken.getUserCredentialsHash()), true)
+            : null;
         if (jwtCacheKey != null) {
             final User cachedUser = tryAuthenticateWithCache(tokenPrincipal, jwtCacheKey);
             if (cachedUser != null) {
@@ -483,6 +486,11 @@ private boolean isCacheEnabled() {
         return jwtCache != null && jwtCacheHelper != null;
     }
 
+    // package private for testing
+    Cache getJwtCache() {
+        return jwtCache;
+    }
+
     /**
      * Format and filter JWT contents as user metadata.
      * @param claimsSet Claims are supported. Claim keys are prefixed by "jwt_claim_".
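[Editor's note: the corrupted-key failure mode described in the commit message above can be reproduced with plain JDK types. The sketch below is illustrative only — it uses HashMap and ByteBuffer as stand-ins for the realm's cache and its BytesArray key, and the class and variable names are invented. The point it demonstrates: a key that wraps a live, mutable buffer is silently invalidated when that buffer is zeroed, whereas a defensive copy, analogous to the `new BytesArray(new BytesRef(...), true)` in the hunk above, survives.]

    import java.nio.ByteBuffer;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;

    public class MutableCacheKeyDemo {
        public static void main(String[] args) {
            Map<ByteBuffer, String> cache = new HashMap<>();

            // Stands in for hash(JWT), i.e. the user credentials hash.
            byte[] credentialsHash = { 1, 2, 3, 4 };

            // Buggy pattern: the key wraps the live buffer without copying it.
            cache.put(ByteBuffer.wrap(credentialsHash), "user-a");

            // The request-scoped token goes out of scope and its buffer is zeroed...
            Arrays.fill(credentialsHash, (byte) 0);

            // ...so a later lookup with the same JWT hash misses, and the old
            // entry is stranded as junk that can never be hit again.
            System.out.println(cache.get(ByteBuffer.wrap(new byte[] { 1, 2, 3, 4 }))); // null

            // Fixed pattern: deep-copy the buffer when building the key,
            // before anything has a chance to zero it out.
            byte[] otherHash = { 5, 6, 7, 8 };
            cache.put(ByteBuffer.wrap(otherHash.clone()), "user-b");
            Arrays.fill(otherHash, (byte) 0);
            System.out.println(cache.get(ByteBuffer.wrap(new byte[] { 5, 6, 7, 8 }))); // user-b
        }
    }

This is also why the test changes in the following diffs iterate over the cache keys and assert that none of them has been nulled out.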
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java index 7fd4a70f7505e..f75876a755557 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java @@ -67,6 +67,26 @@ public void testJwtAuthcRealmAuthcAuthzWithEmptyRoles() throws Exception { doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcCount); } + public void testJwtCache() throws Exception { + jwtIssuerAndRealms = generateJwtIssuerRealmPairs(1, 1, 1, 1, 1, 1, 99, false); + JwtRealm realm = jwtIssuerAndRealms.get(0).realm(); + realm.expireAll(); + assertThat(realm.getJwtCache().count(), is(0)); + final JwtIssuerAndRealm jwtIssuerAndRealm = randomJwtIssuerRealmPair(); + final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm()); + for (int i = 1; i <= randomIntBetween(2, 10); i++) { + User user = randomUser(jwtIssuerAndRealm.issuer()); + doMultipleAuthcAuthzAndVerifySuccess( + jwtIssuerAndRealm.realm(), + user, + randomJwt(jwtIssuerAndRealm, user), + clientSecret, + randomIntBetween(2, 10) + ); + assertThat(realm.getJwtCache().count(), is(i)); + } + } + /** * Test with no authz realms. * @throws Exception Unexpected test failure diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java index ab1b7867ffa04..64f2444e0182d 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java @@ -12,6 +12,8 @@ import com.nimbusds.openid.connect.sdk.Nonce; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.settings.MockSecureSettings; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; @@ -46,6 +48,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Date; +import java.util.HexFormat; import java.util.List; import java.util.Map; import java.util.TreeSet; @@ -290,7 +293,7 @@ protected JwtRealmSettingsBuilder createJwtRealmSettingsBuilder(final JwtIssuer if (randomBoolean()) { authcSettings.put( RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_TTL), - randomIntBetween(10, 120) + randomFrom("s", "m", "h") + randomIntBetween(10, 120) + randomFrom("m", "h") ); } authcSettings.put(RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_SIZE), jwtCacheSize); @@ -378,11 +381,12 @@ protected void doMultipleAuthcAuthzAndVerifySuccess( final int jwtAuthcRepeats ) { final List jwtRealmsList = jwtIssuerAndRealms.stream().map(p -> p.realm).toList(); - + BytesArray firstCacheKeyFound = null; // Select different test JWKs from the JWT realm, and generate test JWTs for the test user. Run the JWT through the chain. 
for (int authcRun = 1; authcRun <= jwtAuthcRepeats; authcRun++) { + final ThreadContext requestThreadContext = createThreadContext(jwt, sharedSecret); - logger.info("REQ[" + authcRun + "/" + jwtAuthcRepeats + "] HEADERS=" + requestThreadContext.getHeaders()); + logger.debug("REQ[" + authcRun + "/" + jwtAuthcRepeats + "] HEADERS=" + requestThreadContext.getHeaders()); // Any JWT realm can recognize and extract the request headers. final var jwtAuthenticationToken = (JwtAuthenticationToken) randomFrom(jwtRealmsList).token(requestThreadContext); @@ -393,11 +397,11 @@ protected void doMultipleAuthcAuthzAndVerifySuccess( // Loop through all authc/authz realms. Confirm user is returned with expected principal and roles. User authenticatedUser = null; realmLoop: for (final JwtRealm candidateJwtRealm : jwtRealmsList) { - logger.info("TRY AUTHC: expected=[" + jwtRealm.name() + "], candidate[" + candidateJwtRealm.name() + "]."); + logger.debug("TRY AUTHC: expected=[" + jwtRealm.name() + "], candidate[" + candidateJwtRealm.name() + "]."); final PlainActionFuture> authenticateFuture = PlainActionFuture.newFuture(); candidateJwtRealm.authenticate(jwtAuthenticationToken, authenticateFuture); final AuthenticationResult authenticationResult = authenticateFuture.actionGet(); - logger.info("Authentication result with realm [{}]: [{}]", candidateJwtRealm.name(), authenticationResult); + logger.debug("Authentication result with realm [{}]: [{}]", candidateJwtRealm.name(), authenticationResult); switch (authenticationResult.getStatus()) { case SUCCESS: assertThat("Unexpected realm SUCCESS status", candidateJwtRealm.name(), equalTo(jwtRealm.name())); @@ -430,20 +434,41 @@ protected void doMultipleAuthcAuthzAndVerifySuccess( equalTo(Map.of("jwt_token_type", JwtRealmInspector.getTokenType(jwtRealm).value())) ); } + // if the cache is enabled ensure the cache is used and does not change for the provided jwt + if (jwtRealm.getJwtCache() != null) { + Cache cache = jwtRealm.getJwtCache(); + if (firstCacheKeyFound == null) { + assertNotNull("could not find cache keys", cache.keys()); + firstCacheKeyFound = cache.keys().iterator().next(); + } + jwtAuthenticationToken.clearCredentials(); // simulates the realm's context closing which clears the credential + boolean foundInCache = false; + for (BytesArray key : cache.keys()) { + logger.trace("cache key: " + HexFormat.of().formatHex(key.array())); + if (key.equals(firstCacheKeyFound)) { + foundInCache = true; + } + assertFalse( + "cache key should not be nulled out", + IntStream.range(0, key.array().length).map(idx -> key.array()[idx]).allMatch(b -> b == 0) + ); + } + assertTrue("cache key was not found in cache", foundInCache); + } } - logger.info("Test succeeded"); + logger.debug("Test succeeded"); } protected User randomUser(final JwtIssuer jwtIssuer) { final User user = randomFrom(jwtIssuer.principals.values()); - logger.info("USER[" + user.principal() + "]: roles=[" + String.join(",", user.roles()) + "]."); + logger.debug("USER[" + user.principal() + "]: roles=[" + String.join(",", user.roles()) + "]."); return user; } protected SecureString randomJwt(final JwtIssuerAndRealm jwtIssuerAndRealm, User user) throws Exception { final JwtIssuer.AlgJwkPair algJwkPair = randomFrom(jwtIssuerAndRealm.issuer.algAndJwksAll); final JWK jwk = algJwkPair.jwk(); - logger.info( + logger.debug( "ALG[" + algJwkPair.alg() + "]. 
JWK: kty=[" @@ -491,7 +516,7 @@ protected void printJwtRealmAndIssuer(JwtIssuerAndRealm jwtIssuerAndRealm) throw } protected void printJwtRealm(final JwtRealm jwtRealm) { - logger.info( + logger.debug( "REALM[" + jwtRealm.name() + "," @@ -527,15 +552,15 @@ protected void printJwtRealm(final JwtRealm jwtRealm) { + "]." ); for (final JWK jwk : JwtRealmInspector.getJwksAlgsHmac(jwtRealm).jwks()) { - logger.info("REALM HMAC: jwk=[{}]", jwk); + logger.debug("REALM HMAC: jwk=[{}]", jwk); } for (final JWK jwk : JwtRealmInspector.getJwksAlgsPkc(jwtRealm).jwks()) { - logger.info("REALM PKC: jwk=[{}]", jwk); + logger.debug("REALM PKC: jwk=[{}]", jwk); } } protected void printJwtIssuer(final JwtIssuer jwtIssuer) { - logger.info( + logger.debug( "ISSUER: iss=[" + jwtIssuer.issuerClaimValue + "], aud=[" @@ -549,13 +574,13 @@ protected void printJwtIssuer(final JwtIssuer jwtIssuer) { + "]." ); if (jwtIssuer.algAndJwkHmacOidc != null) { - logger.info("ISSUER HMAC OIDC: alg=[{}] jwk=[{}]", jwtIssuer.algAndJwkHmacOidc.alg(), jwtIssuer.encodedKeyHmacOidc); + logger.debug("ISSUER HMAC OIDC: alg=[{}] jwk=[{}]", jwtIssuer.algAndJwkHmacOidc.alg(), jwtIssuer.encodedKeyHmacOidc); } for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksHmac) { - logger.info("ISSUER HMAC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk()); + logger.debug("ISSUER HMAC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk()); } for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksPkc) { - logger.info("ISSUER PKC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk()); + logger.debug("ISSUER PKC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk()); } } } From 4a37aef80bb34550b1c70c278a1cf07f0ca8020f Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 6 Nov 2023 19:30:07 +0000 Subject: [PATCH 11/21] Deprecate `ExternalTestCluster` (#101844) `ExternalTestCluster` doesn't really make sense now that the transport client is removed. We only use it in the ML integ test suite and it'd be good to avoid expanding its usage further, so this commit deprecates it and removes the functionality in `ESIntegTestCase` that might quietly switch to using it in a new test suite if running with certain system properties. 
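For a suite that genuinely needs to talk to a pre-provisioned cluster, the HTTP-based `ESRestTestCase` is the supported shape. A minimal sketch (the class name and request below are illustrative only, not part of this change):

    import java.io.IOException;

    import org.elasticsearch.client.Request;
    import org.elasticsearch.client.Response;
    import org.elasticsearch.test.rest.ESRestTestCase;

    public class MyFeatureRestIT extends ESRestTestCase {
        // Talks to the externally started cluster over HTTP via the framework's
        // RestClient instead of the removed transport client.
        public void testClusterIsHealthy() throws IOException {
            Response response = client().performRequest(new Request("GET", "/_cluster/health"));
            assertEquals(200, response.getStatusLine().getStatusCode());
        }
    }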
Relates #49582 --- .../common/ReloadSynonymAnalyzerIT.java | 10 ----- .../lifecycle/CrudDataStreamLifecycleIT.java | 4 -- .../DataStreamLifecycleServiceIT.java | 4 -- .../ExplainDataStreamLifecycleIT.java | 4 -- .../ingest/common/IngestRestartIT.java | 5 --- .../join/query/ParentChildTestCase.java | 5 --- .../documentation/ReindexDocumentationIT.java | 5 --- .../elasticsearch/ESNetty4IntegTestCase.java | 5 --- .../elasticsearch/http/HttpSmokeTestCase.java | 5 --- .../ingest/IngestStatsNamesAndTypesIT.java | 5 --- .../PersistentTasksExecutorFullRestartIT.java | 4 -- .../persistent/PersistentTasksExecutorIT.java | 4 -- .../decider/EnableAssignmentDeciderIT.java | 5 --- .../search/fieldcaps/FieldCapabilitiesIT.java | 5 --- .../elasticsearch/test/ESIntegTestCase.java | 41 ------------------- .../test/ExternalTestCluster.java | 4 ++ ...AutoscalingCapacityRestCancellationIT.java | 5 --- .../DataStreamLifecycleDownsampleIT.java | 4 -- .../eql/action/RestEqlCancellationIT.java | 4 -- .../ClusterStateWaitThresholdBreachTests.java | 5 --- ...ataStreamAndIndexLifecycleMixingTests.java | 4 -- .../xpack/ilm/DataTiersMigrationsTests.java | 5 --- .../IndexLifecycleInitialisationTests.java | 5 --- .../ml/integration/MlNativeIntegTestCase.java | 39 +++++++++++++++++- .../xpack/ml/support/BaseMlIntegTestCase.java | 5 --- .../exporter/http/HttpExporterIT.java | 5 --- .../exporter/http/HttpExporterSslIT.java | 5 --- .../xpack/profiling/ProfilingTestCase.java | 5 --- .../ShrinkIndexWithSecurityTests.java | 5 --- .../sql/action/RestSqlCancellationIT.java | 4 -- .../xpack/sql/action/SqlLicenseIT.java | 4 -- 31 files changed, 41 insertions(+), 178 deletions(-) diff --git a/modules/analysis-common/src/internalClusterTest/java/org/elasticsearch/analysis/common/ReloadSynonymAnalyzerIT.java b/modules/analysis-common/src/internalClusterTest/java/org/elasticsearch/analysis/common/ReloadSynonymAnalyzerIT.java index a9ffdb60419f9..f0063f663142d 100644 --- a/modules/analysis-common/src/internalClusterTest/java/org/elasticsearch/analysis/common/ReloadSynonymAnalyzerIT.java +++ b/modules/analysis-common/src/internalClusterTest/java/org/elasticsearch/analysis/common/ReloadSynonymAnalyzerIT.java @@ -17,7 +17,6 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.InternalTestCluster; import java.io.FileNotFoundException; import java.io.IOException; @@ -44,15 +43,6 @@ protected Collection> nodePlugins() { return Arrays.asList(CommonAnalysisPlugin.class); } - /** - * This test needs to write to the config directory, this is difficult in an external cluster so we overwrite this to force running with - * {@link InternalTestCluster} - */ - @Override - protected boolean ignoreExternalCluster() { - return true; - } - public void testSynonymsUpdateable() throws FileNotFoundException, IOException, InterruptedException { testSynonymsUpdate(false); } diff --git a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/CrudDataStreamLifecycleIT.java b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/CrudDataStreamLifecycleIT.java index ff84501697e21..e33b1fdcfa57a 100644 --- a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/CrudDataStreamLifecycleIT.java +++ b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/CrudDataStreamLifecycleIT.java @@ -36,10 +36,6 @@ 
protected Collection> nodePlugins() { return List.of(DataStreamsPlugin.class, MockTransportService.TestPlugin.class); } - protected boolean ignoreExternalCluster() { - return true; - } - public void testGetLifecycle() throws Exception { DataStreamLifecycle lifecycle = randomLifecycle(); putComposableIndexTemplate("id1", null, List.of("with-lifecycle*"), null, null, lifecycle); diff --git a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceIT.java b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceIT.java index e4f4f88254977..0d3588ba20b9a 100644 --- a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceIT.java +++ b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceIT.java @@ -90,10 +90,6 @@ protected Collection> nodePlugins() { return List.of(DataStreamsPlugin.class, MockTransportService.TestPlugin.class); } - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { Settings.Builder settings = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)); diff --git a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/ExplainDataStreamLifecycleIT.java b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/ExplainDataStreamLifecycleIT.java index 6ff50d88aeb05..c9968a545cb7d 100644 --- a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/ExplainDataStreamLifecycleIT.java +++ b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/lifecycle/ExplainDataStreamLifecycleIT.java @@ -62,10 +62,6 @@ protected Collection> nodePlugins() { return List.of(DataStreamsPlugin.class, MockTransportService.TestPlugin.class); } - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { Settings.Builder settings = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)); diff --git a/modules/ingest-common/src/internalClusterTest/java/org/elasticsearch/ingest/common/IngestRestartIT.java b/modules/ingest-common/src/internalClusterTest/java/org/elasticsearch/ingest/common/IngestRestartIT.java index 96ca77a5f65f9..5709fbd9d8bfc 100644 --- a/modules/ingest-common/src/internalClusterTest/java/org/elasticsearch/ingest/common/IngestRestartIT.java +++ b/modules/ingest-common/src/internalClusterTest/java/org/elasticsearch/ingest/common/IngestRestartIT.java @@ -50,11 +50,6 @@ protected Collection> nodePlugins() { return Arrays.asList(IngestCommonPlugin.class, CustomScriptPlugin.class); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - public static class CustomScriptPlugin extends MockScriptPlugin { @Override protected Map, Object>> pluginScripts() { diff --git a/modules/parent-join/src/internalClusterTest/java/org/elasticsearch/join/query/ParentChildTestCase.java b/modules/parent-join/src/internalClusterTest/java/org/elasticsearch/join/query/ParentChildTestCase.java index b4846a1c003a6..a67ebd4cbca22 100644 --- a/modules/parent-join/src/internalClusterTest/java/org/elasticsearch/join/query/ParentChildTestCase.java +++ 
b/modules/parent-join/src/internalClusterTest/java/org/elasticsearch/join/query/ParentChildTestCase.java @@ -29,11 +29,6 @@ @ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE) public abstract class ParentChildTestCase extends ESIntegTestCase { - @Override - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Collection> nodePlugins() { return Arrays.asList(InternalSettingsPlugin.class, ParentJoinPlugin.class); diff --git a/modules/reindex/src/internalClusterTest/java/org/elasticsearch/client/documentation/ReindexDocumentationIT.java b/modules/reindex/src/internalClusterTest/java/org/elasticsearch/client/documentation/ReindexDocumentationIT.java index 9947d8a727d28..50284008eef48 100644 --- a/modules/reindex/src/internalClusterTest/java/org/elasticsearch/client/documentation/ReindexDocumentationIT.java +++ b/modules/reindex/src/internalClusterTest/java/org/elasticsearch/client/documentation/ReindexDocumentationIT.java @@ -56,11 +56,6 @@ public class ReindexDocumentationIT extends ESIntegTestCase { private static final Semaphore ALLOWED_OPERATIONS = new Semaphore(0); private static final String INDEX_NAME = "source_index"; - @Override - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Collection> nodePlugins() { return Arrays.asList(ReindexPlugin.class, ReindexCancellationPlugin.class); diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/ESNetty4IntegTestCase.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/ESNetty4IntegTestCase.java index 09c6b3d50a380..c996f55198bf6 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/ESNetty4IntegTestCase.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/ESNetty4IntegTestCase.java @@ -19,11 +19,6 @@ public abstract class ESNetty4IntegTestCase extends ESIntegTestCase { - @Override - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected boolean addMockTransportService() { return false; diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/HttpSmokeTestCase.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/HttpSmokeTestCase.java index 2533b213d469c..4536e2ee25fd6 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/HttpSmokeTestCase.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/HttpSmokeTestCase.java @@ -41,11 +41,6 @@ protected Collection> nodePlugins() { return List.of(getTestTransportPlugin(), MainRestPlugin.class); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - public static void assertOK(Response response) { assertThat(response.getStatusLine().getStatusCode(), oneOf(200, 201)); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/ingest/IngestStatsNamesAndTypesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/ingest/IngestStatsNamesAndTypesIT.java index 547da987dcb91..2a4174ba427af 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/ingest/IngestStatsNamesAndTypesIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/ingest/IngestStatsNamesAndTypesIT.java @@ -45,11 +45,6 @@ protected Collection> nodePlugins() { return List.of(CustomIngestTestPlugin.class, CustomScriptPlugin.class); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - @SuppressWarnings("unchecked") public void 
testIngestStatsNamesAndTypes() throws IOException { String pipeline1 = org.elasticsearch.core.Strings.format(""" diff --git a/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorFullRestartIT.java b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorFullRestartIT.java index f51ff1da9bfc9..d1c72a9650b85 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorFullRestartIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorFullRestartIT.java @@ -34,10 +34,6 @@ protected Collection> nodePlugins() { return Collections.singletonList(TestPersistentTasksPlugin.class); } - protected boolean ignoreExternalCluster() { - return true; - } - public void testFullClusterRestart() throws Exception { PersistentTasksService service = internalCluster().getInstance(PersistentTasksService.class); int numberOfTasks = randomIntBetween(1, 10); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java index c91f5138e919f..3cc90a6795e37 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java @@ -51,10 +51,6 @@ protected Collection> nodePlugins() { return Collections.singletonList(TestPersistentTasksPlugin.class); } - protected boolean ignoreExternalCluster() { - return true; - } - @Before public void resetNonClusterStateCondition() { TestPersistentTasksExecutor.setNonClusterStateCondition(true); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/persistent/decider/EnableAssignmentDeciderIT.java b/server/src/internalClusterTest/java/org/elasticsearch/persistent/decider/EnableAssignmentDeciderIT.java index cb24b78a499ac..d9aa15ed6e2f5 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/persistent/decider/EnableAssignmentDeciderIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/persistent/decider/EnableAssignmentDeciderIT.java @@ -36,11 +36,6 @@ protected Collection> nodePlugins() { return singletonList(TestPersistentTasksPlugin.class); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - /** * Test that the {@link EnableAssignmentDecider#CLUSTER_TASKS_ALLOCATION_ENABLE_SETTING} setting correctly * prevents persistent tasks to be assigned after a cluster restart. 
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java index 474d4ebc12843..480556b942ac8 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java @@ -193,11 +193,6 @@ protected boolean addMockHttpTransport() { return false; // enable http } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - public void testFieldAlias() { FieldCapabilitiesResponse response = client().prepareFieldCaps().setFields("distance", "route_length_miles").get(); diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index c75a52a82caf1..37e4176e1818d 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -158,9 +158,7 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; -import java.net.InetAddress; import java.net.InetSocketAddress; -import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; @@ -1983,46 +1981,7 @@ protected Collection> nodePlugins() { return Collections.emptyList(); } - private ExternalTestCluster buildExternalCluster(String clusterAddresses, String clusterName) throws IOException { - String[] stringAddresses = clusterAddresses.split(","); - TransportAddress[] transportAddresses = new TransportAddress[stringAddresses.length]; - int i = 0; - for (String stringAddress : stringAddresses) { - URL url = new URL("http://" + stringAddress); - InetAddress inetAddress = InetAddress.getByName(url.getHost()); - transportAddresses[i++] = new TransportAddress(new InetSocketAddress(inetAddress, url.getPort())); - } - return new ExternalTestCluster( - createTempDir(), - externalClusterClientSettings(), - nodePlugins(), - getClientWrapper(), - clusterName, - transportAddresses - ); - } - - protected Settings externalClusterClientSettings() { - return Settings.EMPTY; - } - - protected boolean ignoreExternalCluster() { - return false; - } - protected TestCluster buildTestCluster(Scope scope, long seed) throws IOException { - String clusterAddresses = System.getProperty(TESTS_CLUSTER); - if (Strings.hasLength(clusterAddresses) && ignoreExternalCluster() == false) { - if (scope == Scope.TEST) { - throw new IllegalArgumentException("Cannot run TEST scope test with " + TESTS_CLUSTER); - } - String clusterName = System.getProperty(TESTS_CLUSTER_NAME); - if (Strings.isNullOrEmpty(clusterName)) { - throw new IllegalArgumentException("External test cluster name must be provided"); - } - return buildExternalCluster(clusterAddresses, clusterName); - } - final String nodePrefix = switch (scope) { case TEST -> TEST_CLUSTER_NODE_PREFIX; case SUITE -> SUITE_CLUSTER_NODE_PREFIX; diff --git a/test/framework/src/main/java/org/elasticsearch/test/ExternalTestCluster.java b/test/framework/src/main/java/org/elasticsearch/test/ExternalTestCluster.java index dcddbbbcece64..5423263c88da6 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ExternalTestCluster.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ExternalTestCluster.java @@ -51,7 +51,11 @@ * 
External cluster to run the tests against. * It is a pure immutable test cluster that allows to send requests to a pre-existing cluster * and supports by nature all the needed test operations like wipeIndices etc. + * + * @deprecated not a realistic test setup since the removal of the transport client, use {@link ESIntegTestCase} for internal-cluster tests + * or {@link org.elasticsearch.test.rest.ESRestTestCase} otherwise. */ +@Deprecated(forRemoval = true) public final class ExternalTestCluster extends TestCluster { private static final Logger logger = LogManager.getLogger(ExternalTestCluster.class); diff --git a/x-pack/plugin/autoscaling/src/internalClusterTest/java/org/elasticsearch/xpack/autoscaling/action/GetAutoscalingCapacityRestCancellationIT.java b/x-pack/plugin/autoscaling/src/internalClusterTest/java/org/elasticsearch/xpack/autoscaling/action/GetAutoscalingCapacityRestCancellationIT.java index d3d7744bddc4c..afe8759acc7a3 100644 --- a/x-pack/plugin/autoscaling/src/internalClusterTest/java/org/elasticsearch/xpack/autoscaling/action/GetAutoscalingCapacityRestCancellationIT.java +++ b/x-pack/plugin/autoscaling/src/internalClusterTest/java/org/elasticsearch/xpack/autoscaling/action/GetAutoscalingCapacityRestCancellationIT.java @@ -60,11 +60,6 @@ protected Collection> nodePlugins() { return Collections.unmodifiableList(result); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - public void testCapacityRestCancellationAndResponse() throws Exception { internalCluster().startMasterOnlyNode(); diff --git a/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DataStreamLifecycleDownsampleIT.java b/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DataStreamLifecycleDownsampleIT.java index c38ed182abc64..cb945e8ffa418 100644 --- a/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DataStreamLifecycleDownsampleIT.java +++ b/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DataStreamLifecycleDownsampleIT.java @@ -43,10 +43,6 @@ protected Collection> nodePlugins() { return List.of(DataStreamsPlugin.class, LocalStateCompositeXPackPlugin.class, Downsample.class, AggregateMetricMapperPlugin.class); } - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { Settings.Builder settings = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)); diff --git a/x-pack/plugin/eql/src/internalClusterTest/java/org/elasticsearch/xpack/eql/action/RestEqlCancellationIT.java b/x-pack/plugin/eql/src/internalClusterTest/java/org/elasticsearch/xpack/eql/action/RestEqlCancellationIT.java index 31f2a4e178c91..6ae49ea7416bb 100644 --- a/x-pack/plugin/eql/src/internalClusterTest/java/org/elasticsearch/xpack/eql/action/RestEqlCancellationIT.java +++ b/x-pack/plugin/eql/src/internalClusterTest/java/org/elasticsearch/xpack/eql/action/RestEqlCancellationIT.java @@ -139,8 +139,4 @@ public void testRestCancellation() throws Exception { expectThrows(CancellationException.class, future::actionGet); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } } diff --git a/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/ClusterStateWaitThresholdBreachTests.java b/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/ClusterStateWaitThresholdBreachTests.java index 
83eee53195c8c..da3323966fb94 100644 --- a/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/ClusterStateWaitThresholdBreachTests.java +++ b/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/ClusterStateWaitThresholdBreachTests.java @@ -62,11 +62,6 @@ public void refreshDataStreamAndPolicy() { managedIndex = "index-" + randomAlphaOfLengthBetween(10, 15).toLowerCase(Locale.ROOT); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Collection> nodePlugins() { return Arrays.asList(LocalStateCompositeXPackPlugin.class, IndexLifecycle.class, Ccr.class); diff --git a/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/DataStreamAndIndexLifecycleMixingTests.java b/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/DataStreamAndIndexLifecycleMixingTests.java index f1f1e1b967d5f..668cc4121b7b5 100644 --- a/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/DataStreamAndIndexLifecycleMixingTests.java +++ b/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/DataStreamAndIndexLifecycleMixingTests.java @@ -80,10 +80,6 @@ protected Collection> nodePlugins() { return List.of(LocalStateCompositeXPackPlugin.class, IndexLifecycle.class, DataStreamsPlugin.class); } - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { Settings.Builder settings = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)); diff --git a/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/DataTiersMigrationsTests.java b/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/DataTiersMigrationsTests.java index a2cdc377a20e2..9dfc3ddcda91e 100644 --- a/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/DataTiersMigrationsTests.java +++ b/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/DataTiersMigrationsTests.java @@ -55,11 +55,6 @@ public void refreshDataStreamAndPolicy() { managedIndex = "index-" + randomAlphaOfLengthBetween(10, 15).toLowerCase(Locale.ROOT); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Collection> nodePlugins() { return Arrays.asList(LocalStateCompositeXPackPlugin.class, IndexLifecycle.class); diff --git a/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/IndexLifecycleInitialisationTests.java b/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/IndexLifecycleInitialisationTests.java index 0a252a0d62958..069771515d1b6 100644 --- a/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/IndexLifecycleInitialisationTests.java +++ b/x-pack/plugin/ilm/src/internalClusterTest/java/org/elasticsearch/xpack/ilm/IndexLifecycleInitialisationTests.java @@ -109,11 +109,6 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { return nodeSettings.build(); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Collection> nodePlugins() { return Arrays.asList(LocalStateCompositeXPackPlugin.class, IndexLifecycle.class, TestILMPlugin.class); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java 
b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java index 46a4e008cc752..e88464c1ff5c4 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java @@ -25,10 +25,12 @@ import org.elasticsearch.cluster.metadata.ComposableIndexTemplate; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.Template; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkModule; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.core.PathUtils; import org.elasticsearch.datastreams.DataStreamsPlugin; import org.elasticsearch.env.Environment; @@ -49,7 +51,9 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ExternalTestCluster; import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.test.TestCluster; import org.elasticsearch.transport.netty4.Netty4Plugin; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.autoscaling.Autoscaling; @@ -96,7 +100,10 @@ import java.io.IOException; import java.io.UncheckedIOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; import java.net.URISyntaxException; +import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; @@ -165,8 +172,7 @@ protected Function getClientWrapper() { return client -> client.filterWithHeader(headers); } - @Override - protected Settings externalClusterClientSettings() { + private Settings externalClusterClientSettings() { final Path home = createTempDir(); final Path xpackConf = home.resolve("config"); try { @@ -207,6 +213,35 @@ protected Settings externalClusterClientSettings() { return builder.build(); } + @Override + protected TestCluster buildTestCluster(Scope scope, long seed) throws IOException { + final String clusterAddresses = System.getProperty(TESTS_CLUSTER); + assertTrue(TESTS_CLUSTER + " must be set", Strings.hasLength(clusterAddresses)); + if (scope == Scope.TEST) { + throw new IllegalArgumentException("Cannot run TEST scope test with " + TESTS_CLUSTER); + } + final String clusterName = System.getProperty(TESTS_CLUSTER_NAME); + if (Strings.isNullOrEmpty(clusterName)) { + throw new IllegalArgumentException("External test cluster name must be provided"); + } + final String[] stringAddresses = clusterAddresses.split(","); + final TransportAddress[] transportAddresses = new TransportAddress[stringAddresses.length]; + int i = 0; + for (String stringAddress : stringAddresses) { + URL url = new URL("http://" + stringAddress); + InetAddress inetAddress = InetAddress.getByName(url.getHost()); + transportAddresses[i++] = new TransportAddress(new InetSocketAddress(inetAddress, url.getPort())); + } + return new ExternalTestCluster( + createTempDir(), + externalClusterClientSettings(), + nodePlugins(), + getClientWrapper(), + clusterName, + transportAddresses + ); + } + protected void cleanUp() { setUpgradeModeTo(false); 
cleanUpResources(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java index 47ed0e9e269ab..e9a89a81f62e2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java @@ -122,11 +122,6 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase { // all the tasks that should be excluded from the cleanup jobs because they are not related to the tests. private static final Set UNRELATED_TASKS = Set.of(TransportListTasksAction.TYPE.name(), HealthNode.TASK_NAME); - @Override - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { Settings.Builder settings = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)); diff --git a/x-pack/plugin/monitoring/src/internalClusterTest/java/org/elasticsearch/xpack/monitoring/exporter/http/HttpExporterIT.java b/x-pack/plugin/monitoring/src/internalClusterTest/java/org/elasticsearch/xpack/monitoring/exporter/http/HttpExporterIT.java index 256d12cd07510..36246902e5597 100644 --- a/x-pack/plugin/monitoring/src/internalClusterTest/java/org/elasticsearch/xpack/monitoring/exporter/http/HttpExporterIT.java +++ b/x-pack/plugin/monitoring/src/internalClusterTest/java/org/elasticsearch/xpack/monitoring/exporter/http/HttpExporterIT.java @@ -128,11 +128,6 @@ public void stopWebServer() { } } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - private Settings.Builder secureSettings(String password) { mockSecureSettings.setString("xpack.monitoring.exporters._http.auth.secure_password", password); return baseSettings().setSecureSettings(mockSecureSettings); diff --git a/x-pack/plugin/monitoring/src/internalClusterTest/java/org/elasticsearch/xpack/monitoring/exporter/http/HttpExporterSslIT.java b/x-pack/plugin/monitoring/src/internalClusterTest/java/org/elasticsearch/xpack/monitoring/exporter/http/HttpExporterSslIT.java index db37b09095e61..3b61b0496c64d 100644 --- a/x-pack/plugin/monitoring/src/internalClusterTest/java/org/elasticsearch/xpack/monitoring/exporter/http/HttpExporterSslIT.java +++ b/x-pack/plugin/monitoring/src/internalClusterTest/java/org/elasticsearch/xpack/monitoring/exporter/http/HttpExporterSslIT.java @@ -52,11 +52,6 @@ public static void cleanUpStatics() { } } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { final Path truststore = getDataPath("/org/elasticsearch/xpack/monitoring/exporter/http/testnode.jks"); diff --git a/x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/ProfilingTestCase.java b/x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/ProfilingTestCase.java index 199662f9ca1f6..29981c8e2f2a3 100644 --- a/x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/ProfilingTestCase.java +++ b/x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/ProfilingTestCase.java @@ -66,11 +66,6 @@ protected boolean addMockHttpTransport() { return false; // enable http } - @Override - protected boolean ignoreExternalCluster() { - return true; - } - private void indexDoc(String index, 
String id, Map source) { DocWriteResponse indexResponse = client().prepareIndex(index).setId(id).setSource(source).setCreate(true).get(); assertEquals(RestStatus.CREATED, indexResponse.status()); diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/ShrinkIndexWithSecurityTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/ShrinkIndexWithSecurityTests.java index c5efabfca13db..c3016a810c27f 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/ShrinkIndexWithSecurityTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/ShrinkIndexWithSecurityTests.java @@ -23,11 +23,6 @@ @ClusterScope(minNumDataNodes = 2) public class ShrinkIndexWithSecurityTests extends SecurityIntegTestCase { - @Override - protected final boolean ignoreExternalCluster() { - return true; - } - @Override protected int minimumNumberOfShards() { return 2; diff --git a/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/RestSqlCancellationIT.java b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/RestSqlCancellationIT.java index 0100634766bfe..48ee5b05ffe0e 100644 --- a/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/RestSqlCancellationIT.java +++ b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/RestSqlCancellationIT.java @@ -183,8 +183,4 @@ private static String queryAsJson(String query) throws IOException { return out.bytes().utf8ToString(); } - @Override - protected boolean ignoreExternalCluster() { - return true; - } } diff --git a/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/SqlLicenseIT.java b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/SqlLicenseIT.java index 3e0841873b6db..0374818d7e3b5 100644 --- a/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/SqlLicenseIT.java +++ b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/SqlLicenseIT.java @@ -36,10 +36,6 @@ @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/37320") public class SqlLicenseIT extends AbstractLicensesIntegrationTestCase { - @Override - protected boolean ignoreExternalCluster() { - return true; - } @Before public void resetLicensing() throws Exception { From 23734f93a3f85e14869d503de3f0122d2fba835f Mon Sep 17 00:00:00 2001 From: Brian Seeders Date: Mon, 6 Nov 2023 15:43:44 -0500 Subject: [PATCH 12/21] [ci] Migrate pull requests to running in Buildkite (#101843) --- .buildkite/pull-requests.json | 3 --- ...rch+pull-request+build-benchmark-part1.yml | 21 ++++++++----------- ...rch+pull-request+build-benchmark-part2.yml | 21 ++++++++----------- ...rch+pull-request+bwc-snapshots-windows.yml | 10 +++------ ...asticsearch+pull-request+bwc-snapshots.yml | 9 +++----- ...lasticsearch+pull-request+cloud-deploy.yml | 9 +++----- ...+elasticsearch+pull-request+docs-check.yml | 8 +++---- ...ticsearch+pull-request+eql-correctness.yml | 8 +++---- ...ticsearch+pull-request+example-plugins.yml | 7 +++---- ...ic+elasticsearch+pull-request+full-bwc.yml | 10 +++------ ...ll-request+packaging-tests-unix-sample.yml | 7 ++----- ...arch+pull-request+packaging-tests-unix.yml | 8 ++----- ...-request+packaging-tests-windows-nojdk.yml | 14 +++++-------- ...t+packaging-tests-windows-sample-nojdk.yml | 13 +++++------- 
...request+packaging-tests-windows-sample.yml | 11 ++++------ ...h+pull-request+packaging-tests-windows.yml | 12 ++++------- ...h+pull-request+packaging-upgrade-tests.yml | 8 ++----- ...elasticsearch+pull-request+part-1-fips.yml | 10 +++------ ...sticsearch+pull-request+part-1-windows.yml | 10 +++------ ...elasticsearch+pull-request+part-2-fips.yml | 10 +++------ ...sticsearch+pull-request+part-2-windows.yml | 10 +++------ ...elasticsearch+pull-request+part-3-fips.yml | 10 +++------ ...sticsearch+pull-request+part-3-windows.yml | 10 +++------ ...stic+elasticsearch+pull-request+part-3.yml | 8 +++---- ...c+elasticsearch+pull-request+precommit.yml | 9 +++----- ...asticsearch+pull-request+release-tests.yml | 9 +++----- ...search+pull-request+rest-compatibility.yml | 8 +++---- .ci/templates.t/pull-request-gradle-unix.yml | 8 +++---- 28 files changed, 96 insertions(+), 185 deletions(-) diff --git a/.buildkite/pull-requests.json b/.buildkite/pull-requests.json index 456fce6aba519..b59bdc79ad293 100644 --- a/.buildkite/pull-requests.json +++ b/.buildkite/pull-requests.json @@ -12,9 +12,6 @@ "build_on_commit": true, "build_on_comment": true, "trigger_comment_regex": "run\\W+elasticsearch-ci.+", - "labels": [ - "buildkite-opt-in" - ], "cancel_intermediate_builds": true, "cancel_intermediate_builds_on_comment": false }, diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+build-benchmark-part1.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+build-benchmark-part1.yml index a3f1345a07f13..173c8dbf805c0 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+build-benchmark-part1.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+build-benchmark-part1.yml @@ -23,7 +23,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/build-bench.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/build-bench.*' github-hooks: true status-context: elasticsearch-ci/build-benchmark-part1 cancel-builds-on-update: true @@ -32,21 +33,17 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'build-benchmark' - black-list-labels: - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 JAVA11_HOME=$HOME/.java/java11 - shell: | - #!/usr/local/bin/runbld --redirect-stderr - $WORKSPACE/.ci/scripts/run-gradle.sh :build-tools-internal:bootstrapPerformanceTests - $WORKSPACE/.ci/scripts/install-gradle-profiler.sh - $WORKSPACE/.ci/scripts/run-gradle-profiler.sh --benchmark --scenario-file build-tools-internal/build/performanceTests/elasticsearch-build-benchmark-part1.scenarios --project-dir . --output-dir profile-out - mkdir $WORKSPACE/build - tar -czf $WORKSPACE/build/${BUILD_NUMBER}.tar.bz2 profile-out + #!/usr/local/bin/runbld --redirect-stderr + $WORKSPACE/.ci/scripts/run-gradle.sh :build-tools-internal:bootstrapPerformanceTests + $WORKSPACE/.ci/scripts/install-gradle-profiler.sh + $WORKSPACE/.ci/scripts/run-gradle-profiler.sh --benchmark --scenario-file build-tools-internal/build/performanceTests/elasticsearch-build-benchmark-part1.scenarios --project-dir . 
--output-dir profile-out + mkdir $WORKSPACE/build + tar -czf $WORKSPACE/build/${BUILD_NUMBER}.tar.bz2 profile-out diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+build-benchmark-part2.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+build-benchmark-part2.yml index f1b11ab1ec75a..5f25c9153040e 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+build-benchmark-part2.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+build-benchmark-part2.yml @@ -23,7 +23,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/build-bench.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/build-bench.*' github-hooks: true status-context: elasticsearch-ci/build-benchmark-part2 cancel-builds-on-update: true @@ -32,21 +33,17 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'build-benchmark' - black-list-labels: - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 JAVA11_HOME=$HOME/.java/java11 - shell: | - #!/usr/local/bin/runbld --redirect-stderr - $WORKSPACE/.ci/scripts/run-gradle.sh :build-tools-internal:bootstrapPerformanceTests - $WORKSPACE/.ci/scripts/install-gradle-profiler.sh - $WORKSPACE/.ci/scripts/run-gradle-profiler.sh --benchmark --scenario-file build-tools-internal/build/performanceTests/elasticsearch-build-benchmark-part2.scenarios --project-dir . --output-dir profile-out - mkdir $WORKSPACE/build - tar -czf $WORKSPACE/build/${BUILD_NUMBER}.tar.bz2 profile-out + #!/usr/local/bin/runbld --redirect-stderr + $WORKSPACE/.ci/scripts/run-gradle.sh :build-tools-internal:bootstrapPerformanceTests + $WORKSPACE/.ci/scripts/install-gradle-profiler.sh + $WORKSPACE/.ci/scripts/run-gradle-profiler.sh --benchmark --scenario-file build-tools-internal/build/performanceTests/elasticsearch-build-benchmark-part2.scenarios --project-dir . 
--output-dir profile-out + mkdir $WORKSPACE/build + tar -czf $WORKSPACE/build/${BUILD_NUMBER}.tar.bz2 profile-out diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+bwc-snapshots-windows.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+bwc-snapshots-windows.yml index c0ed9bf998159..1a0652204b2f2 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+bwc-snapshots-windows.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+bwc-snapshots-windows.yml @@ -16,7 +16,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/bwc-snapshots-windows.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/bwc-snapshots-windows.*' github-hooks: true status-context: elasticsearch-ci/bwc-snapshots-windows cancel-builds-on-update: true @@ -25,11 +26,6 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'test-windows' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' axes: - axis: type: slave @@ -42,7 +38,7 @@ name: "BWC_VERSION" builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$USERPROFILE\\.java\\$ES_BUILD_JAVA JAVA11_HOME=$USERPROFILE\\.java\\java11 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+bwc-snapshots.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+bwc-snapshots.yml index 676f5f6f629b7..9a20115a72f1c 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+bwc-snapshots.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+bwc-snapshots.yml @@ -16,17 +16,14 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/bwc.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/bwc.*' github-hooks: true status-context: elasticsearch-ci/bwc cancel-builds-on-update: true excluded-regions: - ^docs/.* - ^x-pack/docs/.* - black-list-labels: - - '>test-mute' - - 'test-full-bwc' - - 'buildkite-opt-in' axes: - axis: type: slave @@ -39,7 +36,7 @@ name: "BWC_VERSION" builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+cloud-deploy.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+cloud-deploy.yml index 24548954d8a10..a6f42c147dbeb 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+cloud-deploy.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+cloud-deploy.yml @@ -15,7 +15,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/cloud-deploy.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/cloud-deploy.*' github-hooks: true status-context: elasticsearch-ci/cloud-deploy cancel-builds-on-update: true @@ -24,13 +25,9 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'cloud-deploy' - black-list-labels: - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA - shell: | diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+docs-check.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+docs-check.yml index c766b4379a1f6..58b273de2beb9 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+docs-check.yml +++ 
b/.ci/jobs.t/elastic+elasticsearch+pull-request+docs-check.yml @@ -14,19 +14,17 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/docs-check.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/docs-check.*' github-hooks: true status-context: elasticsearch-ci/docs-check cancel-builds-on-update: true included-regions: - ^docs/.* - ^x-pack/docs/.* - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+eql-correctness.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+eql-correctness.yml index 0b9eea62ad9bf..c1789e3b8595a 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+eql-correctness.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+eql-correctness.yml @@ -14,7 +14,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/eql-correctness.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/eql-correctness.*' github-hooks: true status-context: elasticsearch-ci/eql-correctness cancel-builds-on-update: true @@ -23,12 +24,9 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA - shell: | diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+example-plugins.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+example-plugins.yml index 320a9c6176d5f..339fcd17ec77c 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+example-plugins.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+example-plugins.yml @@ -14,7 +14,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/example-plugins.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/example-plugins.*' github-hooks: true status-context: elasticsearch-ci/example-plugins cancel-builds-on-update: true @@ -23,11 +24,9 @@ - build-tools/.* - build-tools-internal/.* - plugins/examples/.* - black-list-labels: - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+full-bwc.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+full-bwc.yml index 2a7920e4bae89..4bb38a810e8f1 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+full-bwc.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+full-bwc.yml @@ -16,18 +16,14 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/full-bwc.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/full-bwc.*' github-hooks: true status-context: elasticsearch-ci/full-bwc cancel-builds-on-update: true excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'test-full-bwc' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' axes: - axis: type: slave @@ -40,7 +36,7 @@ name: "BWC_VERSION" builders: - inject: - 
properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-unix-sample.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-unix-sample.yml index 2d4f372142512..23d94e665f8a3 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-unix-sample.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-unix-sample.yml @@ -15,7 +15,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/packaging-tests-unix-sample.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/packaging-tests-unix-sample.*' github-hooks: true status-context: elasticsearch-ci/packaging-tests-unix-sample cancel-builds-on-update: true @@ -24,10 +25,6 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - black-list-labels: - - ">test-mute" - - ":Delivery/Packaging" - - "buildkite-opt-in" axes: - axis: type: label-expression diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-unix.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-unix.yml index af1d3f493eeb0..901f7bcac3caa 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-unix.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-unix.yml @@ -15,7 +15,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/packaging-tests-unix.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/packaging-tests-unix.*' github-hooks: true status-context: elasticsearch-ci/packaging-tests-unix cancel-builds-on-update: true @@ -24,11 +25,6 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - ":Delivery/Packaging" - black-list-labels: - - ">test-mute" - - "buildkite-opt-in" axes: - axis: type: label-expression diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-nojdk.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-nojdk.yml index ea4097b1a0b93..c39326380fdaf 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-nojdk.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-nojdk.yml @@ -17,7 +17,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/packaging-tests-windows.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/packaging-tests-windows.*' github-hooks: true status-context: elasticsearch-ci/packaging-tests-windows cancel-builds-on-update: true @@ -28,11 +29,6 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - ':Delivery/Packaging' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' axes: - axis: type: label-expression @@ -46,11 +42,11 @@ type: user-defined name: PACKAGING_TASK values: - - 'default-windows-archive' - - 'default-windows-archive-no-jdk' + - "default-windows-archive" + - "default-windows-archive-no-jdk" builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$USERPROFILE\\.java\\$ES_BUILD_JAVA - batch: | diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-sample-nojdk.yml 
b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-sample-nojdk.yml index ec644445ef8de..35705f7e759b1 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-sample-nojdk.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-sample-nojdk.yml @@ -17,7 +17,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/packaging-tests-windows-sample.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/packaging-tests-windows-sample.*' github-hooks: true status-context: elasticsearch-ci/packaging-tests-windows-sample cancel-builds-on-update: true @@ -28,10 +29,6 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - black-list-labels: - - '>test-mute' - - ':Delivery/Packaging' - - 'buildkite-opt-in' axes: - axis: type: label-expression @@ -42,11 +39,11 @@ type: user-defined name: PACKAGING_TASK values: - - 'default-windows-archive' - - 'default-windows-archive-no-jdk' + - "default-windows-archive" + - "default-windows-archive-no-jdk" builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$USERPROFILE\\.java\\$ES_BUILD_JAVA - batch: | diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-sample.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-sample.yml index 242e137cb1d83..8a4eff2d30822 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-sample.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows-sample.yml @@ -17,7 +17,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/packaging-tests-windows-sample.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/packaging-tests-windows-sample.*' github-hooks: true status-context: elasticsearch-ci/packaging-tests-windows-sample cancel-builds-on-update: true @@ -27,10 +28,6 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - black-list-labels: - - '>test-mute' - - ':Delivery/Packaging' - - 'buildkite-opt-in' axes: - axis: type: label-expression @@ -41,10 +38,10 @@ type: user-defined name: PACKAGING_TASK values: - - 'default-windows-archive' + - "default-windows-archive" builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$USERPROFILE\\.java\\$ES_BUILD_JAVA - batch: | diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows.yml index a2ffc7b4050ec..d109477620386 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-tests-windows.yml @@ -17,7 +17,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/packaging-tests-windows.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/packaging-tests-windows.*' github-hooks: true status-context: elasticsearch-ci/packaging-tests-windows cancel-builds-on-update: true @@ -28,11 +29,6 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - ':Delivery/Packaging' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' axes: - axis: type: label-expression @@ -46,10 +42,10 @@ type: 
user-defined name: PACKAGING_TASK values: - - 'default-windows-archive' + - "default-windows-archive" builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$USERPROFILE\\.java\\$ES_BUILD_JAVA - batch: | diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-upgrade-tests.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-upgrade-tests.yml index 19ed5398e3e1d..0cc14224375fb 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-upgrade-tests.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+packaging-upgrade-tests.yml @@ -16,7 +16,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/packaging-upgrade-tests.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/packaging-upgrade-tests.*' github-hooks: true status-context: elasticsearch-ci/packaging-upgrade-tests cancel-builds-on-update: true @@ -25,11 +26,6 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - ":Delivery/Packaging" - black-list-labels: - - ">test-mute" - - "buildkite-opt-in" axes: - axis: type: label-expression diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-1-fips.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-1-fips.yml index a661230d3b93b..aaeeed2f0d52b 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-1-fips.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-1-fips.yml @@ -14,7 +14,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/part-1-fips.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/part-1-fips.*' github-hooks: true status-context: elasticsearch-ci/part-1-fips cancel-builds-on-update: true @@ -23,15 +24,10 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'Team:Security' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: # Use FIPS-specific Java versions - properties-file: '.ci/java-versions-fips.properties' + properties-file: ".ci/java-versions-fips.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA16_HOME=$HOME/.java/openjdk16 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-1-windows.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-1-windows.yml index d7afdd0ac3277..8b348f94026e0 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-1-windows.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-1-windows.yml @@ -15,7 +15,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/part-1-windows.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/part-1-windows.*' github-hooks: true status-context: elasticsearch-ci/part-1-windows cancel-builds-on-update: true @@ -24,14 +25,9 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'test-windows' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$USERPROFILE\\.java\\$ES_BUILD_JAVA JAVA11_HOME=$USERPROFILE\\.java\\java11 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-2-fips.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-2-fips.yml index 913820709dabc..11d168d7567d9 100644 --- 
a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-2-fips.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-2-fips.yml @@ -14,7 +14,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/part-2-fips.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/part-2-fips.*' github-hooks: true status-context: elasticsearch-ci/part-2-fips cancel-builds-on-update: true @@ -23,15 +24,10 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'Team:Security' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: # Use FIPS-specific Java versions - properties-file: '.ci/java-versions-fips.properties' + properties-file: ".ci/java-versions-fips.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA16_HOME=$HOME/.java/openjdk16 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-2-windows.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-2-windows.yml index ae590872be16e..927117cc3bced 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-2-windows.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-2-windows.yml @@ -15,7 +15,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/part-2-windows.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/part-2-windows.*' github-hooks: true status-context: elasticsearch-ci/part-2-windows cancel-builds-on-update: true @@ -24,14 +25,9 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'test-windows' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$USERPROFILE\\.java\\$ES_BUILD_JAVA JAVA11_HOME=$USERPROFILE\\.java\\java11 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3-fips.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3-fips.yml index 6bf6544d40310..3b7984ecbdc43 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3-fips.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3-fips.yml @@ -14,7 +14,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/part-3-fips.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/part-3-fips.*' github-hooks: true status-context: elasticsearch-ci/part-3-fips cancel-builds-on-update: true @@ -24,15 +25,10 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'Team:Security' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: # Use FIPS-specific Java versions - properties-file: '.ci/java-versions-fips.properties' + properties-file: ".ci/java-versions-fips.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA16_HOME=$HOME/.java/openjdk16 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3-windows.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3-windows.yml index 58bad17954b24..7e835b85015ba 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3-windows.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3-windows.yml @@ -15,7 +15,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/part-3-windows.*' + only-trigger-phrase: true + trigger-phrase: 
'.*run\W+jenkins\W+elasticsearch-ci/part-3-windows.*' github-hooks: true status-context: elasticsearch-ci/part-3-windows cancel-builds-on-update: true @@ -25,14 +26,9 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'test-windows' - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$USERPROFILE\\.java\\$ES_BUILD_JAVA JAVA11_HOME=$USERPROFILE\\.java\\java11 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3.yml index 0158b909903b4..e306657693f5f 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+part-3.yml @@ -14,22 +14,20 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/part-3.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/part-3.*' github-hooks: true status-context: elasticsearch-ci/part-3 cancel-builds-on-update: true excluded-regions: - ^docs/.* - ^x-pack/docs/.* - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' black-list-target-branches: - 6.8 - 7.17 builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+precommit.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+precommit.yml index 1267b6a21778e..3994164fba0f3 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+precommit.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+precommit.yml @@ -14,17 +14,14 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/precommit.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/precommit.*' github-hooks: true status-context: elasticsearch-ci/precommit cancel-builds-on-update: true - white-list-labels: - - '>test-mute' - black-list-labels: - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+release-tests.yml b/.ci/jobs.t/elastic+elasticsearch+pull-request+release-tests.yml index 1ab6bd1ce0e5d..a86496d7199f5 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+release-tests.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+release-tests.yml @@ -16,23 +16,20 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/release-tests.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/release-tests.*' github-hooks: true status-context: elasticsearch-ci/release-tests cancel-builds-on-update: true excluded-regions: - ^docs/.* - ^x-pack/docs/.* - white-list-labels: - - 'test-release' - black-list-labels: - - 'buildkite-opt-in' black-list-target-branches: - 7.15 - 6.8 builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 diff --git a/.ci/jobs.t/elastic+elasticsearch+pull-request+rest-compatibility.yml 
b/.ci/jobs.t/elastic+elasticsearch+pull-request+rest-compatibility.yml index 216f8ceae2078..0ed86851c7f33 100644 --- a/.ci/jobs.t/elastic+elasticsearch+pull-request+rest-compatibility.yml +++ b/.ci/jobs.t/elastic+elasticsearch+pull-request+rest-compatibility.yml @@ -14,7 +14,8 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/rest-compatibility.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/rest-compatibility.*' github-hooks: true status-context: elasticsearch-ci/rest-compatibility cancel-builds-on-update: true @@ -26,12 +27,9 @@ excluded-regions: - ^docs/.* - ^x-pack/docs/.* - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 diff --git a/.ci/templates.t/pull-request-gradle-unix.yml b/.ci/templates.t/pull-request-gradle-unix.yml index c09e64c56f32d..7c0711a4e3a97 100644 --- a/.ci/templates.t/pull-request-gradle-unix.yml +++ b/.ci/templates.t/pull-request-gradle-unix.yml @@ -14,19 +14,17 @@ org-list: - elastic allow-whitelist-orgs-as-admins: true - trigger-phrase: '.*run\W+elasticsearch-ci/{pr-job}.*' + only-trigger-phrase: true + trigger-phrase: '.*run\W+jenkins\W+elasticsearch-ci/{pr-job}.*' github-hooks: true status-context: elasticsearch-ci/{pr-job} cancel-builds-on-update: true excluded-regions: - ^docs/.* - ^x-pack/docs/.* - black-list-labels: - - '>test-mute' - - 'buildkite-opt-in' builders: - inject: - properties-file: '.ci/java-versions.properties' + properties-file: ".ci/java-versions.properties" properties-content: | JAVA_HOME=$HOME/.java/$ES_BUILD_JAVA JAVA8_HOME=$HOME/.java/java8 From 461a17508e05fbaa61007e61bfe17e3e26f2197e Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 6 Nov 2023 21:16:22 +0000 Subject: [PATCH 13/21] Separate NoOpClient and ThreadPool lifecycles (#101835) Since the removal of the transport client there are no _real_ clients with their own threadpools, but in tests we still have `NoOpClient` and `NoOpNodeClient` which sometimes implicitly create a threadpool, and terminate their threadpools on close too (even if the threadpool is owned by the caller). This commit detaches the lifecycle of the client from the lifecycle of its threadpool, moving the threadpool creation and termination up into the caller. 
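For illustration, the migration applied across the test classes below follows this pattern (a minimal sketch; `createThreadPool()` is the test-framework helper used throughout this patch, and `TestThreadPool` now implements `Releasable` so it can sit in a try-with-resources block):

    // before: NoOpClient implicitly created its own TestThreadPool and
    // terminated it in close()
    try (var client = new NoOpClient(getTestName())) {
        // ... exercise the client ...
    }

    // after: the caller owns the pool; NoOpClient#close() is now a no-op,
    // and TestThreadPool#close() terminates the pool (10s timeout)
    try (var threadPool = createThreadPool()) {
        final var client = new NoOpClient(threadPool);
        // ... exercise the client ...
    }

The same shape applies to `NoOpNodeClient` and to subclasses such as the various mock clients in these tests, which now take the `ThreadPool` as a constructor argument instead of a test name.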
--- .../DataStreamLifecycleServiceTests.java | 6 +- .../reindex/AsyncBulkByScrollActionTests.java | 9 ++- .../CreateIndexRequestBuilderTests.java | 7 +- .../action/bulk/Retry2Tests.java | 12 ++- .../elasticsearch/action/bulk/RetryTests.java | 12 ++- .../index/IndexRequestBuilderTests.java | 7 +- .../synonyms/PutSynonymRuleActionTests.java | 4 +- .../synonyms/PutSynonymsActionTests.java | 4 +- .../internal/OriginSettingClientTests.java | 33 +++++---- .../ParentTaskAssigningClientTests.java | 28 +++---- .../rest/RestControllerTests.java | 7 +- .../admin/indices/RestAnalyzeActionTests.java | 4 +- .../RestCatComponentTemplateActionTests.java | 11 ++- .../rest/action/cat/RestTasksActionTests.java | 8 +- .../action/document/RestBulkActionTests.java | 43 +++++------ .../scroll/RestClearScrollActionTests.java | 18 ++--- .../scroll/RestSearchScrollActionTests.java | 18 ++--- .../usage/UsageServiceTests.java | 3 +- .../AbstractQueryVectorBuilderTestCase.java | 11 ++- .../elasticsearch/test/client/NoOpClient.java | 19 +---- .../test/client/NoOpNodeClient.java | 23 ------ .../test/rest/RestActionTestCase.java | 12 ++- .../threadpool/TestThreadPool.java | 9 ++- .../AutoscalingNodesInfoServiceTests.java | 12 ++- .../validation/SourceDestValidatorTests.java | 19 +++-- .../core/ilm/CleanupShrinkIndexStepTests.java | 15 ++-- .../core/ilm/CleanupSnapshotStepTests.java | 8 +- .../core/ilm/CleanupTargetIndexStepTests.java | 15 ++-- .../core/ilm/CreateSnapshotStepTests.java | 22 ++++-- .../xpack/core/ilm/DownsampleStepTests.java | 6 +- .../core/ilm/MountSnapshotStepTests.java | 38 +++++----- ...pAliasesAndDeleteSourceIndexStepTests.java | 8 +- .../enrich/EnrichProcessorFactoryTests.java | 74 ++++++++++--------- ...stractRestEnterpriseSearchActionTests.java | 4 +- .../EnterpriseSearchBaseRestHandlerTests.java | 3 +- .../execution/sample/CircuitBreakerTests.java | 8 +- .../sequence/CircuitBreakerTests.java | 31 +++++--- .../execution/sequence/PITFailureTests.java | 8 +- .../ilm/IndexLifecycleTransitionTests.java | 16 ++-- .../TransportDeletePipelineActionTests.java | 8 +- .../TransportGetPipelineActionTests.java | 18 +++-- .../TrainedModelProviderTests.java | 52 ++++++++----- .../exporter/local/LocalExporterTests.java | 3 +- .../xpack/rollup/job/RollupJobTaskTests.java | 12 +-- .../action/role/PutRoleBuilderTests.java | 4 +- .../store/NativePrivilegeStoreTests.java | 6 +- .../action/SecurityBaseRestHandlerTests.java | 3 +- .../apikey/ApiKeyBaseRestHandlerTests.java | 3 +- .../action/role/RestPutRoleActionTests.java | 4 +- .../RestGetUserPrivilegesActionTests.java | 4 +- .../user/RestHasPrivilegesActionTests.java | 13 +--- .../action/user/RestPutUserActionTests.java | 4 +- .../xpack/slm/SnapshotRetentionTaskTests.java | 25 ++++--- .../TransformCheckpointServiceNodeTests.java | 16 ++-- .../TransformPrivilegeCheckerTests.java | 9 +-- .../action/TransformUpdaterTests.java | 16 ++-- .../ClientTransformIndexerTests.java | 22 ++++-- .../TransformIndexerFailureHandlingTests.java | 6 +- ...IndexerFailureOnStatePersistenceTests.java | 10 ++- .../TransformIndexerStateTests.java | 3 +- .../transforms/TransformIndexerTests.java | 3 +- .../transforms/TransformTaskTests.java | 11 ++- .../AggregationSchemaAndResultTests.java | 16 ++-- .../transforms/pivot/PivotTests.java | 67 +++++++++-------- .../transforms/pivot/SchemaUtilTests.java | 12 +-- 65 files changed, 523 insertions(+), 422 deletions(-) diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java 
b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java index bd6100c95b412..0ee168d130986 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java @@ -126,7 +126,6 @@ public class DataStreamLifecycleServiceTests extends ESTestCase { private ThreadPool threadPool; private DataStreamLifecycleService dataStreamLifecycleService; private List clientSeenRequests; - private Client client; private DoExecuteDelegate clientDelegate; private ClusterService clusterService; @@ -145,7 +144,7 @@ public void setupServices() { Clock clock = Clock.fixed(Instant.ofEpochMilli(now), ZoneId.of(randomFrom(ZoneId.getAvailableZoneIds()))); clientSeenRequests = new CopyOnWriteArrayList<>(); - client = getTransportRequestsRecordingClient(); + final Client client = getTransportRequestsRecordingClient(); AllocationService allocationService = new AllocationService( new AllocationDeciders( new HashSet<>( @@ -178,7 +177,6 @@ public void cleanup() { dataStreamLifecycleService.close(); clusterService.close(); threadPool.shutdownNow(); - client.close(); } public void testOperationsExecutedOnce() { @@ -1499,7 +1497,7 @@ private static DiscoveryNode getNode(String nodeId) { * (it does not even notify the listener), but tests can provide an implementation of clientDelegate to provide any needed behavior. */ private Client getTransportRequestsRecordingClient() { - return new NoOpClient(getTestName()) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/AsyncBulkByScrollActionTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/AsyncBulkByScrollActionTests.java index 8878e988eb4fb..3c5a3eb2e40f9 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/AsyncBulkByScrollActionTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/AsyncBulkByScrollActionTests.java @@ -126,6 +126,7 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { private PlainActionFuture listener; private String scrollId; private ThreadPool threadPool; + private ThreadPool clientThreadPool; private TaskManager taskManager; private BulkByScrollTask testTask; private WorkerBulkByScrollTaskState worker; @@ -154,16 +155,18 @@ public void setupForTest() { } private void setupClient(ThreadPool threadPool) { - if (client != null) { - client.close(); + if (clientThreadPool != null) { + terminate(clientThreadPool); } + clientThreadPool = threadPool; client = new MyMockClient(new NoOpClient(threadPool)); client.threadPool().getThreadContext().putHeader(expectedHeaders); } @After public void tearDownAndVerifyCommonStuff() throws Exception { - client.close(); + terminate(clientThreadPool); + clientThreadPool = null; terminate(threadPool); } diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/create/CreateIndexRequestBuilderTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/create/CreateIndexRequestBuilderTests.java index fa2546a98697b..9941f84da7b9a 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/create/CreateIndexRequestBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/create/CreateIndexRequestBuilderTests.java @@ -13,6 +13,7 @@ import 
org.elasticsearch.core.Strings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; @@ -28,19 +29,21 @@ public class CreateIndexRequestBuilderTests extends ESTestCase { private static final String KEY = "my.settings.key"; private static final String VALUE = "my.settings.value"; + private TestThreadPool threadPool; private NoOpClient testClient; @Override @Before public void setUp() throws Exception { super.setUp(); - this.testClient = new NoOpClient(getTestName()); + this.threadPool = createThreadPool(); + this.testClient = new NoOpClient(threadPool); } @Override @After public void tearDown() throws Exception { - this.testClient.close(); + this.threadPool.close(); super.tearDown(); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/Retry2Tests.java b/server/src/test/java/org/elasticsearch/action/bulk/Retry2Tests.java index a33fadc13f4e4..5075c98421af0 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/Retry2Tests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/Retry2Tests.java @@ -18,6 +18,8 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; import org.junit.Before; @@ -39,6 +41,7 @@ public class Retry2Tests extends ESTestCase { private static final int CALLS_TO_FAIL = 5; + private TestThreadPool threadPool; private MockBulkClient bulkClient; /** * Headers that are expected to be sent with all bulk requests. 
@@ -49,7 +52,8 @@ public class Retry2Tests extends ESTestCase { @Before public void setUp() throws Exception { super.setUp(); - this.bulkClient = new MockBulkClient(getTestName(), CALLS_TO_FAIL); + this.threadPool = createThreadPool(); + this.bulkClient = new MockBulkClient(threadPool, CALLS_TO_FAIL); // Stash some random headers so we can assert that we preserve them bulkClient.threadPool().getThreadContext().stashContext(); expectedHeaders.clear(); @@ -60,8 +64,8 @@ public void setUp() throws Exception { @Override @After public void tearDown() throws Exception { + this.threadPool.close(); super.tearDown(); - this.bulkClient.close(); } private BulkRequest createBulkRequest() { @@ -267,8 +271,8 @@ public void assertOnFailureNeverCalled() { private class MockBulkClient extends NoOpClient { private int numberOfCallsToFail; - private MockBulkClient(String testName, int numberOfCallsToFail) { - super(testName); + private MockBulkClient(ThreadPool threadPool, int numberOfCallsToFail) { + super(threadPool); this.numberOfCallsToFail = numberOfCallsToFail; } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/RetryTests.java b/server/src/test/java/org/elasticsearch/action/bulk/RetryTests.java index 65931846bd366..c780c436e78aa 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/RetryTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/RetryTests.java @@ -18,6 +18,8 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; import org.junit.Before; @@ -36,6 +38,7 @@ public class RetryTests extends ESTestCase { private static final TimeValue DELAY = TimeValue.timeValueMillis(1L); private static final int CALLS_TO_FAIL = 5; + private TestThreadPool threadPool; private MockBulkClient bulkClient; /** * Headers that are expected to be sent with all bulk requests. 
@@ -46,7 +49,8 @@ public class RetryTests extends ESTestCase { @Before public void setUp() throws Exception { super.setUp(); - this.bulkClient = new MockBulkClient(getTestName(), CALLS_TO_FAIL); + this.threadPool = createThreadPool(); + this.bulkClient = new MockBulkClient(threadPool, CALLS_TO_FAIL); // Stash some random headers so we can assert that we preserve them bulkClient.threadPool().getThreadContext().stashContext(); expectedHeaders.clear(); @@ -57,8 +61,8 @@ public void setUp() throws Exception { @Override @After public void tearDown() throws Exception { + this.threadPool.close(); super.tearDown(); - this.bulkClient.close(); } private BulkRequest createBulkRequest() { @@ -195,8 +199,8 @@ public void assertOnFailureNeverCalled() { private class MockBulkClient extends NoOpClient { private int numberOfCallsToFail; - private MockBulkClient(String testName, int numberOfCallsToFail) { - super(testName); + private MockBulkClient(ThreadPool threadPool, int numberOfCallsToFail) { + super(threadPool); this.numberOfCallsToFail = numberOfCallsToFail; } diff --git a/server/src/test/java/org/elasticsearch/action/index/IndexRequestBuilderTests.java b/server/src/test/java/org/elasticsearch/action/index/IndexRequestBuilderTests.java index b06f833806059..2f66b9d3b70f8 100644 --- a/server/src/test/java/org/elasticsearch/action/index/IndexRequestBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/action/index/IndexRequestBuilderTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; @@ -24,19 +25,21 @@ public class IndexRequestBuilderTests extends ESTestCase { private static final String EXPECTED_SOURCE = "{\"SomeKey\":\"SomeValue\"}"; + private TestThreadPool threadPool; private NoOpClient testClient; @Override @Before public void setUp() throws Exception { super.setUp(); - this.testClient = new NoOpClient(getTestName()); + this.threadPool = createThreadPool(); + this.testClient = new NoOpClient(threadPool); } @Override @After public void tearDown() throws Exception { - this.testClient.close(); + this.threadPool.close(); super.tearDown(); } diff --git a/server/src/test/java/org/elasticsearch/action/synonyms/PutSynonymRuleActionTests.java b/server/src/test/java/org/elasticsearch/action/synonyms/PutSynonymRuleActionTests.java index b05dd1e0abbd6..7d73281e7c86b 100644 --- a/server/src/test/java/org/elasticsearch/action/synonyms/PutSynonymRuleActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/synonyms/PutSynonymRuleActionTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.action.synonyms; -import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.synonyms.RestPutSynonymRuleAction; import org.elasticsearch.test.ESTestCase; @@ -27,7 +26,8 @@ public void testEmptyRequestBody() throws Exception { .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 0); - try (NodeClient nodeClient = new NoOpNodeClient(this.getTestName())) { + try (var threadPool = createThreadPool()) { + final var nodeClient = new NoOpNodeClient(threadPool); expectThrows(IllegalArgumentException.class, () -> action.handleRequest(request, channel, nodeClient)); } } diff --git 
a/server/src/test/java/org/elasticsearch/action/synonyms/PutSynonymsActionTests.java b/server/src/test/java/org/elasticsearch/action/synonyms/PutSynonymsActionTests.java index 94674134e30ad..df596469c4e1b 100644 --- a/server/src/test/java/org/elasticsearch/action/synonyms/PutSynonymsActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/synonyms/PutSynonymsActionTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.action.synonyms; -import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.synonyms.RestPutSynonymsAction; import org.elasticsearch.test.ESTestCase; @@ -27,7 +26,8 @@ public void testEmptyRequestBody() throws Exception { .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 0); - try (NodeClient nodeClient = new NoOpNodeClient(this.getTestName())) { + try (var threadPool = createThreadPool()) { + final var nodeClient = new NoOpNodeClient(threadPool); expectThrows(IllegalArgumentException.class, () -> action.handleRequest(request, channel, nodeClient)); } } diff --git a/server/src/test/java/org/elasticsearch/client/internal/OriginSettingClientTests.java b/server/src/test/java/org/elasticsearch/client/internal/OriginSettingClientTests.java index cf0bd108e327d..3a93f559284ca 100644 --- a/server/src/test/java/org/elasticsearch/client/internal/OriginSettingClientTests.java +++ b/server/src/test/java/org/elasticsearch/client/internal/OriginSettingClientTests.java @@ -24,23 +24,24 @@ public class OriginSettingClientTests extends ESTestCase { public void testSetsParentId() { String origin = randomAlphaOfLength(7); - /* - * This mock will do nothing but verify that origin is set in the - * thread context before executing the action. - */ - NoOpClient mock = new NoOpClient(getTestName()) { - @Override - protected void doExecute( - ActionType action, - Request request, - ActionListener listener - ) { - assertEquals(origin, threadPool().getThreadContext().getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME)); - super.doExecute(action, request, listener); - } - }; + try (var threadPool = createThreadPool()) { + /* + * This mock will do nothing but verify that origin is set in the + * thread context before executing the action. 
+ */ + final var mock = new NoOpClient(threadPool) { + @Override + protected void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + assertEquals(origin, threadPool().getThreadContext().getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME)); + super.doExecute(action, request, listener); + } + }; - try (OriginSettingClient client = new OriginSettingClient(mock, origin)) { + final var client = new OriginSettingClient(mock, origin); // All of these should have the origin set client.bulk(new BulkRequest()); client.search(new SearchRequest()); diff --git a/server/src/test/java/org/elasticsearch/client/internal/ParentTaskAssigningClientTests.java b/server/src/test/java/org/elasticsearch/client/internal/ParentTaskAssigningClientTests.java index f140e624cc674..0f12076dd53b6 100644 --- a/server/src/test/java/org/elasticsearch/client/internal/ParentTaskAssigningClientTests.java +++ b/server/src/test/java/org/elasticsearch/client/internal/ParentTaskAssigningClientTests.java @@ -23,19 +23,21 @@ public class ParentTaskAssigningClientTests extends ESTestCase { public void testSetsParentId() { TaskId[] parentTaskId = new TaskId[] { new TaskId(randomAlphaOfLength(3), randomLong()) }; - // This mock will do nothing but verify that parentTaskId is set on all requests sent to it. - NoOpClient mock = new NoOpClient(getTestName()) { - @Override - protected void doExecute( - ActionType action, - Request request, - ActionListener listener - ) { - assertEquals(parentTaskId[0], request.getParentTask()); - super.doExecute(action, request, listener); - } - }; - try (ParentTaskAssigningClient client = new ParentTaskAssigningClient(mock, parentTaskId[0])) { + try (var threadPool = createThreadPool()) { + // This mock will do nothing but verify that parentTaskId is set on all requests sent to it. 
+ final var mock = new NoOpClient(threadPool) { + @Override + protected void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + assertEquals(parentTaskId[0], request.getParentTask()); + super.doExecute(action, request, listener); + } + }; + + final var client = new ParentTaskAssigningClient(mock, parentTaskId[0]); assertEquals(parentTaskId[0], client.getParentTask()); // All of these should have the parentTaskId set diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index 068e5933bee1e..06328734e394d 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpNodeClient; import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.transport.BytesRefRecycler; import org.elasticsearch.usage.UsageService; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -87,6 +88,7 @@ public class RestControllerTests extends ESTestCase { private RestController restController; private HierarchyCircuitBreakerService circuitBreakerService; private UsageService usageService; + private TestThreadPool threadPool; private NodeClient client; private Tracer tracer; private List methodList; @@ -107,7 +109,8 @@ public void setup() { inFlightRequestsBreaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS); HttpServerTransport httpServerTransport = new TestHttpServerTransport(); - client = new NoOpNodeClient(this.getTestName()); + threadPool = createThreadPool(); + client = new NoOpNodeClient(threadPool); tracer = mock(Tracer.class); restController = new RestController(null, client, circuitBreakerService, usageService, tracer); restController.registerHandler( @@ -126,7 +129,7 @@ public void setup() { @After public void teardown() throws IOException { - IOUtils.close(client); + IOUtils.close(threadPool); } public void testApplyProductSpecificResponseHeaders() { diff --git a/server/src/test/java/org/elasticsearch/rest/action/admin/indices/RestAnalyzeActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/admin/indices/RestAnalyzeActionTests.java index 1151965a46ce6..f42c450221383 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/admin/indices/RestAnalyzeActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/admin/indices/RestAnalyzeActionTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.rest.action.admin.indices; import org.elasticsearch.action.admin.indices.analyze.AnalyzeAction; -import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.index.analysis.NameOrDefinition; import org.elasticsearch.rest.RestRequest; @@ -95,7 +94,8 @@ public void testParseXContentForAnalyzeRequestWithInvalidJsonThrowsException() { new BytesArray("{invalid_json}"), XContentType.JSON ).build(); - try (NodeClient client = new NoOpNodeClient(this.getClass().getSimpleName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpNodeClient(threadPool); var e = expectThrows(XContentParseException.class, () -> action.handleRequest(request, null, client)); assertThat(e.getMessage(), containsString("expecting double-quote")); } diff --git 
a/server/src/test/java/org/elasticsearch/rest/action/cat/RestCatComponentTemplateActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/cat/RestCatComponentTemplateActionTests.java index 6a422a3fd97aa..33b20cfeee959 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/cat/RestCatComponentTemplateActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/cat/RestCatComponentTemplateActionTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.test.rest.FakeRestChannel; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.threadpool.ThreadPool; import java.util.HashMap; import java.util.Map; @@ -78,7 +79,8 @@ public void testRestCatComponentAction() throws Exception { FakeRestChannel channel = new FakeRestChannel(getCatComponentTemplateRequest, true, 0); // execute action - try (NoOpNodeClient nodeClient = buildNodeClient()) { + try (var threadPool = createThreadPool()) { + final var nodeClient = buildNodeClient(threadPool); action.handleRequest(getCatComponentTemplateRequest, channel, nodeClient); } @@ -96,7 +98,8 @@ public void testRestCatComponentActionWithParam() throws Exception { FakeRestChannel channel = new FakeRestChannel(getCatComponentTemplateRequest, true, 0); // execute action - try (NoOpNodeClient nodeClient = buildNodeClient()) { + try (var threadPool = createThreadPool()) { + final var nodeClient = buildNodeClient(threadPool); action.handleRequest(getCatComponentTemplateRequest, channel, nodeClient); } @@ -106,10 +109,10 @@ public void testRestCatComponentActionWithParam() throws Exception { assertThat(channel.capturedResponse().content().utf8ToString(), emptyString()); } - private NoOpNodeClient buildNodeClient() { + private NoOpNodeClient buildNodeClient(ThreadPool threadPool) { ClusterStateResponse clusterStateResponse = new ClusterStateResponse(clusterName, clusterState, false); - return new NoOpNodeClient(getTestName()) { + return new NoOpNodeClient(threadPool) { @Override @SuppressWarnings("unchecked") public void doExecute( diff --git a/server/src/test/java/org/elasticsearch/rest/action/cat/RestTasksActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/cat/RestTasksActionTests.java index 38d39fda898aa..803f9c2fb1b01 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/cat/RestTasksActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/cat/RestTasksActionTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.test.client.NoOpNodeClient; import org.elasticsearch.test.rest.FakeRestChannel; import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; import java.util.List; @@ -33,7 +34,8 @@ public void testConsumesParameters() throws Exception { Map.of("parent_task_id", "the node:3", "nodes", "node1,node2", "actions", "*") ).build(); FakeRestChannel fakeRestChannel = new FakeRestChannel(fakeRestRequest, false, 1); - try (NoOpNodeClient nodeClient = buildNodeClient()) { + try (var threadPool = createThreadPool()) { + final var nodeClient = buildNodeClient(threadPool); action.handleRequest(fakeRestRequest, fakeRestChannel, nodeClient); } @@ -41,8 +43,8 @@ public void testConsumesParameters() throws Exception { assertThat(fakeRestChannel.responses().get(), is(1)); } - private NoOpNodeClient buildNodeClient() { - return new NoOpNodeClient(getTestName()) { + private NoOpNodeClient buildNodeClient(ThreadPool threadPool) { + 
return new NoOpNodeClient(threadPool) { @Override @SuppressWarnings("unchecked") public void doExecute( diff --git a/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java index d7e3aaf326075..caeb0f36a1000 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.rest.RestChannel; @@ -39,15 +38,16 @@ public class RestBulkActionTests extends ESTestCase { public void testBulkPipelineUpsert() throws Exception { SetOnce bulkCalled = new SetOnce<>(); - try (NodeClient verifyingClient = new NoOpNodeClient(this.getTestName()) { - @Override - public void bulk(BulkRequest request, ActionListener listener) { - bulkCalled.set(true); - assertThat(request.requests(), hasSize(2)); - UpdateRequest updateRequest = (UpdateRequest) request.requests().get(1); - assertThat(updateRequest.upsertRequest().getPipeline(), equalTo("timestamps")); - } - }) { + try (var threadPool = createThreadPool()) { + final var verifyingClient = new NoOpNodeClient(threadPool) { + @Override + public void bulk(BulkRequest request, ActionListener listener) { + bulkCalled.set(true); + assertThat(request.requests(), hasSize(2)); + UpdateRequest updateRequest = (UpdateRequest) request.requests().get(1); + assertThat(updateRequest.upsertRequest().getPipeline(), equalTo("timestamps")); + } + }; final Map params = new HashMap<>(); params.put("pipeline", "timestamps"); new RestBulkAction(settings(IndexVersion.current()).build()).handleRequest( @@ -68,17 +68,18 @@ public void testListExecutedPipelines() throws Exception { AtomicBoolean bulkCalled = new AtomicBoolean(false); AtomicBoolean listExecutedPipelinesRequest1 = new AtomicBoolean(false); AtomicBoolean listExecutedPipelinesRequest2 = new AtomicBoolean(false); - try (NodeClient verifyingClient = new NoOpNodeClient(this.getTestName()) { - @Override - public void bulk(BulkRequest request, ActionListener listener) { - bulkCalled.set(true); - assertThat(request.requests(), hasSize(2)); - IndexRequest indexRequest1 = (IndexRequest) request.requests().get(0); - listExecutedPipelinesRequest1.set(indexRequest1.getListExecutedPipelines()); - IndexRequest indexRequest2 = (IndexRequest) request.requests().get(1); - listExecutedPipelinesRequest2.set(indexRequest2.getListExecutedPipelines()); - } - }) { + try (var threadPool = createThreadPool()) { + final var verifyingClient = new NoOpNodeClient(threadPool) { + @Override + public void bulk(BulkRequest request, ActionListener listener) { + bulkCalled.set(true); + assertThat(request.requests(), hasSize(2)); + IndexRequest indexRequest1 = (IndexRequest) request.requests().get(0); + listExecutedPipelinesRequest1.set(indexRequest1.getListExecutedPipelines()); + IndexRequest indexRequest2 = (IndexRequest) request.requests().get(1); + listExecutedPipelinesRequest2.set(indexRequest2.getListExecutedPipelines()); + } + }; Map params = new HashMap<>(); { new RestBulkAction(settings(IndexVersion.current()).build()).handleRequest( diff --git 
a/server/src/test/java/org/elasticsearch/search/scroll/RestClearScrollActionTests.java b/server/src/test/java/org/elasticsearch/search/scroll/RestClearScrollActionTests.java index fbc4acc959579..a5fabe32de645 100644 --- a/server/src/test/java/org/elasticsearch/search/scroll/RestClearScrollActionTests.java +++ b/server/src/test/java/org/elasticsearch/search/scroll/RestClearScrollActionTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.ClearScrollRequest; import org.elasticsearch.action.search.ClearScrollResponse; -import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.search.RestClearScrollAction; @@ -41,14 +40,15 @@ public void testParseClearScrollRequestWithInvalidJsonThrowsException() throws E public void testBodyParamsOverrideQueryStringParams() throws Exception { SetOnce scrollCalled = new SetOnce<>(); - try (NodeClient nodeClient = new NoOpNodeClient(this.getTestName()) { - @Override - public void clearScroll(ClearScrollRequest request, ActionListener listener) { - scrollCalled.set(true); - assertThat(request.getScrollIds(), hasSize(1)); - assertThat(request.getScrollIds().get(0), equalTo("BODY")); - } - }) { + try (var threadPool = createThreadPool()) { + final var nodeClient = new NoOpNodeClient(threadPool) { + @Override + public void clearScroll(ClearScrollRequest request, ActionListener listener) { + scrollCalled.set(true); + assertThat(request.getScrollIds(), hasSize(1)); + assertThat(request.getScrollIds().get(0), equalTo("BODY")); + } + }; RestClearScrollAction action = new RestClearScrollAction(); RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withParams( Collections.singletonMap("scroll_id", "QUERY_STRING") diff --git a/server/src/test/java/org/elasticsearch/search/scroll/RestSearchScrollActionTests.java b/server/src/test/java/org/elasticsearch/search/scroll/RestSearchScrollActionTests.java index 2e4a92fa54344..9c521d15d7f74 100644 --- a/server/src/test/java/org/elasticsearch/search/scroll/RestSearchScrollActionTests.java +++ b/server/src/test/java/org/elasticsearch/search/scroll/RestSearchScrollActionTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchScrollRequest; -import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.search.RestSearchScrollAction; @@ -41,14 +40,15 @@ public void testParseSearchScrollRequestWithInvalidJsonThrowsException() throws public void testBodyParamsOverrideQueryStringParams() throws Exception { SetOnce scrollCalled = new SetOnce<>(); - try (NodeClient nodeClient = new NoOpNodeClient(this.getTestName()) { - @Override - public void searchScroll(SearchScrollRequest request, ActionListener listener) { - scrollCalled.set(true); - assertThat(request.scrollId(), equalTo("BODY")); - assertThat(request.scroll().keepAlive().getStringRep(), equalTo("1m")); - } - }) { + try (var threadPool = createThreadPool()) { + final var nodeClient = new NoOpNodeClient(threadPool) { + @Override + public void searchScroll(SearchScrollRequest request, ActionListener listener) { + scrollCalled.set(true); + assertThat(request.scrollId(), equalTo("BODY")); + 
assertThat(request.scroll().keepAlive().getStringRep(), equalTo("1m")); + } + }; RestSearchScrollAction action = new RestSearchScrollAction(); Map params = new HashMap<>(); params.put("scroll_id", "QUERY_STRING"); diff --git a/server/src/test/java/org/elasticsearch/usage/UsageServiceTests.java b/server/src/test/java/org/elasticsearch/usage/UsageServiceTests.java index 136a86457ba34..5cc92db89d5a2 100644 --- a/server/src/test/java/org/elasticsearch/usage/UsageServiceTests.java +++ b/server/src/test/java/org/elasticsearch/usage/UsageServiceTests.java @@ -95,7 +95,8 @@ public void testRestUsage() throws Exception { usageService.addRestHandler(handlerD); usageService.addRestHandler(handlerE); usageService.addRestHandler(handlerF); - try (NodeClient client = new NoOpNodeClient(this.getClass().getSimpleName() + "TestClient")) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpNodeClient(threadPool); handlerA.handleRequest(restRequest, null, client); handlerB.handleRequest(restRequest, null, client); handlerA.handleRequest(restRequest, null, client); diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java index eb4a129f09110..ab5e3a7555214 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java @@ -23,6 +23,7 @@ import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.junit.Before; @@ -114,7 +115,8 @@ public final void testKnnSearchRewrite() throws Exception { KnnSearchBuilder::new, TransportVersion.current() ); - try (NoOpClient client = new AssertingClient(expected, queryVectorBuilder)) { + try (var threadPool = createThreadPool()) { + final var client = new AssertingClient(threadPool, expected, queryVectorBuilder); QueryRewriteContext context = new QueryRewriteContext(null, client, null); PlainActionFuture future = new PlainActionFuture<>(); Rewriteable.rewriteAndFetch(randomFrom(serialized, searchBuilder), context, future); @@ -128,7 +130,8 @@ public final void testKnnSearchRewrite() throws Exception { public final void testVectorFetch() throws Exception { float[] expected = randomVector(randomIntBetween(10, 1024)); T queryVectorBuilder = createTestInstance(expected); - try (NoOpClient client = new AssertingClient(expected, queryVectorBuilder)) { + try (var threadPool = createThreadPool()) { + final var client = new AssertingClient(threadPool, expected, queryVectorBuilder); PlainActionFuture future = new PlainActionFuture<>(); queryVectorBuilder.buildVector(client, future); assertThat(future.get(), equalTo(expected)); @@ -163,8 +166,8 @@ private class AssertingClient extends NoOpClient { private final float[] array; private final T queryVectorBuilder; - AssertingClient(float[] array, T queryVectorBuilder) { - super("query_vector_builder_tests"); + AssertingClient(ThreadPool threadPool, float[] array, T queryVectorBuilder) { + super(threadPool); this.array = array; this.queryVectorBuilder = queryVectorBuilder; } diff --git a/test/framework/src/main/java/org/elasticsearch/test/client/NoOpClient.java 
b/test/framework/src/main/java/org/elasticsearch/test/client/NoOpClient.java
index a9f4391975f68..7914d00be91fc 100644
--- a/test/framework/src/main/java/org/elasticsearch/test/client/NoOpClient.java
+++ b/test/framework/src/main/java/org/elasticsearch/test/client/NoOpClient.java
@@ -8,18 +8,14 @@
 package org.elasticsearch.test.client;
 
-import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.client.internal.support.AbstractClient;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 
-import java.util.concurrent.TimeUnit;
-
 /**
  * Client that always responds with {@code null} to every request. Override {@link #doExecute(ActionType, ActionRequest, ActionListener)}
  * for testing.
  */
@@ -34,13 +30,6 @@ public NoOpClient(ThreadPool threadPool) {
         super(Settings.EMPTY, threadPool);
     }
 
-    /**
-     * Create a new {@link TestThreadPool} for this client. This {@linkplain TestThreadPool} is terminated on {@link #close()}.
-     */
-    public NoOpClient(String testName) {
-        super(Settings.EMPTY, new TestThreadPool(testName));
-    }
-
     @Override
     protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
         ActionType<Response> action,
@@ -51,11 +40,5 @@ protected void
     }
 
     @Override
-    public void close() {
-        try {
-            ThreadPool.terminate(threadPool(), 10, TimeUnit.SECONDS);
-        } catch (Exception e) {
-            throw new ElasticsearchException(e.getMessage(), e);
-        }
-    }
+    public void close() {}
 }
diff --git a/test/framework/src/main/java/org/elasticsearch/test/client/NoOpNodeClient.java b/test/framework/src/main/java/org/elasticsearch/test/client/NoOpNodeClient.java
index 0df930048f9c7..0300a3c41d00f 100644
--- a/test/framework/src/main/java/org/elasticsearch/test/client/NoOpNodeClient.java
+++ b/test/framework/src/main/java/org/elasticsearch/test/client/NoOpNodeClient.java
@@ -8,7 +8,6 @@
 package org.elasticsearch.test.client;
 
-import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionResponse;
@@ -20,14 +19,12 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskManager;
-import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.RemoteClusterService;
 import org.elasticsearch.transport.Transport;
 
 import java.util.Map;
 import java.util.concurrent.Executor;
-import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Supplier;
 
@@ -48,13 +45,6 @@ public NoOpNodeClient(ThreadPool threadPool) {
         super(Settings.EMPTY, threadPool);
     }
 
-    /**
-     * Create a new {@link TestThreadPool} for this client. This {@linkplain TestThreadPool} is terminated on {@link #close()}.
-     */
-    public NoOpNodeClient(String testName) {
-        super(Settings.EMPTY, new TestThreadPool(testName));
-    }
-
     @Override
     public <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
         ActionType<Response> action,
@@ -97,17 +87,4 @@ public String getLocalNodeId() {
     public Client getRemoteClusterClient(String clusterAlias, Executor responseExecutor) {
         return null;
     }
-
-    @Override
-    public void close() {
-        try {
-            ThreadPool.terminate(threadPool(), 10, TimeUnit.SECONDS);
-        } catch (Exception e) {
-            throw new ElasticsearchException(e.getMessage(), e);
-        }
-    }
-
-    public long getExecutionCount() {
-        return executionCount.get();
-    }
 }
diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/RestActionTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/rest/RestActionTestCase.java
index 1229b3470775f..9e638425d5c5c 100644
--- a/test/framework/src/main/java/org/elasticsearch/test/rest/RestActionTestCase.java
+++ b/test/framework/src/main/java/org/elasticsearch/test/rest/RestActionTestCase.java
@@ -20,6 +20,8 @@
 import org.elasticsearch.telemetry.tracing.Tracer;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.client.NoOpNodeClient;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.usage.UsageService;
 import org.junit.After;
 import org.junit.Before;
@@ -35,17 +37,19 @@
  */
 public abstract class RestActionTestCase extends ESTestCase {
     private RestController controller;
+    private TestThreadPool threadPool;
     protected VerifyingClient verifyingClient;
 
     @Before
     public void setUpController() {
-        verifyingClient = new VerifyingClient(this.getTestName());
+        threadPool = createThreadPool();
+        verifyingClient = new VerifyingClient(threadPool);
         controller = new RestController(null, verifyingClient, new NoneCircuitBreakerService(), new UsageService(), Tracer.NOOP);
     }
 
     @After
     public void tearDownController() {
-        verifyingClient.close();
+        threadPool.close();
     }
 
     /**
@@ -78,8 +82,8 @@ public static final class VerifyingClient extends NoOpNodeClient {
         AtomicReference<BiFunction<ActionType<?>, ActionRequest, ActionResponse>> executeVerifier = new AtomicReference<>();
         AtomicReference<BiFunction<ActionType<?>, ActionRequest, ActionResponse>> executeLocallyVerifier = new AtomicReference<>();
 
-        public VerifyingClient(String testName) {
-            super(testName);
+        public VerifyingClient(ThreadPool threadPool) {
+            super(threadPool);
             reset();
         }
 
diff --git a/test/framework/src/main/java/org/elasticsearch/threadpool/TestThreadPool.java b/test/framework/src/main/java/org/elasticsearch/threadpool/TestThreadPool.java
index 58f06cbbe9d40..e8a853989e8e5 100644
--- a/test/framework/src/main/java/org/elasticsearch/threadpool/TestThreadPool.java
+++ b/test/framework/src/main/java/org/elasticsearch/threadpool/TestThreadPool.java
@@ -10,14 +10,16 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.core.Releasable;
 import org.elasticsearch.node.Node;
 
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
 
-public class TestThreadPool extends ThreadPool {
+public class TestThreadPool extends ThreadPool implements Releasable {
 
     private final CountDownLatch blockingLatch = new CountDownLatch(1);
     private volatile boolean returnRejectingExecutor = false;
@@ -98,4 +100,9 @@ private synchronized void createRejectingExecutor() {
             throw new RuntimeException(e);
         }
     }
+
+    @Override
+    public void close() {
+        ThreadPool.terminate(this, 10, TimeUnit.SECONDS);
+    }
 }
diff --git a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/capacity/nodeinfo/AutoscalingNodesInfoServiceTests.java b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/capacity/nodeinfo/AutoscalingNodesInfoServiceTests.java
index 8d357a5f050ca..c6fb0613b3e8e 100644
--- a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/capacity/nodeinfo/AutoscalingNodesInfoServiceTests.java
+++ b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/capacity/nodeinfo/AutoscalingNodesInfoServiceTests.java
@@ -42,6 +42,8 @@
 import org.elasticsearch.monitor.os.OsInfo;
 import org.elasticsearch.monitor.os.OsStats;
 import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.autoscaling.AutoscalingMetadata;
 import org.elasticsearch.xpack.autoscaling.AutoscalingTestCase;
 import org.elasticsearch.xpack.autoscaling.policy.AutoscalingPolicy;
@@ -73,6 +75,7 @@
 
 public class AutoscalingNodesInfoServiceTests extends AutoscalingTestCase {
 
+    private TestThreadPool threadPool;
     private NodeStatsClient client;
     private AutoscalingNodeInfoService service;
     private TimeValue fetchTimeout;
@@ -83,7 +86,8 @@ public class AutoscalingNodesInfoServiceTests extends AutoscalingTestCase {
     @Override
    public void setUp() throws Exception {
         super.setUp();
-        client = new NodeStatsClient();
+        threadPool = createThreadPool();
+        client = new NodeStatsClient(threadPool);
         final ClusterService clusterService = mock(ClusterService.class);
         Settings settings;
         if (randomBoolean()) {
@@ -105,8 +109,8 @@ public void setUp() throws Exception {
     @After
     @Override
     public void tearDown() throws Exception {
+        threadPool.close();
         super.tearDown();
-        client.close();
     }
 
     public void testAddRemoveNode() {
@@ -470,8 +474,8 @@ private class NodeStatsClient extends NoOpClient {
         private BiConsumer<NodesStatsRequest, ActionListener<NodesStatsResponse>> responderStats;
         private BiConsumer<NodesInfoRequest, ActionListener<NodesInfoResponse>> responderInfo;
 
-        private NodeStatsClient() {
-            super(getTestName());
+        private NodeStatsClient(ThreadPool threadPool) {
+            super(threadPool);
         }
 
         public void respondInfo(NodesInfoResponse response, Runnable whileFetching) {
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/common/validation/SourceDestValidatorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/common/validation/SourceDestValidatorTests.java
index 2e5891a31d402..a9643ce099262 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/common/validation/SourceDestValidatorTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/common/validation/SourceDestValidatorTests.java
@@ -95,6 +95,7 @@ public class SourceDestValidatorTests extends ESTestCase {
         REMOTE_SOURCE_VALIDATION
     );
 
+    private TestThreadPool clientThreadPool;
     private Client clientWithBasicLicense;
     private Client clientWithExpiredBasicLicense;
     private Client clientWithPlatinumLicense;
@@ -154,8 +155,8 @@ private class MockClientLicenseCheck extends NoOpClient {
         private final String license;
         private final LicenseStatus licenseStatus;
 
-        MockClientLicenseCheck(String testName, String license, LicenseStatus licenseStatus) {
-            super(testName);
+        MockClientLicenseCheck(ThreadPool threadPool, String license, LicenseStatus licenseStatus) {
+            super(threadPool);
             this.license = license;
             this.licenseStatus = licenseStatus;
         }
 
@@ -187,21 +188,19 @@
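
The hunks in this patch largely follow one migration: a test that previously let NoOpClient or NoOpNodeClient create and tear down its own TestThreadPool now creates the pool itself, passes it to the client, and releases it in try-with-resources (TestThreadPool implements Releasable as of the TestThreadPool.java change above). A minimal sketch of the resulting pattern; the test class name is hypothetical, and it assumes ESTestCase#createThreadPool() hands back a TestThreadPool, as the call sites in this patch suggest:

    import org.elasticsearch.action.ActionListener;
    import org.elasticsearch.action.ActionRequest;
    import org.elasticsearch.action.ActionResponse;
    import org.elasticsearch.action.ActionType;
    import org.elasticsearch.test.ESTestCase;
    import org.elasticsearch.test.client.NoOpClient;

    public class ExamplePatternTests extends ESTestCase {
        public void testClientAgainstTestOwnedThreadPool() {
            // The test owns the pool; closing it runs ThreadPool.terminate,
            // which is what NoOpClient#close() did before this patch.
            try (var threadPool = createThreadPool()) {
                final var client = new NoOpClient(threadPool) {
                    @Override
                    protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
                        ActionType<Response> action,
                        Request request,
                        ActionListener<Response> listener
                    ) {
                        // stub whatever the code under test expects
                        listener.onResponse(null);
                    }
                };
                // exercise the code under test with `client`; no client.close() is
                // needed, since NoOpClient#close() is a no-op after this patch
            }
        }
    }
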
protected void @Before public void setupComponents() { - clientWithBasicLicense = new MockClientLicenseCheck(getTestName(), "basic", LicenseStatus.ACTIVE); - clientWithExpiredBasicLicense = new MockClientLicenseCheck(getTestName(), "basic", LicenseStatus.EXPIRED); + clientThreadPool = createThreadPool(); + clientWithBasicLicense = new MockClientLicenseCheck(clientThreadPool, "basic", LicenseStatus.ACTIVE); + clientWithExpiredBasicLicense = new MockClientLicenseCheck(clientThreadPool, "basic", LicenseStatus.EXPIRED); LicensedFeature.Momentary feature = LicensedFeature.momentary(null, "feature", License.OperationMode.BASIC); platinumFeature = LicensedFeature.momentary(null, "platinum-feature", License.OperationMode.PLATINUM); remoteClusterLicenseCheckerBasic = new RemoteClusterLicenseChecker(clientWithBasicLicense, feature); - clientWithPlatinumLicense = new MockClientLicenseCheck(getTestName(), "platinum", LicenseStatus.ACTIVE); - clientWithTrialLicense = new MockClientLicenseCheck(getTestName(), "trial", LicenseStatus.ACTIVE); + clientWithPlatinumLicense = new MockClientLicenseCheck(clientThreadPool, "platinum", LicenseStatus.ACTIVE); + clientWithTrialLicense = new MockClientLicenseCheck(clientThreadPool, "trial", LicenseStatus.ACTIVE); } @After public void closeComponents() throws Exception { - clientWithBasicLicense.close(); - clientWithExpiredBasicLicense.close(); - clientWithPlatinumLicense.close(); - clientWithTrialLicense.close(); + clientThreadPool.close(); ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupShrinkIndexStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupShrinkIndexStepTests.java index 5269b3abd826f..90482be334363 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupShrinkIndexStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupShrinkIndexStepTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ilm.Step.StepKey; import java.util.Map; @@ -100,7 +101,8 @@ public void testPerformAction() { .metadata(Metadata.builder().put(indexMetadata, true).build()) .build(); - try (NoOpClient client = getDeleteIndexRequestAssertingClient(shrinkIndexName)) { + try (var threadPool = createThreadPool()) { + final var client = getDeleteIndexRequestAssertingClient(threadPool, shrinkIndexName); CleanupShrinkIndexStep step = new CleanupShrinkIndexStep(randomStepKey(), randomStepKey(), client); step.performAction(indexMetadata, clusterState, null, ActionListener.noop()); } @@ -126,14 +128,15 @@ public void testDeleteSkippedIfManagedIndexIsShrunkAndSourceDoesntExist() { .metadata(Metadata.builder().put(shrunkIndexMetadata, true).build()) .build(); - try (NoOpClient client = getFailingIfCalledClient()) { + try (var threadPool = createThreadPool()) { + final var client = getFailingIfCalledClient(threadPool); CleanupShrinkIndexStep step = new CleanupShrinkIndexStep(randomStepKey(), randomStepKey(), client); step.performAction(shrunkIndexMetadata, clusterState, null, ActionListener.noop()); } } - private NoOpClient getDeleteIndexRequestAssertingClient(String shrinkIndexName) { - return new NoOpClient(getTestName()) { + private NoOpClient getDeleteIndexRequestAssertingClient(ThreadPool 
threadPool, String shrinkIndexName) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, @@ -147,8 +150,8 @@ protected void }; } - private NoOpClient getFailingIfCalledClient() { - return new NoOpClient(getTestName()) { + private NoOpClient getFailingIfCalledClient(ThreadPool threadPool) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupSnapshotStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupSnapshotStepTests.java index c279e2e82738e..5fddcd51a6614 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupSnapshotStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupSnapshotStepTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ilm.Step.StepKey; import java.util.Map; @@ -108,14 +109,15 @@ public void testPerformAction() { .metadata(Metadata.builder().put(indexMetadata, true).build()) .build(); - try (NoOpClient client = getDeleteSnapshotRequestAssertingClient(snapshotName)) { + try (var threadPool = createThreadPool()) { + final var client = getDeleteSnapshotRequestAssertingClient(threadPool, snapshotName); CleanupSnapshotStep step = new CleanupSnapshotStep(randomStepKey(), randomStepKey(), client); step.performAction(indexMetadata, clusterState, null, ActionListener.noop()); } } - private NoOpClient getDeleteSnapshotRequestAssertingClient(String expectedSnapshotName) { - return new NoOpClient(getTestName()) { + private NoOpClient getDeleteSnapshotRequestAssertingClient(ThreadPool threadPool, String expectedSnapshotName) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupTargetIndexStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupTargetIndexStepTests.java index cad79e1b8c2ca..dea53b2c736ac 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupTargetIndexStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CleanupTargetIndexStepTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ilm.Step.StepKey; import java.util.Map; @@ -118,7 +119,8 @@ public void testPerformAction() { .metadata(Metadata.builder().put(indexMetadata, true).build()) .build(); - try (NoOpClient client = getDeleteIndexRequestAssertingClient(shrinkIndexName)) { + try (var threadPool = createThreadPool()) { + final var client = getDeleteIndexRequestAssertingClient(threadPool, shrinkIndexName); CleanupTargetIndexStep step = new CleanupTargetIndexStep( randomStepKey(), randomStepKey(), @@ -150,7 +152,8 @@ public void testDeleteSkippedIfManagedIndexIsShrunkAndSourceDoesntExist() { .metadata(Metadata.builder().put(shrunkIndexMetadata, true).build()) .build(); - try (NoOpClient client = getFailingIfCalledClient()) { + try (var threadPool = createThreadPool()) { + final var client = 
getFailingIfCalledClient(threadPool); CleanupTargetIndexStep step = new CleanupTargetIndexStep( randomStepKey(), randomStepKey(), @@ -162,8 +165,8 @@ public void testDeleteSkippedIfManagedIndexIsShrunkAndSourceDoesntExist() { } } - private NoOpClient getDeleteIndexRequestAssertingClient(String shrinkIndexName) { - return new NoOpClient(getTestName()) { + private NoOpClient getDeleteIndexRequestAssertingClient(ThreadPool threadPool, String shrinkIndexName) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, @@ -177,8 +180,8 @@ protected void }; } - private NoOpClient getFailingIfCalledClient() { - return new NoOpClient(getTestName()) { + private NoOpClient getFailingIfCalledClient(ThreadPool threadPool) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CreateSnapshotStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CreateSnapshotStepTests.java index ace616d837e94..b954162aee6f2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CreateSnapshotStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CreateSnapshotStepTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.snapshots.SnapshotNameAlreadyInUseException; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ilm.Step.StepKey; import java.util.HashMap; @@ -132,7 +133,8 @@ public void testPerformAction() { .metadata(Metadata.builder().put(indexMetadata, true).build()) .build(); - try (NoOpClient client = getCreateSnapshotRequestAssertingClient(repository, snapshotName, indexName)) { + try (var threadPool = createThreadPool()) { + final var client = getCreateSnapshotRequestAssertingClient(threadPool, repository, snapshotName, indexName); CreateSnapshotStep step = new CreateSnapshotStep(randomStepKey(), randomStepKey(), randomStepKey(), client); step.performAction(indexMetadata, clusterState, null, ActionListener.noop()); } @@ -158,7 +160,8 @@ public void testNextStepKey() { .metadata(Metadata.builder().put(indexMetadata, true).build()) .build(); { - try (NoOpClient client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); StepKey nextKeyOnComplete = randomStepKey(); StepKey nextKeyOnIncomplete = randomStepKey(); CreateSnapshotStep completeStep = new CreateSnapshotStep(randomStepKey(), nextKeyOnComplete, nextKeyOnIncomplete, client) { @@ -173,7 +176,8 @@ void createSnapshot(IndexMetadata indexMetadata, ActionListener listene } { - try (NoOpClient client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); StepKey nextKeyOnComplete = randomStepKey(); StepKey nextKeyOnIncomplete = randomStepKey(); CreateSnapshotStep incompleteStep = new CreateSnapshotStep( @@ -193,7 +197,8 @@ void createSnapshot(IndexMetadata indexMetadata, ActionListener listene } { - try (NoOpClient client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); StepKey nextKeyOnComplete = randomStepKey(); StepKey nextKeyOnIncomplete = randomStepKey(); CreateSnapshotStep doubleInvocationStep = new CreateSnapshotStep( @@ -213,8 +218,13 @@ void createSnapshot(IndexMetadata 
indexMetadata, ActionListener listene } } - private NoOpClient getCreateSnapshotRequestAssertingClient(String expectedRepoName, String expectedSnapshotName, String indexName) { - return new NoOpClient(getTestName()) { + private NoOpClient getCreateSnapshotRequestAssertingClient( + ThreadPool threadPool, + String expectedRepoName, + String expectedSnapshotName, + String indexName + ) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/DownsampleStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/DownsampleStepTests.java index 421f6bd30f432..e3731c4416491 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/DownsampleStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/DownsampleStepTests.java @@ -251,7 +251,8 @@ public void testNextStepKey() { .metadata(Metadata.builder().put(sourceIndexMetadata, true).build()) .build(); { - try (NoOpClient client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); StepKey nextKey = randomStepKey(); DateHistogramInterval fixedInterval = ConfigTestHelpers.randomInterval(); TimeValue timeout = DownsampleAction.DEFAULT_WAIT_TIMEOUT; @@ -265,7 +266,8 @@ void performDownsampleIndex(String indexName, String downsampleIndexName, Action } } { - try (NoOpClient client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); StepKey nextKey = randomStepKey(); DateHistogramInterval fixedInterval = ConfigTestHelpers.randomInterval(); TimeValue timeout = DownsampleAction.DEFAULT_WAIT_TIMEOUT; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/MountSnapshotStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/MountSnapshotStepTests.java index 4bd909fc00ca7..f905ca38e1c5c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/MountSnapshotStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/MountSnapshotStepTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.snapshots.RestoreInfo; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ilm.Step.StepKey; import org.elasticsearch.xpack.core.searchablesnapshots.MountSearchableSnapshotAction; import org.elasticsearch.xpack.core.searchablesnapshots.MountSearchableSnapshotRequest; @@ -164,16 +165,16 @@ public void testPerformAction() throws Exception { .metadata(Metadata.builder().put(indexMetadata, true).build()) .build(); - try ( - NoOpClient client = getRestoreSnapshotRequestAssertingClient( + try (var threadPool = createThreadPool()) { + final var client = getRestoreSnapshotRequestAssertingClient( + threadPool, repository, snapshotName, indexName, RESTORED_INDEX_PREFIX, indexName, new String[] { LifecycleSettings.LIFECYCLE_NAME } - ) - ) { + ); MountSnapshotStep step = new MountSnapshotStep( randomStepKey(), randomStepKey(), @@ -207,7 +208,8 @@ public void testResponseStatusHandling() throws Exception { { RestoreSnapshotResponse responseWithOKStatus = new RestoreSnapshotResponse(new RestoreInfo("test", List.of(), 1, 1)); - try (NoOpClient clientPropagatingOKResponse = getClientTriggeringResponse(responseWithOKStatus)) { + try (var 
threadPool = createThreadPool()) { + final var clientPropagatingOKResponse = getClientTriggeringResponse(threadPool, responseWithOKStatus); MountSnapshotStep step = new MountSnapshotStep( randomStepKey(), randomStepKey(), @@ -221,7 +223,8 @@ public void testResponseStatusHandling() throws Exception { { RestoreSnapshotResponse responseWithACCEPTEDStatus = new RestoreSnapshotResponse((RestoreInfo) null); - try (NoOpClient clientPropagatingACCEPTEDResponse = getClientTriggeringResponse(responseWithACCEPTEDStatus)) { + try (var threadPool = createThreadPool()) { + final var clientPropagatingACCEPTEDResponse = getClientTriggeringResponse(threadPool, responseWithACCEPTEDStatus); MountSnapshotStep step = new MountSnapshotStep( randomStepKey(), randomStepKey(), @@ -286,16 +289,16 @@ public void doTestMountWithoutSnapshotIndexNameInState(String prefix) throws Exc .metadata(Metadata.builder().put(indexMetadata, true).build()) .build(); - try ( - NoOpClient client = getRestoreSnapshotRequestAssertingClient( + try (var threadPool = createThreadPool()) { + final var client = getRestoreSnapshotRequestAssertingClient( + threadPool, repository, snapshotName, indexName, RESTORED_INDEX_PREFIX, indexNameSnippet, new String[] { LifecycleSettings.LIFECYCLE_NAME } - ) - ) { + ); MountSnapshotStep step = new MountSnapshotStep( randomStepKey(), randomStepKey(), @@ -328,16 +331,16 @@ public void testIgnoreTotalShardsPerNodeInFrozenPhase() throws Exception { .metadata(Metadata.builder().put(indexMetadata, true).build()) .build(); - try ( - NoOpClient client = getRestoreSnapshotRequestAssertingClient( + try (var threadPool = createThreadPool()) { + final var client = getRestoreSnapshotRequestAssertingClient( + threadPool, repository, snapshotName, indexName, RESTORED_INDEX_PREFIX, indexName, new String[] { LifecycleSettings.LIFECYCLE_NAME, ShardsLimitAllocationDecider.INDEX_TOTAL_SHARDS_PER_NODE_SETTING.getKey() } - ) - ) { + ); MountSnapshotStep step = new MountSnapshotStep( new StepKey(TimeseriesLifecycleType.FROZEN_PHASE, randomAlphaOfLength(10), randomAlphaOfLength(10)), randomStepKey(), @@ -350,8 +353,8 @@ public void testIgnoreTotalShardsPerNodeInFrozenPhase() throws Exception { } @SuppressWarnings("unchecked") - private NoOpClient getClientTriggeringResponse(RestoreSnapshotResponse response) { - return new NoOpClient(getTestName()) { + private NoOpClient getClientTriggeringResponse(ThreadPool threadPool, RestoreSnapshotResponse response) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, @@ -365,6 +368,7 @@ protected void @SuppressWarnings("unchecked") private NoOpClient getRestoreSnapshotRequestAssertingClient( + ThreadPool threadPool, String expectedRepoName, String expectedSnapshotName, String indexName, @@ -372,7 +376,7 @@ private NoOpClient getRestoreSnapshotRequestAssertingClient( String expectedSnapshotIndexName, String[] expectedIgnoredIndexSettings ) { - return new NoOpClient(getTestName()) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/SwapAliasesAndDeleteSourceIndexStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/SwapAliasesAndDeleteSourceIndexStepTests.java index 4c62781bc2406..7a09b375ed53b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/SwapAliasesAndDeleteSourceIndexStepTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/SwapAliasesAndDeleteSourceIndexStepTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ilm.Step.StepKey; import java.util.Arrays; @@ -103,7 +104,8 @@ public void testPerformAction() { .isHidden(isHidden) ); - try (NoOpClient client = getIndicesAliasAssertingClient(expectedAliasActions)) { + try (var threadPool = createThreadPool()) { + final var client = getIndicesAliasAssertingClient(threadPool, expectedAliasActions); SwapAliasesAndDeleteSourceIndexStep step = new SwapAliasesAndDeleteSourceIndexStep( randomStepKey(), randomStepKey(), @@ -124,8 +126,8 @@ public void testPerformAction() { } } - private NoOpClient getIndicesAliasAssertingClient(List expectedAliasActions) { - return new NoOpClient(getTestName()) { + private NoOpClient getIndicesAliasAssertingClient(ThreadPool threadPool, List expectedAliasActions) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, diff --git a/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichProcessorFactoryTests.java b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichProcessorFactoryTests.java index caa550f3658b8..a26cab231f52c 100644 --- a/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichProcessorFactoryTests.java +++ b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichProcessorFactoryTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.ShardSearchFailure; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.AliasMetadata; @@ -63,7 +62,8 @@ public void initializeScriptService() { public void testCreateProcessorInstance() throws Exception { List enrichValues = List.of("globalRank", "tldRank", "tld"); EnrichPolicy policy = new EnrichPolicy(EnrichPolicy.MATCH_TYPE, null, List.of("source_index"), "my_key", enrichValues); - try (Client client = new NoOpClient(this.getClass().getSimpleName() + "TestClient")) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); EnrichProcessorFactory factory = new EnrichProcessorFactory(client, scriptService, enrichCache); factory.metadata = createMetadata("majestic", policy); @@ -172,7 +172,8 @@ public void testPolicyNameMissing() { public void testUnsupportedPolicy() throws Exception { List enrichValues = List.of("globalRank", "tldRank", "tld"); EnrichPolicy policy = new EnrichPolicy("unsupported", null, List.of("source_index"), "my_key", enrichValues); - try (Client client = new NoOpClient(this.getClass().getSimpleName() + "TestClient")) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); EnrichProcessorFactory factory = new EnrichProcessorFactory(client, scriptService, enrichCache); factory.metadata = createMetadata("majestic", policy); @@ -193,7 +194,8 @@ public void testUnsupportedPolicy() throws Exception { public void testCompactEnrichValuesFormat() throws Exception { List enrichValues = List.of("globalRank", "tldRank", "tld"); EnrichPolicy policy = new EnrichPolicy(EnrichPolicy.MATCH_TYPE, 
null, List.of("source_index"), "host", enrichValues); - try (Client client = new NoOpClient(this.getClass().getSimpleName() + "TestClient")) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); EnrichProcessorFactory factory = new EnrichProcessorFactory(client, scriptService, enrichCache); factory.metadata = createMetadata("majestic", policy); @@ -245,38 +247,38 @@ public void testCaching() throws Exception { enrichCache = new EnrichCache(100L); List enrichValues = List.of("globalRank", "tldRank", "tld"); EnrichPolicy policy = new EnrichPolicy(EnrichPolicy.MATCH_TYPE, null, List.of("source_index"), "host", enrichValues); - try (Client client = new NoOpClient(this.getClass().getSimpleName() + "testCaching") { - - @Override - @SuppressWarnings("unchecked") - protected void doExecute( - ActionType action, - Request request, - ActionListener listener - ) { - assert EnrichCoordinatorProxyAction.NAME.equals(action.name()); - var emptyResponse = new SearchResponse( - new InternalSearchResponse( - new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), 0.0f), - InternalAggregations.EMPTY, - new Suggest(Collections.emptyList()), - new SearchProfileResults(Collections.emptyMap()), - false, - false, - 1 - ), - "", - 1, - 1, - 0, - 0, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - requestCounter[0]++; - listener.onResponse((Response) emptyResponse); - } - }) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool) { + @Override + @SuppressWarnings("unchecked") + protected void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + assert EnrichCoordinatorProxyAction.NAME.equals(action.name()); + var emptyResponse = new SearchResponse( + new InternalSearchResponse( + new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), 0.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 1, + 1, + 0, + 0, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + requestCounter[0]++; + listener.onResponse((Response) emptyResponse); + } + }; EnrichProcessorFactory factory = new EnrichProcessorFactory(client, scriptService, enrichCache); factory.accept(ClusterState.builder(new ClusterName("_name")).metadata(createMetadata("majestic", policy)).build()); diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/AbstractRestEnterpriseSearchActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/AbstractRestEnterpriseSearchActionTests.java index 8591ec821271a..259beb008dd70 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/AbstractRestEnterpriseSearchActionTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/AbstractRestEnterpriseSearchActionTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.application; -import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -28,7 +27,8 @@ protected void checkLicenseForRequest(FakeRestRequest request, LicenseUtils.Prod final FakeRestChannel channel = new FakeRestChannel(request, true, 1); - try (NodeClient nodeClient = new NoOpNodeClient(this.getTestName())) { + try (var threadPool = 
createThreadPool()) { + final var nodeClient = new NoOpNodeClient(threadPool); action.handleRequest(request, channel, nodeClient); } assertThat(channel.capturedResponse(), notNullValue()); diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/EnterpriseSearchBaseRestHandlerTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/EnterpriseSearchBaseRestHandlerTests.java index 7ff8dae594e8a..6cf176e21498e 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/EnterpriseSearchBaseRestHandlerTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/EnterpriseSearchBaseRestHandlerTests.java @@ -59,7 +59,8 @@ public List routes() { FakeRestRequest fakeRestRequest = new FakeRestRequest(); FakeRestChannel fakeRestChannel = new FakeRestChannel(fakeRestRequest, randomBoolean(), licensedFeature ? 0 : 1); - try (NodeClient client = new NoOpNodeClient(this.getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpNodeClient(threadPool); assertFalse(consumerCalled.get()); verifyNoMoreInteractions(licenseState); handler.handleRequest(fakeRestRequest, fakeRestChannel, client); diff --git a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sample/CircuitBreakerTests.java b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sample/CircuitBreakerTests.java index 2df9452137ee4..663a0328a575b 100644 --- a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sample/CircuitBreakerTests.java +++ b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sample/CircuitBreakerTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.eql.EqlTestUtils; import org.elasticsearch.xpack.eql.analysis.PostAnalyzer; import org.elasticsearch.xpack.eql.analysis.PreAnalyzer; @@ -103,8 +104,9 @@ private void testMemoryCleared(boolean fail) { Collections.singletonList(EqlTestUtils.circuitBreakerSettings(Settings.EMPTY)), new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) ); - ESMockClient esClient = new ESMockClient(service.getBreaker(CIRCUIT_BREAKER_NAME)); + var threadPool = createThreadPool() ) { + final var esClient = new ESMockClient(threadPool, service.getBreaker(CIRCUIT_BREAKER_NAME)); CircuitBreaker eqlCircuitBreaker = service.getBreaker(CIRCUIT_BREAKER_NAME); IndexResolver indexResolver = new IndexResolver(esClient, "cluster", DefaultDataTypeRegistry.INSTANCE, () -> emptySet()); EqlSession eqlSession = new EqlSession( @@ -190,8 +192,8 @@ private class ESMockClient extends NoOpClient { protected final CircuitBreaker circuitBreaker; private final String pitId = "test_pit_id"; - ESMockClient(CircuitBreaker circuitBreaker) { - super(getTestName()); + ESMockClient(ThreadPool threadPool, CircuitBreaker circuitBreaker) { + super(threadPool); this.circuitBreaker = circuitBreaker; } diff --git a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sequence/CircuitBreakerTests.java b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sequence/CircuitBreakerTests.java index 51768096458f5..08d21de6d048a 100644 --- a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sequence/CircuitBreakerTests.java +++ 
b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sequence/CircuitBreakerTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.breaker.NoopCircuitBreaker; @@ -50,6 +51,7 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.async.AsyncExecutionId; import org.elasticsearch.xpack.eql.action.EqlSearchAction; import org.elasticsearch.xpack.eql.action.EqlSearchTask; @@ -81,7 +83,6 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicLong; -import java.util.function.BiFunction; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; @@ -201,7 +202,10 @@ public void testMemoryClearedOnShardsException() { assertMemoryCleared(stages, FailureESMockClient::new); } - private void assertMemoryCleared(int sequenceFiltersCount, BiFunction esClientSupplier) { + private void assertMemoryCleared( + int sequenceFiltersCount, + TriFunction esClientSupplier + ) { final int searchRequestsExpectedCount = 2; try ( CircuitBreakerService service = new HierarchyCircuitBreakerService( @@ -209,8 +213,9 @@ private void assertMemoryCleared(int sequenceFiltersCount, BiFunction criteria = buildCriteria(sequenceFiltersCount); @@ -245,8 +250,14 @@ public void testEqlCBCleanedUp_on_ParentCBBreak() { breakerSettings(), new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) ); - ESMockClient esClient = new SuccessfulESMockClient(service.getBreaker(CIRCUIT_BREAKER_NAME), searchRequestsExpectedCount); + var threadPool = createThreadPool() ) { + final var esClient = new SuccessfulESMockClient( + threadPool, + service.getBreaker(CIRCUIT_BREAKER_NAME), + searchRequestsExpectedCount + ); + CircuitBreaker eqlCircuitBreaker = service.getBreaker(CIRCUIT_BREAKER_NAME); QueryClient eqlClient = buildQueryClient(esClient, eqlCircuitBreaker); List criteria = buildCriteria(sequenceFiltersCount); @@ -359,8 +370,8 @@ private abstract class ESMockClient extends NoOpClient { private int searchRequestsRemainingCount; private final String pitId = "test_pit_id"; - ESMockClient(CircuitBreaker circuitBreaker, int searchRequestsRemainingCount) { - super(getTestName()); + ESMockClient(ThreadPool threadPool, CircuitBreaker circuitBreaker, int searchRequestsRemainingCount) { + super(threadPool); this.circuitBreaker = circuitBreaker; this.searchRequestsRemainingCount = searchRequestsRemainingCount; } @@ -404,8 +415,8 @@ int searchRequestsRemainingCount() { */ private class SuccessfulESMockClient extends ESMockClient { - SuccessfulESMockClient(CircuitBreaker circuitBreaker, int expectedSearchRequestsCount) { - super(circuitBreaker, expectedSearchRequestsCount); + SuccessfulESMockClient(ThreadPool threadPool, CircuitBreaker circuitBreaker, int expectedSearchRequestsCount) { + super(threadPool, circuitBreaker, expectedSearchRequestsCount); } @SuppressWarnings("unchecked") @@ -447,8 +458,8 @@ void handleSearchRequest(ActionListener keyExtractors = emptyList(); public void testHandlingPitFailure() { - try (ESMockClient esClient = new 
ESMockClient();) { + try (var threadPool = createThreadPool()) { + final var esClient = new ESMockClient(threadPool); EqlConfiguration eqlConfiguration = new EqlConfiguration( new String[] { "test" }, @@ -146,8 +148,8 @@ public void testHandlingPitFailure() { */ private class ESMockClient extends NoOpClient { - ESMockClient() { - super(getTestName()); + ESMockClient(ThreadPool threadPool) { + super(threadPool); } @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/IndexLifecycleTransitionTests.java b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/IndexLifecycleTransitionTests.java index b1e07729bee0a..9449e0c0574dc 100644 --- a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/IndexLifecycleTransitionTests.java +++ b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/IndexLifecycleTransitionTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.ilm; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -592,7 +591,8 @@ public void testValidateTransitionToCachedStepMissingFromPolicy() { IndexMetadata meta = buildIndexMetadata("my-policy", executionState); - try (Client client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); Step.StepKey currentStepKey = new Step.StepKey("hot", RolloverAction.NAME, WaitForRolloverReadyStep.NAME); Step.StepKey nextStepKey = new Step.StepKey("hot", RolloverAction.NAME, RolloverStep.NAME); Step currentStep = new WaitForRolloverReadyStep( @@ -648,7 +648,8 @@ public void testValidateTransitionToCachedStepWhenMissingPhaseFromPolicy() { IndexMetadata meta = buildIndexMetadata("my-policy", executionState); - try (Client client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); Step.StepKey currentStepKey = new Step.StepKey("warm", MigrateAction.NAME, DataTierMigrationRoutedStep.NAME); Step.StepKey nextStepKey = new Step.StepKey("warm", PhaseCompleteStep.NAME, PhaseCompleteStep.NAME); @@ -708,7 +709,8 @@ public void testValidateTransitionToInjectedMissingStep() { IndexMetadata meta = buildIndexMetadata("my-policy", executionState); - try (Client client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); Step.StepKey currentStepKey = new Step.StepKey("warm", MigrateAction.NAME, MigrateAction.NAME); Step.StepKey nextStepKey = new Step.StepKey("warm", MigrateAction.NAME, DataTierMigrationRoutedStep.NAME); @@ -1198,7 +1200,8 @@ public void testMoveStateToNextActionAndUpdateCachedPhase() { 2L ); - try (Client client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); LifecycleExecutionState newState = moveStateToNextActionAndUpdateCachedPhase( meta, meta.getLifecycleExecutionState(), @@ -1235,7 +1238,8 @@ public void testMoveStateToNextActionAndUpdateCachedPhase() { 2L ); - try (Client client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); LifecycleExecutionState newState = moveStateToNextActionAndUpdateCachedPhase( meta, meta.getLifecycleExecutionState(), diff --git 
a/x-pack/plugin/logstash/src/test/java/org/elasticsearch/xpack/logstash/action/TransportDeletePipelineActionTests.java b/x-pack/plugin/logstash/src/test/java/org/elasticsearch/xpack/logstash/action/TransportDeletePipelineActionTests.java index 56b7965be3686..159da551917c4 100644 --- a/x-pack/plugin/logstash/src/test/java/org/elasticsearch/xpack/logstash/action/TransportDeletePipelineActionTests.java +++ b/x-pack/plugin/logstash/src/test/java/org/elasticsearch/xpack/logstash/action/TransportDeletePipelineActionTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.MockUtils; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteTransportException; import org.elasticsearch.transport.TransportService; @@ -27,7 +28,8 @@ public class TransportDeletePipelineActionTests extends ESTestCase { public void testDeletePipelineWithMissingIndex() throws Exception { - try (Client client = getFailureClient(new IndexNotFoundException("missing .logstash"))) { + try (var threadPool = createThreadPool()) { + final var client = getFailureClient(threadPool, new IndexNotFoundException("missing .logstash")); TransportService transportService = MockUtils.setupTransportServiceWithThreadpoolExecutor(); final TransportDeletePipelineAction action = new TransportDeletePipelineAction( transportService, @@ -41,8 +43,8 @@ public void testDeletePipelineWithMissingIndex() throws Exception { } } - private Client getFailureClient(Exception e) { - return new NoOpClient(getTestName()) { + private Client getFailureClient(ThreadPool threadPool, Exception e) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, diff --git a/x-pack/plugin/logstash/src/test/java/org/elasticsearch/xpack/logstash/action/TransportGetPipelineActionTests.java b/x-pack/plugin/logstash/src/test/java/org/elasticsearch/xpack/logstash/action/TransportGetPipelineActionTests.java index d8a4d048f1fe4..7f1a0f2bcc2cb 100644 --- a/x-pack/plugin/logstash/src/test/java/org/elasticsearch/xpack/logstash/action/TransportGetPipelineActionTests.java +++ b/x-pack/plugin/logstash/src/test/java/org/elasticsearch/xpack/logstash/action/TransportGetPipelineActionTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.test.MockLogAppender; import org.elasticsearch.test.MockUtils; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteTransportException; import org.elasticsearch.transport.TransportService; @@ -99,7 +100,8 @@ public void onFailure(Exception e) { } }; - try (Client client = getMockClient(multiGetResponse)) { + try (var threadPool = createThreadPool()) { + final var client = getMockClient(threadPool, multiGetResponse); Loggers.addAppender(logger, mockLogAppender); TransportService transportService = MockUtils.setupTransportServiceWithThreadpoolExecutor(); TransportGetPipelineAction action = new TransportGetPipelineAction(transportService, mock(ActionFilters.class), client); @@ -151,7 +153,8 @@ public void onFailure(Exception e) { }; TransportService transportService = MockUtils.setupTransportServiceWithThreadpoolExecutor(); - try (Client client = getMockClient(searchResponse)) { + try (var threadPool = createThreadPool()) { + final var client = getMockClient(threadPool, searchResponse); new TransportGetPipelineAction(transportService, mock(ActionFilters.class), client).doExecute( null, request, @@ -163,7 +166,8 @@ 
public void onFailure(Exception e) { } public void testMissingIndexHandling() throws Exception { - try (Client failureClient = getFailureClient(new IndexNotFoundException("foo"))) { + try (var threadPool = createThreadPool()) { + final var failureClient = getFailureClient(threadPool, new IndexNotFoundException("foo")); TransportService transportService = MockUtils.setupTransportServiceWithThreadpoolExecutor(); final TransportGetPipelineAction action = new TransportGetPipelineAction( transportService, @@ -178,8 +182,8 @@ public void testMissingIndexHandling() throws Exception { } } - private Client getMockClient(ActionResponse response) { - return new NoOpClient(getTestName()) { + private Client getMockClient(ThreadPool threadPool, ActionResponse response) { + return new NoOpClient(threadPool) { @Override @SuppressWarnings("unchecked") protected void doExecute( @@ -192,8 +196,8 @@ protected void }; } - private Client getFailureClient(Exception e) { - return new NoOpClient(getTestName()) { + private Client getFailureClient(ThreadPool threadPool, Exception e) { + return new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index 8ae2015466b02..164d4efe6b6f5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -346,7 +347,8 @@ public void testGetDefinitionFromDocs() { } public void testStoreTrainedModelConfigCallsClientExecuteWithOperationCreate() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var config = TrainedModelConfigTests.createTestInstance("modelId").build(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -357,7 +359,8 @@ public void testStoreTrainedModelConfigCallsClientExecuteWithOperationCreate() { } public void testStoreTrainedModelConfigCallsClientExecuteWithOperationCreateWhenAllowOverwriteIsFalse() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var config = TrainedModelConfigTests.createTestInstance("modelId").build(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -368,7 +371,8 @@ public void testStoreTrainedModelConfigCallsClientExecuteWithOperationCreateWhen } public void testStoreTrainedModelConfigCallsClientExecuteWithOperationIndex() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var config = TrainedModelConfigTests.createTestInstance("modelId").build(); var trainedModelProvider = new 
TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -379,7 +383,8 @@ public void testStoreTrainedModelConfigCallsClientExecuteWithOperationIndex() { } public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationCreate() throws IOException { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var config = createTrainedModelConfigWithDefinition("modelId"); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -390,7 +395,8 @@ public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationCr } public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationCreateWhenAllowOverwriteIsFalse() throws IOException { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var config = createTrainedModelConfigWithDefinition("modelId"); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -401,7 +407,8 @@ public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationCr } public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationIndex() throws IOException { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var config = createTrainedModelConfigWithDefinition("modelId"); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -412,7 +419,8 @@ public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationIn } public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationCreate() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var config = TrainedModelDefinitionDocTests.createDefinitionDocInstance(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -423,7 +431,8 @@ public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationCre } public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationCreateWhenAllowOverwriteIsFalse() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var config = TrainedModelDefinitionDocTests.createDefinitionDocInstance(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -434,7 +443,8 @@ public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationCre } public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationIndex() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var config = TrainedModelDefinitionDocTests.createDefinitionDocInstance(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -445,7 +455,8 @@ public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationInd } public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationCreate() { - try (var client = createMockClient()) { + try (var 
threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var vocab = createVocabulary(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -456,7 +467,8 @@ public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationCreate } public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationCreateWhenAllowOverwritingIsFalse() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var vocab = createVocabulary(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -467,7 +479,8 @@ public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationCreate } public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationIndex() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var vocab = createVocabulary(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -478,7 +491,8 @@ public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationIndex( } public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationCreate() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var metadata = TrainedModelMetadataTests.randomInstance(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -489,7 +503,8 @@ public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationCreate() } public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationCreateWhenAllowOverwritingIsFalse() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var metadata = TrainedModelMetadataTests.randomInstance(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -500,7 +515,8 @@ public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationCreateWh } public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationIndex() { - try (var client = createMockClient()) { + try (var threadPool = createThreadPool()) { + final var client = createMockClient(threadPool); var metadata = TrainedModelMetadataTests.randomInstance(); var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); var future = new PlainActionFuture(); @@ -520,10 +536,8 @@ private TrainedModelConfig createTrainedModelConfigWithDefinition(String modelId return TrainedModelConfigTests.createTestInstance(modelId).setDefinitionFromBytes(bytes).build(); } - private Client createMockClient() { - var noOpClient = new NoOpClient(getTestName()); - - return spy(noOpClient); + private Client createMockClient(ThreadPool threadPool) { + return spy(new NoOpClient(threadPool)); } private void assertThatIndexRequestHasOperation(Client client, DocWriteRequest.OpType operation) { diff --git a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterTests.java b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterTests.java index a30975be1055d..2d692e977f3d5 
100644 --- a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterTests.java +++ b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterTests.java @@ -51,7 +51,8 @@ public void testLocalExporterDoesNotInteractWithClusterServiceUntilStateIsRecove final Exporter.Config config = new Exporter.Config("name", "type", Settings.EMPTY, clusterService, licenseState); final CleanerService cleanerService = mock(CleanerService.class); final MonitoringMigrationCoordinator migrationCoordinator = new MonitoringMigrationCoordinator(); - try (Client client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); final LocalExporter exporter = new LocalExporter(config, client, migrationCoordinator, cleanerService); final TimeValue retention = TimeValue.timeValueDays(randomIntBetween(1, 90)); diff --git a/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupJobTaskTests.java b/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupJobTaskTests.java index c315218bdaab4..5befaafba0f8a 100644 --- a/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupJobTaskTests.java +++ b/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupJobTaskTests.java @@ -64,7 +64,7 @@ public class RollupJobTaskTests extends ESTestCase { private ThreadPool pool; @Before - public void createThreadPool() { + public void createSuiteThreadPool() { pool = new TestThreadPool("test"); } @@ -291,7 +291,8 @@ public void testStartWhenStopping() throws InterruptedException { final CountDownLatch block = new CountDownLatch(1); final CountDownLatch unblock = new CountDownLatch(1); - try (NoOpClient client = getEmptySearchResponseClient(block, unblock)) { + try (var threadPool = createThreadPool()) { + final var client = getEmptySearchResponseClient(threadPool, block, unblock); SchedulerEngine schedulerEngine = mock(SchedulerEngine.class); AtomicInteger counter = new AtomicInteger(0); @@ -949,7 +950,8 @@ public void testStopWhenStopping() throws InterruptedException { RollupJob job = new RollupJob(ConfigTestHelpers.randomRollupJobConfig(random()), Collections.emptyMap()); final CountDownLatch block = new CountDownLatch(1); final CountDownLatch unblock = new CountDownLatch(1); - try (NoOpClient client = getEmptySearchResponseClient(block, unblock)) { + try (var threadPool = createThreadPool()) { + final var client = getEmptySearchResponseClient(threadPool, block, unblock); SchedulerEngine schedulerEngine = mock(SchedulerEngine.class); AtomicInteger counter = new AtomicInteger(0); @@ -1118,8 +1120,8 @@ private static void assertUnblockIn10s(CountDownLatch latch) { } } - private NoOpClient getEmptySearchResponseClient(CountDownLatch unblock, CountDownLatch block) { - return new NoOpClient(getTestName()) { + private NoOpClient getEmptySearchResponseClient(ThreadPool threadPool, CountDownLatch unblock, CountDownLatch block) { + return new NoOpClient(threadPool) { @SuppressWarnings("unchecked") @Override protected void doExecute( diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/role/PutRoleBuilderTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/role/PutRoleBuilderTests.java index efe7892daec70..984442e82be16 100644 --- 
a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/role/PutRoleBuilderTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/role/PutRoleBuilderTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.security.action.role; import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; @@ -26,7 +25,8 @@ public void testBWCFieldPermissions() throws Exception { Path path = getDataPath("roles2xformat.json"); byte[] bytes = Files.readAllBytes(path); String roleString = new String(bytes, Charset.defaultCharset()); - try (Client client = new NoOpClient("testBWCFieldPermissions")) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); ElasticsearchParseException e = expectThrows( ElasticsearchParseException.class, () -> new PutRoleRequestBuilder(client).source("role1", new BytesArray(roleString), XContentType.JSON) diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/NativePrivilegeStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/NativePrivilegeStoreTests.java index 01d3ca6db354e..ecc69e957d8ba 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/NativePrivilegeStoreTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/NativePrivilegeStoreTests.java @@ -38,7 +38,6 @@ import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; -import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; @@ -108,7 +107,8 @@ public class NativePrivilegeStoreTests extends ESTestCase { public void setup() { requests = new ArrayList<>(); listener = new AtomicReference<>(); - client = new NoOpClient(getTestName()) { + threadPool = createThreadPool(); + client = new NoOpClient(threadPool) { @Override @SuppressWarnings("unchecked") protected void doExecute( @@ -144,7 +144,6 @@ public void searchScroll(SearchScrollRequest request, ActionListener void super.doExecute(action, request, listener); } } - } - ) { + }; SnapshotLifecyclePolicy policy = new SnapshotLifecyclePolicy( policyId, "snap", @@ -307,8 +308,9 @@ public void testErrStillRunsFailureHandlerWhenDeleting() throws Exception { ); try ( ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool, settings); - Client noOpClient = new NoOpClient("slm-test") { - + var clientThreadPool = createThreadPool() + ) { + final var noOpClient = new NoOpClient(clientThreadPool) { @Override @SuppressWarnings("unchecked") protected void doExecute( @@ -323,8 +325,7 @@ protected void super.doExecute(action, request, listener); } } - } - ) { + }; final String policyId = "policy"; final String repoId = "repo"; SnapshotLifecyclePolicy policy = new SnapshotLifecyclePolicy( @@ -393,8 +394,9 @@ private void doTestSkipDuringMode(OperationMode mode) throws Exception { ); try ( ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool, settings); - Client noOpClient = new NoOpClient("slm-test") + var clientThreadPool = createThreadPool() ) { + final var 
noOpClient = new NoOpClient(clientThreadPool); final String policyId = "policy"; final String repoId = "repo"; SnapshotLifecyclePolicy policy = new SnapshotLifecyclePolicy( @@ -449,8 +451,9 @@ private void doTestRunManuallyDuringMode(OperationMode mode) throws Exception { ); try ( ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool, settings); - Client noOpClient = new NoOpClient("slm-test") + var clientThreadPool = createThreadPool() ) { + final var noOpClient = new NoOpClient(clientThreadPool); final String policyId = "policy"; final String repoId = "repo"; SnapshotLifecyclePolicy policy = new SnapshotLifecyclePolicy( diff --git a/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCheckpointServiceNodeTests.java b/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCheckpointServiceNodeTests.java index 9b6b67e76c01c..366f3e6f917bf 100644 --- a/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCheckpointServiceNodeTests.java +++ b/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCheckpointServiceNodeTests.java @@ -48,6 +48,8 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ActionNotFoundTransportException; import org.elasticsearch.xpack.core.transform.action.GetCheckpointAction; import org.elasticsearch.xpack.core.transform.transforms.TransformCheckpoint; @@ -82,6 +84,7 @@ public class TransformCheckpointServiceNodeTests extends TransformSingleNodeTest // re-use the mock client for the whole test suite as the underlying thread pool and the // corresponding context if recreated cause unreliable test execution // see https://github.com/elastic/elasticsearch/issues/45238 and https://github.com/elastic/elasticsearch/issues/42577 + private static TestThreadPool threadPool; private static MockClientForCheckpointing mockClientForCheckpointing = null; private IndexBasedTransformConfigManager transformsConfigManager; @@ -96,11 +99,10 @@ private class MockClientForCheckpointing extends NoOpClient { /** * Mock client for checkpointing * - * @param testName name of the test, used for naming the threadpool * @param supportTransformCheckpointApi whether to mock the checkpoint API, if false throws action not found */ - MockClientForCheckpointing(String testName, boolean supportTransformCheckpointApi) { - super(testName); + MockClientForCheckpointing(ThreadPool threadPool, boolean supportTransformCheckpointApi) { + super(threadPool); this.supportTransformCheckpointApi = supportTransformCheckpointApi; } @@ -156,8 +158,11 @@ protected void @Before public void createComponents() { // it's not possible to run it as @BeforeClass as clients aren't initialized + if (threadPool == null) { + threadPool = new TestThreadPool("TransformCheckpointServiceNodeTests"); + } if (mockClientForCheckpointing == null) { - mockClientForCheckpointing = new MockClientForCheckpointing("TransformCheckpointServiceNodeTests", randomBoolean()); + mockClientForCheckpointing = new MockClientForCheckpointing(threadPool, randomBoolean()); } ClusterService clusterService = mock(ClusterService.class); transformsConfigManager = new IndexBasedTransformConfigManager( @@ 
-185,8 +190,9 @@ public void createComponents() { @AfterClass public static void tearDownClient() { - mockClientForCheckpointing.close(); mockClientForCheckpointing = null; + threadPool.close(); + threadPool = null; } public void testCreateReadDeleteCheckpoint() throws InterruptedException { diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransformPrivilegeCheckerTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransformPrivilegeCheckerTests.java index 5dd01a32d5c43..8549f669dda0d 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransformPrivilegeCheckerTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransformPrivilegeCheckerTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.indices.TestIndexNameExpressionResolver; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; -import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.security.SecurityContext; import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest; @@ -93,8 +92,8 @@ public void setupClient() { if (client != null) { client.close(); } - client = new MyMockClient(getTestName()); - threadPool = new TestThreadPool("transform_privilege_checker_tests"); + threadPool = createThreadPool(); + client = new MyMockClient(threadPool); securityContext = new SecurityContext(Settings.EMPTY, threadPool.getThreadContext()) { public User getUser() { return new User(USER_NAME); @@ -404,8 +403,8 @@ private static class MyMockClient extends NoOpClient { emptyMap() ); - MyMockClient(String testName) { - super(testName); + MyMockClient(ThreadPool threadPool) { + super(threadPool); } @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransformUpdaterTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransformUpdaterTests.java index 27ca18d0e1d2f..fa957a2ac89cf 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransformUpdaterTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransformUpdaterTests.java @@ -26,6 +26,8 @@ import org.elasticsearch.indices.TestIndexNameExpressionResolver; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.indexing.IndexerState; import org.elasticsearch.xpack.core.security.SecurityContext; @@ -73,6 +75,7 @@ public class TransformUpdaterTests extends ESTestCase { private static final String JOHN = "john"; private final SecurityContext johnSecurityContext = newSecurityContextFor(JOHN); private final IndexNameExpressionResolver indexNameExpressionResolver = TestIndexNameExpressionResolver.newInstance(); + private TestThreadPool threadPool; private Client client; private ClusterService clusterService = mock(ClusterService.class); private TransformAuditor auditor = new MockTransformAuditor(clusterService); @@ -81,8 +84,8 @@ public class TransformUpdaterTests extends ESTestCase { private static class MyMockClient extends NoOpClient { - MyMockClient(String testName) { - super(testName); + 
MyMockClient(ThreadPool threadPool) { + super(threadPool); } @SuppressWarnings("unchecked") @@ -116,17 +119,18 @@ protected void @Before public void setupClient() { - if (client != null) { - client.close(); + if (threadPool != null) { + threadPool.close(); } - client = new MyMockClient(getTestName()); + threadPool = createThreadPool(); + client = new MyMockClient(threadPool); clusterService = mock(ClusterService.class); auditor = new MockTransformAuditor(clusterService); } @After public void tearDownClient() { - client.close(); + threadPool.close(); } public void testTransformUpdateNoAction() throws InterruptedException { diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/ClientTransformIndexerTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/ClientTransformIndexerTests.java index a850c7beef7dd..06de37af346d2 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/ClientTransformIndexerTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/ClientTransformIndexerTests.java @@ -127,7 +127,8 @@ public void testPitInjection() throws InterruptedException { new SettingsConfig.Builder().setUsePit(true).build() ).build(); - try (PitMockClient client = new PitMockClient(getTestName(), true)) { + try (var threadPool = createThreadPool()) { + final var client = new PitMockClient(threadPool, true); MockClientTransformIndexer indexer = new MockClientTransformIndexer( mock(ThreadPool.class), new TransformServices( @@ -220,7 +221,8 @@ public void testPitInjectionIfPitNotSupported() throws InterruptedException { new SettingsConfig.Builder().setUsePit(true).build() ).build(); - try (PitMockClient client = new PitMockClient(getTestName(), false)) { + try (var threadPool = createThreadPool()) { + final var client = new PitMockClient(threadPool, false); MockClientTransformIndexer indexer = new MockClientTransformIndexer( mock(ThreadPool.class), new TransformServices( @@ -296,7 +298,8 @@ public void testDisablePit() throws InterruptedException { TransformConfig config = TransformConfigTests.randomTransformConfig(); boolean pitEnabled = config.getSettings().getUsePit() == null || config.getSettings().getUsePit(); - try (PitMockClient client = new PitMockClient(getTestName(), true)) { + try (var threadPool = createThreadPool()) { + final var client = new PitMockClient(threadPool, true); MockClientTransformIndexer indexer = new MockClientTransformIndexer( mock(ThreadPool.class), new TransformServices( @@ -359,7 +362,8 @@ public void testDisablePitWhenThereIsRemoteIndexInSource() throws InterruptedExc .build(); boolean pitEnabled = config.getSettings().getUsePit() == null || config.getSettings().getUsePit(); - try (PitMockClient client = new PitMockClient(getTestName(), true)) { + try (var threadPool = createThreadPool()) { + final var client = new PitMockClient(threadPool, true); MockClientTransformIndexer indexer = new MockClientTransformIndexer( mock(ThreadPool.class), new TransformServices( @@ -413,7 +417,8 @@ public void testDisablePitWhenThereIsRemoteIndexInSource() throws InterruptedExc public void testHandlePitIndexNotFound() throws InterruptedException { // simulate a deleted index due to ILM - try (PitMockClient client = new PitMockClient(getTestName(), true)) { + try (var threadPool = createThreadPool()) { + final var client = new PitMockClient(threadPool, true); ClientTransformIndexer indexer = createTestIndexer(new 
ParentTaskAssigningClient(client, new TaskId("dummy-node:123456"))); SearchRequest searchRequest = new SearchRequest("deleted-index"); searchRequest.source().pointInTimeBuilder(new PointInTimeBuilder("the_pit_id")); @@ -425,7 +430,8 @@ public void testHandlePitIndexNotFound() throws InterruptedException { } // simulate a deleted index that is essential, search must fail (after a retry without pit) - try (PitMockClient client = new PitMockClient(getTestName(), true)) { + try (var threadPool = createThreadPool()) { + final var client = new PitMockClient(threadPool, true); ClientTransformIndexer indexer = createTestIndexer(new ParentTaskAssigningClient(client, new TaskId("dummy-node:123456"))); SearchRequest searchRequest = new SearchRequest("essential-deleted-index"); searchRequest.source().pointInTimeBuilder(new PointInTimeBuilder("the_pit_id")); @@ -483,8 +489,8 @@ private static class PitMockClient extends NoOpClient { private final boolean pitSupported; private AtomicLong pitContextCounter = new AtomicLong(); - PitMockClient(String testName, boolean pitSupported) { - super(testName); + PitMockClient(ThreadPool threadPool, boolean pitSupported) { + super(threadPool); this.pitSupported = pitSupported; } diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerFailureHandlingTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerFailureHandlingTests.java index 161f364b6e7e4..f59aaab33f0f1 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerFailureHandlingTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerFailureHandlingTests.java @@ -35,7 +35,6 @@ import org.elasticsearch.search.suggest.Suggest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; -import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.common.notifications.Level; import org.elasticsearch.xpack.core.indexing.IndexerState; @@ -274,13 +273,12 @@ protected void persistState(TransformState state, ActionListener listener) @Before public void setUpMocks() { - client = new NoOpClient(getTestName()); - threadPool = new TestThreadPool(getTestName()); + threadPool = createThreadPool(); + client = new NoOpClient(threadPool); } @After public void tearDownClient() { - client.close(); ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); } diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerFailureOnStatePersistenceTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerFailureOnStatePersistenceTests.java index aeb94cd2c2f66..33ced92a8fa19 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerFailureOnStatePersistenceTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerFailureOnStatePersistenceTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.ParentTaskAssigningClient; import org.elasticsearch.common.settings.Settings; import 
org.elasticsearch.core.Tuple; @@ -209,7 +208,8 @@ public void fail(String failureMessage, ActionListener listener) { ? new VersionConflictEngineException(new ShardId("index", "indexUUID", 42), "some_id", 45L, 44L, 43L, 42L) : new ElasticsearchTimeoutException("timeout"); TransformConfigManager configManager = new FailingToPutStoredDocTransformConfigManager(Set.of(0, 1, 2, 3), exceptionToThrow); - try (Client client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); MockClientTransformIndexer indexer = new MockClientTransformIndexer( mock(ThreadPool.class), @@ -292,7 +292,8 @@ public void fail(String failureMessage, ActionListener listener) { ? new VersionConflictEngineException(new ShardId("index", "indexUUID", 42), "some_id", 45L, 44L, 43L, 42L) : new ElasticsearchTimeoutException("timeout"); TransformConfigManager configManager = new FailingToPutStoredDocTransformConfigManager(Set.of(0, 2, 3, 4), exceptionToThrow); - try (Client client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); MockClientTransformIndexer indexer = new MockClientTransformIndexer( mock(ThreadPool.class), new TransformServices( @@ -422,7 +423,8 @@ public void fail(String failureMessage, ActionListener listener) { TransformContext context = new TransformContext(state.get(), null, 0, contextListener); TransformConfigManager configManager = new SeqNoCheckingTransformConfigManager(); - try (Client client = new NoOpClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new NoOpClient(threadPool); MockClientTransformIndexer indexer = new MockClientTransformIndexer( mock(ThreadPool.class), new TransformServices( diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerStateTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerStateTests.java index 33a72fd1e0181..638a66fa3fb0d 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerStateTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerStateTests.java @@ -351,13 +351,12 @@ public void initialize() { public void setUpMocks() { auditor = MockTransformAuditor.createMockAuditor(); transformConfigManager = new InMemoryTransformConfigManager(); - client = new NoOpClient(getTestName()); threadPool = new TestThreadPool(ThreadPool.Names.GENERIC); + client = new NoOpClient(threadPool); } @After public void tearDownClient() { - client.close(); ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); } diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerTests.java index 41fcd84af2fcc..6406308312f04 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformIndexerTests.java @@ -280,13 +280,12 @@ void validate(ActionListener listener) { public void setUpMocks() { auditor = MockTransformAuditor.createMockAuditor(); transformConfigManager = new InMemoryTransformConfigManager(); - client = new NoOpClient(getTestName()); threadPool = new 
TestThreadPool(ThreadPool.Names.GENERIC); + client = new NoOpClient(threadPool); } @After public void tearDownClient() { - client.close(); ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); } diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformTaskTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformTaskTests.java index 4d5807913636d..277553cd9f4ec 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformTaskTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformTaskTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.indexing.IndexerState; import org.elasticsearch.xpack.core.transform.TransformConfigVersion; @@ -69,19 +70,21 @@ public class TransformTaskTests extends ESTestCase { + private TestThreadPool threadPool; private Client client; @Before public void setupClient() { - if (client != null) { - client.close(); + if (threadPool != null) { + threadPool.close(); } - client = new NoOpClient(getTestName()); + threadPool = createThreadPool(); + client = new NoOpClient(threadPool); } @After public void tearDownClient() { - client.close(); + threadPool.close(); } // see https://github.com/elastic/elasticsearch/issues/48957 diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationSchemaAndResultTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationSchemaAndResultTests.java index b93c2544da456..a2dda2a1603f1 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationSchemaAndResultTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationSchemaAndResultTests.java @@ -25,6 +25,8 @@ import org.elasticsearch.search.aggregations.metrics.Percentile; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.transform.transforms.pivot.AggregationConfig; import org.elasticsearch.xpack.core.transform.transforms.pivot.GroupConfig; import org.elasticsearch.xpack.core.transform.transforms.pivot.GroupConfigTests; @@ -48,25 +50,27 @@ public class AggregationSchemaAndResultTests extends ESTestCase { + private TestThreadPool threadPool; private Client client; @Before public void setupClient() { - if (client != null) { - client.close(); + if (threadPool != null) { + threadPool.close(); } - client = new MyMockClient(getTestName()); + threadPool = createThreadPool(); + client = new MyMockClient(threadPool); } @After public void tearDownClient() { - client.close(); + threadPool.close(); } private class MyMockClient extends NoOpClient { - MyMockClient(String testName) { - super(testName); + MyMockClient(ThreadPool threadPool) { + super(threadPool); } @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/PivotTests.java 
b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/PivotTests.java index 40af93c35cd39..37bee4a4eb999 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/PivotTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/PivotTests.java @@ -30,6 +30,8 @@ import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -82,6 +84,7 @@ public class PivotTests extends ESTestCase { private NamedXContentRegistry namedXContentRegistry; + private TestThreadPool threadPool; private Client client; // exclude aggregations from the analytics module as we don't have parser for it here @@ -102,15 +105,16 @@ public void registerAggregationNamedObjects() throws Exception { @Before public void setupClient() { - if (client != null) { - client.close(); + if (threadPool != null) { + threadPool.close(); } - client = new MyMockClient(getTestName()); + threadPool = createThreadPool(); + client = new MyMockClient(threadPool); } @After public void tearDownClient() { - client.close(); + threadPool.close(); } @Override @@ -279,17 +283,17 @@ public void testPreviewForEmptyAggregation() throws Exception { final AtomicReference exceptionHolder = new AtomicReference<>(); final AtomicReference>> responseHolder = new AtomicReference<>(); - Client emptyAggregationClient = new MyMockClientWithEmptyAggregation("empty aggregation test for preview"); - pivot.preview(emptyAggregationClient, null, new HashMap<>(), new SourceConfig("test"), null, 1, ActionListener.wrap(r -> { - responseHolder.set(r); - latch.countDown(); - }, e -> { - exceptionHolder.set(e); - latch.countDown(); - })); - assertTrue(latch.await(100, TimeUnit.MILLISECONDS)); - emptyAggregationClient.close(); - + try (var threadPool = createThreadPool()) { + final var emptyAggregationClient = new MyMockClientWithEmptyAggregation(threadPool); + pivot.preview(emptyAggregationClient, null, new HashMap<>(), new SourceConfig("test"), null, 1, ActionListener.wrap(r -> { + responseHolder.set(r); + latch.countDown(); + }, e -> { + exceptionHolder.set(e); + latch.countDown(); + })); + assertTrue(latch.await(100, TimeUnit.MILLISECONDS)); + } assertThat(exceptionHolder.get(), is(nullValue())); assertThat(responseHolder.get(), is(empty())); } @@ -306,16 +310,17 @@ public void testPreviewForCompositeAggregation() throws Exception { final AtomicReference exceptionHolder = new AtomicReference<>(); final AtomicReference>> responseHolder = new AtomicReference<>(); - Client compositeAggregationClient = new MyMockClientWithCompositeAggregation("composite aggregation test for preview"); - pivot.preview(compositeAggregationClient, null, new HashMap<>(), new SourceConfig("test"), null, 1, ActionListener.wrap(r -> { - responseHolder.set(r); - latch.countDown(); - }, e -> { - exceptionHolder.set(e); - latch.countDown(); - })); - assertTrue(latch.await(100, TimeUnit.MILLISECONDS)); - compositeAggregationClient.close(); + try (var threadPool = createThreadPool()) { + final var compositeAggregationClient = new MyMockClientWithCompositeAggregation(threadPool); + 
pivot.preview(compositeAggregationClient, null, new HashMap<>(), new SourceConfig("test"), null, 1, ActionListener.wrap(r -> { + responseHolder.set(r); + latch.countDown(); + }, e -> { + exceptionHolder.set(e); + latch.countDown(); + })); + assertTrue(latch.await(100, TimeUnit.MILLISECONDS)); + } assertThat(exceptionHolder.get(), is(nullValue())); assertThat(responseHolder.get(), is(empty())); @@ -328,8 +333,8 @@ private static SearchResponse searchResponseFromAggs(Aggregations aggs) { } private class MyMockClient extends NoOpClient { - MyMockClient(String testName) { - super(testName); + MyMockClient(ThreadPool threadPool) { + super(threadPool); } @SuppressWarnings("unchecked") @@ -383,8 +388,8 @@ protected void } private class MyMockClientWithEmptyAggregation extends NoOpClient { - MyMockClientWithEmptyAggregation(String testName) { - super(testName); + MyMockClientWithEmptyAggregation(ThreadPool threadPool) { + super(threadPool); } @SuppressWarnings("unchecked") @@ -401,8 +406,8 @@ protected void } private class MyMockClientWithCompositeAggregation extends NoOpClient { - MyMockClientWithCompositeAggregation(String testName) { - super(testName); + MyMockClientWithCompositeAggregation(ThreadPool threadPool) { + super(threadPool); } @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtilTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtilTests.java index 37206f20d1269..778ca4bf7767d 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtilTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtilTests.java @@ -16,10 +16,10 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.action.support.ActionTestUtils; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.ThreadPool; import java.math.BigInteger; import java.util.Collections; @@ -96,7 +96,8 @@ public void testConvertToIntegerTypeIfNeeded() { } public void testGetSourceFieldMappings() throws InterruptedException { - try (Client client = new FieldCapsMockClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new FieldCapsMockClient(threadPool); // fields is null this.>assertAsync( listener -> SchemaUtil.getSourceFieldMappings( @@ -188,7 +189,8 @@ public void testGetSourceFieldMappingsWithRuntimeMappings() throws InterruptedEx put("field-3", singletonMap("type", "boolean")); } }; - try (Client client = new FieldCapsMockClient(getTestName())) { + try (var threadPool = createThreadPool()) { + final var client = new FieldCapsMockClient(threadPool); this.>assertAsync( listener -> SchemaUtil.getSourceFieldMappings( client, @@ -210,8 +212,8 @@ public void testGetSourceFieldMappingsWithRuntimeMappings() throws InterruptedEx } private static class FieldCapsMockClient extends NoOpClient { - FieldCapsMockClient(String testName) { - super(testName); + FieldCapsMockClient(ThreadPool threadPool) { + super(threadPool); } @SuppressWarnings("unchecked") From d3aec39c72e0a7c5f4ad0ed684a1b57efc822a28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20R=C3=BChsen?= Date: Tue, 7 Nov 2023 10:18:31 
+0100 Subject: [PATCH 14/21] [Profiling] Add helper class StopWatch (#101683) * [Profiling] Add helper class StopWatch * Generalize StopWatch * Simplify StopWatch, use lambda for logging --------- Co-authored-by: Elastic Machine --- .../xpack/profiling/StopWatch.java | 35 +++++++++++++++++++ .../TransportGetFlamegraphAction.java | 12 ++----- .../TransportGetStackTracesAction.java | 16 ++++----- 3 files changed, 46 insertions(+), 17 deletions(-) create mode 100644 x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/StopWatch.java diff --git a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/StopWatch.java b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/StopWatch.java new file mode 100644 index 0000000000000..c423fe12f3581 --- /dev/null +++ b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/StopWatch.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.profiling; + +/** + * Measures time and logs it in milliseconds. + */ +public final class StopWatch { + private final String name; + private final long start; + + public StopWatch(String name) { + this.name = name; + start = System.nanoTime(); + } + + /** + * Return a textual report including the name and the number of elapsed milliseconds since object creation. + */ + public String report() { + return name + " took [" + millis() + " ms]."; + } + + /** + * Return number of elapsed milliseconds since object creation. + */ + public double millis() { + return (System.nanoTime() - start) / 1_000_000.0d; + } +} diff --git a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/TransportGetFlamegraphAction.java b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/TransportGetFlamegraphAction.java index f26a6b1fb3a84..b791684bec233 100644 --- a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/TransportGetFlamegraphAction.java +++ b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/TransportGetFlamegraphAction.java @@ -44,20 +44,14 @@ public TransportGetFlamegraphAction(NodeClient nodeClient, TransportService tran @Override protected void doExecute(Task task, GetStackTracesRequest request, ActionListener listener) { Client client = new ParentTaskAssigningClient(this.nodeClient, transportService.getLocalNode(), task); - long start = System.nanoTime(); + StopWatch watch = new StopWatch("getFlamegraphAction"); client.execute(GetStackTracesAction.INSTANCE, request, new ActionListener<>() { @Override public void onResponse(GetStackTracesResponse response) { - long responseStart = System.nanoTime(); try { + StopWatch processingWatch = new StopWatch("Processing response"); GetFlamegraphResponse flamegraphResponse = buildFlamegraph(response); - log.debug( - "getFlamegraphAction took [" - + (System.nanoTime() - start) / 1_000_000.0d - + "] ms (processing response: [" - + (System.nanoTime() - responseStart) / 1_000_000.0d - + "] ms." 
- ); + log.debug(() -> watch.report() + " " + processingWatch.report()); listener.onResponse(flamegraphResponse); } catch (Exception ex) { listener.onFailure(ex); diff --git a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/TransportGetStackTracesAction.java b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/TransportGetStackTracesAction.java index e15792adc489d..8b9fce4d04040 100644 --- a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/TransportGetStackTracesAction.java +++ b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/TransportGetStackTracesAction.java @@ -126,7 +126,7 @@ public TransportGetStackTracesAction( @Override protected void doExecute(Task submitTask, GetStackTracesRequest request, ActionListener submitListener) { licenseChecker.requireSupportedLicense(); - long start = System.nanoTime(); + StopWatch watch = new StopWatch("getResampledIndex"); Client client = new ParentTaskAssigningClient(this.nodeClient, transportService.getLocalNode(), submitTask); EventsIndex mediumDownsampled = EventsIndex.MEDIUM_DOWNSAMPLED; client.prepareSearch(mediumDownsampled.getName()) @@ -143,7 +143,7 @@ protected void doExecute(Task submitTask, GetStackTracesRequest request, ActionL mediumDownsampled, resampledIndex ); - log.debug("getResampledIndex took [" + (System.nanoTime() - start) / 1_000_000.0d + " ms]."); + log.debug(() -> watch.report()); searchEventGroupByStackTrace(client, request, resampledIndex, submitListener); }, e -> { // All profiling-events data streams are created lazily. In a relatively empty cluster it can happen that there are so few @@ -166,7 +166,7 @@ private void searchEventGroupByStackTrace( EventsIndex eventsIndex, ActionListener submitListener ) { - long start = System.nanoTime(); + StopWatch watch = new StopWatch("searchEventGroupByStackTrace"); GetStackTracesResponseBuilder responseBuilder = new GetStackTracesResponseBuilder(); responseBuilder.setSampleRate(eventsIndex.getSampleRate()); client.prepareSearch(eventsIndex.getName()) @@ -216,7 +216,7 @@ private void searchEventGroupByStackTrace( totalFinalCount, stackTraceEvents.size() ); - log.debug("searchEventGroupByStackTrace took [" + (System.nanoTime() - start) / 1_000_000.0d + " ms]."); + log.debug(() -> watch.report()); if (stackTraceEvents.isEmpty() == false) { responseBuilder.setStart(Instant.ofEpochMilli(minTime)); responseBuilder.setEnd(Instant.ofEpochMilli(maxTime)); @@ -287,7 +287,7 @@ private class StackTraceHandler { private final Set stackFrameIds = new ConcurrentSkipListSet<>(); private final Set executableIds = new ConcurrentSkipListSet<>(); private final AtomicInteger totalFrames = new AtomicInteger(); - private final long start = System.nanoTime(); + private final StopWatch watch = new StopWatch("retrieveStackTraces"); private StackTraceHandler( ClusterState clusterState, @@ -334,7 +334,7 @@ public void onResponse(MultiGetResponse multiGetItemResponses) { stackFrameIds.size(), executableIds.size() ); - log.debug("retrieveStackTraces took [" + (System.nanoTime() - start) / 1_000_000.0d + " ms]."); + log.debug(() -> watch.report()); retrieveStackTraceDetails( clusterState, client, @@ -409,7 +409,7 @@ private static class DetailsHandler { private final Map executables; private final Map stackFrames; private final AtomicInteger expectedSlices; - private final long start = System.nanoTime(); + private final StopWatch watch = new StopWatch("retrieveStackTraceDetails"); private DetailsHandler( 
GetStackTracesResponseBuilder builder, @@ -479,7 +479,7 @@ public void mayFinish() { builder.setExecutables(executables); builder.setStackFrames(stackFrames); log.debug("retrieveStackTraceDetails found [{}] stack frames, [{}] executables.", stackFrames.size(), executables.size()); - log.debug("retrieveStackTraceDetails took [" + (System.nanoTime() - start) / 1_000_000.0d + " ms]."); + log.debug(() -> watch.report()); submitListener.onResponse(builder.build()); } } From c773e04761f9709c1c190ab0bebb69a6356ce4b7 Mon Sep 17 00:00:00 2001 From: Artem Prigoda Date: Tue, 7 Nov 2023 10:23:23 +0100 Subject: [PATCH 15/21] Respect regional AWS STS endpoints (#101705) The AWS SDK supports regional STS endpoints via the AWS_STS_REGIONAL_ENDPOINTS environment variable. If the user sets it to regional and provides the region in the AWS_REGION env variable, we should respect that and make the STS client use the regionally adjusted STS endpoint like https://sts.us-west-2.amazonaws.com. See https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html Resolves https://github.com/elastic/elasticsearch/issues/89175 --- docs/changelog/101705.yaml | 6 +++ .../repositories/s3/S3Service.java | 28 ++++++++++-- ...IdentityTokenCredentialsProviderTests.java | 44 ++++++++++++++++--- 3 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 docs/changelog/101705.yaml diff --git a/docs/changelog/101705.yaml b/docs/changelog/101705.yaml new file mode 100644 index 0000000000000..baa7e69d48d88 --- /dev/null +++ b/docs/changelog/101705.yaml @@ -0,0 +1,6 @@ +pr: 101705 +summary: Respect regional AWS STS endpoints +area: Snapshot/Restore +type: bug +issues: + - 89175 diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java index 291cf84019cd1..25bba12db6952 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java @@ -9,6 +9,7 @@ package org.elasticsearch.repositories.s3; import com.amazonaws.ClientConfiguration; +import com.amazonaws.SDKGlobalConfiguration; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.AWSCredentialsProviderChain; @@ -320,6 +321,7 @@ static class CustomWebIdentityTokenCredentialsProvider implements AWSCredentials private STSAssumeRoleWithWebIdentitySessionCredentialsProvider credentialsProvider; private AWSSecurityTokenService stsClient; + private String stsRegion; CustomWebIdentityTokenCredentialsProvider( Environment environment, @@ -361,10 +363,24 @@ static class CustomWebIdentityTokenCredentialsProvider implements AWSCredentials ); AWSSecurityTokenServiceClientBuilder stsClientBuilder = AWSSecurityTokenServiceClient.builder(); - // Custom system property used for specifying a mocked version of the STS for testing - String customStsEndpoint = jvmEnvironment.getProperty("com.amazonaws.sdk.stsMetadataServiceEndpointOverride", STS_HOSTNAME); - // Set the region explicitly via the endpoint URL, so the AWS SDK doesn't make any guesses internally.
- stsClientBuilder.withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration(customStsEndpoint, null)); + // Check if we need to use regional STS endpoints + // https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html + if ("regional".equalsIgnoreCase(systemEnvironment.getEnv("AWS_STS_REGIONAL_ENDPOINTS"))) { + // AWS_REGION should be injected by the EKS pod identity webhook: + // https://github.com/aws/amazon-eks-pod-identity-webhook/pull/41 + stsRegion = systemEnvironment.getEnv(SDKGlobalConfiguration.AWS_REGION_ENV_VAR); + if (stsRegion != null) { + stsClientBuilder.withRegion(stsRegion); + } else { + LOGGER.warn("Unable to use regional STS endpoints because the AWS_REGION environment variable is not set"); + } + } + if (stsRegion == null) { + // Custom system property used for specifying a mocked version of the STS for testing + String customStsEndpoint = jvmEnvironment.getProperty("com.amazonaws.sdk.stsMetadataServiceEndpointOverride", STS_HOSTNAME); + // Set the region explicitly via the endpoint URL, so the AWS SDK doesn't make any guesses internally. + stsClientBuilder.withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration(customStsEndpoint, null)); + } stsClientBuilder.withCredentials(new AWSStaticCredentialsProvider(new AnonymousAWSCredentials())); stsClient = SocketAccess.doPrivileged(stsClientBuilder::build); try { @@ -383,6 +399,10 @@ boolean isActive() { return credentialsProvider != null; } + String getStsRegion() { + return stsRegion; + } + @Override public AWSCredentials getCredentials() { Objects.requireNonNull(credentialsProvider, "credentialsProvider is not set"); diff --git a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/CustomWebIdentityTokenCredentialsProviderTests.java b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/CustomWebIdentityTokenCredentialsProviderTests.java index 04c47bb9b55e6..f245b1ad91fe4 100644 --- a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/CustomWebIdentityTokenCredentialsProviderTests.java +++ b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/CustomWebIdentityTokenCredentialsProviderTests.java @@ -22,6 +22,7 @@ import org.junit.Assert; import org.mockito.Mockito; +import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URLDecoder; @@ -42,6 +43,15 @@ public class CustomWebIdentityTokenCredentialsProviderTests extends ESTestCase { private static final String ROLE_ARN = "arn:aws:iam::123456789012:role/FederatedWebIdentityRole"; private static final String ROLE_NAME = "aws-sdk-java-1651084775908"; + private static Environment getEnvironment() throws IOException { + Path configDirectory = Files.createTempDirectory("web-identity-token-test"); + Files.createDirectory(configDirectory.resolve("repository-s3")); + Files.writeString(configDirectory.resolve("repository-s3/aws-web-identity-token-file"), "YXdzLXdlYi1pZGVudGl0eS10b2tlbi1maWxl"); + Environment environment = Mockito.mock(Environment.class); + Mockito.when(environment.configFile()).thenReturn(configDirectory); + return environment; + } + @SuppressForbidden(reason = "HTTP server is used for testing") public void testCreateWebIdentityTokenCredentialsProvider() throws Exception { HttpServer httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 0), 0); @@ -88,11 +98,7 @@ public void testCreateWebIdentityTokenCredentialsProvider() throws 
Exception { }); httpServer.start(); - Path configDirectory = Files.createTempDirectory("web-identity-token-test"); - Files.createDirectory(configDirectory.resolve("repository-s3")); - Files.writeString(configDirectory.resolve("repository-s3/aws-web-identity-token-file"), "YXdzLXdlYi1pZGVudGl0eS10b2tlbi1maWxl"); - Environment environment = Mockito.mock(Environment.class); - Mockito.when(environment.configFile()).thenReturn(configDirectory); + Environment environment = getEnvironment(); // No region is set, but the SDK shouldn't fail because of that Map environmentVariables = Map.of( @@ -125,4 +131,32 @@ public void testCreateWebIdentityTokenCredentialsProvider() throws Exception { httpServer.stop(0); } } + + public void testSupportRegionalizedEndpoints() throws Exception { + Map environmentVariables = Map.of( + "AWS_WEB_IDENTITY_TOKEN_FILE", + "/var/run/secrets/eks.amazonaws.com/serviceaccount/token", + "AWS_ROLE_ARN", + ROLE_ARN, + "AWS_STS_REGIONAL_ENDPOINTS", + "regional", + "AWS_REGION", + "us-west-2" + ); + Map systemProperties = Map.of(); + + var webIdentityTokenCredentialsProvider = new S3Service.CustomWebIdentityTokenCredentialsProvider( + getEnvironment(), + environmentVariables::get, + systemProperties::getOrDefault, + Clock.systemUTC() + ); + // We can't verify that webIdentityTokenCredentialsProvider's STS client uses the "https://sts.us-west-2.amazonaws.com" + // endpoint in a unit test. The client depends on a hardcoded RegionalEndpointsOptionResolver that in turn depends + // on the system environment that we can't change in the test. So we just verify that we called `withRegion` + // on stsClientBuilder, which should internally configure the endpoint correctly when the STS client is built. + assertEquals("us-west-2", webIdentityTokenCredentialsProvider.getStsRegion()); + + webIdentityTokenCredentialsProvider.shutdown(); + } } From eed3c6b15c8291a3fa43bd6972418ab8040b4384 Mon Sep 17 00:00:00 2001 From: Rene Groeschke Date: Tue, 7 Nov 2023 11:16:20 +0100 Subject: [PATCH 16/21] Make enforce TestConvention cc compatible (#101822) --- rest-api-spec/build.gradle | 3 ++- x-pack/plugin/build.gradle | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle index e484b98d3188e..787d684c3779e 100644 --- a/rest-api-spec/build.gradle +++ b/rest-api-spec/build.gradle @@ -228,8 +228,9 @@ tasks.named("yamlRestTestV7CompatTransform").configure { task -> } tasks.register('enforceYamlTestConvention').configure { + def tree = fileTree('src/main/resources/rest-api-spec/test') doLast { - if (fileTree('src/main/resources/rest-api-spec/test').files) { + if (tree.files) { throw new GradleException("There are YAML tests in src/main source set. These should be moved to src/yamlRestTest.") } } } diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle index 17495a3568923..eae3031512d4f 100644 --- a/x-pack/plugin/build.gradle +++ b/x-pack/plugin/build.gradle @@ -177,16 +177,18 @@ tasks.named("yamlRestTestV7CompatTransform").configure { task -> } tasks.register('enforceApiSpecsConvention').configure { + def mainApiSpecs = fileTree('src/test/resources/rest-api-spec/api') doLast { - if (fileTree('src/test/resources/rest-api-spec/api').files) { + if (mainApiSpecs.files) { throw new GradleException("There are REST specs in src/test source set. 
These should be moved to the :rest-api-spec project.") } } } tasks.register('enforceYamlTestConvention').configure { + def mainYamlFiles = fileTree('src/test/resources/rest-api-spec/test') doLast { - if (fileTree('src/test/resources/rest-api-spec/test').files) { + if (mainYamlFiles.files) { throw new GradleException("There are YAML tests in src/test source set. These should be moved to src/yamlRestTest.") } } } From e1184cd7a8ac3d3c1cf3abb5dbf8456fe3c4ab7c Mon Sep 17 00:00:00 2001 From: Rene Groeschke Date: Tue, 7 Nov 2023 11:17:21 +0100 Subject: [PATCH 17/21] Fix snippet task cc incompatibilities (#101823) Addresses some Gradle configuration cache issues related to https://github.com/elastic/elasticsearch/issues/57918 --- .../gradle/internal/doc/DocsTestPlugin.groovy | 19 ++++++++++++----- .../doc/RestTestsFromSnippetsTask.groovy | 21 +++++++++++++++---- .../gradle/internal/doc/SnippetsTask.groovy | 7 ++++--- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/DocsTestPlugin.groovy b/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/DocsTestPlugin.groovy index 874141f2135ad..38b4cb499eeb9 100644 --- a/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/DocsTestPlugin.groovy +++ b/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/DocsTestPlugin.groovy @@ -12,6 +12,7 @@ import org.elasticsearch.gradle.Version import org.elasticsearch.gradle.VersionProperties import org.elasticsearch.gradle.internal.test.rest.CopyRestApiTask import org.elasticsearch.gradle.internal.test.rest.CopyRestTestsTask +import org.gradle.api.Action import org.gradle.api.Plugin import org.gradle.api.Project import org.gradle.api.file.Directory @@ -61,16 +62,24 @@ class DocsTestPlugin implements Plugin { group 'Docs' description 'List each snippet' defaultSubstitutions = commonDefaultSubstitutions - perSnippet { println(it.toString()) } + perSnippet = new Action() { + @Override + void execute(SnippetsTask.Snippet snippet) { + println(snippet.toString()) + } + } } project.tasks.register('listConsoleCandidates', SnippetsTask) { group 'Docs' description 'List snippets that probably should be marked // CONSOLE' defaultSubstitutions = commonDefaultSubstitutions - perSnippet { - if (RestTestsFromSnippetsTask.isConsoleCandidate(it)) { - println(it.toString()) + perSnippet = new Action() { + @Override + void execute(SnippetsTask.Snippet snippet) { + if (RestTestsFromSnippetsTask.isConsoleCandidate(snippet)) { + println(snippet.toString()) + } } } } @@ -80,7 +89,7 @@ class DocsTestPlugin implements Plugin { defaultSubstitutions = commonDefaultSubstitutions testRoot.convention(restRootDir) doFirst { - fileOperations.delete(restRootDir) + getFileOperations().delete(testRoot.get()) } } diff --git a/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/RestTestsFromSnippetsTask.groovy b/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/RestTestsFromSnippetsTask.groovy index eda86355ee306..81207181dc9a7 100644 --- a/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/RestTestsFromSnippetsTask.groovy +++ b/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/RestTestsFromSnippetsTask.groovy @@ -10,8 +10,10 @@ package org.elasticsearch.gradle.internal.doc import groovy.transform.PackageScope import org.elasticsearch.gradle.internal.doc.SnippetsTask.Snippet +import org.gradle.api.Action import 
org.gradle.api.InvalidUserDataException import org.gradle.api.file.DirectoryProperty +import org.gradle.api.internal.file.FileOperations import org.gradle.api.tasks.Input import org.gradle.api.tasks.Internal import org.gradle.api.tasks.OutputDirectory @@ -24,7 +26,7 @@ import java.nio.file.Path /** * Generates REST tests for each snippet marked // TEST. */ -class RestTestsFromSnippetsTask extends SnippetsTask { +abstract class RestTestsFromSnippetsTask extends SnippetsTask { /** * These languages aren't supported by the syntax highlighter so we * shouldn't use them. @@ -64,13 +66,23 @@ class RestTestsFromSnippetsTask extends SnippetsTask { @Internal Set names = new HashSet<>() + @Inject + abstract FileOperations getFileOperations(); + @Inject RestTestsFromSnippetsTask(ObjectFactory objectFactory) { testRoot = objectFactory.directoryProperty() TestBuilder builder = new TestBuilder() - perSnippet builder.&handleSnippet - doLast builder.&checkUnconverted - doLast builder.&finishLastTest + perSnippet = new Action() { + @Override + void execute(Snippet snippet) { + builder.handleSnippet(snippet) + } + } + doLast { + builder.checkUnconverted() + builder.finishLastTest() + } } /** @@ -190,6 +202,7 @@ class RestTestsFromSnippetsTask extends SnippetsTask { * Called each time a snippet is encountered. Tracks the snippets and * calls buildTest to actually build the test. */ + void handleSnippet(Snippet snippet) { if (RestTestsFromSnippetsTask.isConsoleCandidate(snippet)) { unconvertedCandidates.add(snippet.path.toString() diff --git a/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/SnippetsTask.groovy b/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/SnippetsTask.groovy index 1580ec891ed2b..3e4ad91024082 100644 --- a/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/SnippetsTask.groovy +++ b/build-tools-internal/src/main/groovy/org/elasticsearch/gradle/internal/doc/SnippetsTask.groovy @@ -11,8 +11,9 @@ package org.elasticsearch.gradle.internal.doc import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonParseException; -import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.core.JsonToken +import org.gradle.api.Action; import org.gradle.api.DefaultTask import org.gradle.api.InvalidUserDataException import org.gradle.api.file.ConfigurableFileTree @@ -44,7 +45,7 @@ class SnippetsTask extends DefaultTask { * instance of Snippet. */ @Internal - Closure perSnippet + Action perSnippet /** * The docs to scan. 
Defaults to every file in the directory exception the @@ -134,7 +135,7 @@ class SnippetsTask extends DefaultTask { + "After substitutions and munging, the json looks like:\n" + quoted, e); } } - perSnippet(snippet) + perSnippet.execute(snippet) snippet = null } file.eachLine('UTF-8') { String line, int lineNumber -> From d189d0e5c53fe5ca10f995e9a858d180f6bb0641 Mon Sep 17 00:00:00 2001 From: Rene Groeschke Date: Tue, 7 Nov 2023 11:18:02 +0100 Subject: [PATCH 18/21] Make addRemote task configuration cache compatible (#101830) --- .../elasticsearch/gradle/internal/InternalBwcGitPlugin.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalBwcGitPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalBwcGitPlugin.java index d51770ffd30ed..71c76b2045007 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalBwcGitPlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalBwcGitPlugin.java @@ -72,20 +72,19 @@ public void apply(Project project) { createClone.commandLine("git", "clone", buildLayout.getRootDirectory(), gitExtension.getCheckoutDir().get()); }); - ExtraPropertiesExtension extraProperties = project.getExtensions().getExtraProperties(); TaskProvider findRemoteTaskProvider = tasks.register("findRemote", LoggedExec.class, findRemote -> { findRemote.dependsOn(createCloneTaskProvider); findRemote.getWorkingDir().set(gitExtension.getCheckoutDir()); findRemote.commandLine("git", "remote", "-v"); findRemote.getCaptureOutput().set(true); - findRemote.doLast(t -> { extraProperties.set("remoteExists", isRemoteAvailable(remote, findRemote.getOutput())); }); + findRemote.doLast(t -> System.setProperty("remoteExists", String.valueOf(isRemoteAvailable(remote, findRemote.getOutput())))); }); TaskProvider addRemoteTaskProvider = tasks.register("addRemote", addRemote -> { String rootProjectName = project.getRootProject().getName(); addRemote.dependsOn(findRemoteTaskProvider); - addRemote.onlyIf("remote exists", task -> ((boolean) extraProperties.get("remoteExists")) == false); + addRemote.onlyIf("remote exists", task -> (Boolean.valueOf(providerFactory.systemProperty("remoteExists").get()) == false)); addRemote.doLast(new Action() { @Override public void execute(Task task) { From c58427d9d3003b3a56dded3e44b82431e5859082 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lorenzo=20Dematt=C3=A9?= Date: Tue, 7 Nov 2023 11:19:48 +0100 Subject: [PATCH 19/21] Add MlFeatures basics + test feature (#101858) --- .../plugin/ml/src/main/java/module-info.java | 1 + .../xpack/ml/MachineLearning.java | 3 +++ .../elasticsearch/xpack/ml/MlFeatures.java | 24 +++++++++++++++++++ ...lasticsearch.features.FeatureSpecification | 8 +++++++ 4 files changed, 36 insertions(+) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlFeatures.java create mode 100644 x-pack/plugin/ml/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification diff --git a/x-pack/plugin/ml/src/main/java/module-info.java b/x-pack/plugin/ml/src/main/java/module-info.java index a73c9bdfa32b4..52dee889d15fc 100644 --- a/x-pack/plugin/ml/src/main/java/module-info.java +++ b/x-pack/plugin/ml/src/main/java/module-info.java @@ -33,6 +33,7 @@ provides org.elasticsearch.painless.spi.PainlessExtension with org.elasticsearch.xpack.ml.MachineLearningPainlessExtension; provides 
org.elasticsearch.xpack.autoscaling.AutoscalingExtension with org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingExtension; + provides org.elasticsearch.features.FeatureSpecification with org.elasticsearch.xpack.ml.MlFeatures; exports org.elasticsearch.xpack.ml; exports org.elasticsearch.xpack.ml.action; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index f4bce4906c0b0..b4b8084b4b328 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -44,6 +44,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.TimeValue; import org.elasticsearch.env.Environment; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.index.query.QueryBuilder; @@ -485,6 +486,8 @@ public class MachineLearning extends Plugin public static final String TRAINED_MODEL_CIRCUIT_BREAKER_NAME = "model_inference"; + public static final NodeFeature STATE_RESET_FALLBACK_ON_DISABLED = new NodeFeature("ml.state_reset_fallback_on_disabled"); + private static final long DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT = (long) ((0.50) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()); private static final double DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD = 1.0D; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlFeatures.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlFeatures.java new file mode 100644 index 0000000000000..29aa189b2acd4 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlFeatures.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml; + +import org.elasticsearch.Version; +import org.elasticsearch.features.FeatureSpecification; +import org.elasticsearch.features.NodeFeature; + +import java.util.Map; + +/** + * This class specifies source code features exposed by the ML plugin. + */ +public class MlFeatures implements FeatureSpecification { + @Override + public Map getHistoricalFeatures() { + return Map.of(MachineLearning.STATE_RESET_FALLBACK_ON_DISABLED, Version.V_8_7_0); + } +} diff --git a/x-pack/plugin/ml/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification b/x-pack/plugin/ml/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification new file mode 100644 index 0000000000000..7dbef291bdd46 --- /dev/null +++ b/x-pack/plugin/ml/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification @@ -0,0 +1,8 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0; you may not use this file except in compliance with the Elastic License +# 2.0.
From 73ca01ebf5a40476d009ea7d871e3d3fbde8b44b Mon Sep 17 00:00:00 2001
From: Rene Groeschke
Date: Tue, 7 Nov 2023 11:20:44 +0100
Subject: [PATCH 20/21] Fix configuration cache incompatibility in Rest
 compatibility tests (#101842)

related to https://github.com/elastic/elasticsearch/issues/57918

---
 .../gradle/internal/test/rest/CopyRestTestsTask.java    | 6 +++++-
 .../rest/compat/compat/RestCompatTestTransformTask.java | 6 ++++--
 x-pack/qa/xpack-prefix-rest-compat/build.gradle         | 7 ++++---
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestTestsTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestTestsTask.java
index 9359272b29610..94345ed80eec7 100644
--- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestTestsTask.java
+++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestTestsTask.java
@@ -15,6 +15,7 @@
 import org.gradle.api.file.FileSystemOperations;
 import org.gradle.api.file.FileTree;
 import org.gradle.api.file.ProjectLayout;
+import org.gradle.api.internal.file.FileOperations;
 import org.gradle.api.model.ObjectFactory;
 import org.gradle.api.provider.ListProperty;
 import org.gradle.api.tasks.IgnoreEmptyDirectories;
@@ -43,7 +44,7 @@
  *
  * @see RestResourcesPlugin
  */
-public class CopyRestTestsTask extends DefaultTask {
+public abstract class CopyRestTestsTask extends DefaultTask {
     private static final String REST_TEST_PREFIX = "rest-api-spec/test";
     private final ListProperty<String> includeCore;
     private final ListProperty<String> includeXpack;
@@ -62,6 +63,9 @@ public class CopyRestTestsTask extends DefaultTask {
     private final ProjectLayout projectLayout;
     private final FileSystemOperations fileSystemOperations;

+    @Inject
+    public abstract FileOperations getFileOperations();
+
     @Inject
     public CopyRestTestsTask(
         ProjectLayout projectLayout,

diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/compat/compat/RestCompatTestTransformTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/compat/compat/RestCompatTestTransformTask.java
index 76004e3e5f6db..9b1e8a67deec8 100644
--- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/compat/compat/RestCompatTestTransformTask.java
+++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/compat/compat/RestCompatTestTransformTask.java
@@ -457,15 +457,17 @@ public void transform() throws IOException {
                     Collections.singletonList(new Skip(skippedFilesWithReason.get(file)))
                 );
             } else {
+                List<RestTestTransform<?>> transformations = new ArrayList<>(getTransformations().get());
+
                 if (skippedFilesWithTestAndReason.containsKey(file)) {
                     // skip the named tests for this file
                     skippedFilesWithTestAndReason.get(file).forEach(fullTestNameAndReasonPair -> {
                         String prefix = file.getName().replace(".yml", "/");
                         String singleTestName = fullTestNameAndReasonPair.getLeft().replaceAll(".*" + prefix, "");
-                        getTransformations().add(new Skip(singleTestName, fullTestNameAndReasonPair.getRight()));
+                        transformations.add(new Skip(singleTestName, fullTestNameAndReasonPair.getRight()));
                     });
                 }
-                transformRestTests = transformer.transformRestTests(new LinkedList<>(tests), getTransformations().get());
+                transformRestTests = transformer.transformRestTests(new LinkedList<>(tests), transformations);
             }
             // convert to url to ensure forward slashes

diff --git a/x-pack/qa/xpack-prefix-rest-compat/build.gradle b/x-pack/qa/xpack-prefix-rest-compat/build.gradle
index caca3b63d4951..8b91aae21ff73 100644
--- a/x-pack/qa/xpack-prefix-rest-compat/build.gradle
+++ b/x-pack/qa/xpack-prefix-rest-compat/build.gradle
@@ -34,10 +34,11 @@ tasks.named("copyRestCompatTestTask").configure { task ->
   task.dependsOn(configurations.compatXpackTests);
   task.setXpackConfig(configurations.compatXpackTests);
   task.getIncludeXpack().set(List.of("license", "migration", "ml", "rollup", "sql", "ssl"));
-  task.getOutputResourceDir().set(project.getLayout().getBuildDirectory().dir("restResources/v${compatVersion}/yamlTests/original"));
+  def fileOperations = task.getFileOperations()
+  task.getOutputResourceDir().set(project.getLayout().getBuildDirectory().dir("restResources/v${compatVersion}/yamlTests/original"))
   task.setXpackConfigToFileTree(
-    config -> fileTree(
-      config.getSingleFile().toPath()
+    config -> fileOperations.fileTree(
+      config.getSingleFile()
    )
  )
 }
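[Editor's note] The patch above applies two configuration-cache fixes: the task obtains file operations from an injected service instead of calling Project.fileTree(), and the execution-time Skip transforms go into a local copy of the transformation list rather than back into the task's declared input property. A condensed sketch of both ideas, with invented task and input names (only the FileOperations injection mirrors the patch directly):

import org.gradle.api.DefaultTask;
import org.gradle.api.internal.file.FileOperations;
import org.gradle.api.provider.ListProperty;
import org.gradle.api.tasks.Input;
import org.gradle.api.tasks.TaskAction;

import javax.inject.Inject;
import java.util.ArrayList;
import java.util.List;

// Sketch of a task that stays safe under the configuration cache: no Project
// access at execution time, and no mutation of declared inputs.
public abstract class YamlTransformTask extends DefaultTask {

    // Gradle injects the service, so the task never captures the Project.
    @Inject
    public abstract FileOperations getFileOperations();

    // Declared input; mutating it during execution would change the task's
    // fingerprint mid-run, hence the local copy below.
    @Input
    public abstract ListProperty<String> getTransformations();

    @TaskAction
    public void transform() {
        // Per-run additions stay in a local list instead of the input property.
        List<String> transformations = new ArrayList<>(getTransformations().get());
        transformations.add("skip-test-for-this-run");
        getFileOperations().fileTree("src/yamlRestTest").forEach(
            file -> getLogger().lifecycle("transforming {} with {}", file, transformations)
        );
    }
}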
From 9132f95fb4fc07965be10d61ef106c685afc6072 Mon Sep 17 00:00:00 2001
From: Abdon Pijpelink
Date: Tue, 7 Nov 2023 11:35:37 +0100
Subject: [PATCH 21/21] [DOCS] Add 'Using ES|QL in Elastic Security' (#101677)

* [DOCS] Add 'Using ES|QL in Elastic Security'

* Add a note about enabling knowledge base

* Update links

---
 .../esql/esql-security-solution.asciidoc | 41 +++++++++++++++++++
 docs/reference/esql/esql-using.asciidoc  |  7 +++-
 docs/reference/esql/index.asciidoc       |  4 +-
 3 files changed, 49 insertions(+), 3 deletions(-)
 create mode 100644 docs/reference/esql/esql-security-solution.asciidoc

diff --git a/docs/reference/esql/esql-security-solution.asciidoc b/docs/reference/esql/esql-security-solution.asciidoc
new file mode 100644
index 0000000000000..45e8e44e44bdd
--- /dev/null
+++ b/docs/reference/esql/esql-security-solution.asciidoc
@@ -0,0 +1,41 @@
+[[esql-elastic-security]]
+=== Using {esql} in {elastic-sec}
+
+++++
+Using {esql} in {elastic-sec}
+++++
+
+You can use {esql} in {elastic-sec} to investigate events in Timeline and create
+detection rules. Use the Elastic AI Assistant to build {esql} queries, or answer
+questions about the {esql} query language.
+
+[discrete]
+[[esql-elastic-security-timeline]]
+=== Use {esql} to investigate events in Timeline
+
+You can use {esql} in Timeline to filter, transform, and analyze event data
+stored in {es}. To start using {esql}, open the **{esql}** tab. To learn
+more, refer to {security-guide}/timelines-ui.html#esql-in-timeline[Investigate
+events in Timeline].
+
+[discrete]
+[[esql-elastic-security-detection-rules]]
+=== Use {esql} to create detection rules
+
+Use the {esql} rule type to create detection rules using {esql} queries. The
+{esql} rule type supports aggregating and non-aggregating queries. To learn
+more, refer to {security-guide}/rules-ui-create.html#create-esql-rule[Create an
+{esql} rule].
+
+[discrete]
+[[esql-elastic-security-ai-assistant]]
+=== Elastic AI Assistant
+
+Use the Elastic AI Assistant to build {esql} queries, or answer questions about
+the {esql} query language. To learn more, refer to
+{security-guide}/security-assistant.html[AI Assistant].
+
+NOTE: For AI Assistant to answer questions about {esql} and write {esql}
+queries, you need to
+{security-guide}/security-assistant.html#set-up-ai-assistant[enable knowledge
+base].
\ No newline at end of file diff --git a/docs/reference/esql/esql-using.asciidoc b/docs/reference/esql/esql-using.asciidoc index f586f3a28de5c..dbab521ead4d1 100644 --- a/docs/reference/esql/esql-using.asciidoc +++ b/docs/reference/esql/esql-using.asciidoc @@ -6,11 +6,16 @@ Information about using the <>. <>:: Using {esql} in {kib} to query and aggregate your data, create visualizations, -and set up alerts. +and set up alerts. + +<>:: +Using {esql} in {elastic-sec} to investigate events in Timeline and create +detection rules. <>:: Using the <> to list and cancel {esql} queries. include::esql-rest.asciidoc[] include::esql-kibana.asciidoc[] +include::esql-security-solution.asciidoc[] include::task-management.asciidoc[] \ No newline at end of file diff --git a/docs/reference/esql/index.asciidoc b/docs/reference/esql/index.asciidoc index 799f95751aa69..dcbe426b1bcac 100644 --- a/docs/reference/esql/index.asciidoc +++ b/docs/reference/esql/index.asciidoc @@ -55,8 +55,8 @@ fields>> and <>. And guidance for GROK>> and <>. <>:: -An overview of using the <>, <>, and -<>. +An overview of using the <>, <>, +<>, and <>. <>:: The current limitations of {esql}.