diff --git a/docs/changelog/92558.yaml b/docs/changelog/92558.yaml new file mode 100644 index 0000000000000..f0621449fa20d --- /dev/null +++ b/docs/changelog/92558.yaml @@ -0,0 +1,6 @@ +pr: 92558 +summary: Protect `NodeConnectionsService` from stale conns +area: Network +type: bug +issues: + - 92029 diff --git a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java index 2e67288358c2b..8daa253988e04 100644 --- a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java +++ b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.ListenableFuture; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; @@ -219,6 +220,13 @@ private class ConnectionTarget { private final AtomicInteger consecutiveFailureCount = new AtomicInteger(); private final AtomicReference connectionRef = new AtomicReference<>(); + // all access to these fields is synchronized + private ActionListener pendingListener; + private boolean connectionInProgress; + + // placeholder listener for a fire-and-forget connection attempt + private static final ActionListener NOOP = ActionListener.noop(); + ConnectionTarget(DiscoveryNode discoveryNode) { this.discoveryNode = discoveryNode; } @@ -229,57 +237,97 @@ private void setConnectionRef(Releasable connectionReleasable) { Runnable connect(ActionListener listener) { return () -> { - final boolean alreadyConnected = transportService.nodeConnected(discoveryNode); + registerListener(listener); + doConnect(); + }; + } - if (alreadyConnected) { - logger.trace("refreshing connection to {}", discoveryNode); - } else { - logger.debug("connecting to {}", discoveryNode); + private synchronized void registerListener(ActionListener listener) { + if (listener == null) { + pendingListener = pendingListener == null ? NOOP : pendingListener; + } else if (pendingListener == null || pendingListener == NOOP) { + pendingListener = listener; + } else if (pendingListener instanceof ListenableFuture listenableFuture) { + listenableFuture.addListener(listener); + } else { + var wrapper = new ListenableFuture(); + wrapper.addListener(pendingListener); + wrapper.addListener(listener); + pendingListener = wrapper; + } + } + + private synchronized ActionListener acquireListener() { + // Avoid concurrent connection attempts because they don't necessarily complete in order otherwise, and out-of-order completion + // might mean we end up disconnected from a node even though we triggered a call to connect() after all close() calls had + // finished. + if (connectionInProgress == false) { + var listener = pendingListener; + if (listener != null) { + pendingListener = null; + connectionInProgress = true; + return listener; } + } + return null; + } - // It's possible that connectionRef is a reference to an older connection that closed out from under us, but that something - // else has opened a fresh connection to the node. Therefore we always call connectToNode() and update connectionRef. - transportService.connectToNode(discoveryNode, new ActionListener<>() { - @Override - public void onResponse(Releasable connectionReleasable) { - if (alreadyConnected) { - logger.trace("refreshed connection to {}", discoveryNode); - } else { - logger.debug("connected to {}", discoveryNode); - } - consecutiveFailureCount.set(0); - setConnectionRef(connectionReleasable); - - final boolean isActive; - synchronized (mutex) { - isActive = targetsByNode.get(discoveryNode) == ConnectionTarget.this; - } - if (isActive == false) { - logger.debug("connected to stale {} - releasing stale connection", discoveryNode); - setConnectionRef(null); - } - if (listener != null) { - listener.onResponse(null); - } + private synchronized void releaseListener() { + assert connectionInProgress; + connectionInProgress = false; + } + + private void doConnect() { + var listener = acquireListener(); + if (listener == null) { + return; + } + + final boolean alreadyConnected = transportService.nodeConnected(discoveryNode); + + if (alreadyConnected) { + logger.trace("refreshing connection to {}", discoveryNode); + } else { + logger.debug("connecting to {}", discoveryNode); + } + + // It's possible that connectionRef is a reference to an older connection that closed out from under us, but that something else + // has opened a fresh connection to the node. Therefore we always call connectToNode() and update connectionRef. + transportService.connectToNode(discoveryNode, ActionListener.runAfter(new ActionListener<>() { + @Override + public void onResponse(Releasable connectionReleasable) { + if (alreadyConnected) { + logger.trace("refreshed connection to {}", discoveryNode); + } else { + logger.debug("connected to {}", discoveryNode); } + consecutiveFailureCount.set(0); + setConnectionRef(connectionReleasable); - @Override - public void onFailure(Exception e) { - final int currentFailureCount = consecutiveFailureCount.incrementAndGet(); - // only warn every 6th failure - final Level level = currentFailureCount % 6 == 1 ? Level.WARN : Level.DEBUG; - logger.log( - level, - () -> format("failed to connect to %s (tried [%s] times)", discoveryNode, currentFailureCount), - e - ); + final boolean isActive; + synchronized (mutex) { + isActive = targetsByNode.get(discoveryNode) == ConnectionTarget.this; + } + if (isActive == false) { + logger.debug("connected to stale {} - releasing stale connection", discoveryNode); setConnectionRef(null); - if (listener != null) { - listener.onFailure(e); - } } - }); - }; + listener.onResponse(null); + } + + @Override + public void onFailure(Exception e) { + final int currentFailureCount = consecutiveFailureCount.incrementAndGet(); + // only warn every 6th failure + final Level level = currentFailureCount % 6 == 1 ? Level.WARN : Level.DEBUG; + logger.log(level, () -> format("failed to connect to %s (tried [%s] times)", discoveryNode, currentFailureCount), e); + setConnectionRef(null); + listener.onFailure(e); + } + }, () -> { + releaseListener(); + transportService.getThreadPool().generic().execute(this::doConnect); + })); } void disconnect() { diff --git a/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java index 461fec42d9fc5..a0e491858e9e6 100644 --- a/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java @@ -55,7 +55,9 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -97,9 +99,7 @@ public void testEventuallyConnectsOnlyToAppliedNodes() throws Exception { final AtomicBoolean keepGoing = new AtomicBoolean(true); final Thread reconnectionThread = new Thread(() -> { while (keepGoing.get()) { - final PlainActionFuture future = new PlainActionFuture<>(); - service.ensureConnections(() -> future.onResponse(null)); - future.actionGet(); + ensureConnections(service); } }, "reconnection thread"); reconnectionThread.start(); @@ -109,34 +109,18 @@ public void testEventuallyConnectsOnlyToAppliedNodes() throws Exception { final boolean isDisrupting = randomBoolean(); final Thread disruptionThread = new Thread(() -> { while (isDisrupting && keepGoing.get()) { - final Transport.Connection connection; - try { - connection = transportService.getConnection(randomFrom(allNodes)); - } catch (NodeNotConnectedException e) { - continue; - } - - final PlainActionFuture future = new PlainActionFuture<>(); - connection.addRemovedListener(future); - connection.close(); - future.actionGet(10, TimeUnit.SECONDS); + closeConnection(transportService, randomFrom(allNodes)); } }, "disruption thread"); disruptionThread.start(); for (int i = 0; i < 10; i++) { - final DiscoveryNodes connectNodes = discoveryNodesFromList(randomSubsetOf(allNodes)); - final PlainActionFuture future = new PlainActionFuture<>(); - service.connectToNodes(connectNodes, () -> future.onResponse(null)); - future.actionGet(10, TimeUnit.SECONDS); - final DiscoveryNodes disconnectExceptNodes = discoveryNodesFromList(randomSubsetOf(allNodes)); - service.disconnectFromNodesExcept(disconnectExceptNodes); + connectToNodes(service, discoveryNodesFromList(randomSubsetOf(allNodes))); + service.disconnectFromNodesExcept(discoveryNodesFromList(randomSubsetOf(allNodes))); } final DiscoveryNodes nodes = discoveryNodesFromList(randomSubsetOf(allNodes)); - final PlainActionFuture connectFuture = new PlainActionFuture<>(); - service.connectToNodes(nodes, () -> connectFuture.onResponse(null)); - connectFuture.actionGet(10, TimeUnit.SECONDS); + connectToNodes(service, nodes); service.disconnectFromNodesExcept(nodes); assertTrue(keepGoing.compareAndSet(true, false)); @@ -144,14 +128,59 @@ public void testEventuallyConnectsOnlyToAppliedNodes() throws Exception { disruptionThread.join(); if (isDisrupting) { - final PlainActionFuture ensureFuture = new PlainActionFuture<>(); - service.ensureConnections(() -> ensureFuture.onResponse(null)); - ensureFuture.actionGet(10, TimeUnit.SECONDS); + ensureConnections(service); } + assertConnected(transportService, nodes); assertBusy(() -> assertConnectedExactlyToNodes(nodes)); } + public void testConcurrentConnectAndDisconnect() throws Exception { + final NodeConnectionsService service = new NodeConnectionsService(Settings.EMPTY, threadPool, transportService); + + final AtomicBoolean keepGoing = new AtomicBoolean(true); + final Thread reconnectionThread = new Thread(() -> { + while (keepGoing.get()) { + ensureConnections(service); + } + }, "reconnection thread"); + reconnectionThread.start(); + + final var node = new DiscoveryNode("node", buildNewFakeTransportAddress(), Map.of(), Set.of(), Version.CURRENT); + final var nodes = discoveryNodesFromList(List.of(node)); + + final Thread disruptionThread = new Thread(() -> { + while (keepGoing.get()) { + closeConnection(transportService, node); + } + }, "disruption thread"); + disruptionThread.start(); + + final var reconnectPermits = new Semaphore(1000); + final var reconnectThreads = 10; + final var reconnectCountDown = new CountDownLatch(reconnectThreads); + for (int i = 0; i < reconnectThreads; i++) { + threadPool.generic().execute(new Runnable() { + @Override + public void run() { + if (reconnectPermits.tryAcquire()) { + service.connectToNodes(nodes, () -> threadPool.generic().execute(this)); + } else { + reconnectCountDown.countDown(); + } + } + }); + } + + assertTrue(reconnectCountDown.await(10, TimeUnit.SECONDS)); + assertTrue(keepGoing.compareAndSet(true, false)); + reconnectionThread.join(); + disruptionThread.join(); + + ensureConnections(service); + assertConnectedExactlyToNodes(nodes); + } + public void testPeriodicReconnection() { final Settings.Builder settings = Settings.builder(); final long reconnectIntervalMillis; @@ -234,30 +263,24 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti // connect to one node final DiscoveryNode node0 = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); final DiscoveryNodes nodes0 = DiscoveryNodes.builder().add(node0).build(); - final PlainActionFuture future0 = new PlainActionFuture<>(); - service.connectToNodes(nodes0, () -> future0.onResponse(null)); - future0.actionGet(10, TimeUnit.SECONDS); + connectToNodes(service, nodes0); assertConnectedExactlyToNodes(nodes0); // connection attempts to node0 block indefinitely final CyclicBarrier connectionBarrier = new CyclicBarrier(2); try { - nodeConnectionBlocks.put(node0, connectionBarrier::await); + nodeConnectionBlocks.put(node0, () -> connectionBarrier.await(10, TimeUnit.SECONDS)); transportService.disconnectFromNode(node0); // can still connect to another node without blocking final DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT); final DiscoveryNodes nodes1 = DiscoveryNodes.builder().add(node1).build(); final DiscoveryNodes nodes01 = DiscoveryNodes.builder(nodes0).add(node1).build(); - final PlainActionFuture future1 = new PlainActionFuture<>(); - service.connectToNodes(nodes01, () -> future1.onResponse(null)); - future1.actionGet(10, TimeUnit.SECONDS); + connectToNodes(service, nodes01); assertConnectedExactlyToNodes(nodes1); // can also disconnect from node0 without blocking - final PlainActionFuture future2 = new PlainActionFuture<>(); - service.connectToNodes(nodes1, () -> future2.onResponse(null)); - future2.actionGet(10, TimeUnit.SECONDS); + connectToNodes(service, nodes1); service.disconnectFromNodesExcept(nodes1); assertConnectedExactlyToNodes(nodes1); @@ -273,17 +296,15 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti // the reconnection is also blocked but the connection future doesn't wait, it completes straight away transportService.disconnectFromNode(node0); - final PlainActionFuture future4 = new PlainActionFuture<>(); - service.connectToNodes(nodes01, () -> future4.onResponse(null)); - future4.actionGet(10, TimeUnit.SECONDS); + connectToNodes(service, nodes01); assertConnectedExactlyToNodes(nodes1); // a blocked reconnection attempt doesn't also block the node from being deregistered service.disconnectFromNodesExcept(nodes1); - final PlainActionFuture disconnectFuture1 = new PlainActionFuture<>(); - assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture1)); - connectionBarrier.await(); - assertThat(disconnectFuture1.actionGet(10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here + assertThat(PlainActionFuture.get(disconnectFuture1 -> { + assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture1)); + connectionBarrier.await(10, TimeUnit.SECONDS); + }, 10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here assertConnectedExactlyToNodes(nodes1); // a blocked connection attempt to a new node also doesn't prevent an immediate deregistration @@ -294,10 +315,10 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti service.disconnectFromNodesExcept(nodes1); assertConnectedExactlyToNodes(nodes1); - final PlainActionFuture disconnectFuture2 = new PlainActionFuture<>(); - assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture2)); - connectionBarrier.await(10, TimeUnit.SECONDS); - assertThat(disconnectFuture2.actionGet(10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here + assertThat(PlainActionFuture.get(disconnectFuture2 -> { + assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture2)); + connectionBarrier.await(10, TimeUnit.SECONDS); + }, 10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here assertConnectedExactlyToNodes(nodes1); assertTrue(future5.isDone()); } finally { @@ -310,7 +331,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti reason = "testing that DEBUG-level logging is reasonable", value = "org.elasticsearch.cluster.NodeConnectionsService:DEBUG" ) - public void testDebugLogging() throws IllegalAccessException { + public void testDebugLogging() { final DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue(); MockTransport transport = new MockTransport(deterministicTaskQueue.getThreadPool()); @@ -706,4 +727,23 @@ public RequestHandlers getRequestHandlers() { return requestHandlers; } } + + private static void connectToNodes(NodeConnectionsService service, DiscoveryNodes discoveryNodes) { + PlainActionFuture.get(future -> service.connectToNodes(discoveryNodes, () -> future.onResponse(null)), 10, TimeUnit.SECONDS); + } + + private static void ensureConnections(NodeConnectionsService service) { + PlainActionFuture.get(future -> service.ensureConnections(() -> future.onResponse(null)), 10, TimeUnit.SECONDS); + } + + private static void closeConnection(TransportService transportService, DiscoveryNode discoveryNode) { + try { + final var connection = transportService.getConnection(discoveryNode); + connection.close(); + PlainActionFuture.get(connection::addRemovedListener, 10, TimeUnit.SECONDS); + } catch (NodeNotConnectedException e) { + // ok + } + } + }