diff --git a/shuffle-plugin/pom.xml b/shuffle-plugin/pom.xml
index 140e7e00b1f4..97820a145da0 100644
--- a/shuffle-plugin/pom.xml
+++ b/shuffle-plugin/pom.xml
@@ -44,7 +44,7 @@
org.openucx
jucx
- 1.11
+ 1.12.0
compile
diff --git a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala
index 177eeef1a5cc..de8676d4328b 100644
--- a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala
+++ b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala
@@ -136,6 +136,7 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
}
var workerParams = new UcpWorkerParams()
+ .setClientId(localExecutorId)
if (rapidsConf.shuffleUcxUseWakeup) {
workerParams = workerParams
@@ -775,6 +776,7 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
// enables `onError` callback
.setPeerErrorHandlingMode()
.setErrorHandler(this)
+ .sendClientId()
/**
* Get a `ClientConnection` after optionally connecting to a peer given by `peerExecutorId`,
@@ -806,7 +808,8 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
onWorkerThreadAsync(() => {
endpoints.computeIfAbsent(peerExecutorId, _ => {
val sockAddr = new InetSocketAddress(peerHost, peerPort)
- val ep = worker.newEndpoint(epParams.setSocketAddress(sockAddr))
+ val ep = worker.newEndpoint(
+ epParams.setSocketAddress(sockAddr))
logDebug(s"Initiator: created an endpoint $ep to $peerExecutorId")
reverseLookupEndpoints.put(ep, peerExecutorId)
ep
@@ -816,12 +819,20 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
// UcpListenerConnectionHandler interface - called from progress thread
// handles an incoming connection to our UCP Listener
- // TODO: in the future, this function may reject `ConnectionRequest`s
- // given a peer id we already established a connection to:
- // https://github.com/openucx/ucx/pull/6859
override def onConnectionRequest(connectionRequest: UcpConnectionRequest): Unit = {
logInfo(s"Got UcpListener request from ${connectionRequest.getClientAddress}")
+ val clientId = connectionRequest.getClientId
+
+ if (endpoints.containsKey(clientId)) {
+ connectionRequest.reject()
+ logWarning(s"Rejected connection request from ${clientId}, we already had an " +
+ s"endpoint established: ${endpoints.get(clientId)}")
+ return
+ } else {
+ logWarning(s"Accepting connection request from ${clientId}!")
+ }
+
// accept it
val ep = worker.newEndpoint(epParams.setConnectionRequest(connectionRequest))