Skip to content

Commit

Permalink
Protect NodeConnectionsService from stale conns (#92558)
Browse files Browse the repository at this point in the history
A call to `ConnectionTarget#connect` which happens strictly after all
calls that close connections should leave us connected to the target.
However concurrent calls to `ConnectionTarget#connect` can overlap, and
today this means that a connection returned from an earlier call may
overwrite one from a later call. The trouble is that the earlier
connection attempt may yield a closed connection (it was concurrent with
the disconnections) so we must not let it supersede the newer one.

With this commit we prevent concurrent connection attempts, which avoids
earlier attempts from overwriting the connections resulting from later
attempts.

When combined with #92546, closes #92029
  • Loading branch information
DaveCTurner authored Jan 3, 2023
1 parent eb8cb10 commit 1a650ec
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 92 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/92558.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 92558
summary: Protect `NodeConnectionsService` from stale conns
area: Network
type: bug
issues:
- 92029
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
Expand Down Expand Up @@ -219,6 +220,13 @@ private class ConnectionTarget {
private final AtomicInteger consecutiveFailureCount = new AtomicInteger();
private final AtomicReference<Releasable> connectionRef = new AtomicReference<>();

// all access to these fields is synchronized
private ActionListener<Void> pendingListener;
private boolean connectionInProgress;

// placeholder listener for a fire-and-forget connection attempt
private static final ActionListener<Void> NOOP = ActionListener.noop();

ConnectionTarget(DiscoveryNode discoveryNode) {
this.discoveryNode = discoveryNode;
}
Expand All @@ -229,57 +237,97 @@ private void setConnectionRef(Releasable connectionReleasable) {

Runnable connect(ActionListener<Void> listener) {
return () -> {
final boolean alreadyConnected = transportService.nodeConnected(discoveryNode);
registerListener(listener);
doConnect();
};
}

if (alreadyConnected) {
logger.trace("refreshing connection to {}", discoveryNode);
} else {
logger.debug("connecting to {}", discoveryNode);
private synchronized void registerListener(ActionListener<Void> listener) {
if (listener == null) {
pendingListener = pendingListener == null ? NOOP : pendingListener;
} else if (pendingListener == null || pendingListener == NOOP) {
pendingListener = listener;
} else if (pendingListener instanceof ListenableFuture<Void> listenableFuture) {
listenableFuture.addListener(listener);
} else {
var wrapper = new ListenableFuture<Void>();
wrapper.addListener(pendingListener);
wrapper.addListener(listener);
pendingListener = wrapper;
}
}

private synchronized ActionListener<Void> acquireListener() {
// Avoid concurrent connection attempts because they don't necessarily complete in order otherwise, and out-of-order completion
// might mean we end up disconnected from a node even though we triggered a call to connect() after all close() calls had
// finished.
if (connectionInProgress == false) {
var listener = pendingListener;
if (listener != null) {
pendingListener = null;
connectionInProgress = true;
return listener;
}
}
return null;
}

// It's possible that connectionRef is a reference to an older connection that closed out from under us, but that something
// else has opened a fresh connection to the node. Therefore we always call connectToNode() and update connectionRef.
transportService.connectToNode(discoveryNode, new ActionListener<>() {
@Override
public void onResponse(Releasable connectionReleasable) {
if (alreadyConnected) {
logger.trace("refreshed connection to {}", discoveryNode);
} else {
logger.debug("connected to {}", discoveryNode);
}
consecutiveFailureCount.set(0);
setConnectionRef(connectionReleasable);

final boolean isActive;
synchronized (mutex) {
isActive = targetsByNode.get(discoveryNode) == ConnectionTarget.this;
}
if (isActive == false) {
logger.debug("connected to stale {} - releasing stale connection", discoveryNode);
setConnectionRef(null);
}
if (listener != null) {
listener.onResponse(null);
}
private synchronized void releaseListener() {
assert connectionInProgress;
connectionInProgress = false;
}

private void doConnect() {
var listener = acquireListener();
if (listener == null) {
return;
}

final boolean alreadyConnected = transportService.nodeConnected(discoveryNode);

if (alreadyConnected) {
logger.trace("refreshing connection to {}", discoveryNode);
} else {
logger.debug("connecting to {}", discoveryNode);
}

// It's possible that connectionRef is a reference to an older connection that closed out from under us, but that something else
// has opened a fresh connection to the node. Therefore we always call connectToNode() and update connectionRef.
transportService.connectToNode(discoveryNode, ActionListener.runAfter(new ActionListener<>() {
@Override
public void onResponse(Releasable connectionReleasable) {
if (alreadyConnected) {
logger.trace("refreshed connection to {}", discoveryNode);
} else {
logger.debug("connected to {}", discoveryNode);
}
consecutiveFailureCount.set(0);
setConnectionRef(connectionReleasable);

@Override
public void onFailure(Exception e) {
final int currentFailureCount = consecutiveFailureCount.incrementAndGet();
// only warn every 6th failure
final Level level = currentFailureCount % 6 == 1 ? Level.WARN : Level.DEBUG;
logger.log(
level,
() -> format("failed to connect to %s (tried [%s] times)", discoveryNode, currentFailureCount),
e
);
final boolean isActive;
synchronized (mutex) {
isActive = targetsByNode.get(discoveryNode) == ConnectionTarget.this;
}
if (isActive == false) {
logger.debug("connected to stale {} - releasing stale connection", discoveryNode);
setConnectionRef(null);
if (listener != null) {
listener.onFailure(e);
}
}
});
};
listener.onResponse(null);
}

@Override
public void onFailure(Exception e) {
final int currentFailureCount = consecutiveFailureCount.incrementAndGet();
// only warn every 6th failure
final Level level = currentFailureCount % 6 == 1 ? Level.WARN : Level.DEBUG;
logger.log(level, () -> format("failed to connect to %s (tried [%s] times)", discoveryNode, currentFailureCount), e);
setConnectionRef(null);
listener.onFailure(e);
}
}, () -> {
releaseListener();
transportService.getThreadPool().generic().execute(this::doConnect);
}));
}

void disconnect() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -97,9 +99,7 @@ public void testEventuallyConnectsOnlyToAppliedNodes() throws Exception {
final AtomicBoolean keepGoing = new AtomicBoolean(true);
final Thread reconnectionThread = new Thread(() -> {
while (keepGoing.get()) {
final PlainActionFuture<Void> future = new PlainActionFuture<>();
service.ensureConnections(() -> future.onResponse(null));
future.actionGet();
ensureConnections(service);
}
}, "reconnection thread");
reconnectionThread.start();
Expand All @@ -109,49 +109,78 @@ public void testEventuallyConnectsOnlyToAppliedNodes() throws Exception {
final boolean isDisrupting = randomBoolean();
final Thread disruptionThread = new Thread(() -> {
while (isDisrupting && keepGoing.get()) {
final Transport.Connection connection;
try {
connection = transportService.getConnection(randomFrom(allNodes));
} catch (NodeNotConnectedException e) {
continue;
}

final PlainActionFuture<Void> future = new PlainActionFuture<>();
connection.addRemovedListener(future);
connection.close();
future.actionGet(10, TimeUnit.SECONDS);
closeConnection(transportService, randomFrom(allNodes));
}
}, "disruption thread");
disruptionThread.start();

for (int i = 0; i < 10; i++) {
final DiscoveryNodes connectNodes = discoveryNodesFromList(randomSubsetOf(allNodes));
final PlainActionFuture<Void> future = new PlainActionFuture<>();
service.connectToNodes(connectNodes, () -> future.onResponse(null));
future.actionGet(10, TimeUnit.SECONDS);
final DiscoveryNodes disconnectExceptNodes = discoveryNodesFromList(randomSubsetOf(allNodes));
service.disconnectFromNodesExcept(disconnectExceptNodes);
connectToNodes(service, discoveryNodesFromList(randomSubsetOf(allNodes)));
service.disconnectFromNodesExcept(discoveryNodesFromList(randomSubsetOf(allNodes)));
}

final DiscoveryNodes nodes = discoveryNodesFromList(randomSubsetOf(allNodes));
final PlainActionFuture<Void> connectFuture = new PlainActionFuture<>();
service.connectToNodes(nodes, () -> connectFuture.onResponse(null));
connectFuture.actionGet(10, TimeUnit.SECONDS);
connectToNodes(service, nodes);
service.disconnectFromNodesExcept(nodes);

assertTrue(keepGoing.compareAndSet(true, false));
reconnectionThread.join();
disruptionThread.join();

if (isDisrupting) {
final PlainActionFuture<Void> ensureFuture = new PlainActionFuture<>();
service.ensureConnections(() -> ensureFuture.onResponse(null));
ensureFuture.actionGet(10, TimeUnit.SECONDS);
ensureConnections(service);
}

assertConnected(transportService, nodes);
assertBusy(() -> assertConnectedExactlyToNodes(nodes));
}

public void testConcurrentConnectAndDisconnect() throws Exception {
final NodeConnectionsService service = new NodeConnectionsService(Settings.EMPTY, threadPool, transportService);

final AtomicBoolean keepGoing = new AtomicBoolean(true);
final Thread reconnectionThread = new Thread(() -> {
while (keepGoing.get()) {
ensureConnections(service);
}
}, "reconnection thread");
reconnectionThread.start();

final var node = new DiscoveryNode("node", buildNewFakeTransportAddress(), Map.of(), Set.of(), Version.CURRENT);
final var nodes = discoveryNodesFromList(List.of(node));

final Thread disruptionThread = new Thread(() -> {
while (keepGoing.get()) {
closeConnection(transportService, node);
}
}, "disruption thread");
disruptionThread.start();

final var reconnectPermits = new Semaphore(1000);
final var reconnectThreads = 10;
final var reconnectCountDown = new CountDownLatch(reconnectThreads);
for (int i = 0; i < reconnectThreads; i++) {
threadPool.generic().execute(new Runnable() {
@Override
public void run() {
if (reconnectPermits.tryAcquire()) {
service.connectToNodes(nodes, () -> threadPool.generic().execute(this));
} else {
reconnectCountDown.countDown();
}
}
});
}

assertTrue(reconnectCountDown.await(10, TimeUnit.SECONDS));
assertTrue(keepGoing.compareAndSet(true, false));
reconnectionThread.join();
disruptionThread.join();

ensureConnections(service);
assertConnectedExactlyToNodes(nodes);
}

public void testPeriodicReconnection() {
final Settings.Builder settings = Settings.builder();
final long reconnectIntervalMillis;
Expand Down Expand Up @@ -234,30 +263,24 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
// connect to one node
final DiscoveryNode node0 = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT);
final DiscoveryNodes nodes0 = DiscoveryNodes.builder().add(node0).build();
final PlainActionFuture<Void> future0 = new PlainActionFuture<>();
service.connectToNodes(nodes0, () -> future0.onResponse(null));
future0.actionGet(10, TimeUnit.SECONDS);
connectToNodes(service, nodes0);
assertConnectedExactlyToNodes(nodes0);

// connection attempts to node0 block indefinitely
final CyclicBarrier connectionBarrier = new CyclicBarrier(2);
try {
nodeConnectionBlocks.put(node0, connectionBarrier::await);
nodeConnectionBlocks.put(node0, () -> connectionBarrier.await(10, TimeUnit.SECONDS));
transportService.disconnectFromNode(node0);

// can still connect to another node without blocking
final DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT);
final DiscoveryNodes nodes1 = DiscoveryNodes.builder().add(node1).build();
final DiscoveryNodes nodes01 = DiscoveryNodes.builder(nodes0).add(node1).build();
final PlainActionFuture<Void> future1 = new PlainActionFuture<>();
service.connectToNodes(nodes01, () -> future1.onResponse(null));
future1.actionGet(10, TimeUnit.SECONDS);
connectToNodes(service, nodes01);
assertConnectedExactlyToNodes(nodes1);

// can also disconnect from node0 without blocking
final PlainActionFuture<Void> future2 = new PlainActionFuture<>();
service.connectToNodes(nodes1, () -> future2.onResponse(null));
future2.actionGet(10, TimeUnit.SECONDS);
connectToNodes(service, nodes1);
service.disconnectFromNodesExcept(nodes1);
assertConnectedExactlyToNodes(nodes1);

Expand All @@ -273,17 +296,15 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti

// the reconnection is also blocked but the connection future doesn't wait, it completes straight away
transportService.disconnectFromNode(node0);
final PlainActionFuture<Void> future4 = new PlainActionFuture<>();
service.connectToNodes(nodes01, () -> future4.onResponse(null));
future4.actionGet(10, TimeUnit.SECONDS);
connectToNodes(service, nodes01);
assertConnectedExactlyToNodes(nodes1);

// a blocked reconnection attempt doesn't also block the node from being deregistered
service.disconnectFromNodesExcept(nodes1);
final PlainActionFuture<DiscoveryNode> disconnectFuture1 = new PlainActionFuture<>();
assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture1));
connectionBarrier.await();
assertThat(disconnectFuture1.actionGet(10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here
assertThat(PlainActionFuture.get(disconnectFuture1 -> {
assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture1));
connectionBarrier.await(10, TimeUnit.SECONDS);
}, 10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here
assertConnectedExactlyToNodes(nodes1);

// a blocked connection attempt to a new node also doesn't prevent an immediate deregistration
Expand All @@ -294,10 +315,10 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
service.disconnectFromNodesExcept(nodes1);
assertConnectedExactlyToNodes(nodes1);

final PlainActionFuture<DiscoveryNode> disconnectFuture2 = new PlainActionFuture<>();
assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture2));
connectionBarrier.await(10, TimeUnit.SECONDS);
assertThat(disconnectFuture2.actionGet(10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here
assertThat(PlainActionFuture.get(disconnectFuture2 -> {
assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture2));
connectionBarrier.await(10, TimeUnit.SECONDS);
}, 10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here
assertConnectedExactlyToNodes(nodes1);
assertTrue(future5.isDone());
} finally {
Expand All @@ -310,7 +331,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
reason = "testing that DEBUG-level logging is reasonable",
value = "org.elasticsearch.cluster.NodeConnectionsService:DEBUG"
)
public void testDebugLogging() throws IllegalAccessException {
public void testDebugLogging() {
final DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue();

MockTransport transport = new MockTransport(deterministicTaskQueue.getThreadPool());
Expand Down Expand Up @@ -706,4 +727,23 @@ public RequestHandlers getRequestHandlers() {
return requestHandlers;
}
}

private static void connectToNodes(NodeConnectionsService service, DiscoveryNodes discoveryNodes) {
PlainActionFuture.get(future -> service.connectToNodes(discoveryNodes, () -> future.onResponse(null)), 10, TimeUnit.SECONDS);
}

private static void ensureConnections(NodeConnectionsService service) {
PlainActionFuture.get(future -> service.ensureConnections(() -> future.onResponse(null)), 10, TimeUnit.SECONDS);
}

private static void closeConnection(TransportService transportService, DiscoveryNode discoveryNode) {
try {
final var connection = transportService.getConnection(discoveryNode);
connection.close();
PlainActionFuture.get(connection::addRemovedListener, 10, TimeUnit.SECONDS);
} catch (NodeNotConnectedException e) {
// ok
}
}

}

0 comments on commit 1a650ec

Please sign in to comment.