Skip to content

Commit

Permalink
[TEST] Address rejected execution in SearchAsyncActionTests (#37028)
Browse files Browse the repository at this point in the history
SearchAsyncActionTests may fail with RejectedExecutionException because InitialSearchPhase may try to execute a runnable after the test has successfully completed and the corresponding executor has already been shut down. The latch was counted down in getNextPhase, which is almost correct but does not cover the last finishAndRunNext round, which gets executed after onShardResult is invoked.

This commit changes the latch so that it counts the number of shards, which allows the test to count down later, after finishAndRunNext has potentially been forked. This way nothing else will be executed once the executor is shut down at the end of the tests.

Closes #36221
Closes #33699
  • Loading branch information
javanna authored Jan 2, 2019
1 parent 35c09ad commit 9e70696
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ private void executePhase(SearchPhase phase) {
}
}


private ShardSearchFailure[] buildShardFailures() {
AtomicArray<ShardSearchFailure> shardFailures = this.shardFailures.get();
if (shardFailures == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,6 @@ public final void run() {
}
}

/**
 * Runs {@code runnable}, handing it to {@link #fork(Runnable)} when the calling thread is
 * {@code thread} and invoking it inline on the calling thread otherwise.
 *
 * @param thread   the thread that is compared against the current thread
 * @param runnable the work to fork or run
 */
private void maybeFork(final Thread thread, final Runnable runnable) {
if (thread == Thread.currentThread()) {
// still on the given thread: hand the work off via fork(...) instead of running it inline
fork(runnable);
} else {
runnable.run();
}
}

private void fork(final Runnable runnable) {
executor.execute(new AbstractRunnable() {
@Override
Expand Down Expand Up @@ -232,10 +224,18 @@ private synchronized Runnable tryQueue(Runnable runnable) {
}

/**
 * Adapts {@code pendingExecutions} to the {@link Runnable}-based overload: passes
 * {@code pendingExecutions::finishAndRunNext}, or {@code null} when {@code pendingExecutions}
 * is {@code null}, together with {@code originalThread}.
 */
private void executeNext(PendingExecutions pendingExecutions, Thread originalThread) {
executeNext(pendingExecutions == null ? null : pendingExecutions::finishAndRunNext, originalThread);
}

/**
 * When {@code throttleConcurrentRequests} is set, forks {@code runnable} if the calling thread is
 * {@code originalThread} and runs it inline otherwise; when it is not set, only asserts that there
 * is nothing to run.
 *
 * NOTE(review): this span is rendered diff output without +/- markers, so the replaced lines
 * (the maybeFork call and the pendingExecutions assertion) appear to sit next to their
 * replacements -- confirm against the actual file before relying on this listing.
 *
 * @param runnable       the follow-up work; asserted to be {@code null} when not throttling
 * @param originalThread the thread that is compared against the current thread
 */
protected void executeNext(Runnable runnable, Thread originalThread) {
if (throttleConcurrentRequests) {
maybeFork(originalThread, pendingExecutions::finishAndRunNext);
if (originalThread == Thread.currentThread()) {
// still on originalThread: hand the work off via fork(...) instead of running it inline
fork(runnable);
} else {
runnable.run();
}
} else {
assert pendingExecutions == null;
assert runnable == null;
}
}

Expand All @@ -256,28 +256,26 @@ private void performPhaseOnShard(final int shardIndex, final SearchShardIterator
Runnable r = () -> {
final Thread thread = Thread.currentThread();
try {
executePhaseOnShard(shardIt, shard, new SearchActionListener<FirstResult>(
shardIt.newSearchShardTarget(shard.currentNodeId()), shardIndex) {
@Override
public void innerOnResponse(FirstResult result) {
try {
onShardResult(result, shardIt);
} finally {
executeNext(pendingExecutions, thread);
executePhaseOnShard(shardIt, shard,
new SearchActionListener<FirstResult>(shardIt.newSearchShardTarget(shard.currentNodeId()), shardIndex) {
@Override
public void innerOnResponse(FirstResult result) {
try {
onShardResult(result, shardIt);
} finally {
executeNext(pendingExecutions, thread);
}
}
}

@Override
public void onFailure(Exception t) {
try {
onShardFailure(shardIndex, shard, shard.currentNodeId(), shardIt, t);
} finally {
executeNext(pendingExecutions, thread);
@Override
public void onFailure(Exception t) {
try {
onShardFailure(shardIndex, shard, shard.currentNodeId(), shardIt, t);
} finally {
executeNext(pendingExecutions, thread);
}
}
}
});


});
} catch (final Exception e) {
try {
/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,34 +62,25 @@ public class SearchAsyncActionTests extends ESTestCase {
public void testSkipSearchShards() throws InterruptedException {
SearchRequest request = new SearchRequest();
request.allowPartialSearchResults(true);
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<TestSearchResponse> response = new AtomicReference<>();
ActionListener<SearchResponse> responseListener = new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse searchResponse) {
response.set((TestSearchResponse) searchResponse);
}

@Override
public void onFailure(Exception e) {
logger.warn("test failed", e);
fail(e.getMessage());
}
};
int numShards = 10;
ActionListener<SearchResponse> responseListener = ActionListener.wrap(response -> {},
(e) -> { throw new AssertionError("unexpected", e);});
DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);

AtomicInteger contextIdGenerator = new AtomicInteger(0);
GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS),
10, randomBoolean(), primaryNode, replicaNode);
numShards, randomBoolean(), primaryNode, replicaNode);
int numSkipped = 0;
for (SearchShardIterator iter : shardsIter) {
if (iter.shardId().id() % 2 == 0) {
iter.resetAndSkip();
numSkipped++;
}
}
CountDownLatch latch = new CountDownLatch(numShards - numSkipped);
AtomicBoolean searchPhaseDidRun = new AtomicBoolean(false);

SearchTransportService transportService = new SearchTransportService(null, null);
Map<String, Transport.Connection> lookup = new HashMap<>();
Expand Down Expand Up @@ -142,15 +133,22 @@ protected SearchPhase getNextPhase(SearchPhaseResults<TestSearchPhaseResult> res
return new SearchPhase("test") {
@Override
public void run() {
latch.countDown();
assertTrue(searchPhaseDidRun.compareAndSet(false, true));
}
};
}

// Test hook: delegate first, then count the latch down, so the latch is only released after
// super.executeNext(...) has had the chance to fork or run the follow-up runnable.
@Override
protected void executeNext(Runnable runnable, Thread originalThread) {
super.executeNext(runnable, originalThread);
latch.countDown();
}
};
asyncAction.start();
latch.await();
assertTrue(searchPhaseDidRun.get());
SearchResponse searchResponse = asyncAction.buildSearchResponse(null, null);
assertEquals(shardsIter.size()-numSkipped, numRequests.get());
assertEquals(shardsIter.size() - numSkipped, numRequests.get());
assertEquals(0, searchResponse.getFailedShards());
assertEquals(numSkipped, searchResponse.getSkippedShards());
assertEquals(shardsIter.size(), searchResponse.getSuccessfulShards());
Expand All @@ -161,28 +159,19 @@ public void testLimitConcurrentShardRequests() throws InterruptedException {
request.allowPartialSearchResults(true);
int numConcurrent = randomIntBetween(1, 5);
request.setMaxConcurrentShardRequests(numConcurrent);
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<TestSearchResponse> response = new AtomicReference<>();
ActionListener<SearchResponse> responseListener = new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse searchResponse) {
response.set((TestSearchResponse) searchResponse);
}

@Override
public void onFailure(Exception e) {
logger.warn("test failed", e);
fail(e.getMessage());
}
};
int numShards = 10;
CountDownLatch latch = new CountDownLatch(numShards);
AtomicBoolean searchPhaseDidRun = new AtomicBoolean(false);
ActionListener<SearchResponse> responseListener = ActionListener.wrap(response -> {},
(e) -> { throw new AssertionError("unexpected", e);});
DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
// for the sake of this test we place the replica on the same node. ie. this is not a mistake since we limit per node now
DiscoveryNode replicaNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);

AtomicInteger contextIdGenerator = new AtomicInteger(0);
GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS),
10, randomBoolean(), primaryNode, replicaNode);
numShards, randomBoolean(), primaryNode, replicaNode);
SearchTransportService transportService = new SearchTransportService(null, null);
Map<String, Transport.Connection> lookup = new HashMap<>();
Map<ShardId, Boolean> seenShard = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -246,15 +235,22 @@ protected SearchPhase getNextPhase(SearchPhaseResults<TestSearchPhaseResult> res
return new SearchPhase("test") {
@Override
public void run() {
latch.countDown();
assertTrue(searchPhaseDidRun.compareAndSet(false, true));
}
};
}

