diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 1d6d3f284d546..1597b31e89871 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -456,7 +456,11 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh } final int totalOps = this.totalOps.incrementAndGet(); if (totalOps == expectedTotalOps) { - onPhaseDone(); + try { + onPhaseDone(); + } catch (final Exception ex) { + onPhaseFailure(this, "The phase has failed", ex); + } } else if (totalOps > expectedTotalOps) { throw new AssertionError( "unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]", @@ -561,7 +565,11 @@ private void successfulShardExecution(SearchShardIterator shardsIt) { } final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator); if (xTotalOps == expectedTotalOps) { - onPhaseDone(); + try { + onPhaseDone(); + } catch (final Exception ex) { + onPhaseFailure(this, "The phase has failed", ex); + } } else if (xTotalOps > expectedTotalOps) { throw new AssertionError( "unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]", diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index f4b45b9c36f96..b44b59b8a4ad5 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -32,6 +32,8 @@ package org.opensearch.action.search; +import org.junit.After; +import org.junit.Before; import org.opensearch.action.ActionListener; import org.opensearch.action.OriginalIndices; import org.opensearch.action.support.IndicesOptions; @@ -43,25 +45,34 @@ import org.opensearch.index.Index; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.shard.ShardId; +import org.opensearch.index.shard.ShardNotFoundException; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.UUID; import java.util.concurrent.CopyOnWriteArraySet; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import java.util.stream.IntStream; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -71,6 +82,22 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase { private final List> resolvedNodes = new ArrayList<>(); private final Set releasedContexts = new CopyOnWriteArraySet<>(); + private ExecutorService executor; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + executor = Executors.newFixedThreadPool(1); + } + + @After + @Override + public void tearDown() throws Exception { + super.tearDown(); + executor.shutdown(); + assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS)); + } private AbstractSearchAsyncAction createAction( SearchRequest request, @@ -78,6 +105,26 @@ private AbstractSearchAsyncAction createAction( ActionListener listener, final boolean controlled, final AtomicLong expected + ) { + return createAction( + request, + results, + listener, + controlled, + false, + expected, + new SearchShardIterator(null, null, Collections.emptyList(), null) + ); + } + + private AbstractSearchAsyncAction createAction( + SearchRequest request, + ArraySearchPhaseResults results, + ActionListener listener, + final boolean controlled, + final boolean failExecutePhaseOnShard, + final AtomicLong expected, + final SearchShardIterator... shards ) { final Runnable runnable; final TransportSearchAction.SearchTimeProvider timeProvider; @@ -105,10 +152,10 @@ private AbstractSearchAsyncAction createAction( Collections.singletonMap("foo", new AliasFilter(new MatchAllQueryBuilder())), Collections.singletonMap("foo", 2.0f), Collections.singletonMap("name", Sets.newHashSet("bar", "baz")), - null, + executor, request, listener, - new GroupShardsIterator<>(Collections.singletonList(new SearchShardIterator(null, null, Collections.emptyList(), null))), + new GroupShardsIterator<>(Arrays.asList(shards)), timeProvider, ClusterState.EMPTY_STATE, null, @@ -126,7 +173,13 @@ protected void executePhaseOnShard( final SearchShardIterator shardIt, final SearchShardTarget shard, final SearchActionListener listener - ) {} + ) { + if (failExecutePhaseOnShard) { + listener.onFailure(new ShardNotFoundException(shardIt.shardId())); + } else { + listener.onResponse(new QuerySearchResult()); + } + } @Override long buildTookInMillis() { @@ -328,6 +381,102 @@ private static ArraySearchPhaseResults phaseResults( return phaseResults; } + public void testOnShardFailurePhaseDoneFailure() throws InterruptedException { + final Index index = new Index("test", UUID.randomUUID().toString()); + final CountDownLatch latch = new CountDownLatch(1); + final AtomicBoolean fail = new AtomicBoolean(true); + + final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10)) + .mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1", "n2", "n3"), null, null, null)) + .toArray(SearchShardIterator[]::new); + + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.setMaxConcurrentShardRequests(1); + + final ArraySearchPhaseResults queryResult = new ArraySearchPhaseResults<>(shards.length); + AbstractSearchAsyncAction action = createAction( + searchRequest, + queryResult, + new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + + } + + @Override + public void onFailure(Exception e) { + if (fail.compareAndExchange(true, false)) { + try { + throw new RuntimeException("Simulated exception"); + } finally { + executor.submit(() -> latch.countDown()); + } + } + } + }, + false, + true, + new AtomicLong(), + shards + ); + action.run(); + assertTrue(latch.await(1, TimeUnit.SECONDS)); + + InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); + SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null); + assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations()); + assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest()); + assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile()); + assertSame(searchResponse.getHits(), internalSearchResponse.hits()); + assertThat(searchResponse.getSuccessfulShards(), equalTo(0)); + } + + public void testOnShardSuccessPhaseDoneFailure() throws InterruptedException { + final Index index = new Index("test", UUID.randomUUID().toString()); + final CountDownLatch latch = new CountDownLatch(1); + final AtomicBoolean fail = new AtomicBoolean(true); + + final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10)) + .mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1", "n2", "n3"), null, null, null)) + .toArray(SearchShardIterator[]::new); + + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.setMaxConcurrentShardRequests(1); + + final ArraySearchPhaseResults queryResult = new ArraySearchPhaseResults<>(shards.length); + AbstractSearchAsyncAction action = createAction( + searchRequest, + queryResult, + new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + if (fail.compareAndExchange(true, false)) { + throw new RuntimeException("Simulated exception"); + } + } + + @Override + public void onFailure(Exception e) { + executor.submit(() -> latch.countDown()); + } + }, + false, + false, + new AtomicLong(), + shards + ); + action.run(); + assertTrue(latch.await(1, TimeUnit.SECONDS)); + + InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); + SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null); + assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations()); + assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest()); + assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile()); + assertSame(searchResponse.getHits(), internalSearchResponse.hits()); + assertThat(searchResponse.getSuccessfulShards(), equalTo(shards.length)); + } + private static final class PhaseResult extends SearchPhaseResult { PhaseResult(ShardSearchContextId contextId) { this.contextId = contextId;