Filter shards for sliced search at coordinator
Prior to this commit, a sliced search would fan out to every shard,
then apply a MatchNoDocsQuery filter on shards that don't correspond
to the current slice. This still creates a (useless) search context
on each shard for every slice, though. For a long-running sliced
scroll, this can quickly exhaust the number of available scroll
contexts.

This change avoids fanning out to all the shards by checking at the
coordinator if a shard is matched by the current slice. This should
reduce the number of open scroll contexts to max(numShards, numSlices)
instead of numShards * numSlices.
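
For intuition, here is a minimal, self-contained sketch of the matching rule introduced below in SliceBuilder#shardMatches (the wrapper class and example numbers are illustrative, not part of the change):

public class SliceRoutingSketch {
    // Mirrors SliceBuilder#shardMatches: with at least as many slices as
    // shards, slices are distributed over shards; otherwise shards are
    // distributed over slices.
    static boolean shardMatches(int sliceId, int maxSlices, int shardId, int numShards) {
        if (maxSlices >= numShards) {
            return sliceId % numShards == shardId;
        }
        return shardId % maxSlices == sliceId;
    }

    public static void main(String[] args) {
        int numShards = 2, maxSlices = 4;
        // Each slice targets exactly one shard, so concurrent contexts total
        // max(numShards, maxSlices) = 4 rather than numShards * maxSlices = 8.
        for (int slice = 0; slice < maxSlices; slice++) {
            for (int shard = 0; shard < numShards; shard++) {
                if (shardMatches(slice, maxSlices, shard, numShards)) {
                    System.out.println("slice " + slice + " -> shard " + shard);
                }
            }
        }
    }
}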

Signed-off-by: Michael Froh <froh@amazon.com>
msfroh committed Dec 10, 2024
1 parent 5ba909a commit b4aaa2f
Showing 9 changed files with 117 additions and 42 deletions.
@@ -32,6 +32,7 @@

package org.opensearch.action.admin.cluster.shards;

import org.opensearch.Version;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.IndicesRequest;
import org.opensearch.action.support.IndicesOptions;
@@ -41,6 +42,7 @@
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.search.slice.SliceBuilder;

import java.io.IOException;
import java.util.Objects;
@@ -61,6 +63,8 @@ public class ClusterSearchShardsRequest extends ClusterManagerNodeReadRequest<Cl
@Nullable
private String preference;
private IndicesOptions indicesOptions = IndicesOptions.lenientExpandOpen();
@Nullable
private SliceBuilder sliceBuilder;

public ClusterSearchShardsRequest() {}

@@ -76,6 +80,12 @@ public ClusterSearchShardsRequest(StreamInput in) throws IOException {
preference = in.readOptionalString();

indicesOptions = IndicesOptions.readIndicesOptions(in);
if (in.getVersion().onOrAfter(Version.V_3_0_0)) {
boolean hasSlice = in.readBoolean();
if (hasSlice) {
sliceBuilder = new SliceBuilder(in);
}
}
}

@Override
@@ -84,8 +94,15 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeStringArray(indices);
out.writeOptionalString(routing);
out.writeOptionalString(preference);

indicesOptions.writeIndicesOptions(out);
if (out.getVersion().onOrAfter(Version.V_3_0_0)) {
if (sliceBuilder != null) {
out.writeBoolean(true);
sliceBuilder.writeTo(out);
} else {
out.writeBoolean(false);
}
}
}

@Override
@@ -166,4 +183,13 @@ public ClusterSearchShardsRequest preference(String preference) {
public String preference() {
return this.preference;
}

public ClusterSearchShardsRequest slice(SliceBuilder sliceBuilder) {
this.sliceBuilder = sliceBuilder;
return this;
}

public SliceBuilder slice() {
return this.sliceBuilder;
}
}
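
The new accessor pair lets callers scope a shards request to a single slice. A hedged usage sketch (index and routing values are hypothetical; the slice only survives serialization when both nodes are on V_3_0_0 or later, per the version check above):

ClusterSearchShardsRequest request = new ClusterSearchShardsRequest("logs") // hypothetical index
    .routing("user-42")                // hypothetical routing value
    .slice(new SliceBuilder(0, 4));    // keep only shards matched by slice 0 of 4
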
@@ -133,7 +133,7 @@ protected void clusterManagerOperation(

Set<String> nodeIds = new HashSet<>();
GroupShardsIterator<ShardIterator> groupShardsIterator = clusterService.operationRouting()
.searchShards(clusterState, concreteIndices, routingMap, request.preference());
.searchShards(clusterState, concreteIndices, routingMap, request.preference(), null, null, request.slice());
ShardRouting shard;
ClusterSearchShardsGroup[] groupResponses = new ClusterSearchShardsGroup[groupShardsIterator.size()];
int currentGroup = 0;
@@ -247,8 +247,7 @@ private AsyncShardsAction(FieldCapabilitiesIndexRequest request, ActionListener<
throw blockException;
}

shardsIt = clusterService.operationRouting()
.searchShards(clusterService.state(), new String[] { request.index() }, null, null, null, null);
shardsIt = clusterService.operationRouting().searchShards(clusterService.state(), new String[] { request.index() }, null, null);
}

public void start() {
@@ -85,6 +85,7 @@
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.search.profile.ProfileShardResult;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.slice.SliceBuilder;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskResourceTrackingService;
@@ -551,6 +552,7 @@ private ActionListener<SearchSourceBuilder> buildRewriteListener(
);
} else {
AtomicInteger skippedClusters = new AtomicInteger(0);
SliceBuilder slice = searchRequest.source() == null ? null : searchRequest.source().slice();
collectSearchShards(
searchRequest.indicesOptions(),
searchRequest.preference(),
@@ -559,6 +561,7 @@
remoteClusterIndices,
remoteClusterService,
threadPool,
slice,
ActionListener.wrap(searchShardsResponses -> {
final BiFunction<String, String, DiscoveryNode> clusterNodeLookup = getRemoteClusterNodeLookup(
searchShardsResponses
@@ -787,6 +790,7 @@ static void collectSearchShards(
Map<String, OriginalIndices> remoteIndicesByCluster,
RemoteClusterService remoteClusterService,
ThreadPool threadPool,
SliceBuilder slice,
ActionListener<Map<String, ClusterSearchShardsResponse>> listener
) {
final CountDown responsesCountDown = new CountDown(remoteIndicesByCluster.size());
@@ -800,7 +804,8 @@
ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest(indices).indicesOptions(indicesOptions)
.local(true)
.preference(preference)
.routing(routing);
.routing(routing)
.slice(slice);
clusterClient.admin()
.cluster()
.searchShards(
@@ -1042,14 +1047,16 @@ private void executeSearch(
concreteLocalIndices[i] = indices[i].getName();
}
Map<String, Long> nodeSearchCounts = searchTransportService.getPendingSearchRequests();
SliceBuilder slice = searchRequest.source() == null ? null : searchRequest.source().slice();
GroupShardsIterator<ShardIterator> localShardRoutings = clusterService.operationRouting()
.searchShards(
clusterState,
concreteLocalIndices,
routingMap,
searchRequest.preference(),
searchService.getResponseCollectorService(),
nodeSearchCounts
nodeSearchCounts,
slice
);
localShardIterators = StreamSupport.stream(localShardRoutings.spliterator(), false)
.map(it -> new SearchShardIterator(searchRequest.getLocalClusterAlias(), it.shardId(), it.getShardRoutings(), localIndices))
@@ -32,6 +32,7 @@

package org.opensearch.cluster.routing;

import org.apache.lucene.util.CollectionUtil;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.WeightedRoutingMetadata;
@@ -44,14 +45,17 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.common.Strings;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.IndexModule;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.node.ResponseCollectorService;
import org.opensearch.search.slice.SliceBuilder;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -230,7 +234,7 @@ public GroupShardsIterator<ShardIterator> searchShards(
@Nullable Map<String, Set<String>> routing,
@Nullable String preference
) {
return searchShards(clusterState, concreteIndices, routing, preference, null, null);
return searchShards(clusterState, concreteIndices, routing, preference, null, null, null);
}

public GroupShardsIterator<ShardIterator> searchShards(
@@ -239,11 +243,14 @@ public GroupShardsIterator<ShardIterator> searchShards(
@Nullable Map<String, Set<String>> routing,
@Nullable String preference,
@Nullable ResponseCollectorService collectorService,
@Nullable Map<String, Long> nodeCounts
@Nullable Map<String, Long> nodeCounts,
@Nullable SliceBuilder slice
) {
final Set<IndexShardRoutingTable> shards = computeTargetedShards(clusterState, concreteIndices, routing);
final Set<ShardIterator> set = new HashSet<>(shards.size());

Map<Index, List<ShardIterator>> shardIterators = new HashMap<>();
for (IndexShardRoutingTable shard : shards) {

IndexMetadata indexMetadataForShard = indexMetadata(clusterState, shard.shardId.getIndex().getName());
if (indexMetadataForShard.isRemoteSnapshot() && (preference == null || preference.isEmpty())) {
preference = Preference.PRIMARY.type();
@@ -274,10 +281,25 @@ public GroupShardsIterator<ShardIterator> searchShards(
clusterState.metadata().weightedRoutingMetadata()
);
if (iterator != null) {
set.add(iterator);
shardIterators.computeIfAbsent(iterator.shardId().getIndex(), k -> new ArrayList<>()).add(iterator);
}
}
List<ShardIterator> allShardIterators = new ArrayList<>();
for (List<ShardIterator> indexIterators : shardIterators.values()) {
if (slice != null) {
// Filter the returned shards for the given slice
CollectionUtil.timSort(indexIterators);
for (int i = 0; i < indexIterators.size(); i++) {
if (slice.shardMatches(i, indexIterators.size())) {
allShardIterators.add(indexIterators.get(i));
}
}
} else {
allShardIterators.addAll(indexIterators);
}
}
return GroupShardsIterator.sortAndCreate(new ArrayList<>(set));

return GroupShardsIterator.sortAndCreate(allShardIterators);
}

public static ShardIterator getShards(ClusterState clusterState, ShardId shardId) {
@@ -311,6 +333,7 @@ private Set<IndexShardRoutingTable> computeTargetedShards(
set.add(indexShard);
}
}

}
return set;
}
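The per-index filtering loop in searchShards above first sorts each index's shard iterators (CollectionUtil.timSort) so that ordinals line up deterministically across requests, then keeps only the ordinals the slice matches. A self-contained sketch of the regime where shards outnumber slices (numbers illustrative; shardMatches restated locally):

import java.util.ArrayList;
import java.util.List;

public class SliceFilterSketch {
    // Same rule as SliceBuilder#shardMatches, restated for this example.
    static boolean shardMatches(int sliceId, int maxSlices, int shardId, int numShards) {
        return maxSlices >= numShards ? sliceId % numShards == shardId : shardId % maxSlices == sliceId;
    }

    public static void main(String[] args) {
        int numShards = 5, maxSlices = 2;
        for (int slice = 0; slice < maxSlices; slice++) {
            List<Integer> kept = new ArrayList<>();
            for (int ordinal = 0; ordinal < numShards; ordinal++) {
                if (shardMatches(slice, maxSlices, ordinal, numShards)) {
                    kept.add(ordinal);
                }
            }
            // slice 0 keeps [0, 2, 4]; slice 1 keeps [1, 3]: five contexts in
            // total, i.e. max(numShards, maxSlices), instead of ten.
            System.out.println("slice " + slice + " keeps shard ordinals " + kept);
        }
    }
}
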
@@ -40,6 +40,7 @@
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.io.IOException;
import java.util.List;
@@ -81,6 +82,13 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
clusterSearchShardsRequest.routing(request.param("routing"));
clusterSearchShardsRequest.preference(request.param("preference"));
clusterSearchShardsRequest.indicesOptions(IndicesOptions.fromRequest(request, clusterSearchShardsRequest.indicesOptions()));
if (request.hasContent()) {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
request.withContentOrSourceParamParserOrNull(sourceBuilder::parseXContent);
if (sourceBuilder.slice() != null) {
clusterSearchShardsRequest.slice(sourceBuilder.slice());
}
}
return channel -> client.admin().cluster().searchShards(clusterSearchShardsRequest, new RestToXContentListener<>(channel));
}
}
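
Given the handler change above, a request body carrying a slice now narrows the _search_shards response to that slice's shards. A sketch of such a call (index name hypothetical; the body is parsed as a search source, so the standard slice object applies):

GET /logs/_search_shards
{
  "slice": { "id": 0, "max": 4 }
}
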
server/src/main/java/org/opensearch/search/slice/SliceBuilder.java (18 additions & 13 deletions)
@@ -214,6 +214,15 @@ public int hashCode() {
return Objects.hash(this.field, this.id, this.max);
}

public boolean shardMatches(int shardId, int numShards) {
if (max >= numShards) {
// Slices are distributed over shards
return id % numShards == shardId;
}
// Shards are distributed over slices
return shardId % max == id;
}

/**
* Converts this QueryBuilder to a lucene {@link Query}.
*
@@ -255,7 +264,12 @@ public Query toFilter(ClusterService clusterService, ShardSearchRequest request,
}
}

String field = this.field;
if (shardMatches(shardId, numShards) == false) {
// We should have already excluded this shard before routing to it.
// If we somehow land here, then we match nothing.
return new MatchNoDocsQuery("this shard is not part of the slice");
}

boolean useTermQuery = false;
if ("_uid".equals(field)) {
throw new IllegalArgumentException("Computing slices on the [_uid] field is illegal for 7.x indices, use [_id] instead");
@@ -277,12 +291,7 @@ public Query toFilter(ClusterService clusterService, ShardSearchRequest request,
// the number of slices is greater than the number of shards
// in such case we can reduce the number of requested shards by slice

// first we check if the slice is responsible of this shard
int targetShard = id % numShards;
if (targetShard != shardId) {
// the shard is not part of this slice, we can skip it.
return new MatchNoDocsQuery("this shard is not part of the slice");
}
// compute the number of slices where this shard appears
int numSlicesInShard = max / numShards;
int rest = max % numShards;
@@ -301,14 +310,8 @@ public Query toFilter(ClusterService clusterService, ShardSearchRequest request,
? new TermsSliceQuery(field, shardSlice, numSlicesInShard)
: new DocValuesSliceQuery(field, shardSlice, numSlicesInShard);
}
// the number of shards is greater than the number of slices
// the number of shards is greater than the number of slices. If we target this shard, we target all of it.

// check if the shard is assigned to the slice
int targetSlice = shardId % max;
if (id != targetSlice) {
// the shard is not part of this slice, we can skip it.
return new MatchNoDocsQuery("this shard is not part of the slice");
}
return new MatchAllDocsQuery();
}

@@ -321,6 +324,8 @@ private GroupShardsIterator<ShardIterator> buildShardIterator(ClusterService clu
Map<String, Set<String>> routingMap = request.indexRoutings().length > 0
? Collections.singletonMap(indices[0], Sets.newHashSet(request.indexRoutings()))
: null;
// Note that we do *not* want to filter this set of shard IDs based on the slice, since we want the
// full set of shards matched by the routing parameters.
return clusterService.operationRouting().searchShards(state, indices, routingMap, request.preference());
}

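Taken together with the coordinator-side filter, toFilter on a shard that was actually routed to reduces to two cases: with at least as many slices as shards, the shard evaluates only its local sub-slice(s) via TermsSliceQuery or DocValuesSliceQuery; otherwise the slice owns the whole shard and gets MatchAllDocsQuery, with the MatchNoDocsQuery early return kept purely as a defensive fallback. A coarse sketch of that decision (string labels stand in for the actual Query objects):

public class SliceFilterChoiceSketch {
    // Which filter a matched shard ends up with; the labels name the
    // Query classes used in the diff above.
    static String filterForMatchedShard(int maxSlices, int numShards) {
        if (maxSlices >= numShards) {
            // One or more slices share this shard; evaluate only the
            // local sub-slice(s).
            return "TermsSliceQuery or DocValuesSliceQuery";
        }
        // The slice owns the whole shard.
        return "MatchAllDocsQuery";
    }

    public static void main(String[] args) {
        System.out.println(filterForMatchedShard(4, 2)); // sub-slice query per shard
        System.out.println(filterForMatchedShard(2, 5)); // MatchAllDocsQuery
    }
}
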
@@ -809,6 +809,7 @@ public void testCollectSearchShards() throws Exception {
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -835,6 +836,7 @@
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -880,6 +882,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -907,6 +910,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -949,6 +953,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);