// Test hook: delegate first, then count the latch down, so the latch is only released after
// super.executeNext(...) has had the chance to fork or run the follow-up runnable.
@Override
protected void executeNext(Runnable runnable, Thread originalThread) {
super.executeNext(runnable, originalThread);
latch.countDown();
}
};
asyncAction.start();
assertEquals(numConcurrent, numRequests.get());
awaitInitialRequests.countDown();
latch.await();
assertTrue(searchPhaseDidRun.get());
assertEquals(10, numRequests.get());
}

Expand All @@ -263,26 +259,18 @@ public void testFanOutAndCollect() throws InterruptedException {
request.allowPartialSearchResults(true);
request.setMaxConcurrentShardRequests(randomIntBetween(1, 100));
AtomicReference<TestSearchResponse> response = new AtomicReference<>();
ActionListener<SearchResponse> responseListener = new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse searchResponse) {
response.set((TestSearchResponse) searchResponse);
}

@Override
public void onFailure(Exception e) {
logger.warn("test failed", e);
fail(e.getMessage());
}
};
ActionListener<SearchResponse> responseListener = ActionListener.wrap(
searchResponse -> response.set((TestSearchResponse) searchResponse),
(e) -> { throw new AssertionError("unexpected", e);});
DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);

Map<DiscoveryNode, Set<Long>> nodeToContextMap = newConcurrentMap();
AtomicInteger contextIdGenerator = new AtomicInteger(0);
int numShards = randomIntBetween(1, 10);
GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS),
randomIntBetween(1, 10), randomBoolean(), primaryNode, replicaNode);
numShards, randomBoolean(), primaryNode, replicaNode);
AtomicInteger numFreedContext = new AtomicInteger();
SearchTransportService transportService = new SearchTransportService(null, null) {
@Override
Expand All @@ -296,9 +284,8 @@ public void sendFreeContext(Transport.Connection connection, long contextId, Ori
lookup.put(primaryNode.getId(), new MockConnection(primaryNode));
lookup.put(replicaNode.getId(), new MockConnection(replicaNode));
Map<String, AliasFilter> aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY));
final ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors()));
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean latchTriggered = new AtomicBoolean();
ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors()));
final CountDownLatch latch = new CountDownLatch(numShards);
AbstractSearchAsyncAction<TestSearchPhaseResult> asyncAction =
new AbstractSearchAsyncAction<TestSearchPhaseResult>(
"test",
Expand Down Expand Up @@ -349,13 +336,15 @@ public void run() {
sendReleaseSearchContext(result.getRequestId(), new MockConnection(result.node), OriginalIndices.NONE);
}
responseListener.onResponse(response);
if (latchTriggered.compareAndSet(false, true) == false) {
throw new AssertionError("latch triggered twice");
}
latch.countDown();
}
};
}

// Test hook: delegate first, then count the latch down, so the latch is only released after
// super.executeNext(...) has had the chance to fork or run the follow-up runnable.
@Override
protected void executeNext(Runnable runnable, Thread originalThread) {
super.executeNext(runnable, originalThread);
latch.countDown();
}
};
asyncAction.start();
latch.await();
Expand Down

0 comments on commit 9e70696

Please sign in to comment.