From e1734838864b77ab422dec507428209814ddd53c Mon Sep 17 00:00:00 2001 From: Boaz Leskes Date: Tue, 7 Feb 2017 22:11:32 +0200 Subject: [PATCH] TransportService.connectToNode should validate remote node ID (#22828) #22194 gave us the ability to open low level temporary connections to remote node based on their address. With this use case out of the way, actual full blown connections should validate the node on the other side, making sure we speak to who we think we speak to. This helps in case where multiple nodes are started on the same host and a quick node restart causes them to swap addresses, which in turn can cause confusion down the road. --- .../TransportClientNodesService.java | 189 +++++++++--------- .../cluster/node/DiscoveryNode.java | 18 +- .../common/CheckedBiConsumer.java | 30 +++ .../discovery/zen/ZenDiscovery.java | 49 +++-- .../transport/ConnectionProfile.java | 12 ++ .../elasticsearch/transport/TcpTransport.java | 35 +++- .../elasticsearch/transport/Transport.java | 16 +- .../transport/TransportService.java | 23 ++- .../node/tasks/CancellableTasksTests.java | 18 +- .../node/tasks/TaskManagerTestCase.java | 30 ++- .../node/tasks/TransportTasksActionTests.java | 14 +- .../transport/FailAndRetryMockTransport.java | 6 +- .../TransportClientHeadersTests.java | 20 +- .../cluster/NodeConnectionsServiceTests.java | 5 +- .../discovery/ZenFaultDetectionTests.java | 74 ++++--- .../zen/PublishClusterStateActionTests.java | 7 +- .../discovery/zen/UnicastZenPingTests.java | 6 +- .../transport/ConnectionProfileTests.java | 15 ++ .../transport/TCPTransportTests.java | 35 ++++ .../transport/TransportActionProxyTests.java | 9 +- .../TransportServiceHandshakeTests.java | 18 ++ .../netty4/Netty4ScheduledPingTests.java | 9 +- .../test/ClusterServiceUtils.java | 9 +- .../test/transport/CapturingTransport.java | 5 +- .../test/transport/MockTransportService.java | 47 +++-- .../AbstractSimpleTransportTestCase.java | 171 ++++++++++++---- .../transport/MockTcpTransport.java | 3 +- .../transport/MockTcpTransportTests.java | 4 +- 28 files changed, 593 insertions(+), 284 deletions(-) create mode 100644 core/src/main/java/org/elasticsearch/common/CheckedBiConsumer.java diff --git a/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java b/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java index ea2906dab67fc..dbcf0edef2897 100644 --- a/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java +++ b/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java @@ -22,6 +22,7 @@ import com.carrotsearch.hppc.cursors.ObjectCursor; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; +import org.apache.lucene.util.IOUtils; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; @@ -38,6 +39,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.FutureUtils; import org.elasticsearch.threadpool.ThreadPool; @@ -46,6 +48,8 @@ import org.elasticsearch.transport.FutureTransportResponseHandler; import org.elasticsearch.transport.NodeDisconnectedException; import org.elasticsearch.transport.NodeNotConnectedException; +import org.elasticsearch.transport.PlainTransportFuture; +import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponseHandler; @@ -401,51 +405,37 @@ protected void doSample() { HashSet newNodes = new HashSet<>(); HashSet newFilteredNodes = new HashSet<>(); for (DiscoveryNode listedNode : listedNodes) { - if (!transportService.nodeConnected(listedNode)) { - try { - // its a listed node, light connect to it... - logger.trace("connecting to listed node [{}]", listedNode); - transportService.connectToNode(listedNode, LISTED_NODES_PROFILE); - } catch (Exception e) { - logger.info( - (Supplier) - () -> new ParameterizedMessage("failed to connect to node [{}], removed from nodes list", listedNode), e); - hostFailureListener.onNodeDisconnected(listedNode, e); - newFilteredNodes.add(listedNode); - continue; - } - } - try { - LivenessResponse livenessResponse = transportService.submitRequest(listedNode, TransportLivenessAction.NAME, - new LivenessRequest(), - TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STATE).withTimeout(pingTimeout).build(), - new FutureTransportResponseHandler() { - @Override - public LivenessResponse newInstance() { - return new LivenessResponse(); - } - }).txGet(); + try (Transport.Connection connection = transportService.openConnection(listedNode, LISTED_NODES_PROFILE)){ + final PlainTransportFuture handler = new PlainTransportFuture<>( + new FutureTransportResponseHandler() { + @Override + public LivenessResponse newInstance() { + return new LivenessResponse(); + } + }); + transportService.sendRequest(connection, TransportLivenessAction.NAME, new LivenessRequest(), + TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STATE).withTimeout(pingTimeout).build(), + handler); + final LivenessResponse livenessResponse = handler.txGet(); if (!ignoreClusterName && !clusterName.equals(livenessResponse.getClusterName())) { logger.warn("node {} not part of the cluster {}, ignoring...", listedNode, clusterName); newFilteredNodes.add(listedNode); - } else if (livenessResponse.getDiscoveryNode() != null) { + } else { // use discovered information but do keep the original transport address, // so people can control which address is exactly used. DiscoveryNode nodeWithInfo = livenessResponse.getDiscoveryNode(); newNodes.add(new DiscoveryNode(nodeWithInfo.getName(), nodeWithInfo.getId(), nodeWithInfo.getEphemeralId(), nodeWithInfo.getHostName(), nodeWithInfo.getHostAddress(), listedNode.getAddress(), nodeWithInfo.getAttributes(), nodeWithInfo.getRoles(), nodeWithInfo.getVersion())); - } else { - // although we asked for one node, our target may not have completed - // initialization yet and doesn't have cluster nodes - logger.debug("node {} didn't return any discovery info, temporarily using transport discovery node", listedNode); - newNodes.add(listedNode); } + } catch (ConnectTransportException e) { + logger.debug( + (Supplier) + () -> new ParameterizedMessage("failed to connect to node [{}], ignoring...", listedNode), e); + hostFailureListener.onNodeDisconnected(listedNode, e); } catch (Exception e) { logger.info( (Supplier) () -> new ParameterizedMessage("failed to get node info for {}, disconnecting...", listedNode), e); - transportService.disconnectFromNode(listedNode); - hostFailureListener.onNodeDisconnected(listedNode, e); } } @@ -470,78 +460,91 @@ protected void doSample() { final CountDownLatch latch = new CountDownLatch(nodesToPing.size()); final ConcurrentMap clusterStateResponses = ConcurrentCollections.newConcurrentMap(); - for (final DiscoveryNode listedNode : nodesToPing) { - threadPool.executor(ThreadPool.Names.MANAGEMENT).execute(new Runnable() { - @Override - public void run() { - try { - if (!transportService.nodeConnected(listedNode)) { - try { + try { + for (final DiscoveryNode nodeToPing : nodesToPing) { + threadPool.executor(ThreadPool.Names.MANAGEMENT).execute(new AbstractRunnable() { + + /** + * we try to reuse existing connections but if needed we will open a temporary connection + * that will be closed at the end of the execution. + */ + Transport.Connection connectionToClose = null; + + @Override + public void onAfter() { + IOUtils.closeWhileHandlingException(connectionToClose); + } - // if its one of the actual nodes we will talk to, not to listed nodes, fully connect - if (nodes.contains(listedNode)) { - logger.trace("connecting to cluster node [{}]", listedNode); - transportService.connectToNode(listedNode); - } else { - // its a listed node, light connect to it... - logger.trace("connecting to listed node (light) [{}]", listedNode); - transportService.connectToNode(listedNode, LISTED_NODES_PROFILE); - } - } catch (Exception e) { - logger.debug( - (Supplier) - () -> new ParameterizedMessage("failed to connect to node [{}], ignoring...", listedNode), e); - latch.countDown(); - return; + @Override + public void onFailure(Exception e) { + latch.countDown(); + if (e instanceof ConnectTransportException) { + logger.debug((Supplier) + () -> new ParameterizedMessage("failed to connect to node [{}], ignoring...", nodeToPing), e); + hostFailureListener.onNodeDisconnected(nodeToPing, e); + } else { + logger.info( + (Supplier) () -> new ParameterizedMessage( + "failed to get local cluster state info for {}, disconnecting...", nodeToPing), e); + } + } + + @Override + protected void doRun() throws Exception { + Transport.Connection pingConnection = null; + if (nodes.contains(nodeToPing)) { + try { + pingConnection = transportService.getConnection(nodeToPing); + } catch (NodeNotConnectedException e) { + // will use a temp connection } } - transportService.sendRequest(listedNode, ClusterStateAction.NAME, - Requests.clusterStateRequest().clear().nodes(true).local(true), - TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STATE) - .withTimeout(pingTimeout).build(), - new TransportResponseHandler() { - - @Override - public ClusterStateResponse newInstance() { - return new ClusterStateResponse(); - } + if (pingConnection == null) { + logger.trace("connecting to cluster node [{}]", nodeToPing); + connectionToClose = transportService.openConnection(nodeToPing, LISTED_NODES_PROFILE); + pingConnection = connectionToClose; + } + transportService.sendRequest(pingConnection, ClusterStateAction.NAME, + Requests.clusterStateRequest().clear().nodes(true).local(true), + TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STATE) + .withTimeout(pingTimeout).build(), + new TransportResponseHandler() { + + @Override + public ClusterStateResponse newInstance() { + return new ClusterStateResponse(); + } - @Override - public String executor() { - return ThreadPool.Names.SAME; - } + @Override + public String executor() { + return ThreadPool.Names.SAME; + } - @Override - public void handleResponse(ClusterStateResponse response) { - clusterStateResponses.put(listedNode, response); - latch.countDown(); - } + @Override + public void handleResponse(ClusterStateResponse response) { + clusterStateResponses.put(nodeToPing, response); + latch.countDown(); + } - @Override - public void handleException(TransportException e) { - logger.info( - (Supplier) () -> new ParameterizedMessage( - "failed to get local cluster state for {}, disconnecting...", listedNode), e); - transportService.disconnectFromNode(listedNode); + @Override + public void handleException(TransportException e) { + logger.info( + (Supplier) () -> new ParameterizedMessage( + "failed to get local cluster state for {}, disconnecting...", nodeToPing), e); + try { + hostFailureListener.onNodeDisconnected(nodeToPing, e); + } + finally { latch.countDown(); - hostFailureListener.onNodeDisconnected(listedNode, e); } - }); - } catch (Exception e) { - logger.info( - (Supplier)() -> new ParameterizedMessage( - "failed to get local cluster state info for {}, disconnecting...", listedNode), e); - transportService.disconnectFromNode(listedNode); - latch.countDown(); - hostFailureListener.onNodeDisconnected(listedNode, e); + } + }); } - } - }); - } - - try { + }); + } latch.await(); } catch (InterruptedException e) { + Thread.currentThread().interrupt(); return; } diff --git a/core/src/main/java/org/elasticsearch/cluster/node/DiscoveryNode.java b/core/src/main/java/org/elasticsearch/cluster/node/DiscoveryNode.java index ed6f273db44b3..7ecab461753b8 100644 --- a/core/src/main/java/org/elasticsearch/cluster/node/DiscoveryNode.java +++ b/core/src/main/java/org/elasticsearch/cluster/node/DiscoveryNode.java @@ -192,18 +192,24 @@ public DiscoveryNode(String nodeName, String nodeId, String ephemeralId, String /** Creates a DiscoveryNode representing the local node. */ public static DiscoveryNode createLocal(Settings settings, TransportAddress publishAddress, String nodeId) { Map attributes = new HashMap<>(Node.NODE_ATTRIBUTES.get(settings).getAsMap()); - Set roles = new HashSet<>(); + Set roles = getRolesFromSettings(settings); + + return new DiscoveryNode(Node.NODE_NAME_SETTING.get(settings), nodeId, publishAddress, attributes, roles, Version.CURRENT); + } + + /** extract node roles from the given settings */ + public static Set getRolesFromSettings(Settings settings) { + Set roles = new HashSet<>(); if (Node.NODE_INGEST_SETTING.get(settings)) { - roles.add(DiscoveryNode.Role.INGEST); + roles.add(Role.INGEST); } if (Node.NODE_MASTER_SETTING.get(settings)) { - roles.add(DiscoveryNode.Role.MASTER); + roles.add(Role.MASTER); } if (Node.NODE_DATA_SETTING.get(settings)) { - roles.add(DiscoveryNode.Role.DATA); + roles.add(Role.DATA); } - - return new DiscoveryNode(Node.NODE_NAME_SETTING.get(settings), nodeId, publishAddress, attributes, roles, Version.CURRENT); + return roles; } /** diff --git a/core/src/main/java/org/elasticsearch/common/CheckedBiConsumer.java b/core/src/main/java/org/elasticsearch/common/CheckedBiConsumer.java new file mode 100644 index 0000000000000..3f8b76bf3653f --- /dev/null +++ b/core/src/main/java/org/elasticsearch/common/CheckedBiConsumer.java @@ -0,0 +1,30 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.common; + +import java.util.function.BiConsumer; + +/** + * A {@link BiConsumer}-like interface which allows throwing checked exceptions. + */ +@FunctionalInterface +public interface CheckedBiConsumer { + void accept(T t, U u) throws E; +} diff --git a/core/src/main/java/org/elasticsearch/discovery/zen/ZenDiscovery.java b/core/src/main/java/org/elasticsearch/discovery/zen/ZenDiscovery.java index 23a1b117b6c74..1b91a56ec01e5 100644 --- a/core/src/main/java/org/elasticsearch/discovery/zen/ZenDiscovery.java +++ b/core/src/main/java/org/elasticsearch/discovery/zen/ZenDiscovery.java @@ -27,9 +27,9 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterName; -import org.elasticsearch.cluster.ClusterStateTaskExecutor; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateTaskConfig; +import org.elasticsearch.cluster.ClusterStateTaskExecutor; import org.elasticsearch.cluster.ClusterStateTaskListener; import org.elasticsearch.cluster.LocalClusterUpdateTask; import org.elasticsearch.cluster.NotMasterException; @@ -51,6 +51,7 @@ import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoverySettings; import org.elasticsearch.discovery.DiscoveryStats; @@ -114,6 +115,7 @@ public class ZenDiscovery extends AbstractLifecycleComponent implements Discover private final NodesFaultDetection nodesFD; private final PublishClusterStateAction publishClusterState; private final MembershipAction membership; + private final ThreadPool threadPool; private final TimeValue pingTimeout; private final TimeValue joinTimeout; @@ -157,6 +159,7 @@ public ZenDiscovery(Settings settings, ThreadPool threadPool, TransportService t this.joinRetryDelay = JOIN_RETRY_DELAY_SETTING.get(settings); this.maxPingsFromAnotherMaster = MAX_PINGS_FROM_ANOTHER_MASTER_SETTING.get(settings); this.sendLeaveRequest = SEND_LEAVE_REQUEST_SETTING.get(settings); + this.threadPool = threadPool; this.masterElectionIgnoreNonMasters = MASTER_ELECTION_IGNORE_NON_MASTER_PINGS_SETTING.get(settings); this.masterElectionWaitForJoinsTimeout = MASTER_ELECTION_WAIT_FOR_JOINS_TIMEOUT_SETTING.get(settings); @@ -190,7 +193,7 @@ public ZenDiscovery(Settings settings, ThreadPool threadPool, TransportService t discoverySettings, clusterService.getClusterName()); this.membership = new MembershipAction(settings, transportService, this::localNode, new MembershipListener()); - this.joinThreadControl = new JoinThreadControl(threadPool); + this.joinThreadControl = new JoinThreadControl(); transportService.registerRequestHandler( DISCOVERY_REJOIN_ACTION_NAME, RejoinClusterRequest::new, ThreadPool.Names.SAME, new RejoinClusterRequestHandler()); @@ -972,21 +975,28 @@ private ClusterStateTaskExecutor.ClusterTasksResult handleAnotherMaster(ClusterS return rejoin(localClusterState, "zen-disco-discovered another master with a new cluster_state [" + otherMaster + "][" + reason + "]"); } else { logger.warn("discovered [{}] which is also master but with an older cluster_state, telling [{}] to rejoin the cluster ([{}])", otherMaster, otherMaster, reason); - try { - // make sure we're connected to this node (connect to node does nothing if we're already connected) - // since the network connections are asymmetric, it may be that we received a state but have disconnected from the node - // in the past (after a master failure, for example) - transportService.connectToNode(otherMaster); - transportService.sendRequest(otherMaster, DISCOVERY_REJOIN_ACTION_NAME, new RejoinClusterRequest(localClusterState.nodes().getLocalNodeId()), new EmptyTransportResponseHandler(ThreadPool.Names.SAME) { - - @Override - public void handleException(TransportException exp) { - logger.warn((Supplier) () -> new ParameterizedMessage("failed to send rejoin request to [{}]", otherMaster), exp); - } - }); - } catch (Exception e) { - logger.warn((Supplier) () -> new ParameterizedMessage("failed to send rejoin request to [{}]", otherMaster), e); - } + // spawn to a background thread to not do blocking operations on the cluster state thread + threadPool.generic().execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + logger.warn((Supplier) () -> new ParameterizedMessage("failed to send rejoin request to [{}]", otherMaster), e); + } + + @Override + protected void doRun() throws Exception { + // make sure we're connected to this node (connect to node does nothing if we're already connected) + // since the network connections are asymmetric, it may be that we received a state but have disconnected from the node + // in the past (after a master failure, for example) + transportService.connectToNode(otherMaster); + transportService.sendRequest(otherMaster, DISCOVERY_REJOIN_ACTION_NAME, new RejoinClusterRequest(localNode().getId()), new EmptyTransportResponseHandler(ThreadPool.Names.SAME) { + + @Override + public void handleException(TransportException exp) { + logger.warn((Supplier) () -> new ParameterizedMessage("failed to send rejoin request to [{}]", otherMaster), exp); + } + }); + } + }); return LocalClusterUpdateTask.unchanged(); } } @@ -1136,14 +1146,9 @@ public void onFailure(String source, Exception e) { */ private class JoinThreadControl { - private final ThreadPool threadPool; private final AtomicBoolean running = new AtomicBoolean(false); private final AtomicReference currentJoinThread = new AtomicReference<>(); - JoinThreadControl(ThreadPool threadPool) { - this.threadPool = threadPool; - } - /** returns true if join thread control is started and there is currently an active join thread */ public boolean joinThreadActive() { Thread currentThread = currentJoinThread.get(); diff --git a/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java b/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java index 2f193fc33852b..095b7d2d6ed24 100644 --- a/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java +++ b/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java @@ -79,6 +79,18 @@ public static class Builder { private TimeValue connectTimeout; private TimeValue handshakeTimeout; + /** create an empty builder */ + public Builder() { + } + + /** copy constructor, using another profile as a base */ + public Builder(ConnectionProfile source) { + handles.addAll(source.getHandles()); + offset = source.getNumConnections(); + handles.forEach(th -> addedTypes.addAll(th.types)); + connectTimeout = source.getConnectTimeout(); + handshakeTimeout = source.getHandshakeTimeout(); + } /** * Sets a connect timeout for this connection profile */ diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index d0376867de431..8337875ffff81 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -27,6 +27,8 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; @@ -207,6 +209,7 @@ static ConnectionProfile buildDefaultConnectionProfile(Settings settings) { int connectionsPerNodePing = CONNECTIONS_PER_NODE_PING.get(settings); ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); builder.setConnectTimeout(TCP_CONNECT_TIMEOUT.get(settings)); + builder.setHandshakeTimeout(TCP_CONNECT_TIMEOUT.get(settings)); builder.addConnections(connectionsPerNodeBulk, TransportRequestOptions.Type.BULK); builder.addConnections(connectionsPerNodePing, TransportRequestOptions.Type.PING); // if we are not master eligible we don't need a dedicated channel to publish the state @@ -442,8 +445,10 @@ public boolean nodeConnected(DiscoveryNode node) { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) { - connectionProfile = connectionProfile == null ? defaultConnectionProfile : connectionProfile; + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { + connectionProfile = resolveConnectionProfile(connectionProfile, defaultConnectionProfile); if (node == null) { throw new ConnectTransportException(null, "can't connect to a null node"); } @@ -458,10 +463,12 @@ public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfil try { try { nodeChannels = openConnection(node, connectionProfile); + connectionValidator.accept(nodeChannels, connectionProfile); } catch (Exception e) { logger.trace( (Supplier) () -> new ParameterizedMessage( "failed to connect to [{}], cleaning dangling connections", node), e); + IOUtils.closeWhileHandlingException(nodeChannels); throw e; } // we acquire a connection lock, so no way there is an existing connection @@ -481,6 +488,29 @@ public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfil } } + /** + * takes a {@link ConnectionProfile} that have been passed as a parameter to the public methods + * and resolves it to a fully specified (i.e., no nulls) profile + */ + static ConnectionProfile resolveConnectionProfile(@Nullable ConnectionProfile connectionProfile, + ConnectionProfile defaultConnectionProfile) { + Objects.requireNonNull(defaultConnectionProfile); + if (connectionProfile == null) { + return defaultConnectionProfile; + } else if (connectionProfile.getConnectTimeout() != null && connectionProfile.getHandshakeTimeout() != null) { + return connectionProfile; + } else { + ConnectionProfile.Builder builder = new ConnectionProfile.Builder(connectionProfile); + if (connectionProfile.getConnectTimeout() == null) { + builder.setConnectTimeout(defaultConnectionProfile.getConnectTimeout()); + } + if (connectionProfile.getHandshakeTimeout() == null) { + builder.setHandshakeTimeout(defaultConnectionProfile.getHandshakeTimeout()); + } + return builder.build(); + } + } + @Override public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile connectionProfile) throws IOException { if (node == null) { @@ -488,6 +518,7 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c } boolean success = false; NodeChannels nodeChannels = null; + connectionProfile = resolveConnectionProfile(connectionProfile, defaultConnectionProfile); globalLock.readLock().lock(); // ensure we don't open connections while we are closing try { ensureOpen(); diff --git a/core/src/main/java/org/elasticsearch/transport/Transport.java b/core/src/main/java/org/elasticsearch/transport/Transport.java index ef31aba5569d1..6636a75ff13b6 100644 --- a/core/src/main/java/org/elasticsearch/transport/Transport.java +++ b/core/src/main/java/org/elasticsearch/transport/Transport.java @@ -21,6 +21,7 @@ import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.component.LifecycleComponent; @@ -71,9 +72,11 @@ public interface Transport extends LifecycleComponent { boolean nodeConnected(DiscoveryNode node); /** - * Connects to a node with the given connection profile. If the node is already connected this method has no effect + * Connects to a node with the given connection profile. If the node is already connected this method has no effect. + * Once a successful is established, it can be validated before being exposed. */ - void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException; + void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) throws ConnectTransportException; /** * Disconnected from the given node, if not connected, will do nothing. @@ -102,15 +105,16 @@ default CircuitBreaker getInFlightRequestBreaker() { * implementation. * * @throws NodeNotConnectedException if the node is not connected - * @see #connectToNode(DiscoveryNode, ConnectionProfile) + * @see #connectToNode(DiscoveryNode, ConnectionProfile, CheckedBiConsumer) */ Connection getConnection(DiscoveryNode node); /** - * Opens a new connection to the given node and returns it. In contrast to {@link #connectToNode(DiscoveryNode, ConnectionProfile)} - * the returned connection is not managed by the transport implementation. This connection must be closed once it's not needed anymore. + * Opens a new connection to the given node and returns it. In contrast to + * {@link #connectToNode(DiscoveryNode, ConnectionProfile, CheckedBiConsumer)} the returned connection is not managed by + * the transport implementation. This connection must be closed once it's not needed anymore. * This connection type can be used to execute a handshake between two nodes before the node will be published via - * {@link #connectToNode(DiscoveryNode, ConnectionProfile)}. + * {@link #connectToNode(DiscoveryNode, ConnectionProfile, CheckedBiConsumer)}. */ Connection openConnection(DiscoveryNode node, ConnectionProfile profile) throws IOException; diff --git a/core/src/main/java/org/elasticsearch/transport/TransportService.java b/core/src/main/java/org/elasticsearch/transport/TransportService.java index f7f3c932a0d07..a7e4cd23057f3 100644 --- a/core/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/core/src/main/java/org/elasticsearch/transport/TransportService.java @@ -73,7 +73,7 @@ public class TransportService extends AbstractLifecycleComponent { public static final String DIRECT_RESPONSE_PROFILE = ".direct"; - private static final String HANDSHAKE_ACTION_NAME = "internal:transport/handshake"; + public static final String HANDSHAKE_ACTION_NAME = "internal:transport/handshake"; private final CountDownLatch blockIncomingRequestsLatch = new CountDownLatch(1); protected final Transport transport; @@ -130,7 +130,7 @@ public DiscoveryNode getNode() { @Override public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options) throws IOException, TransportException { - sendLocalRequest(requestId, action, request); + sendLocalRequest(requestId, action, request, options); } @Override @@ -206,6 +206,7 @@ protected void doStart() { HANDSHAKE_ACTION_NAME, () -> HandshakeRequest.INSTANCE, ThreadPool.Names.SAME, + false, false, (request, channel) -> channel.sendResponse( new HandshakeResponse(localNode, clusterName, localNode.getVersion()))); } @@ -311,7 +312,13 @@ public void connectToNode(final DiscoveryNode node, ConnectionProfile connection if (isLocalNode(node)) { return; } - transport.connectToNode(node, connectionProfile); + transport.connectToNode(node, connectionProfile, (newConnection, actualProfile) -> { + // We don't validate cluster names to allow for tribe node connections. + final DiscoveryNode remote = handshake(newConnection, actualProfile.getHandshakeTimeout().millis(), cn -> true); + if (node.equals(remote) == false) { + throw new ConnectTransportException(node, "handshake failed. unexpected remote node " + remote); + } + }); } /** @@ -397,7 +404,7 @@ private HandshakeRequest() { } - static class HandshakeResponse extends TransportResponse { + public static class HandshakeResponse extends TransportResponse { private DiscoveryNode discoveryNode; private ClusterName clusterName; private Version version; @@ -405,7 +412,7 @@ static class HandshakeResponse extends TransportResponse { HandshakeResponse() { } - HandshakeResponse(DiscoveryNode discoveryNode, ClusterName clusterName, Version version) { + public HandshakeResponse(DiscoveryNode discoveryNode, ClusterName clusterName, Version version) { this.discoveryNode = discoveryNode; this.version = version; this.clusterName = clusterName; @@ -599,9 +606,11 @@ protected void doRun() throws Exception { } } - private void sendLocalRequest(long requestId, final String action, final TransportRequest request) { + private void sendLocalRequest(long requestId, final String action, final TransportRequest request, TransportRequestOptions options) { final DirectResponseChannel channel = new DirectResponseChannel(logger, localNode, action, requestId, adapter, threadPool); try { + adapter.onRequestSent(localNode, requestId, action, request, options); + adapter.onRequestReceived(requestId, action); final RequestHandlerRegistry reg = adapter.getRequestHandler(action); if (reg == null) { throw new ActionNotFoundTransportException("Action [" + action + "] not found"); @@ -1080,6 +1089,7 @@ public void sendResponse(TransportResponse response) throws IOException { @Override public void sendResponse(final TransportResponse response, TransportResponseOptions options) throws IOException { + adapter.onResponseSent(requestId, action, response, options); final TransportResponseHandler handler = adapter.onResponseReceived(requestId); // ignore if its null, the adapter logs it if (handler != null) { @@ -1103,6 +1113,7 @@ protected void processResponse(TransportResponseHandler handler, TransportRespon @Override public void sendResponse(Exception exception) throws IOException { + adapter.onResponseSent(requestId, action, exception); final TransportResponseHandler handler = adapter.onResponseReceived(requestId); // ignore if its null, the adapter logs it if (handler != null) { diff --git a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java index decff2ffc37fb..c28fddf68ad72 100644 --- a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java +++ b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java @@ -212,7 +212,7 @@ private Task startCancellableTestNodesAction(boolean waitForActionToStart, Colle CancellableTestNodesAction[] actions = new CancellableTestNodesAction[nodesCount]; for (int i = 0; i < testNodes.length; i++) { boolean shouldBlock = blockOnNodes.contains(testNodes[i]); - logger.info("The action in the node [{}] should block: [{}]", testNodes[i].discoveryNode.getId(), shouldBlock); + logger.info("The action in the node [{}] should block: [{}]", testNodes[i].getNodeId(), shouldBlock); actions[i] = new CancellableTestNodesAction(CLUSTER_SETTINGS, "testAction", threadPool, testNodes[i] .clusterService, testNodes[i].transportService, shouldBlock, actionLatch); } @@ -251,7 +251,7 @@ public void onFailure(Exception e) { // Cancel main task CancelTasksRequest request = new CancelTasksRequest(); request.setReason("Testing Cancellation"); - request.setTaskId(new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId())); + request.setTaskId(new TaskId(testNodes[0].getNodeId(), mainTask.getId())); // And send the cancellation request to a random node CancelTasksResponse response = testNodes[randomIntBetween(0, testNodes.length - 1)].transportCancelTasksAction.execute(request) .get(); @@ -288,7 +288,7 @@ public void onFailure(Exception e) { // Make sure that tasks are no longer running ListTasksResponse listTasksResponse = testNodes[randomIntBetween(0, testNodes.length - 1)] .transportListTasksAction.execute(new ListTasksRequest().setTaskId( - new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId()))).get(); + new TaskId(testNodes[0].getNodeId(), mainTask.getId()))).get(); assertEquals(0, listTasksResponse.getTasks().size()); // Make sure that there are no leftover bans, the ban removal is async, so we might return from the cancellation @@ -323,7 +323,7 @@ public void onFailure(Exception e) { // Cancel all child tasks without cancelling the main task, which should quit on its own CancelTasksRequest request = new CancelTasksRequest(); request.setReason("Testing Cancellation"); - request.setParentTaskId(new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId())); + request.setParentTaskId(new TaskId(testNodes[0].getNodeId(), mainTask.getId())); // And send the cancellation request to a random node CancelTasksResponse response = testNodes[randomIntBetween(1, testNodes.length - 1)].transportCancelTasksAction.execute(request) .get(); @@ -339,7 +339,7 @@ public void onFailure(Exception e) { // Make sure that main task is no longer running ListTasksResponse listTasksResponse = testNodes[randomIntBetween(0, testNodes.length - 1)] .transportListTasksAction.execute(new ListTasksRequest().setTaskId( - new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId()))).get(); + new TaskId(testNodes[0].getNodeId(), mainTask.getId()))).get(); assertEquals(0, listTasksResponse.getTasks().size()); } catch (ExecutionException | InterruptedException ex) { @@ -374,7 +374,7 @@ public void onFailure(Exception e) { } }); - String mainNode = testNodes[0].discoveryNode.getId(); + String mainNode = testNodes[0].getNodeId(); // Make sure that tasks are running ListTasksResponse listTasksResponse = testNodes[randomIntBetween(0, testNodes.length - 1)] @@ -384,12 +384,12 @@ public void onFailure(Exception e) { // Simulate the coordinating node leaving the cluster DiscoveryNode[] discoveryNodes = new DiscoveryNode[testNodes.length - 1]; for (int i = 1; i < testNodes.length; i++) { - discoveryNodes[i - 1] = testNodes[i].discoveryNode; + discoveryNodes[i - 1] = testNodes[i].discoveryNode(); } DiscoveryNode master = discoveryNodes[0]; for (int i = 1; i < testNodes.length; i++) { // Notify only nodes that should remain in the cluster - setState(testNodes[i].clusterService, ClusterStateCreationUtils.state(testNodes[i].discoveryNode, master, discoveryNodes)); + setState(testNodes[i].clusterService, ClusterStateCreationUtils.state(testNodes[i].discoveryNode(), master, discoveryNodes)); } if (simulateBanBeforeLeaving) { @@ -397,7 +397,7 @@ public void onFailure(Exception e) { // Simulate issuing cancel request on the node that is about to leave the cluster CancelTasksRequest request = new CancelTasksRequest(); request.setReason("Testing Cancellation"); - request.setTaskId(new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId())); + request.setTaskId(new TaskId(testNodes[0].getNodeId(), mainTask.getId())); // And send the cancellation request to a random node CancelTasksResponse response = testNodes[0].transportCancelTasksAction.execute(request).get(); logger.info("--> Done simulating issuing cancel request on the node that is about to leave the cluster"); diff --git a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index 2724e227892a9..17be09740815f 100644 --- a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.action.admin.cluster.node.tasks; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.Version; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.TransportCancelTasksAction; @@ -39,6 +40,8 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.BoundTransportAddress; +import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; @@ -56,6 +59,7 @@ import java.util.HashSet; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.function.Function; import java.util.function.Supplier; import static java.util.Collections.emptyMap; @@ -167,12 +171,15 @@ protected boolean accumulateExceptions() { public static class TestNode implements Releasable { public TestNode(String name, ThreadPool threadPool, Settings settings) { - clusterService = createClusterService(threadPool); + final Function boundTransportAddressDiscoveryNodeFunction = + address -> { + discoveryNode.set(new DiscoveryNode(name, address.publishAddress(), emptyMap(), emptySet(), Version.CURRENT)); + return discoveryNode.get(); + }; transportService = new TransportService(settings, new LocalTransport(settings, threadPool, new NamedWriteableRegistry(Collections.emptyList()), new NoneCircuitBreakerService()), threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, - x -> clusterService.localNode(), null) { - + boundTransportAddressDiscoveryNodeFunction, null) { @Override protected TaskManager createTaskManager() { if (MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.get(settings)) { @@ -183,9 +190,8 @@ protected TaskManager createTaskManager() { } }; transportService.start(); + clusterService = createClusterService(threadPool, discoveryNode.get()); clusterService.addStateApplier(transportService.getTaskManager()); - discoveryNode = new DiscoveryNode(name, transportService.boundAddress().publishAddress(), - emptyMap(), emptySet(), Version.CURRENT); IndexNameExpressionResolver indexNameExpressionResolver = new IndexNameExpressionResolver(settings); ActionFilters actionFilters = new ActionFilters(emptySet()); transportListTasksAction = new TransportListTasksAction(settings, threadPool, clusterService, transportService, @@ -197,7 +203,7 @@ protected TaskManager createTaskManager() { public final ClusterService clusterService; public final TransportService transportService; - public final DiscoveryNode discoveryNode; + private final SetOnce discoveryNode = new SetOnce<>(); public final TransportListTasksAction transportListTasksAction; public final TransportCancelTasksAction transportCancelTasksAction; @@ -208,22 +214,24 @@ public void close() { } public String getNodeId() { - return discoveryNode.getId(); + return discoveryNode().getId(); } + + public DiscoveryNode discoveryNode() { return discoveryNode.get(); } } public static void connectNodes(TestNode... nodes) { DiscoveryNode[] discoveryNodes = new DiscoveryNode[nodes.length]; for (int i = 0; i < nodes.length; i++) { - discoveryNodes[i] = nodes[i].discoveryNode; + discoveryNodes[i] = nodes[i].discoveryNode(); } DiscoveryNode master = discoveryNodes[0]; for (TestNode node : nodes) { - setState(node.clusterService, ClusterStateCreationUtils.state(node.discoveryNode, master, discoveryNodes)); + setState(node.clusterService, ClusterStateCreationUtils.state(node.discoveryNode(), master, discoveryNodes)); } for (TestNode nodeA : nodes) { for (TestNode nodeB : nodes) { - nodeA.transportService.connectToNode(nodeB.discoveryNode); + nodeA.transportService.connectToNode(nodeB.discoveryNode()); } } } @@ -231,7 +239,7 @@ public static void connectNodes(TestNode... nodes) { public static RecordingTaskManagerListener[] setupListeners(TestNode[] nodes, String... actionMasks) { RecordingTaskManagerListener[] listeners = new RecordingTaskManagerListener[nodes.length]; for (int i = 0; i < nodes.length; i++) { - listeners[i] = new RecordingTaskManagerListener(nodes[i].discoveryNode.getId(), actionMasks); + listeners[i] = new RecordingTaskManagerListener(nodes[i].getNodeId(), actionMasks); ((MockTaskManager) (nodes[i].transportService.getTaskManager())).addListener(listeners[i]); } return listeners; diff --git a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java index 07859070d1098..4e624164fa051 100644 --- a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java +++ b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java @@ -310,7 +310,7 @@ protected NodeResponse nodeOperation(NodeRequest request) { Thread.currentThread().interrupt(); } logger.info("Action on node {} finished", node); - return new NodeResponse(testNodes[node].discoveryNode); + return new NodeResponse(testNodes[node].discoveryNode()); } }; } @@ -370,10 +370,10 @@ public void onFailure(Exception e) { assertEquals(testNodes.length, response.getPerNodeTasks().size()); // Coordinating node - assertEquals(2, response.getPerNodeTasks().get(testNodes[0].discoveryNode.getId()).size()); + assertEquals(2, response.getPerNodeTasks().get(testNodes[0].getNodeId()).size()); // Other nodes node for (int i = 1; i < testNodes.length; i++) { - assertEquals(1, response.getPerNodeTasks().get(testNodes[i].discoveryNode.getId()).size()); + assertEquals(1, response.getPerNodeTasks().get(testNodes[i].getNodeId()).size()); } // There should be a single main task when grouped by tasks assertEquals(1, response.getTaskGroups().size()); @@ -535,7 +535,7 @@ public void onFailure(Exception e) { // Try to cancel main task using action name CancelTasksRequest request = new CancelTasksRequest(); - request.setNodes(testNodes[0].discoveryNode.getId()); + request.setNodes(testNodes[0].getNodeId()); request.setReason("Testing Cancellation"); request.setActions(actionName); CancelTasksResponse response = testNodes[randomIntBetween(0, testNodes.length - 1)].transportCancelTasksAction.execute(request) @@ -550,7 +550,7 @@ public void onFailure(Exception e) { // Try to cancel main task using id request = new CancelTasksRequest(); request.setReason("Testing Cancellation"); - request.setTaskId(new TaskId(testNodes[0].discoveryNode.getId(), task.getId())); + request.setTaskId(new TaskId(testNodes[0].getNodeId(), task.getId())); response = testNodes[randomIntBetween(0, testNodes.length - 1)].transportCancelTasksAction.execute(request).get(); // Shouldn't match any tasks since testAction doesn't support cancellation @@ -766,11 +766,11 @@ public void testTasksToXContentGrouping() throws Exception { byNodes = (Map) byNodes.get("nodes"); // One element on the top level assertEquals(testNodes.length, byNodes.size()); - Map firstNode = (Map) byNodes.get(testNodes[0].discoveryNode.getId()); + Map firstNode = (Map) byNodes.get(testNodes[0].getNodeId()); firstNode = (Map) firstNode.get("tasks"); assertEquals(2, firstNode.size()); // two tasks for the first node for (int i = 1; i < testNodes.length; i++) { - Map otherNode = (Map) byNodes.get(testNodes[i].discoveryNode.getId()); + Map otherNode = (Map) byNodes.get(testNodes[i].getNodeId()); otherNode = (Map) otherNode.get("tasks"); assertEquals(1, otherNode.size()); // one tasks for the all other nodes } diff --git a/core/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java b/core/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java index 63d2aaf7c0852..9827d2162891c 100644 --- a/core/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java +++ b/core/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java @@ -26,6 +26,7 @@ import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.settings.Settings; @@ -48,7 +49,6 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -187,7 +187,9 @@ public boolean nodeConnected(DiscoveryNode node) { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { } diff --git a/core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java b/core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java index 291abc0e84218..b7bbf1009d632 100644 --- a/core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java +++ b/core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java @@ -49,14 +49,13 @@ import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponseHandler; +import org.elasticsearch.transport.TransportService; import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import static org.hamcrest.Matchers.is; - public class TransportClientHeadersTests extends AbstractClientHeadersTestCase { private MockTransportService transportService; @@ -149,15 +148,14 @@ public void sendRequest(Transport.Connection conne TransportRequest request, TransportRequestOptions options, TransportResponseHandler handler) { + final ClusterName clusterName = new ClusterName("cluster1"); if (TransportLivenessAction.NAME.equals(action)) { assertHeaders(threadPool); ((TransportResponseHandler) handler).handleResponse( - new LivenessResponse(new ClusterName("cluster1"), connection.getNode())); - return; - } - if (ClusterStateAction.NAME.equals(action)) { + new LivenessResponse(clusterName, connection.getNode())); + } else if (ClusterStateAction.NAME.equals(action)) { assertHeaders(threadPool); - ClusterName cluster1 = new ClusterName("cluster1"); + ClusterName cluster1 = clusterName; ClusterState.Builder builder = ClusterState.builder(cluster1); //the sniffer detects only data nodes builder.nodes(DiscoveryNodes.builder().add(new DiscoveryNode("node_id", "someId", "some_ephemeralId_id", @@ -166,10 +164,12 @@ public void sendRequest(Transport.Connection conne ((TransportResponseHandler) handler) .handleResponse(new ClusterStateResponse(cluster1, builder.build())); clusterStateLatch.countDown(); - return; + } else if (TransportService.HANDSHAKE_ACTION_NAME .equals(action)) { + ((TransportResponseHandler) handler).handleResponse( + new TransportService.HandshakeResponse(connection.getNode(), clusterName, connection.getNode().getVersion())); + } else { + handler.handleException(new TransportException("", new InternalException(action))); } - - handler.handleException(new TransportException("", new InternalException(action))); } }; } diff --git a/core/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java b/core/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java index adb463be5ddb5..d44b027f8ca95 100644 --- a/core/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java +++ b/core/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.LifecycleListener; @@ -205,7 +206,9 @@ public boolean nodeConnected(DiscoveryNode node) { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { if (connectionProfile == null) { if (connectedNodes.contains(node) == false && randomConnectionExceptions && randomBoolean()) { throw new ConnectTransportException(node, "simulated"); diff --git a/core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java b/core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java index 0a8a3132012d8..451d9babe6bd9 100644 --- a/core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java +++ b/core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java @@ -37,6 +37,7 @@ import org.elasticsearch.discovery.zen.NodesFaultDetection; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.HierarchyCircuitBreakerService; +import org.elasticsearch.node.Node; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; @@ -56,8 +57,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import static java.util.Collections.emptyMap; -import static java.util.Collections.emptySet; import static java.util.Collections.singleton; import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.elasticsearch.test.ClusterServiceUtils.setState; @@ -73,10 +72,12 @@ public class ZenFaultDetectionTests extends ESTestCase { protected static final Version version0 = Version.fromId(/*0*/99); protected DiscoveryNode nodeA; protected MockTransportService serviceA; + private Settings settingsA; protected static final Version version1 = Version.fromId(199); protected DiscoveryNode nodeB; protected MockTransportService serviceB; + private Settings settingsB; @Override @Before @@ -87,17 +88,19 @@ public void setUp() throws Exception { .build(); ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); threadPool = new TestThreadPool(getClass().getName()); - clusterServiceA = createClusterService(threadPool); - clusterServiceB = createClusterService(threadPool); circuitBreakerService = new HierarchyCircuitBreakerService(settings, clusterSettings); - serviceA = build(Settings.builder().put("name", "TS_A").build(), version0); - nodeA = new DiscoveryNode("TS_A", "TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); - serviceB = build(Settings.builder().put("name", "TS_B").build(), version1); - nodeB = new DiscoveryNode("TS_B", "TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + settingsA = Settings.builder().put("node.name", "TS_A").put(settings).build(); + serviceA = build(settingsA, version0); + nodeA = serviceA.getLocalDiscoNode(); + settingsB = Settings.builder().put("node.name", "TS_B").put(settings).build(); + serviceB = build(settingsB, version1); + nodeB = serviceB.getLocalDiscoNode(); + clusterServiceA = createClusterService(settingsA, threadPool, nodeA); + clusterServiceB = createClusterService(settingsB, threadPool, nodeB); // wait till all nodes are properly connected and the event has been sent, so tests in this class // will not get this callback called on the connections done in this setup - final CountDownLatch latch = new CountDownLatch(4); + final CountDownLatch latch = new CountDownLatch(2); TransportConnectionListener waitForConnection = new TransportConnectionListener() { @Override public void onNodeConnected(DiscoveryNode node) { @@ -136,18 +139,22 @@ public void tearDown() throws Exception { protected MockTransportService build(Settings settings, Version version) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); MockTransportService transportService = - new MockTransportService( - Settings.builder() - // trace zenfd actions but keep the default otherwise - .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), singleton(TransportLivenessAction.NAME)) - .build(), - new LocalTransport(settings, threadPool, namedWriteableRegistry, circuitBreakerService) { - @Override - protected Version getVersion() { - return version; - } - }, - threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, null); + new MockTransportService( + Settings.builder() + // trace zenfd actions but keep the default otherwise + .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), singleton(TransportLivenessAction.NAME)) + .build(), + new LocalTransport(settings, threadPool, namedWriteableRegistry, circuitBreakerService) { + @Override + protected Version getVersion() { + return version; + } + }, + threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, + (boundAddress) -> + new DiscoveryNode(Node.NODE_NAME_SETTING.get(settings), boundAddress.publishAddress(), + Node.NODE_ATTRIBUTES.get(settings).getAsMap(), DiscoveryNode.getRolesFromSettings(settings), version), + null); transportService.start(); transportService.acceptIncomingRequests(); return transportService; @@ -172,15 +179,17 @@ private DiscoveryNodes buildNodesForB(boolean master) { } public void testNodesFaultDetectionConnectOnDisconnect() throws InterruptedException { - Settings.Builder settings = Settings.builder(); boolean shouldRetry = randomBoolean(); // make sure we don't ping again after the initial ping - settings.put(FaultDetection.CONNECT_ON_NETWORK_DISCONNECT_SETTING.getKey(), shouldRetry) - .put(FaultDetection.PING_INTERVAL_SETTING.getKey(), "5m"); + final Settings pingSettings = Settings.builder() + .put(FaultDetection.CONNECT_ON_NETWORK_DISCONNECT_SETTING.getKey(), shouldRetry) + .put(FaultDetection.PING_INTERVAL_SETTING.getKey(), "5m").build(); ClusterState clusterState = ClusterState.builder(new ClusterName("test")).nodes(buildNodesForA(true)).build(); - NodesFaultDetection nodesFDA = new NodesFaultDetection(settings.build(), threadPool, serviceA, clusterState.getClusterName()); + NodesFaultDetection nodesFDA = new NodesFaultDetection(Settings.builder().put(settingsA).put(pingSettings).build(), + threadPool, serviceA, clusterState.getClusterName()); nodesFDA.setLocalNode(nodeA); - NodesFaultDetection nodesFDB = new NodesFaultDetection(settings.build(), threadPool, serviceB, clusterState.getClusterName()); + NodesFaultDetection nodesFDB = new NodesFaultDetection(Settings.builder().put(settingsB).put(pingSettings).build(), + threadPool, serviceB, clusterState.getClusterName()); nodesFDB.setLocalNode(nodeB); final CountDownLatch pingSent = new CountDownLatch(1); nodesFDB.addListener(new NodesFaultDetection.Listener() { @@ -262,13 +271,12 @@ public void testMasterFaultDetectionConnectOnDisconnect() throws InterruptedExce } public void testMasterFaultDetectionNotSizeLimited() throws InterruptedException { - Settings.Builder settings = Settings.builder(); boolean shouldRetry = randomBoolean(); ClusterName clusterName = new ClusterName(randomAsciiOfLengthBetween(3, 20)); - settings + final Settings settings = Settings.builder() .put(FaultDetection.CONNECT_ON_NETWORK_DISCONNECT_SETTING.getKey(), shouldRetry) .put(FaultDetection.PING_INTERVAL_SETTING.getKey(), "1s") - .put("cluster.name", clusterName.value()); + .put("cluster.name", clusterName.value()).build(); final ClusterState stateNodeA = ClusterState.builder(clusterName).nodes(buildNodesForA(false)).build(); setState(clusterServiceA, stateNodeA); @@ -280,15 +288,15 @@ public void testMasterFaultDetectionNotSizeLimited() throws InterruptedException serviceA.addTracer(pingProbeA); serviceB.addTracer(pingProbeB); - MasterFaultDetection masterFDNodeA = new MasterFaultDetection(settings.build(), threadPool, serviceA, - clusterServiceA); + MasterFaultDetection masterFDNodeA = new MasterFaultDetection(Settings.builder().put(settingsA).put(settings).build(), + threadPool, serviceA, clusterServiceA); masterFDNodeA.start(nodeB, "test"); final ClusterState stateNodeB = ClusterState.builder(clusterName).nodes(buildNodesForB(true)).build(); setState(clusterServiceB, stateNodeB); - MasterFaultDetection masterFDNodeB = new MasterFaultDetection(settings.build(), threadPool, serviceB, - clusterServiceB); + MasterFaultDetection masterFDNodeB = new MasterFaultDetection(Settings.builder().put(settingsB).put(settings).build(), + threadPool, serviceB, clusterServiceB); masterFDNodeB.start(nodeB, "test"); // let's do a few pings diff --git a/core/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java b/core/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java index 426c498c5db13..b6f7293561abe 100644 --- a/core/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java +++ b/core/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java @@ -43,7 +43,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoverySettings; -import org.elasticsearch.env.NodeEnvironment; import org.elasticsearch.node.Node; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -168,11 +167,10 @@ public static MockNode createMockNode(String name, final Settings basSettings, @ .build(); MockTransportService service = buildTransportService(settings, threadPool); - DiscoveryNode discoveryNode = DiscoveryNode.createLocal(settings, service.boundAddress().publishAddress(), - NodeEnvironment.generateNodeId(settings)); + DiscoveryNode discoveryNode = service.getLocalDiscoNode(); MockNode node = new MockNode(discoveryNode, service, listener, logger); node.action = buildPublishClusterStateAction(settings, service, () -> node.clusterState, node); - final CountDownLatch latch = new CountDownLatch(nodes.size() * 2 + 1); + final CountDownLatch latch = new CountDownLatch(nodes.size() * 2); TransportConnectionListener waitForConnection = new TransportConnectionListener() { @Override public void onNodeConnected(DiscoveryNode node) { @@ -190,7 +188,6 @@ public void onNodeDisconnected(DiscoveryNode node) { curNode.connectTo(node.discoveryNode); node.connectTo(curNode.discoveryNode); } - node.connectTo(node.discoveryNode); assertThat("failed to wait for all nodes to connect", latch.await(5, TimeUnit.SECONDS), equalTo(true)); for (MockNode curNode : nodes.values()) { curNode.service.removeConnectionListener(waitForConnection); diff --git a/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java b/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java index 05c13c76a549a..cf08e80bb2aaa 100644 --- a/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java +++ b/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode.Role; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.network.NetworkService; @@ -45,6 +46,7 @@ import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.MockTcpTransport; import org.elasticsearch.transport.Transport; @@ -150,7 +152,9 @@ public void testSimplePings() throws IOException, InterruptedException, Executio networkService, v) { @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { throw new AssertionError("zen pings should never connect to node (got [" + node + "])"); } }; diff --git a/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java b/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java index 1785853d0e162..b18b57e371782 100644 --- a/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java +++ b/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java @@ -29,10 +29,15 @@ public class ConnectionProfileTests extends ESTestCase { public void testBuildConnectionProfile() { ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); TimeValue connectTimeout = TimeValue.timeValueMillis(randomIntBetween(1, 10)); + TimeValue handshaketTimeout = TimeValue.timeValueMillis(randomIntBetween(1, 10)); final boolean setConnectTimeout = randomBoolean(); if (setConnectTimeout) { builder.setConnectTimeout(connectTimeout); } + final boolean setHandshakeTimeout = randomBoolean(); + if (setHandshakeTimeout) { + builder.setHandshakeTimeout(handshaketTimeout); + } builder.addConnections(1, TransportRequestOptions.Type.BULK); builder.addConnections(2, TransportRequestOptions.Type.STATE, TransportRequestOptions.Type.RECOVERY); builder.addConnections(3, TransportRequestOptions.Type.PING); @@ -44,12 +49,22 @@ public void testBuildConnectionProfile() { assertEquals("type [PING] is already registered", illegalArgumentException.getMessage()); builder.addConnections(4, TransportRequestOptions.Type.REG); ConnectionProfile build = builder.build(); + if (randomBoolean()) { + build = new ConnectionProfile.Builder(build).build(); + } assertEquals(10, build.getNumConnections()); if (setConnectTimeout) { assertEquals(connectTimeout, build.getConnectTimeout()); } else { assertNull(build.getConnectTimeout()); } + + if (setHandshakeTimeout) { + assertEquals(handshaketTimeout, build.getHandshakeTimeout()); + } else { + assertNull(build.getHandshakeTimeout()); + } + Integer[] array = new Integer[10]; for (int i = 0; i < array.length; i++) { array[i] = i; diff --git a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java index c4c0a768ff6c8..f40a3abd0afc8 100644 --- a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.LocalTransportAddress; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; @@ -39,6 +40,8 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import static org.hamcrest.Matchers.equalTo; + /** Unit tests for TCPTransport */ public class TCPTransportTests extends ESTestCase { @@ -240,6 +243,38 @@ public void writeTo(StreamOutput out) throws IOException { } } + public void testConnectionProfileResolve() { + final ConnectionProfile defaultProfile = TcpTransport.buildDefaultConnectionProfile(Settings.EMPTY); + assertEquals(defaultProfile, TcpTransport.resolveConnectionProfile(null, defaultProfile)); + + final ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.BULK); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.RECOVERY); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.REG); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.STATE); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.PING); + + final boolean connectionTimeoutSet = randomBoolean(); + if (connectionTimeoutSet) { + builder.setConnectTimeout(TimeValue.timeValueMillis(randomNonNegativeLong())); + } + final boolean connectionHandshakeSet = randomBoolean(); + if (connectionHandshakeSet) { + builder.setHandshakeTimeout(TimeValue.timeValueMillis(randomNonNegativeLong())); + } + + final ConnectionProfile profile = builder.build(); + final ConnectionProfile resolved = TcpTransport.resolveConnectionProfile(profile, defaultProfile); + assertNotEquals(resolved, defaultProfile); + assertThat(resolved.getNumConnections(), equalTo(profile.getNumConnections())); + assertThat(resolved.getHandles(), equalTo(profile.getHandles())); + + assertThat(resolved.getConnectTimeout(), + equalTo(connectionTimeoutSet ? profile.getConnectTimeout() : defaultProfile.getConnectTimeout())); + assertThat(resolved.getHandshakeTimeout(), + equalTo(connectionHandshakeSet ? profile.getHandshakeTimeout() : defaultProfile.getHandshakeTimeout())); + } + public void testDefaultConnectionProfile() { ConnectionProfile profile = TcpTransport.buildDefaultConnectionProfile(Settings.EMPTY); assertEquals(13, profile.getNumConnections()); diff --git a/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java b/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java index 2a3006f2f7def..54af79ca3cb52 100644 --- a/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java @@ -35,9 +35,6 @@ import java.io.IOException; import java.util.concurrent.CountDownLatch; -import static java.util.Collections.emptyMap; -import static java.util.Collections.emptySet; - public class TransportActionProxyTests extends ESTestCase { protected ThreadPool threadPool; // we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions @@ -61,11 +58,11 @@ public void setUp() throws Exception { super.setUp(); threadPool = new TestThreadPool(getClass().getName()); serviceA = buildService(version0); // this one supports dynamic tracer updates - nodeA = new DiscoveryNode("TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); + nodeA = serviceA.getLocalDiscoNode(); serviceB = buildService(version1); // this one doesn't support dynamic tracer updates - nodeB = new DiscoveryNode("TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + nodeB = serviceB.getLocalDiscoNode(); serviceC = buildService(version1); // this one doesn't support dynamic tracer updates - nodeC = new DiscoveryNode("TS_C", serviceC.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + nodeC = serviceC.getLocalDiscoNode(); } @Override diff --git a/core/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java b/core/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java index c00f5fb07a5b1..3b6165adab158 100644 --- a/core/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java @@ -161,6 +161,24 @@ public void testIncompatibleVersions() { assertFalse(handleA.transportService.nodeConnected(discoveryNode)); } + public void testNodeConnectWithDifferentNodeId() { + Settings settings = Settings.builder().put("cluster.name", "test").build(); + NetworkHandle handleA = startServices("TS_A", settings, Version.CURRENT); + NetworkHandle handleB = startServices("TS_B", settings, Version.CURRENT); + DiscoveryNode discoveryNode = new DiscoveryNode( + randomAsciiOfLength(10), + handleB.discoveryNode.getAddress(), + emptyMap(), + emptySet(), + handleB.discoveryNode.getVersion()); + ConnectTransportException ex = expectThrows(ConnectTransportException.class, () -> { + handleA.transportService.connectToNode(discoveryNode, MockTcpTransport.LIGHT_PROFILE); + }); + assertThat(ex.getMessage(), containsString("unexpected remote node")); + assertFalse(handleA.transportService.nodeConnected(discoveryNode)); + } + + private static class NetworkHandle { private TransportService transportService; private DiscoveryNode discoveryNode; diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java index 22fd472fe1853..3657074e778c3 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java @@ -18,7 +18,6 @@ */ package org.elasticsearch.transport.netty4; -import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.lease.Releasables; @@ -48,8 +47,6 @@ import java.io.IOException; import java.util.Collections; -import static java.util.Collections.emptyMap; -import static java.util.Collections.emptySet; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -81,10 +78,8 @@ public void testScheduledPing() throws Exception { serviceB.start(); serviceB.acceptIncomingRequests(); - DiscoveryNode nodeA = - new DiscoveryNode("TS_A", "TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), emptySet(), Version.CURRENT); - DiscoveryNode nodeB = - new DiscoveryNode("TS_B", "TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), Version.CURRENT); + DiscoveryNode nodeA = serviceA.getLocalDiscoNode(); + DiscoveryNode nodeB = serviceB.getLocalDiscoNode(); if (randomBoolean()) { // use connection profile with different connect timeout diff --git a/test/framework/src/main/java/org/elasticsearch/test/ClusterServiceUtils.java b/test/framework/src/main/java/org/elasticsearch/test/ClusterServiceUtils.java index 8150b3a7564df..ac04de48cf761 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ClusterServiceUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ClusterServiceUtils.java @@ -48,8 +48,13 @@ public static ClusterService createClusterService(ThreadPool threadPool) { } public static ClusterService createClusterService(ThreadPool threadPool, DiscoveryNode localNode) { - ClusterService clusterService = new ClusterService(Settings.builder().put("cluster.name", "ClusterServiceTests").build(), - new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), + return createClusterService(Settings.EMPTY, threadPool, localNode); + } + + public static ClusterService createClusterService(Settings settings, ThreadPool threadPool, DiscoveryNode localNode) { + ClusterService clusterService = new ClusterService( + Settings.builder().put("cluster.name", "ClusterServiceTests").put(settings).build(), + new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), threadPool, () -> localNode); clusterService.setNodeConnectionsService(new NodeConnectionsService(Settings.EMPTY, null, null) { @Override diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/CapturingTransport.java b/test/framework/src/main/java/org/elasticsearch/test/transport/CapturingTransport.java index bb1250b31b241..503503dd68b8f 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/CapturingTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/CapturingTransport.java @@ -21,6 +21,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.component.Lifecycle; @@ -243,7 +244,9 @@ public boolean nodeConnected(DiscoveryNode node) { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { } diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java index e7a31735416a1..4644b6b1a78c0 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java @@ -22,6 +22,7 @@ import org.elasticsearch.Version; import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.component.Lifecycle; @@ -105,7 +106,16 @@ protected Version getVersion() { return version; } }; - return new MockTransportService(settings, transport, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, clusterSettings); + return createNewService(settings, transport, version, threadPool, clusterSettings); + } + + public static MockTransportService createNewService(Settings settings, Transport transport, Version version, ThreadPool threadPool, + @Nullable ClusterSettings clusterSettings) { + return new MockTransportService(settings, transport, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, + boundAddress -> + new DiscoveryNode(Node.NODE_NAME_SETTING.get(settings), UUIDs.randomBase64UUID(), boundAddress.publishAddress(), + Node.NODE_ATTRIBUTES.get(settings).getAsMap(), DiscoveryNode.getRolesFromSettings(settings), version), + clusterSettings); } public static MockTransportService mockTcp(Settings settings, Version version, ThreadPool threadPool, @@ -236,7 +246,9 @@ public void addFailToSendNoConnectRule(TransportAddress transportAddress) { addDelegate(transportAddress, new DelegateTransport(original) { @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { if (original.nodeConnected(node) == false) { // connecting to an already connected node is a no-op throw new ConnectTransportException(node, "DISCONNECT: simulated"); @@ -282,8 +294,10 @@ public void addFailToSendNoConnectRule(TransportAddress transportAddress, final addDelegate(transportAddress, new DelegateTransport(original) { @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { - original.connectToNode(node, connectionProfile); + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { + original.connectToNode(node, connectionProfile, connectionValidator); } @Override @@ -316,7 +330,9 @@ public void addUnresponsiveRule(TransportAddress transportAddress) { addDelegate(transportAddress, new DelegateTransport(original) { @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { if (original.nodeConnected(node) == false) { // connecting to an already connected node is a no-op throw new ConnectTransportException(node, "UNRESPONSIVE: simulated"); @@ -361,14 +377,16 @@ TimeValue getDelay() { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { if (original.nodeConnected(node)) { // connecting to an already connected node is a no-op return; } TimeValue delay = getDelay(); if (delay.millis() <= 0) { - original.connectToNode(node, connectionProfile); + original.connectToNode(node, connectionProfile, connectionValidator); return; } @@ -377,7 +395,7 @@ public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfil try { if (delay.millis() < connectingTimeout.millis()) { Thread.sleep(delay.millis()); - original.connectToNode(node, connectionProfile); + original.connectToNode(node, connectionProfile, connectionValidator); } else { Thread.sleep(connectingTimeout.millis()); throw new ConnectTransportException(node, "UNRESPONSIVE: simulated"); @@ -524,10 +542,11 @@ public boolean nodeConnected(DiscoveryNode node) { return getTransport(node).nodeConnected(node); } - @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { - getTransport(node).connectToNode(node, connectionProfile); + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { + getTransport(node).connectToNode(node, connectionProfile, connectionValidator); } @Override @@ -585,8 +604,10 @@ public boolean nodeConnected(DiscoveryNode node) { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { - transport.connectToNode(node, connectionProfile); + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { + transport.connectToNode(node, connectionProfile, connectionValidator); } @Override diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 7aaf821405a78..d58ce660ef8da 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -24,6 +24,7 @@ import org.apache.lucene.util.CollectionUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.IOUtils; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListenerResponseHandler; @@ -73,6 +74,7 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static java.util.Collections.emptyMap; @@ -361,6 +363,101 @@ public String executor() { assertThat(responseString.get(), equalTo("test")); } + public void testAdapterSendReceiveCallbacks() throws Exception { + final TransportRequestHandler requestHandler = (request, channel) -> { + try { + if (randomBoolean()) { + channel.sendResponse(TransportResponse.Empty.INSTANCE); + } else { + channel.sendResponse(new ElasticsearchException("simulated")); + } + } catch (IOException e) { + logger.error("Unexpected failure", e); + fail(e.getMessage()); + } + }; + serviceA.registerRequestHandler("action", TransportRequest.Empty::new, ThreadPool.Names.GENERIC, + requestHandler); + serviceB.registerRequestHandler("action", TransportRequest.Empty::new, ThreadPool.Names.GENERIC, + requestHandler); + + + class CountingTracer extends MockTransportService.Tracer { + AtomicInteger requestsReceived = new AtomicInteger(); + AtomicInteger requestsSent = new AtomicInteger(); + AtomicInteger responseReceived = new AtomicInteger(); + AtomicInteger responseSent = new AtomicInteger(); + @Override + public void receivedRequest(long requestId, String action) { + requestsReceived.incrementAndGet(); + } + + @Override + public void responseSent(long requestId, String action) { + responseSent.incrementAndGet(); + } + + @Override + public void responseSent(long requestId, String action, Throwable t) { + responseSent.incrementAndGet(); + } + + @Override + public void receivedResponse(long requestId, DiscoveryNode sourceNode, String action) { + responseReceived.incrementAndGet(); + } + + @Override + public void requestSent(DiscoveryNode node, long requestId, String action, TransportRequestOptions options) { + requestsSent.incrementAndGet(); + } + } + final CountingTracer tracerA = new CountingTracer(); + final CountingTracer tracerB = new CountingTracer(); + serviceA.addTracer(tracerA); + serviceB.addTracer(tracerB); + + try { + serviceA + .submitRequest(nodeB, "action", TransportRequest.Empty.INSTANCE, EmptyTransportResponseHandler.INSTANCE_SAME).get(); + } catch (ExecutionException e) { + assertThat(e.getCause(), instanceOf(ElasticsearchException.class)); + assertThat(ExceptionsHelper.unwrapCause(e.getCause()).getMessage(), equalTo("simulated")); + } + + // use assert busy as call backs are sometime called after the response have been sent + assertBusy(() -> { + assertThat(tracerA.requestsReceived.get(), equalTo(0)); + assertThat(tracerA.requestsSent.get(), equalTo(1)); + assertThat(tracerA.responseReceived.get(), equalTo(1)); + assertThat(tracerA.responseSent.get(), equalTo(0)); + assertThat(tracerB.requestsReceived.get(), equalTo(1)); + assertThat(tracerB.requestsSent.get(), equalTo(0)); + assertThat(tracerB.responseReceived.get(), equalTo(0)); + assertThat(tracerB.responseSent.get(), equalTo(1)); + }); + + try { + serviceA + .submitRequest(nodeA, "action", TransportRequest.Empty.INSTANCE, EmptyTransportResponseHandler.INSTANCE_SAME).get(); + } catch (ExecutionException e) { + assertThat(e.getCause(), instanceOf(ElasticsearchException.class)); + assertThat(ExceptionsHelper.unwrapCause(e.getCause()).getMessage(), equalTo("simulated")); + } + + // use assert busy as call backs are sometime called after the response have been sent + assertBusy(() -> { + assertThat(tracerA.requestsReceived.get(), equalTo(1)); + assertThat(tracerA.requestsSent.get(), equalTo(2)); + assertThat(tracerA.responseReceived.get(), equalTo(2)); + assertThat(tracerA.responseSent.get(), equalTo(1)); + assertThat(tracerB.requestsReceived.get(), equalTo(1)); + assertThat(tracerB.requestsSent.get(), equalTo(0)); + assertThat(tracerB.responseReceived.get(), equalTo(0)); + assertThat(tracerB.responseSent.get(), equalTo(1)); + }); + } + public void testVoidMessageCompressed() { serviceA.registerRequestHandler("sayHello", TransportRequest.Empty::new, ThreadPool.Names.GENERIC, (request, channel) -> { @@ -621,7 +718,7 @@ public void onAfter() { MockTransportService newService = buildService("TS_B_" + i, version1, null); newService.registerRequestHandler("test", TestRequest::new, ThreadPool.Names.SAME, ignoringRequestHandler); serviceB = newService; - nodeB = new DiscoveryNode("TS_B_" + i, "TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + nodeB = newService.getLocalDiscoNode(); serviceB.connectToNode(nodeA); serviceA.connectToNode(nodeB); } else if (serviceA.nodeConnected(nodeB)) { @@ -1468,42 +1565,42 @@ public void testBlockingIncomingRequests() throws Exception { channel.sendResponse(TransportResponse.Empty.INSTANCE); }); - DiscoveryNode node = - new DiscoveryNode("TS_TEST", "TS_TEST", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); + DiscoveryNode node = service.getLocalNode(); serviceA.close(); serviceA = buildService("TS_A", version0, null, Settings.EMPTY, true, false); - serviceA.connectToNode(node); - - CountDownLatch latch = new CountDownLatch(1); - serviceA.sendRequest(node, "action", new TestRequest(), new TransportResponseHandler() { - @Override - public TestResponse newInstance() { - return new TestResponse(); - } + try (Transport.Connection connection = serviceA.openConnection(node, null)) { + CountDownLatch latch = new CountDownLatch(1); + serviceA.sendRequest(connection, "action", new TestRequest(), TransportRequestOptions.EMPTY, + new TransportResponseHandler() { + @Override + public TestResponse newInstance() { + return new TestResponse(); + } - @Override - public void handleResponse(TestResponse response) { - latch.countDown(); - } + @Override + public void handleResponse(TestResponse response) { + latch.countDown(); + } - @Override - public void handleException(TransportException exp) { - latch.countDown(); - } + @Override + public void handleException(TransportException exp) { + latch.countDown(); + } - @Override - public String executor() { - return ThreadPool.Names.SAME; - } - }); + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + }); - assertFalse(requestProcessed.get()); + assertFalse(requestProcessed.get()); - service.acceptIncomingRequests(); - assertBusy(() -> assertTrue(requestProcessed.get())); + service.acceptIncomingRequests(); + assertBusy(() -> assertTrue(requestProcessed.get())); - latch.await(); + latch.await(); + } } } @@ -1784,12 +1881,12 @@ public void testTimeoutPerConnection() throws IOException { // connection with one connection and a large timeout -- should consume the one spot in the backlog queue try (TransportService service = buildService("TS_TPC", Version.CURRENT, null, Settings.EMPTY, true, false)) { - service.connectToNode(first, builder.build()); + IOUtils.close(service.openConnection(first, builder.build())); builder.setConnectTimeout(TimeValue.timeValueMillis(1)); final ConnectionProfile profile = builder.build(); // now with the 1ms timeout we got and test that is it's applied long startTime = System.nanoTime(); - ConnectTransportException ex = expectThrows(ConnectTransportException.class, () -> service.connectToNode(second, profile)); + ConnectTransportException ex = expectThrows(ConnectTransportException.class, () -> service.openConnection(second, profile)); final long now = System.nanoTime(); final long timeTaken = TimeValue.nsecToMSec(now - startTime); assertTrue("test didn't timeout quick enough, time taken: [" + timeTaken + "]", @@ -1873,13 +1970,13 @@ protected String handleRequest(MockChannel mockChannel, String profileName, Stre serviceA.disconnectFromNode(node); } - try (TransportService service = buildService("TS_TPC", Version.CURRENT, null)) { - DiscoveryNode node = - new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); - serviceA.connectToNode(node); - TcpTransport.NodeChannels connection = originalTransport.getConnection(node); - Version version = originalTransport.executeHandshake(node, connection.channel(TransportRequestOptions.Type.PING), - TimeValue.timeValueSeconds(10)); + try (TransportService service = buildService("TS_TPC", Version.CURRENT, null); + TcpTransport.NodeChannels connection = originalTransport.openConnection( + new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0), + null + ) ) { + Version version = originalTransport.executeHandshake(connection.getNode(), + connection.channel(TransportRequestOptions.Type.PING), TimeValue.timeValueSeconds(10)); assertEquals(version, Version.CURRENT); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index e58771c4e03ed..022f7033ec55a 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -201,8 +201,7 @@ protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile p InetSocketAddress address = ((InetSocketTransportAddress) node.getAddress()).address(); // we just use a single connections configureSocket(socket); - final TimeValue connectTimeout = profile.getConnectTimeout() == null ? defaultConnectionProfile.getConnectTimeout() - : profile.getConnectTimeout(); + final TimeValue connectTimeout = profile.getConnectTimeout(); try { socket.connect(address, Math.toIntExact(connectTimeout.millis())); } catch (SocketTimeoutException ex) { diff --git a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java index 2cc84c4c0cd0f..75d450b5d53ca 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java @@ -48,8 +48,8 @@ protected Version executeHandshake(DiscoveryNode node, MockChannel mockChannel, } } }; - MockTransportService mockTransportService = new MockTransportService(Settings.EMPTY, transport, threadPool, - TransportService.NOOP_TRANSPORT_INTERCEPTOR, clusterSettings); + MockTransportService mockTransportService = + MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings); mockTransportService.start(); return mockTransportService; }