diff --git a/shuffle-plugin/pom.xml b/shuffle-plugin/pom.xml index 140e7e00b1f4..97820a145da0 100644 --- a/shuffle-plugin/pom.xml +++ b/shuffle-plugin/pom.xml @@ -44,7 +44,7 @@ org.openucx jucx - 1.11 + 1.12.0 compile diff --git a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala index 177eeef1a5cc..de8676d4328b 100644 --- a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala +++ b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala @@ -136,6 +136,7 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: } var workerParams = new UcpWorkerParams() + .setClientId(localExecutorId) if (rapidsConf.shuffleUcxUseWakeup) { workerParams = workerParams @@ -775,6 +776,7 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: // enables `onError` callback .setPeerErrorHandlingMode() .setErrorHandler(this) + .sendClientId() /** * Get a `ClientConnection` after optionally connecting to a peer given by `peerExecutorId`, @@ -806,7 +808,8 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: onWorkerThreadAsync(() => { endpoints.computeIfAbsent(peerExecutorId, _ => { val sockAddr = new InetSocketAddress(peerHost, peerPort) - val ep = worker.newEndpoint(epParams.setSocketAddress(sockAddr)) + val ep = worker.newEndpoint( + epParams.setSocketAddress(sockAddr)) logDebug(s"Initiator: created an endpoint $ep to $peerExecutorId") reverseLookupEndpoints.put(ep, peerExecutorId) ep @@ -816,12 +819,20 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: // UcpListenerConnectionHandler interface - called from progress thread // handles an incoming connection to our UCP Listener - // TODO: in the future, this function may reject `ConnectionRequest`s - // given a peer id we already established a connection to: - // https://github.com/openucx/ucx/pull/6859 override def onConnectionRequest(connectionRequest: UcpConnectionRequest): Unit = { logInfo(s"Got UcpListener request from ${connectionRequest.getClientAddress}") + val clientId = connectionRequest.getClientId + + if (endpoints.containsKey(clientId)) { + connectionRequest.reject() + logWarning(s"Rejected connection request from ${clientId}, we already had an " + + s"endpoint established: ${endpoints.get(clientId)}") + return + } else { + logWarning(s"Accepting connection request from ${clientId}!") + } + // accept it val ep = worker.newEndpoint(epParams.setConnectionRequest(connectionRequest))