Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] opensearch crashes on closed client connection before search reply #3626

Merged
merged 2 commits into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,11 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh
}
final int totalOps = this.totalOps.incrementAndGet();
if (totalOps == expectedTotalOps) {
onPhaseDone();
try {
onPhaseDone();
} catch (final Exception ex) {
onPhaseFailure(this, "The phase has failed", ex);
}
} else if (totalOps > expectedTotalOps) {
throw new AssertionError(
"unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]",
Expand Down Expand Up @@ -561,7 +565,11 @@ private void successfulShardExecution(SearchShardIterator shardsIt) {
}
final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
if (xTotalOps == expectedTotalOps) {
onPhaseDone();
try {
onPhaseDone();
} catch (final Exception ex) {
onPhaseFailure(this, "The phase has failed", ex);
}
} else if (xTotalOps > expectedTotalOps) {
throw new AssertionError(
"unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

package org.opensearch.action.search;

import org.junit.After;
import org.junit.Before;
import org.opensearch.action.ActionListener;
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.support.IndicesOptions;
Expand All @@ -43,25 +45,34 @@
import org.opensearch.index.Index;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.shard.ShardId;
import org.opensearch.index.shard.ShardNotFoundException;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.Transport;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
Expand All @@ -71,13 +82,49 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase {

private final List<Tuple<String, String>> resolvedNodes = new ArrayList<>();
private final Set<ShardSearchContextId> releasedContexts = new CopyOnWriteArraySet<>();
private ExecutorService executor;

@Before
@Override
public void setUp() throws Exception {
super.setUp();
executor = Executors.newFixedThreadPool(1);
}

@After
@Override
public void tearDown() throws Exception {
super.tearDown();
executor.shutdown();
assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS));
}

private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
SearchRequest request,
ArraySearchPhaseResults<SearchPhaseResult> results,
ActionListener<SearchResponse> listener,
final boolean controlled,
final AtomicLong expected
) {
return createAction(
request,
results,
listener,
controlled,
false,
expected,
new SearchShardIterator(null, null, Collections.emptyList(), null)
);
}

private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
SearchRequest request,
ArraySearchPhaseResults<SearchPhaseResult> results,
ActionListener<SearchResponse> listener,
final boolean controlled,
final boolean failExecutePhaseOnShard,
final AtomicLong expected,
final SearchShardIterator... shards
) {
final Runnable runnable;
final TransportSearchAction.SearchTimeProvider timeProvider;
Expand Down Expand Up @@ -105,10 +152,10 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
Collections.singletonMap("foo", new AliasFilter(new MatchAllQueryBuilder())),
Collections.singletonMap("foo", 2.0f),
Collections.singletonMap("name", Sets.newHashSet("bar", "baz")),
null,
executor,
request,
listener,
new GroupShardsIterator<>(Collections.singletonList(new SearchShardIterator(null, null, Collections.emptyList(), null))),
new GroupShardsIterator<>(Arrays.asList(shards)),
timeProvider,
ClusterState.EMPTY_STATE,
null,
Expand All @@ -126,7 +173,13 @@ protected void executePhaseOnShard(
final SearchShardIterator shardIt,
final SearchShardTarget shard,
final SearchActionListener<SearchPhaseResult> listener
) {}
) {
if (failExecutePhaseOnShard) {
listener.onFailure(new ShardNotFoundException(shardIt.shardId()));
} else {
listener.onResponse(new QuerySearchResult());
}
}

@Override
long buildTookInMillis() {
Expand Down Expand Up @@ -328,6 +381,102 @@ private static ArraySearchPhaseResults<SearchPhaseResult> phaseResults(
return phaseResults;
}

public void testOnShardFailurePhaseDoneFailure() throws InterruptedException {
final Index index = new Index("test", UUID.randomUUID().toString());
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean fail = new AtomicBoolean(true);

final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10))
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1", "n2", "n3"), null, null, null))
.toArray(SearchShardIterator[]::new);

SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
searchRequest.setMaxConcurrentShardRequests(1);

final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(
searchRequest,
queryResult,
new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse response) {

}

@Override
public void onFailure(Exception e) {
if (fail.compareAndExchange(true, false)) {
try {
throw new RuntimeException("Simulated exception");
} finally {
executor.submit(() -> latch.countDown());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this countDown() call have to happen on a separate thread? I think this is the only reason the executor is wired in here and it's not obvious to me why you can't just synchronously invoke countDown() here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good one, the execution of this callback is happening inside executor thread, so for predictability, we have to make sure the current thread finishes and we could release the latch. Otherwise, the current thread will unblock the latch, but the processing may not be finished just yet, making tests flaky.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I missed that the executor was being used to fork off the callbacks as well.

}
}
}
},
false,
true,
new AtomicLong(),
shards
);
action.run();
assertTrue(latch.await(1, TimeUnit.SECONDS));

InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
assertThat(searchResponse.getSuccessfulShards(), equalTo(0));
}

public void testOnShardSuccessPhaseDoneFailure() throws InterruptedException {
final Index index = new Index("test", UUID.randomUUID().toString());
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean fail = new AtomicBoolean(true);

final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10))
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1", "n2", "n3"), null, null, null))
.toArray(SearchShardIterator[]::new);

SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
searchRequest.setMaxConcurrentShardRequests(1);

final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(
searchRequest,
queryResult,
new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse response) {
if (fail.compareAndExchange(true, false)) {
throw new RuntimeException("Simulated exception");
}
}

@Override
public void onFailure(Exception e) {
executor.submit(() -> latch.countDown());
}
},
false,
false,
new AtomicLong(),
shards
);
action.run();
assertTrue(latch.await(1, TimeUnit.SECONDS));

InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
assertThat(searchResponse.getSuccessfulShards(), equalTo(shards.length));
}

private static final class PhaseResult extends SearchPhaseResult {
PhaseResult(ShardSearchContextId contextId) {
this.contextId = contextId;
Expand Down