diff --git a/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java b/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java index 145208baf0db8..3e50b4d74c916 100644 --- a/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java +++ b/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java @@ -469,14 +469,17 @@ protected void doSample() { */ Transport.Connection connectionToClose = null; - @Override - public void onAfter() { - IOUtils.closeWhileHandlingException(connectionToClose); + void onDone() { + try { + IOUtils.closeWhileHandlingException(connectionToClose); + } finally { + latch.countDown(); + } } @Override public void onFailure(Exception e) { - latch.countDown(); + onDone(); if (e instanceof ConnectTransportException) { logger.debug((Supplier) () -> new ParameterizedMessage("failed to connect to node [{}], ignoring...", nodeToPing), e); @@ -522,7 +525,7 @@ public String executor() { @Override public void handleResponse(ClusterStateResponse response) { clusterStateResponses.put(nodeToPing, response); - latch.countDown(); + onDone(); } @Override @@ -532,9 +535,8 @@ public void handleException(TransportException e) { "failed to get local cluster state for {}, disconnecting...", nodeToPing), e); try { hostFailureListener.onNodeDisconnected(nodeToPing, e); - } - finally { - latch.countDown(); + } finally { + onDone(); } } }); diff --git a/core/src/test/java/org/elasticsearch/client/transport/TransportClientNodesServiceTests.java b/core/src/test/java/org/elasticsearch/client/transport/TransportClientNodesServiceTests.java index e2e200b9e2795..b260163a81c0f 100644 --- a/core/src/test/java/org/elasticsearch/client/transport/TransportClientNodesServiceTests.java +++ b/core/src/test/java/org/elasticsearch/client/transport/TransportClientNodesServiceTests.java @@ -19,21 +19,13 @@ package org.elasticsearch.client.transport; -import java.io.Closeable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.node.liveness.LivenessResponse; import org.elasticsearch.action.admin.cluster.node.liveness.TransportLivenessAction; +import org.elasticsearch.action.admin.cluster.state.ClusterStateAction; +import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest; +import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -42,19 +34,40 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.LocalTransportAddress; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.node.Node; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportInterceptor; import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import org.hamcrest.CustomMatcher; +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.test.transport.MockTransportService.createNewService; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.everyItem; import static org.hamcrest.CoreMatchers.instanceOf; @@ -323,6 +336,157 @@ public boolean matches(Object item) { } } + public void testSniffNodesSamplerClosesConnections() throws Exception { + final TestThreadPool threadPool = new TestThreadPool("testSniffNodesSamplerClosesConnections"); + + Settings remoteSettings = Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "remote").build(); + try (MockTransportService remoteService = createNewService(remoteSettings, Version.CURRENT, threadPool, null)) { + final MockHandler handler = new MockHandler(remoteService); + remoteService.registerRequestHandler(ClusterStateAction.NAME, ClusterStateRequest::new, ThreadPool.Names.SAME, handler); + remoteService.start(); + remoteService.acceptIncomingRequests(); + + Settings clientSettings = Settings.builder() + .put(TransportClient.CLIENT_TRANSPORT_SNIFF.getKey(), true) + .put(TransportClient.CLIENT_TRANSPORT_PING_TIMEOUT.getKey(), TimeValue.timeValueSeconds(1)) + .put(TransportClient.CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL.getKey(), TimeValue.timeValueSeconds(30)) + .build(); + + try (MockTransportService clientService = createNewService(clientSettings, Version.CURRENT, threadPool, null)) { + final List establishedConnections = new CopyOnWriteArrayList<>(); + final List reusedConnections = new CopyOnWriteArrayList<>(); + + clientService.addDelegate(remoteService, new MockTransportService.DelegateTransport(clientService.original()) { + @Override + public Connection openConnection(DiscoveryNode node, ConnectionProfile profile) throws IOException { + MockConnection connection = new MockConnection(super.openConnection(node, profile)); + establishedConnections.add(connection); + return connection; + } + + @Override + public Connection getConnection(DiscoveryNode node) { + MockConnection connection = new MockConnection(super.getConnection(node)); + reusedConnections.add(connection); + return connection; + } + }); + + clientService.start(); + clientService.acceptIncomingRequests(); + + try (TransportClientNodesService transportClientNodesService = + new TransportClientNodesService(clientSettings, clientService, threadPool, (a, b) -> {})) { + assertEquals(0, transportClientNodesService.connectedNodes().size()); + assertEquals(0, establishedConnections.size()); + assertEquals(0, reusedConnections.size()); + + transportClientNodesService.addTransportAddresses(remoteService.getLocalDiscoNode().getAddress()); + assertEquals(1, transportClientNodesService.connectedNodes().size()); + assertClosedConnections(establishedConnections, 1); + + transportClientNodesService.doSample(); + assertClosedConnections(establishedConnections, 2); + assertOpenConnections(reusedConnections, 1); + + handler.blockRequest(); + Thread thread = new Thread(transportClientNodesService::doSample); + thread.start(); + + assertBusy(() -> assertEquals(3, establishedConnections.size())); + assertFalse("Temporary ping connection must be opened", establishedConnections.get(2).isClosed()); + + handler.releaseRequest(); + thread.join(); + + assertClosedConnections(establishedConnections, 3); + } + } + } finally { + terminate(threadPool); + } + } + + private void assertClosedConnections(final List connections, final int size) { + assertEquals("Expecting " + size + " closed connections but got " + connections.size(), size, connections.size()); + connections.forEach(c -> assertConnection(c, true)); + } + + private void assertOpenConnections(final List connections, final int size) { + assertEquals("Expecting " + size + " open connections but got " + connections.size(), size, connections.size()); + connections.forEach(c -> assertConnection(c, false)); + } + + private static void assertConnection(final MockConnection connection, final boolean closed) { + assertEquals("Connection [" + connection + "] must be " + (closed ? "closed" : "open"), closed, connection.isClosed()); + } + + class MockConnection implements Transport.Connection { + private final AtomicBoolean closed = new AtomicBoolean(false); + private final Transport.Connection connection; + + private MockConnection(Transport.Connection connection) { + this.connection = connection; + } + + @Override + public DiscoveryNode getNode() { + return connection.getNode(); + } + + @Override + public Version getVersion() { + return connection.getVersion(); + } + + @Override + public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options) + throws IOException, TransportException { + connection.sendRequest(requestId, action, request, options); + } + + @Override + public void close() throws IOException { + if (closed.compareAndSet(false, true)) { + connection.close(); + } + } + + boolean isClosed() { + return closed.get(); + } + } + + class MockHandler implements TransportRequestHandler { + private final AtomicBoolean block = new AtomicBoolean(false); + private final CountDownLatch release = new CountDownLatch(1); + private final MockTransportService transportService; + + MockHandler(MockTransportService transportService) { + this.transportService = transportService; + } + + @Override + public void messageReceived(ClusterStateRequest request, TransportChannel channel) throws Exception { + if (block.get()) { + release.await(); + return; + } + DiscoveryNodes discoveryNodes = DiscoveryNodes.builder().add(transportService.getLocalDiscoNode()).build(); + ClusterState build = ClusterState.builder(ClusterName.DEFAULT).nodes(discoveryNodes).build(); + channel.sendResponse(new ClusterStateResponse(ClusterName.DEFAULT, build)); + } + + void blockRequest() { + if (block.compareAndSet(false, true) == false) { + throw new AssertionError("Request handler is already marked as blocking"); + } + } + void releaseRequest() { + release.countDown(); + } + } + public static class TestRequest extends TransportRequest { } diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java index 5e74333159770..e888684639ab4 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java @@ -109,6 +109,14 @@ protected Version getVersion() { return createNewService(settings, transport, version, threadPool, clusterSettings); } + public static MockTransportService createNewService(Settings settings, Version version, ThreadPool threadPool, + @Nullable ClusterSettings clusterSettings) { + NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(ClusterModule.getNamedWriteables()); + final Transport transport = new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, + new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(settings, Collections.emptyList()), version); + return createNewService(settings, transport, version, threadPool, clusterSettings); + } + public static MockTransportService createNewService(Settings settings, Transport transport, Version version, ThreadPool threadPool, @Nullable ClusterSettings clusterSettings) { return new MockTransportService(settings, transport, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR,