From f156f7bc0951e60cb0a82d9122f144f4d68bc77a Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 3 Aug 2021 09:02:48 -0500 Subject: [PATCH] Use clientId api --- .../scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala index 177eeef1a5c..95f4b742306 100644 --- a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala +++ b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala @@ -775,6 +775,7 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: // enables `onError` callback .setPeerErrorHandlingMode() .setErrorHandler(this) + .setClientId(executor.executorId.toLong) /** * Get a `ClientConnection` after optionally connecting to a peer given by `peerExecutorId`, @@ -820,7 +821,13 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: // given a peer id we already established a connection to: // https://github.com/openucx/ucx/pull/6859 override def onConnectionRequest(connectionRequest: UcpConnectionRequest): Unit = { - logInfo(s"Got UcpListener request from ${connectionRequest.getClientAddress}") + logInfo(s"Got UcpListener request from ${connectionRequest.getClientAddress} for " + + s"executorId ${connectionRequest.getClientId}") + // if we already have an endpoint in our endpoints cache, reject + if (endpoints.contains(connectionRequest.getClientId)) { + logError(s"SHOULD REJECT ${connectionRequest.getClientId}") + return + } // accept it val ep = worker.newEndpoint(epParams.setConnectionRequest(connectionRequest))