From 8788f30a177ed117bc29a57d8045c4551c139f04 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Thu, 20 May 2021 22:55:57 -0400 Subject: [PATCH] Change shuffle metadata messages to use UCX Active Messages (#2409) * Change metadata messages to use Active Messages Signed-off-by: Alessandro Bellina * Code review comments * Comments: copyright, visibility, tests, conf check * Comments: spacing * Comments: private final, no side effects in BSS * Comments: simplify UCXActiveMessage * Comments: UCX.scala better comments, interface changes * Comments: UCX.scala putIfAbsent * Comments: small cleanup in UCX.scala * Move some static tag-handling functions out of the UCXConnection instances. To help testing * Fix a couple of bugs I introduced during refactorings in this PR * Comments: Refactor Active Message registrations a bit * Update comment * UCXConnectionSuite * Comments: shift-then-mask while extracting executorId * Add a few test cases touching the active message header generation * Comments: Make request am handler a constructor argument * Fix RequestActiveMessageRegistration --- docs/configs.md | 1 + .../nvidia/spark/rapids/shuffle/ucx/UCX.scala | 357 +++++++++++++++--- .../rapids/shuffle/ucx/UCXConnection.scala | 344 +++++++++++------ .../shuffle/ucx/UCXShuffleTransport.scala | 9 +- .../rapids/shuffle/ucx/UCXTransaction.scala | 105 +++++- .../main/format/ShuffleMetadataRequest.fbs | 11 +- .../main/format/ShuffleMetadataResponse.fbs | 5 - .../main/format/ShuffleTransferRequest.fbs | 6 - .../spark/rapids/format/MetadataRequest.java | 34 +- .../spark/rapids/format/MetadataResponse.java | 20 +- .../spark/rapids/format/TransferRequest.java | 26 +- .../com/nvidia/spark/rapids/MetaUtils.scala | 32 +- .../com/nvidia/spark/rapids/RapidsConf.scala | 15 +- .../rapids/shuffle/BufferSendState.scala | 43 ++- .../rapids/shuffle/RapidsShuffleClient.scala | 183 +++------ .../rapids/shuffle/RapidsShuffleServer.scala | 216 +++++------ .../shuffle/RapidsShuffleTransport.scala | 143 ++++--- .../shuffle/RapidsShuffleClientSuite.scala | 55 +-- .../shuffle/RapidsShuffleIteratorSuite.scala | 10 +- .../shuffle/RapidsShuffleServerSuite.scala | 28 +- .../shuffle/RapidsShuffleTestHelper.scala | 51 ++- .../rapids/shuffle/UCXConnectionSuite.scala | 45 +++ 22 files changed, 1057 insertions(+), 682 deletions(-) create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/shuffle/UCXConnectionSuite.scala diff --git a/docs/configs.md b/docs/configs.md index b15ca027ba9..1f6d739b30e 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -49,6 +49,7 @@ Name | Description | Default Value spark.rapids.shuffle.transport.earlyStart|Enable early connection establishment for RAPIDS Shuffle|true spark.rapids.shuffle.transport.earlyStart.heartbeatInterval|Shuffle early start heartbeat interval (milliseconds)|5000 spark.rapids.shuffle.transport.maxReceiveInflightBytes|Maximum aggregate amount of bytes that be fetched at any given time from peers during shuffle|1073741824 +spark.rapids.shuffle.ucx.activeMessages.mode|Set to 'rndv', 'eager', or 'auto' to indicate what UCX Active Message mode to use. We set 'rndv' (Rendezvous) by default because UCX 1.10.x doesn't support 'eager' fully. This restriction can be lifted if the user is running UCX 1.11+.|rndv spark.rapids.shuffle.ucx.managementServerHost|The host to be used to start the management server|null spark.rapids.shuffle.ucx.useWakeup|When set to true, use UCX's event-based progress (epoll) in order to wake up the progress thread when needed, instead of a hot loop.|true spark.rapids.sql.batchSizeBytes|Set the target number of bytes for a GPU batch. Splits sizes for input data is covered by separate configs. The maximum setting is 2 GB to avoid exceeding the cudf row count limit of a column.|2147483647 diff --git a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala index c114aadb296..67ccc892ce8 100644 --- a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala +++ b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala @@ -27,7 +27,7 @@ import scala.util.Random import ai.rapids.cudf.{MemoryBuffer, NvtxColor, NvtxRange} import com.google.common.util.concurrent.ThreadFactoryBuilder -import com.nvidia.spark.rapids.{GpuDeviceManager, RapidsConf} +import com.nvidia.spark.rapids.{Arm, GpuDeviceManager, RapidsConf} import com.nvidia.spark.rapids.shuffle.{AddressLengthTag, ClientConnection, MemoryRegistrationCallback, TransportUtils} import org.openucx.jucx._ import org.openucx.jucx.ucp._ @@ -35,12 +35,27 @@ import org.openucx.jucx.ucs.UcsConstants import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging +import org.apache.spark.sql.rapids.storage.RapidsStorageUtils import org.apache.spark.storage.BlockManagerId case class WorkerAddress(address: ByteBuffer) case class Rkeys(rkeys: Seq[ByteBuffer]) +/** + * A simple wrapper for an Active Message Id and a header. This pair + * is used together when dealing with Active Messages, with `activeMessageId` + * being a fire-and-forget registration with UCX, and `header` being a dynamic long + * we continue to update (it contains the local executor id, and the transaction id). + * + * This allows us to send a request (with a header that the response handler knows about), + * and for the request handler to echo back that header when it's done. + */ +case class UCXActiveMessage(activeMessageId: Int, header: Long) { + override def toString: String = + UCX.formatAmIdAndHeader(activeMessageId, header) +} + /** * The UCX class wraps JUCX classes and handles all communication with UCX from other * parts of the shuffle code. It manages a `UcpContext` and `UcpWorker`, for the @@ -52,12 +67,16 @@ case class Rkeys(rkeys: Seq[ByteBuffer]) * This class uses an extra TCP management connection to perform a handshake with remote peers, * this port should be distributed to peers by other means (e.g. via the `BlockManagerId`) * + * @param transport transport instance for UCX * @param executor blockManagerId of the local executorId * @param rapidsConf rapids configuration */ -class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseable with Logging { +class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: RapidsConf) + extends AutoCloseable with Logging with Arm { private[this] val context = { - val contextParams = new UcpParams().requestTagFeature() + val contextParams = new UcpParams() + .requestTagFeature() + .requestAmFeature() if (rapidsConf.shuffleUcxUseWakeup) { contextParams.requestWakeupFeature() } @@ -66,8 +85,10 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl logInfo(s"UCX context created") + def getExecutorId: Int = executor.executorId.toInt + // this object implements the transport-friendly interface for UCX - private[this] val serverConnection = new UCXServerConnection(this) + private[this] val serverConnection = new UCXServerConnection(this, transport) // monotonically increasing counter that holds the txId (for debug purposes, at this stage) private[this] val txId = new AtomicLong(0L) @@ -110,7 +131,6 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl // This makes sure that all executor threads get the same [[Connection]] object for a specific // management (host, port) key. private val connectionCache = new ConcurrentHashMap[Long, ClientConnection]() - private val executorIdToPeerTag = new ConcurrentHashMap[Long, Long]() // holds memory registered against UCX that should be de-register on exit (used for bounce // buffers) @@ -121,13 +141,19 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl // the worker thread. We need this to complete prior to getting the `rkeys`. private var pendingRegistration = false + // There will be 1 entry in this map per UCX-registered Active Message. Presently + // that means: 2 request active messages (Metadata and Transfer Request), and 2 + // response active messages (Metadata and Transfer response). + private val amRegistrations = new ConcurrentHashMap[Int, ActiveMessageRegistration]() + // Error handler that would be invoked on endpoint failure. private val epErrorHandler = new UcpEndpointErrorHandler { override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = { - endpoints.values().removeIf(ep => ep == ucpEndpoint) - ucpEndpoint.close() - if (errorCode != UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) { - logError(s"Endpoint to $ucpEndpoint got error: $errorString") + withResource(ucpEndpoint) { _ => + if (errorCode != UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) { + logError(s"Endpoint to $ucpEndpoint got error: $errorString") + } + endpoints.values().removeIf(ep => ep == ucpEndpoint) } } } @@ -205,11 +231,8 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl // this could change in the future to 1 progress call per loop, or be used // entirely differently once polling is figured out def drainWorker(): Unit = { - val nvtxRange = new NvtxRange("UCX Draining Worker", NvtxColor.RED) - try { + withResource(new NvtxRange("UCX Draining Worker", NvtxColor.RED)) { _ => while (worker.progress() > 0) {} - } finally { - nvtxRange.close() } } @@ -219,23 +242,17 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl // else worker.progress returned 0 if (rapidsConf.shuffleUcxUseWakeup) { drainWorker() - val sleepRange = new NvtxRange("UCX Sleeping", NvtxColor.PURPLE) - try { + withResource(new NvtxRange("UCX Sleeping", NvtxColor.PURPLE)) { _ => worker.waitForEvents() - } finally { - sleepRange.close() } } while (!workerTasks.isEmpty) { - val nvtxRange = new NvtxRange("UCX Handling Tasks", NvtxColor.CYAN) - try { + withResource(new NvtxRange("UCX Handling Tasks", NvtxColor.CYAN)) { _ => val wt = workerTasks.poll() if (wt != null) { wt() } - } finally { - nvtxRange.close() } worker.progress() } @@ -320,7 +337,7 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl override def onError(ucsStatus: Int, errorMsg: String): Unit = { if (ucsStatus == UCX.UCS_ERR_CANCELED) { logWarning( - s"Cancelled: tag=${TransportUtils.formatTag(alt.tag)}," + + s"Cancelled: tag=${TransportUtils.toHex(alt.tag)}," + s" status=$ucsStatus, msg=$errorMsg") cb.onCancel(alt) } else { @@ -345,6 +362,246 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl }) } + /** + * This trait and next two implementations represent the mapping between an Active Message Id + * and the callback that should be triggered when a message is received. + * + * There are two types of Active Messages we care about: requests and responses. + * + * For requests: + * - `activeMessageId` for requests is the value of the `RequestType` enum, and it is + * set once when the `RapidsShuffleServer` is initialized, and no new request handlers + * are established. + * + * - The Active Message header is handed to the request handler, via the transaction. + * The request handler needs to echo the header back for the response handler on the + * other side of the request. + * + * - On a request, a callback is instantiated using `requestCallbackGen`, which creates + * a transaction each time. This is one way to handle several requests inbound to a server. + * + * For responses: + * - `activeMessageId` for responses is the value of the `RequestType` enum with an extra + * bit flipped (see `UCXConnection.composeResponseAmId`). These are also set once, as requests + * are sent out. + * + * - The Active Message header is used to pick the correct callback to call. In this case + * there could be several expected responses, for a single response `activeMessageId`, so the + * server echoes back our header so we can invoke the correct response callback. + * + * - Each response received at the response activeMessageId, will be demuxed using the header: + * responseActiveMessageId1 -> [callbackForHeader1, callbackForHeader2, ..., callbackForHeaderN] + */ + trait ActiveMessageRegistration { + val activeMessageId: Int + def getCallback(header: Long): UCXAmCallback + } + + class RequestActiveMessageRegistration(override val activeMessageId: Int, + requestCbGen: () => UCXAmCallback) + extends ActiveMessageRegistration { + + def getCallback(header: Long): UCXAmCallback = requestCbGen() + } + + class ResponseActiveMessageRegistration(override val activeMessageId: Int) + extends ActiveMessageRegistration { + private val responseCallbacks = new ConcurrentHashMap[Long, UCXAmCallback]() + + def getCallback(header: Long): UCXAmCallback = { + val cb = responseCallbacks.remove(header) // 1 callback per header + require (cb != null, + s"Failed to get a response Active Message callback for " + + s"${UCX.formatAmIdAndHeader(activeMessageId, header)}") + cb + } + + def addResponseActiveMessageHandler(header: Long, responseCallback: UCXAmCallback): Unit = { + val prior = responseCallbacks.putIfAbsent(header, responseCallback) + require(prior == null, + s"Invalid Active Message re-registration of response handler for " + + s"${UCX.formatAmIdAndHeader(activeMessageId, header)}") + } + } + + /** + * Register a response handler (clients will use this) + * + * @note This function will be called for each client, with the same `am.activeMessageId` + * @param activeMessageId (up to 5 bits) used to register with UCX an Active Message + * @param header a long used to demux responses arriving at `activeMessageId` + * @param responseCallback callback to handle a particular response + */ + def registerResponseHandler( + activeMessageId: Int, header: Long, responseCallback: UCXAmCallback): Unit = { + logDebug(s"Register Active Message " + + s"${UCX.formatAmIdAndHeader(activeMessageId, header)} response handler") + + amRegistrations.computeIfAbsent(activeMessageId, + _ => { + val reg = new ResponseActiveMessageRegistration(activeMessageId) + registerActiveMessage(reg) + reg + }) match { + case reg: ResponseActiveMessageRegistration => + reg.addResponseActiveMessageHandler(header, responseCallback) + case other => + throw new IllegalStateException( + s"Attempted to add a response Active Message handler to existing registration $other " + + s"for ${UCX.formatAmIdAndHeader(activeMessageId, header)}") + } + } + + /** + * Register a request handler (the server will use this) + * @note This function will be called once for the server for an `activeMessageId` + * @param activeMessageId (up to 5 bits) used to register with UCX an Active Message + * @param requestCallbackGen a function that instantiates a callback to handle + * a particular request + */ + def registerRequestHandler(activeMessageId: Int, + requestCallbackGen: () => UCXAmCallback): Unit = { + logDebug(s"Register Active Message $TransportUtils.request handler") + val reg = new RequestActiveMessageRegistration(activeMessageId, requestCallbackGen) + val oldReg = amRegistrations.putIfAbsent(activeMessageId, reg) + require(oldReg == null, + s"Tried to re-register a request handler for $activeMessageId") + registerActiveMessage(reg) + } + + private def registerActiveMessage(reg: ActiveMessageRegistration): Unit = { + onWorkerThreadAsync(() => { + worker.setAmRecvHandler(reg.activeMessageId, + (headerAddr, headerSize, amData: UcpAmData, _) => { + if (headerSize != 8) { + // this is a coding error, so I am just blowing up. It should never happen. + throw new IllegalStateException( + s"Received message with wrong header size $headerSize") + } else { + val header = UcxUtils.getByteBufferView(headerAddr, headerSize).getLong + val am = UCXActiveMessage(reg.activeMessageId, header) + + logDebug(s"Active Message received: $am") + + val cb = reg.getCallback(header) + + if (amData.isDataValid) { + require(notForcingAmRndv, + s"Handling an eager Active Message, but we are using " + + s"'${rapidsConf.shuffleUcxActiveMessagesMode}' as our configured mode.") + logDebug(s"Handling an EAGER active message receive ${amData}") + val resp = UcxUtils.getByteBufferView(amData.getDataAddress, amData.getLength) + + // copy the data onto a buffer we own because it is going to be reused + // in UCX + val dbb = cb.onHostMessageReceived(amData.getLength) + val bb = dbb.getBuffer() + bb.put(resp) + bb.rewind() + cb.onSuccess(am, dbb) + + // we return OK telling UCX `amData` is ok to be closed, along with the eagerly + // received data + UcsConstants.STATUS.UCS_OK + } else { + // RNDV case: we get a direct buffer and UCX will fill it with data at `receive` + // callback + val resp = cb.onHostMessageReceived(amData.getLength) + + val receiveAm = amData.receive(UcxUtils.getAddress(resp.getBuffer()), + new UcxCallback { + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + withResource(resp) { _ => + withResource(amData) { _ => + if (ucsStatus == UCX.UCS_ERR_CANCELED) { + logWarning( + s"Cancelled Active Message " + + s"${TransportUtils.toHex(reg.activeMessageId)}" + + s" status=$ucsStatus, msg=$errorMsg") + cb.onCancel(am) + } else { + cb.onError(am, ucsStatus, errorMsg) + } + } + } + } + + override def onSuccess(request: UcpRequest): Unit = { + withResource(amData) { _ => + cb.onSuccess(am, resp) + } + } + }) + + cb.onMessageStarted(receiveAm) + UcsConstants.STATUS.UCS_INPROGRESS + } + } + }) + }) + } + + // If we are not forcing RNDV (i.e. we are in auto or eager) other handling + // can happen when we receive an Active Message message (it can contain + // inline data that must be copied out in the callback). + private lazy val notForcingAmRndv: Boolean = { + !rapidsConf.shuffleUcxActiveMessagesMode + .equalsIgnoreCase("rndv") + } + + private lazy val activeMessageMode: Long = { + rapidsConf.shuffleUcxActiveMessagesMode match { + case "eager" => + UcpConstants.UCP_AM_SEND_FLAG_EAGER + case "rndv" => + UcpConstants.UCP_AM_SEND_FLAG_RNDV + case "auto" => + 0L + case _ => + throw new IllegalArgumentException( + s"${rapidsConf.shuffleUcxActiveMessagesMode} is an invalid Active Message mode. " + + s"Please ensure that ${RapidsConf.SHUFFLE_UCX_ACTIVE_MESSAGES_MODE.key} is set correctly") + } + } + + def sendActiveMessage(endpointId: Long, am: UCXActiveMessage, + dataAddress: Long, dataSize: Long, cb: UcxCallback): Unit = { + onWorkerThreadAsync(() => { + val ep = endpoints.get(endpointId) + if (ep == null) { + throw new IllegalStateException( + s"Trying to send a message to an endpoint that doesn't exist ${endpointId}") + } + logDebug(s"Sending $am msg of size $dataSize") + + // This isn't coming from the pool right now because it would be a bit of a + // waste to get a larger hard-partitioned buffer just for 8 bytes. + // TODO: since we no longer have metadata limits, the pool can be managed using the + // address-space allocator, so we should obtain this direct buffer from that pool + val header = ByteBuffer.allocateDirect(8) + header.putLong(am.header) + header.rewind() + + ep.sendAmNonBlocking( + am.activeMessageId, + TransportUtils.getAddress(header), + header.remaining(), + dataAddress, + dataSize, + activeMessageMode, + new UcxCallback { + override def onSuccess(request: UcpRequest): Unit = { + cb.onSuccess(request) + RapidsStorageUtils.dispose(header) + } + + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + cb.onError(ucsStatus, errorMsg) + RapidsStorageUtils.dispose(header) + } + }) + }) + } def getServerConnection: UCXServerConnection = serverConnection @@ -353,7 +610,7 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl override def onError(ucsStatus: Int, errorMsg: String): Unit = { if (ucsStatus == UCX.UCS_ERR_CANCELED) { logWarning( - s"Cancelled: tag=${TransportUtils.formatTag(alt.tag)}," + + s"Cancelled: tag=${TransportUtils.toHex(alt.tag)}," + s" status=$ucsStatus, msg=$errorMsg") cb.onCancel(alt) } else { @@ -363,13 +620,13 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl } override def onSuccess(request: UcpRequest): Unit = { - logTrace(s"Success receiving calling callback ${TransportUtils.formatTag(alt.tag)}") + logTrace(s"Success receiving calling callback ${TransportUtils.toHex(alt.tag)}") cb.onSuccess(alt) } } onWorkerThreadAsync(() => { - logTrace(s"Handling receive for tag ${TransportUtils.formatTag(alt.tag)}") + logTrace(s"Handling receive for tag ${TransportUtils.toHex(alt.tag)}") val request = worker.recvTaggedNonBlocking( alt.address, alt.length, @@ -391,7 +648,7 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl }) } - private[ucx] def assignResponseTag(): Long = responseTag.incrementAndGet() + def assignResponseTag(): Long = responseTag.incrementAndGet() private lazy val ucxAddress: ByteBuffer = if (rapidsConf.shuffleUcxUseSockaddr) { val listenerAddress = listener.get.getAddress @@ -418,7 +675,7 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl * @return returns a [[UcpEndpoint]] that can later be used to send on (from the * progress thread) */ - private[ucx] def setupEndpoint( + def setupEndpoint( endpointId: Long, workerAddress: WorkerAddress, peerRkeys: Rkeys): UcpEndpoint = { logDebug(s"Starting/reusing an endpoint to $workerAddress with id $endpointId") // create an UCX endpoint using workerAddress or socket address @@ -454,7 +711,8 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl peerMgmtPort: Int): ClientConnection = { val getConnectionStartTime = System.currentTimeMillis() val result = connectionCache.computeIfAbsent(peerExecutorId, _ => { - val connection = new UCXClientConnection(peerExecutorId, peerTag.incrementAndGet(), this) + val connection = new UCXClientConnection( + peerExecutorId, peerTag.incrementAndGet(), this, transport) startConnection(connection, peerMgmtHost, peerMgmtPort) connection }) @@ -463,7 +721,7 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl result } - private[ucx] def onWorkerThreadAsync(task: () => Unit): Unit = { + def onWorkerThreadAsync(task: () => Unit): Unit = { workerTasks.add(task) if (rapidsConf.shuffleUcxUseWakeup) { worker.signal() @@ -475,16 +733,14 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl peerMgmtHost: String, peerMgmtPort: Int) = { logInfo(s"Connecting to $peerMgmtHost:$peerMgmtPort") - val nvtx = new NvtxRange(s"UCX Connect to $peerMgmtHost:$peerMgmtPort", NvtxColor.RED) - try { - val socket = new Socket(peerMgmtHost, peerMgmtPort) - try { + withResource(new NvtxRange(s"UCX Connect to $peerMgmtHost:$peerMgmtPort", NvtxColor.RED)) { _ => + withResource(new Socket(peerMgmtHost, peerMgmtPort)) { socket => socket.setTcpNoDelay(true) val os = socket.getOutputStream val is = socket.getInputStream // this executor id will receive on tmpLocalReceiveTag for this Connection - UCXConnection.writeHandshakeHeader(os, getUcxAddress, executor.executorId.toInt, localRkeys) + UCXConnection.writeHandshakeHeader(os, getUcxAddress, getExecutorId, localRkeys) // the remote executor will receive on remoteReceiveTag, and expects this executor to // receive on localReceiveTag @@ -501,30 +757,22 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl }) logInfo(s"NEW OUTGOING UCX CONNECTION $connection") - } finally { - socket.close() } connection - } finally { - nvtx.close() } } - def assignPeerTag(peerExecutorId: Long): Long = - executorIdToPeerTag.computeIfAbsent(peerExecutorId, _ => peerTag.incrementAndGet()) - /** * Handle an incoming connection on the TCP management port * This will fetch the [[WorkerAddress]] from the peer, and establish a UcpEndpoint * * @param socket an accepted socket to a remote client */ - private[ucx] def handleSocket(socket: Socket): Unit = { - val connectionRange = - new NvtxRange(s"UCX Handle Connection from ${socket.getInetAddress}", NvtxColor.RED) - try { + private def handleSocket(socket: Socket): Unit = { + withResource(new NvtxRange(s"UCX Handle Connection from ${socket.getInetAddress}", + NvtxColor.RED)) { _ => logDebug(s"Reading worker address from: $socket") - try { + withResource(socket) { _ => val is = socket.getInputStream val os = socket.getOutputStream @@ -535,7 +783,7 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl logInfo(s"Got peer worker address from executor $peerExecutorId") // ack what we saw as the local and remote peer tags - UCXConnection.writeHandshakeHeader(os, getUcxAddress, executor.executorId.toInt, localRkeys) + UCXConnection.writeHandshakeHeader(os, getUcxAddress, getExecutorId, localRkeys) onWorkerThreadAsync(() => { setupEndpoint(peerExecutorId, peerWorkerAddress, peerRkeys) @@ -543,13 +791,7 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl // peer would have established an endpoint peer -> local logInfo(s"Sent server UCX worker address to executor $peerExecutorId") - } finally { - // at this point we have handshaked, UCX is ready to go for this point-to-point connection. - // assume that we get a list of block ids, tag tuples we want to transfer out - socket.close() } - } finally { - connectionRange.close() } } @@ -608,6 +850,12 @@ class UCX(executor: BlockManagerId, rapidsConf: RapidsConf) extends AutoCloseabl override def close(): Unit = { onWorkerThreadAsync(() => { + amRegistrations.forEach { (activeMessageId, _) => + logDebug(s"Removing Active Message registration for " + + s"${TransportUtils.toHex(activeMessageId)}") + worker.removeAmRecvHandler(activeMessageId) + } + logInfo(s"De-registering UCX ${registeredMemory.size} memory buffers.") registeredMemory.synchronized { registeredMemory.foreach(_.deregister()) @@ -663,4 +911,7 @@ object UCX { // We may consider matching tags partially for different request types private val MATCH_FULL_TAG: Long = 0xFFFFFFFFFFFFFFFFL + + def formatAmIdAndHeader(activeMessageId: Int, header: Long) = + s"[amId=${TransportUtils.toHex(activeMessageId)}, hdr=${TransportUtils.toHex(header)}]" } diff --git a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXConnection.scala b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXConnection.scala index 4af67739663..3098bab3739 100644 --- a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXConnection.scala +++ b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXConnection.scala @@ -24,14 +24,19 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids.shuffle._ +import org.openucx.jucx.UcxCallback import org.openucx.jucx.ucp.UcpRequest import org.apache.spark.internal.Logging /** - * This is a private api used within the ucx package. - * It is used by [[Transaction]] to call into the UCX functions. It adds the tag - * as we use that to track the message and for debugging. + * These are private apis used within the ucx package. + */ + +/** + * `UCXTagCallback` is used by [[Transaction]] to handle UCX tag-based operations. + * It adds the `AddressLengthTag` instance as we use that to track the message and + * for debugging. */ private[ucx] abstract class UCXTagCallback { def onError(alt: AddressLengthTag, ucsStatus: Int, errorMsg: String): Unit @@ -40,21 +45,98 @@ private[ucx] abstract class UCXTagCallback { def onCancel(alt: AddressLengthTag): Unit } -class UCXServerConnection(ucx: UCX) extends UCXConnection(ucx) with ServerConnection { +/** + * `UCXAmCallback` is used by [[Transaction]] to handle UCX Active Messages operations. + * The `UCXActiveMessage` object encapsulates an activeMessageId and a header. + */ +private[ucx] abstract class UCXAmCallback { + def onError(am: UCXActiveMessage, ucsStatus: Int, errorMsg: String): Unit + def onMessageStarted(receiveAm: UcpRequest): Unit + def onSuccess(am: UCXActiveMessage, buff: RefCountedDirectByteBuffer): Unit + def onCancel(am: UCXActiveMessage): Unit + + // hook to allocate memory on the host + // a similar hook will be needed for GPU memory + def onHostMessageReceived(size: Long): RefCountedDirectByteBuffer +} + +class UCXServerConnection(ucx: UCX, transport: UCXShuffleTransport) + extends UCXConnection(ucx) with ServerConnection with Logging { override def startManagementPort(host: String): Int = { ucx.startManagementPort(host) } - override def send(sendPeerExecutorId: Long, bounceBuffers: Seq[AddressLengthTag], + override def send(sendPeerExecutorId: Long, buffer: AddressLengthTag, cb: TransactionCallback): Transaction = - send(sendPeerExecutorId, null, bounceBuffers, cb) + send(sendPeerExecutorId, buffer, Seq.empty, cb) - override def send(sendPeerExecutorId: Long, header: AddressLengthTag, - cb: TransactionCallback): Transaction = - send(sendPeerExecutorId, header, Seq.empty, cb) + override def registerRequestHandler( + requestType: RequestType.Value, cb: TransactionCallback): Unit = { + + ucx.registerRequestHandler( + UCXConnection.composeRequestAmId(requestType), () => new UCXAmCallback { + private val tx = createTransaction + + tx.start(UCXTransactionType.Request, 1, cb) + + override def onSuccess(am: UCXActiveMessage, buff: RefCountedDirectByteBuffer): Unit = { + logDebug(s"At requestHandler for ${requestType} and am: " + + s"$am") + tx.completeWithSuccess(requestType, Option(am.header), Option(buff)) + } + + override def onHostMessageReceived(size: Long): RefCountedDirectByteBuffer = { + transport.getDirectByteBuffer(size) + } + + override def onError(am: UCXActiveMessage, ucsStatus: Int, errorMsg: String): Unit = { + tx.completeWithError(errorMsg) + } + + override def onCancel(am: UCXActiveMessage): Unit = { + tx.completeCancelled(requestType, am.header) + } + + override def onMessageStarted(receiveAm: UcpRequest): Unit = { + tx.registerPendingMessage(receiveAm) + } + }) + } + + override def respond(peerExecutorId: Long, + messageType: RequestType.Value, + header: Long, + response: ByteBuffer, + cb: TransactionCallback): Transaction = { + + val tx = createTransaction + tx.start(UCXTransactionType.Request, 1, cb) + + logDebug(s"Responding to ${peerExecutorId} at ${TransportUtils.toHex(header)} " + + s"with ${response}") + + val responseAm = UCXActiveMessage( + UCXConnection.composeResponseAmId(messageType), header) + + ucx.sendActiveMessage(peerExecutorId, responseAm, + TransportUtils.getAddress(response), response.remaining(), + new UcxCallback { + override def onSuccess(request: UcpRequest): Unit = { + logDebug(s"AM success respond $responseAm") + tx.complete(TransactionStatus.Success) + } + + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"AM Error responding ${ucsStatus} ${errorMsg} for $responseAm") + tx.completeWithError(errorMsg) + } + }) + tx + } } -class UCXClientConnection(peerExecutorId: Int, peerClientId: Long, ucx: UCX) +class UCXClientConnection(peerExecutorId: Int, peerClientId: Long, + ucx: UCX, transport: UCXShuffleTransport) extends UCXConnection(peerExecutorId, ucx) with ClientConnection { @@ -66,108 +148,79 @@ class UCXClientConnection(peerExecutorId: Int, peerClientId: Long, ucx: UCX) logInfo(s"UCX Client $this started") - // tag used for a unique response to a request initiated by [[Transaction.request]] - override def assignResponseTag: Long = composeResponseTag(peerClientId, ucx.assignResponseTag()) - - override def assignBufferTag(msgId: Int): Long = composeBufferTag(peerClientId, msgId) + override def assignBufferTag(msgId: Int): Long = + UCXConnection.composeBufferTag(peerClientId, msgId) override def getPeerExecutorId: Long = peerExecutorId override def request( - request: AddressLengthTag, - response: AddressLengthTag, + requestType: RequestType.Value, request: ByteBuffer, cb: TransactionCallback): Transaction = { val tx = createTransaction + tx.start(UCXTransactionType.Request, 1, cb) - tx.start(UCXTransactionType.Request, 2, cb) + // this header is unique, so we can send it with the request + // expecting it to be echoed back in the response + val requestHeader = UCXConnection.composeRequestHeader(ucx.getExecutorId.toLong, tx.txId) - logDebug(s"Performing header request on tag ${TransportUtils.formatTag(request.tag)} " + - s"for tx $tx") - send(peerExecutorId, request, Seq.empty, (sendTx: Transaction) => { - logDebug(s"UCX request send callback $sendTx") - if (sendTx.getStatus == TransactionStatus.Success) { - tx.incrementSendSize(request.length) - if (tx.decrementPendingAndGet <= 0) { - logDebug(s"Header request is done on send: ${sendTx.getStatus}, " + - s"tag: ${TransportUtils.formatTag(request.tag)} for $tx") - tx.txCallback(TransactionStatus.Success) - } + // This is the response active message handler, when the response shows up + // we'll create a transaction, and set header/message and complete it. + val amCallback = new UCXAmCallback { + override def onHostMessageReceived(size: Long): RefCountedDirectByteBuffer = { + transport.getDirectByteBuffer(size.toInt) } - sendTx.close() - }) - receive(response, receiveTx => { - logDebug(s"UCX request receive callback $receiveTx") - if (receiveTx.getStatus == TransactionStatus.Success) { - tx.incrementReceiveSize(response.length) - if (tx.decrementPendingAndGet <= 0) { - logDebug(s"Header request is done on receive: $this, " + - s"tag: ${TransportUtils.formatTag(response.tag)}") - tx.txCallback(TransactionStatus.Success) - } + override def onSuccess(am: UCXActiveMessage, buff: RefCountedDirectByteBuffer): Unit = { + tx.completeWithSuccess(requestType, Option(am.header), Option(buff)) } - receiveTx.close() - }) - tx - } -} - -class UCXConnection(peerExecutorId: Int, ucx: UCX) extends Connection with Logging { - // alternate constructor for a server connection (-1 because it is not to a peer) - def this(ucx: UCX) = this(-1, ucx) + override def onError(am: UCXActiveMessage, ucsStatus: Int, errorMsg: String): Unit = { + tx.completeWithError(errorMsg) + } - private[this] val pendingTransactions = new ConcurrentHashMap[Long, UCXTransaction]() + override def onMessageStarted(receiveAm: UcpRequest): Unit = { + tx.registerPendingMessage(receiveAm) + } - /** - * 1) client gets upper 28 bits - * 2) then comes the type, which gets 4 bits - * 3) the remaining 32 bits are used for buffer specific tags - */ - private val requestMsgType: Long = 0x00000000000000000L - private val responseMsgType: Long = 0x0000000000000000AL - private val bufferMsgType: Long = 0x0000000000000000BL - - override def composeRequestTag(requestType: RequestType.Value): Long = { - val requestTypeId = requestType match { - case RequestType.MetadataRequest => 0 - case RequestType.TransferRequest => 1 + override def onCancel(am: UCXActiveMessage): Unit = { + tx.completeCancelled(requestType, am.header) + } } - requestMsgType | requestTypeId - } - protected def composeResponseTag(peerClientId: Long, bufferTag: Long): Long = { - // response tags are [peerClientId, 1, bufferTag] - composeTag(composeUpperBits(peerClientId, responseMsgType), bufferTag) - } + // Register the active message response handler. Note that the `requestHeader` + // is expected to come back with the response, and is used to find the + // correct callback (this is an implementation detail in UCX.scala) + ucx.registerResponseHandler( + UCXConnection.composeResponseAmId(requestType), requestHeader, amCallback) - protected def composeBufferTag(peerClientId: Long, bufferTag: Long): Long = { - // buffer tags are [peerClientId, 0, bufferTag] - composeTag(composeUpperBits(peerClientId, bufferMsgType), bufferTag) - } + // kick-off the request + val requestAm = UCXActiveMessage( + UCXConnection.composeRequestAmId(requestType), requestHeader) - private def composeTag(upperBits: Long, lowerBits: Long): Long = { - if ((upperBits & 0xFFFFFFFF00000000L) != upperBits) { - throw new IllegalArgumentException( - s"Invalid tag, upperBits would alias: ${TransportUtils.formatTag(upperBits)}") - } - // the lower 32bits aliasing is not a big deal, we expect it with msg rollover - // so we don't check for it - upperBits | (lowerBits & 0x00000000FFFFFFFFL) - } + logDebug(s"Performing a ${requestType} request of size ${request.remaining()} " + + s"with tx ${tx}. Active messages: request $requestAm") - private def composeUpperBits(peerClientId: Long, msgType: Long): Long = { - if ((peerClientId & 0x000000000FFFFFFFL) != peerClientId) { - throw new IllegalArgumentException( - s"Invalid tag, peerClientId would alias: ${TransportUtils.formatTag(peerClientId)}") - } - if ((msgType & 0x000000000000000FL) != msgType) { - throw new IllegalArgumentException( - s"Invalid tag, msgType would alias: ${TransportUtils.formatTag(msgType)}") - } - (peerClientId << 36) | (msgType << 32) + ucx.sendActiveMessage(peerExecutorId, requestAm, + TransportUtils.getAddress(request), request.remaining(), + new UcxCallback { + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + tx.completeWithError(errorMsg) + } + // we don't handle `onSuccess` here, because we want the response + // to complete that + }) + + tx } +} + +class UCXConnection(peerExecutorId: Int, val ucx: UCX) extends Connection with Logging { + // alternate constructor for a server connection (-1 because it is not to a peer) + def this(ucx: UCX) = this(-1, ucx) + + private[this] val pendingTransactions = new ConcurrentHashMap[Long, UCXTransaction]() + private[ucx] def send(executorId: Long, alt: AddressLengthTag, ucxCallback: UCXTagCallback): Unit = { val sendRange = new NvtxRange("Connection Send", NvtxColor.PURPLE) @@ -187,26 +240,27 @@ class UCXConnection(peerExecutorId: Int, ucx: UCX) extends Connection with Loggi val numMessages = if (header != null) buffers.size + 1 else buffers.size + tx.start(UCXTransactionType.Send, numMessages, cb) val ucxCallback = new UCXTagCallback { override def onError(alt: AddressLengthTag, ucsStatus: Int, errorMsg: String): Unit = { - tx.handleTagError(alt.tag, errorMsg) + tx.handleTagError(alt.tag) logError(s"Error sending: $errorMsg, tx: $tx") - tx.txCallback(TransactionStatus.Error) + tx.completeWithError(errorMsg) } override def onSuccess(alt: AddressLengthTag): Unit = { - logDebug(s"Successful send: ${TransportUtils.formatTag(alt.tag)}, tx = $tx") + logDebug(s"Successful send: ${TransportUtils.toHex(alt.tag)}, tx = $tx") tx.handleTagCompleted(alt.tag) if (tx.decrementPendingAndGet <= 0) { - tx.txCallback(TransactionStatus.Success) + tx.complete(TransactionStatus.Success) } } override def onCancel(alt: AddressLengthTag): Unit = { tx.handleTagCancelled(alt.tag) - tx.txCallback(TransactionStatus.Cancelled) + tx.complete(TransactionStatus.Cancelled) } override def onMessageStarted(ucxMessage: UcpRequest): Unit = { @@ -217,14 +271,14 @@ class UCXConnection(peerExecutorId: Int, ucx: UCX) extends Connection with Loggi // send the header if (header != null) { logDebug(s"Sending meta [executor_id=$sendPeerExecutorId, " + - s"tag=${TransportUtils.formatTag(header.tag)}, size=${header.length}]") + s"tag=${TransportUtils.toHex(header.tag)}, size=${header.length}]") send(sendPeerExecutorId, header, ucxCallback) tx.incrementSendSize(header.length) } buffers.foreach { alt => logDebug(s"Sending [executor_id=$sendPeerExecutorId, " + - s"tag=${TransportUtils.formatTag(alt.tag)}, size=${alt.length}]") + s"tag=${TransportUtils.toHex(alt.tag)}, size=${alt.length}]") send(sendPeerExecutorId, alt, ucxCallback) tx.incrementSendSize(alt.length) } @@ -238,23 +292,23 @@ class UCXConnection(peerExecutorId: Int, ucx: UCX) extends Connection with Loggi tx.start(UCXTransactionType.Receive, 1, cb) val ucxCallback = new UCXTagCallback { override def onError(alt: AddressLengthTag, ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Got an error... for tag: ${TransportUtils.formatTag(alt.tag)}, tx $tx") - tx.handleTagError(alt.tag, errorMsg) - tx.txCallback(TransactionStatus.Error) + logError(s"Got an error... for tag: ${TransportUtils.toHex(alt.tag)}, tx $tx") + tx.handleTagError(alt.tag) + tx.completeWithError(errorMsg) } override def onSuccess(alt: AddressLengthTag): Unit = { - logDebug(s"Successful receive: ${TransportUtils.formatTag(alt.tag)}, tx $tx") + logDebug(s"Successful receive: ${TransportUtils.toHex(alt.tag)}, tx $tx") tx.handleTagCompleted(alt.tag) if (tx.decrementPendingAndGet <= 0) { - logDebug(s"Receive done for tag: ${TransportUtils.formatTag(alt.tag)}, tx $tx") - tx.txCallback(TransactionStatus.Success) + logDebug(s"Receive done for tag: ${TransportUtils.toHex(alt.tag)}, tx $tx") + tx.complete(TransactionStatus.Success) } } override def onCancel(alt: AddressLengthTag): Unit = { tx.handleTagCancelled(alt.tag) - tx.txCallback(TransactionStatus.Cancelled) + tx.complete(TransactionStatus.Cancelled) } override def onMessageStarted(ucxMessage: UcpRequest): Unit = { @@ -262,7 +316,7 @@ class UCXConnection(peerExecutorId: Int, ucx: UCX) extends Connection with Loggi } } - logDebug(s"Receiving [tag=${TransportUtils.formatTag(alt.tag)}, size=${alt.length}]") + logDebug(s"Receiving [tag=${TransportUtils.toHex(alt.tag)}, size=${alt.length}]") ucx.receive(alt, ucxCallback) tx.incrementReceiveSize(alt.length) tx @@ -292,6 +346,84 @@ class UCXConnection(peerExecutorId: Int, ucx: UCX) extends Connection with Loggi } object UCXConnection extends Logging { + /** + * 1) client gets upper 28 bits + * 2) then comes the type, which gets 4 bits + * 3) the remaining 32 bits are used for buffer specific tags + */ + private final val bufferMsgType: Long = 0x0000000000000000BL + + // message type mask for UCX tags. The only message type using tags + // is `bufferMsgType` + private final val msgTypeMask: Long = 0x000000000000000FL + + // UCX Active Message masks (we can use up to 5 bits for these ids) + private final val amRequestMask: Int = 0x0000000F + private final val amResponseMask: Int = 0x0000001F + + // We pick the 5th bit set to 1 as a "response" active message + private final val amResponseFlag: Int = 0x00000010 + + // pick up the lower and upper parts of a long + private final val lowerBitsMask: Long = 0x00000000FFFFFFFFL + private final val upperBitsMask: Long = 0xFFFFFFFF00000000L + + def composeBufferTag(peerClientId: Long, bufferTag: Long): Long = { + // buffer tags are [peerClientId, 0, bufferTag] + composeTag(composeUpperBits(peerClientId, bufferMsgType), bufferTag) + } + + def composeRequestAmId(requestType: RequestType.Value): Int = { + val amId = requestType.id + if ((amId & amRequestMask) != amId) { + throw new IllegalArgumentException( + s"Invalid request amId, it must be 4 bits: ${TransportUtils.toHex(amId)}") + } + amId + } + + def composeResponseAmId(requestType: RequestType.Value): Int = { + val amId = amResponseFlag | composeRequestAmId(requestType) + if ((amId & amResponseMask) != amId) { + throw new IllegalArgumentException( + s"Invalid response amId, it must be 5 bits: ${TransportUtils.toHex(amId)}") + } + amId + } + + def composeTag(upperBits: Long, lowerBits: Long): Long = { + if ((upperBits & upperBitsMask) != upperBits) { + throw new IllegalArgumentException( + s"Invalid tag, upperBits would alias: ${TransportUtils.toHex(upperBits)}") + } + // the lower 32bits aliasing is not a big deal, we expect it with msg rollover + // so we don't check for it + upperBits | (lowerBits & lowerBitsMask) + } + + private def composeUpperBits(peerClientId: Long, msgType: Long): Long = { + if ((peerClientId & lowerBitsMask) != peerClientId) { + throw new IllegalArgumentException( + s"Invalid tag, peerClientId would alias: ${TransportUtils.toHex(peerClientId)}") + } + if ((msgType & msgTypeMask) != msgType) { + throw new IllegalArgumentException( + s"Invalid tag, msgType would alias: ${TransportUtils.toHex(msgType)}") + } + (peerClientId << 36) | (msgType << 32) + } + + def composeRequestHeader(executorId: Long, txId: Long): Long = { + require(executorId >= 0, + s"Attempted to pack negative $executorId") + require((executorId & lowerBitsMask) == executorId, + s"ExecutorId would alias: ${TransportUtils.toHex(executorId)}") + composeTag(executorId << 32, txId) + } + + def extractExecutorId(header: Long): Long = { + (header >> 32) & lowerBitsMask + } // // Handshake message code. This, I expect, could be folded into the [[BlockManagerId]], // but I have not tried this. If we did, it would eliminate the extra TCP connection diff --git a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXShuffleTransport.scala b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXShuffleTransport.scala index a26b400efa3..8acbbe82ce1 100644 --- a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXShuffleTransport.scala +++ b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXShuffleTransport.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon private[this] lazy val ucx = { logWarning("UCX Shuffle Transport Enabled") - val ucxImpl = new UCX(shuffleServerId, rapidsConf) + val ucxImpl = new UCX(this, shuffleServerId, rapidsConf) ucxImpl.init() initBounceBufferPools(bounceBufferSize, @@ -88,7 +88,7 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon ucxImpl } - override def getMetaBuffer(size: Long): RefCountedDirectByteBuffer = { + override def getDirectByteBuffer(size: Long): RefCountedDirectByteBuffer = { if (size > rapidsConf.shuffleMaxMetadataSize) { logWarning(s"Large metadata message size $size B, larger " + s"than ${rapidsConf.shuffleMaxMetadataSize} B. " + @@ -251,8 +251,7 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon clientConnection, this, clientExecutor, - clientCopyExecutor, - rapidsConf.shuffleMaxMetadataSize) + clientCopyExecutor) }) } diff --git a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXTransaction.scala b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXTransaction.scala index 69523da91f3..26c84fc6238 100644 --- a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXTransaction.scala +++ b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXTransaction.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package com.nvidia.spark.rapids.shuffle.ucx +import java.nio.ByteBuffer import java.util.concurrent.{ConcurrentLinkedQueue, TimeUnit} import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.locks.ReentrantLock @@ -24,7 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.shuffle.{AddressLengthTag, Transaction, TransactionCallback, TransactionStats, TransactionStatus, TransportUtils} +import com.nvidia.spark.rapids.shuffle.{AddressLengthTag, RefCountedDirectByteBuffer, RequestType, Transaction, TransactionCallback, TransactionStats, TransactionStatus, TransportUtils} import org.openucx.jucx.ucp.UcpRequest import org.apache.spark.internal.Logging @@ -42,6 +43,12 @@ private[ucx] object UCXTransactionType extends Enumeration { private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long) extends Transaction with Logging { + // Active Messages: header used to disambiguate responses for a request + private var header: Option[Long] = None + + // Type of request this transaction is handling, used to simplify the `respond` method + private var messageType: Option[RequestType.Value] = None + // various threads can access the status during the course of a Transaction // the UCX progress thread, client/server pools, and the executor task thread @volatile private[this] var status = TransactionStatus.NotStarted @@ -67,14 +74,12 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long) * This will mark the tag as having an error for debugging purposes. * * @param tag the tag involved in the error - * @param errorMsg error description from UCX */ - def handleTagError(tag: Long, errorMsg: String): Unit = { + def handleTagError(tag: Long): Unit = { if (registeredByTag.contains(tag)) { val origBuff = registeredByTag(tag) errored += origBuff } - errorMessage = Some(errorMsg) } /** @@ -106,7 +111,7 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long) private var hadError: Boolean = false - private[ucx] var txCallback: TransactionStatus.Value => Unit = _ + private var txCallback: TransactionStatus.Value => Unit = _ // Start and end times used for metrics private var start: Long = 0 @@ -244,7 +249,7 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long) def registerForSend(alt: AddressLengthTag): Unit = { registeredByTag.put(alt.tag, alt) registered += alt - logTrace(s"Assigned tag for send ${TransportUtils.formatTag(alt.tag)} for message at " + + logTrace(s"Assigned tag for send ${TransportUtils.toHex(alt.tag)} for message at " + s"buffer ${alt.address} with size ${alt.length}") } @@ -254,7 +259,7 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long) def registerForReceive(alt: AddressLengthTag): Unit = { registered += alt registeredByTag.put(alt.tag, alt) - logTrace(s"Assigned tag for receive ${TransportUtils.formatTag(alt.tag)} for message at " + + logTrace(s"Assigned tag for receive ${TransportUtils.toHex(alt.tag)} for message at " + s"buffer ${alt.address} with size ${alt.length}") } @@ -323,6 +328,9 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long) hadError = true } } + // close any active message we may have + activeMessageData.foreach(_.close()) + activeMessageData = None } catch { case t: Throwable => if (ex == null) { @@ -360,5 +368,86 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long) } var callbackCalled: Boolean = false + + private var activeMessageData: Option[RefCountedDirectByteBuffer] = None + + override def respond(response: ByteBuffer, + cb: TransactionCallback): Transaction = { + logDebug(s"Responding to ${peerExecutorId} at ${TransportUtils.toHex(this.getHeader)} " + + s"with ${response}") + + conn match { + case serverConnection: UCXServerConnection => + serverConnection.respond(peerExecutorId(), messageType.get, this.getHeader, response, cb) + case _ => + throw new IllegalStateException("Tried to respond using a client connection. " + + "This is not supported.") + } + } + + def complete(status: TransactionStatus.Value, + messageType: Option[RequestType.Value] = None, + header: Option[Long] = None, + message: Option[RefCountedDirectByteBuffer] = None, + errorMessage: Option[String] = None): Unit = { + setHeader(header) + setActiveMessageData(message) + setMessageType(messageType) + setErrorMessage(errorMessage) + setHeader(header) + txCallback(status) + } + + def completeWithError(errorMsg: String): Unit = { + complete(TransactionStatus.Error, + errorMessage = Option(errorMsg)) + } + + def completeCancelled(requestType: RequestType.Value, hdr: Long): Unit = { + complete(TransactionStatus.Cancelled, + messageType = Option(requestType), + header = Option(hdr)) + } + + def completeWithSuccess( + messageType: RequestType.Value, + hdr: Option[Long], + message: Option[RefCountedDirectByteBuffer]): Unit = { + complete(TransactionStatus.Success, + messageType = Option(messageType), + header = hdr, + message = message) + } + + // Reference count is not updated here. The caller is responsible to close + private[ucx] def setActiveMessageData(data: Option[RefCountedDirectByteBuffer]): Unit = { + activeMessageData = data + } + + // Reference count is not updated here. The caller is responsible to close + override def releaseMessage(): RefCountedDirectByteBuffer = { + val msg = activeMessageData.get + activeMessageData = None + msg + } + + private[ucx] def setHeader(id: Option[Long]): Unit = header = id + + override def getHeader: Long = { + require(header.nonEmpty, + "Attempted to get an Active Message header, but it was not set!") + header.get + } + + private[ucx] def setMessageType(msgType: Option[RequestType.Value]): Unit = { + messageType = msgType + } + + private[ucx] def setErrorMessage(errorMsg: Option[String]): Unit = { + errorMessage = errorMessage + } + + override def peerExecutorId(): Long = + UCXConnection.extractExecutorId(getHeader) } diff --git a/sql-plugin/src/main/format/ShuffleMetadataRequest.fbs b/sql-plugin/src/main/format/ShuffleMetadataRequest.fbs index b83e60c0cee..78f9c3477d4 100644 --- a/sql-plugin/src/main/format/ShuffleMetadataRequest.fbs +++ b/sql-plugin/src/main/format/ShuffleMetadataRequest.fbs @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2020, NVIDIA CORPORATION. +// Copyright (c) 2019-2021, NVIDIA CORPORATION. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,15 +24,6 @@ table BlockIdMeta { /// Flat buffer for Rapids UCX Shuffle Metadata Request. table MetadataRequest { - /// Spark executor ID - executor_id: long; - - /// UCX message tag to use when sending the response - response_tag: long; - - /// maximum size in bytes for the response message. - max_response_size: long; - /// array of shuffle block descriptors for which metadata is needed block_ids : [BlockIdMeta]; } diff --git a/sql-plugin/src/main/format/ShuffleMetadataResponse.fbs b/sql-plugin/src/main/format/ShuffleMetadataResponse.fbs index 9bb346ac7f4..a7f7754f3ab 100644 --- a/sql-plugin/src/main/format/ShuffleMetadataResponse.fbs +++ b/sql-plugin/src/main/format/ShuffleMetadataResponse.fbs @@ -30,11 +30,6 @@ table TableMeta { /// Flat buffer for Rapids UCX Shuffle Metadata Response table MetadataResponse { - /// Buffer size in bytes required to hold the full response. If this value is larger than the - /// maximum response size sent in the corresponding request then the metadata contents in this - /// response are incomplete and must be re-requested with a larger response buffer allocated. - full_response_size: long; - /// metadata for each table table_metas: [TableMeta]; } diff --git a/sql-plugin/src/main/format/ShuffleTransferRequest.fbs b/sql-plugin/src/main/format/ShuffleTransferRequest.fbs index 5639fd93c40..bee2fe11b6d 100644 --- a/sql-plugin/src/main/format/ShuffleTransferRequest.fbs +++ b/sql-plugin/src/main/format/ShuffleTransferRequest.fbs @@ -24,12 +24,6 @@ table BufferTransferRequest { /// Flat buffer for Rapids UCX Shuffle Transfer Request. table TransferRequest { - /// peer executor id to send response to - executor_id: long; - - /// UCX message tag to use when sending the response - response_tag: long; - /// array of table requests to transfer requests : [BufferTransferRequest]; } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/MetadataRequest.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/MetadataRequest.java index fc2849ae9eb..301d1a9a682 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/MetadataRequest.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/MetadataRequest.java @@ -17,46 +17,22 @@ public final class MetadataRequest extends Table { public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; vtable_start = bb_pos - bb.getInt(bb_pos); vtable_size = bb.getShort(vtable_start); } public MetadataRequest __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; } - /** - * Spark executor ID - */ - public long executorId() { int o = __offset(4); return o != 0 ? bb.getLong(o + bb_pos) : 0L; } - public boolean mutateExecutorId(long executor_id) { int o = __offset(4); if (o != 0) { bb.putLong(o + bb_pos, executor_id); return true; } else { return false; } } - /** - * UCX message tag to use when sending the response - */ - public long responseTag() { int o = __offset(6); return o != 0 ? bb.getLong(o + bb_pos) : 0L; } - public boolean mutateResponseTag(long response_tag) { int o = __offset(6); if (o != 0) { bb.putLong(o + bb_pos, response_tag); return true; } else { return false; } } - /** - * maximum size in bytes for the response message. - */ - public long maxResponseSize() { int o = __offset(8); return o != 0 ? bb.getLong(o + bb_pos) : 0L; } - public boolean mutateMaxResponseSize(long max_response_size) { int o = __offset(8); if (o != 0) { bb.putLong(o + bb_pos, max_response_size); return true; } else { return false; } } /** * array of shuffle block descriptors for which metadata is needed */ public BlockIdMeta blockIds(int j) { return blockIds(new BlockIdMeta(), j); } - public BlockIdMeta blockIds(BlockIdMeta obj, int j) { int o = __offset(10); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } - public int blockIdsLength() { int o = __offset(10); return o != 0 ? __vector_len(o) : 0; } + public BlockIdMeta blockIds(BlockIdMeta obj, int j) { int o = __offset(4); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } + public int blockIdsLength() { int o = __offset(4); return o != 0 ? __vector_len(o) : 0; } public static int createMetadataRequest(FlatBufferBuilder builder, - long executor_id, - long response_tag, - long max_response_size, int block_idsOffset) { - builder.startObject(4); - MetadataRequest.addMaxResponseSize(builder, max_response_size); - MetadataRequest.addResponseTag(builder, response_tag); - MetadataRequest.addExecutorId(builder, executor_id); + builder.startObject(1); MetadataRequest.addBlockIds(builder, block_idsOffset); return MetadataRequest.endMetadataRequest(builder); } - public static void startMetadataRequest(FlatBufferBuilder builder) { builder.startObject(4); } - public static void addExecutorId(FlatBufferBuilder builder, long executorId) { builder.addLong(0, executorId, 0L); } - public static void addResponseTag(FlatBufferBuilder builder, long responseTag) { builder.addLong(1, responseTag, 0L); } - public static void addMaxResponseSize(FlatBufferBuilder builder, long maxResponseSize) { builder.addLong(2, maxResponseSize, 0L); } - public static void addBlockIds(FlatBufferBuilder builder, int blockIdsOffset) { builder.addOffset(3, blockIdsOffset, 0); } + public static void startMetadataRequest(FlatBufferBuilder builder) { builder.startObject(1); } + public static void addBlockIds(FlatBufferBuilder builder, int blockIdsOffset) { builder.addOffset(0, blockIdsOffset, 0); } public static int createBlockIdsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startBlockIdsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endMetadataRequest(FlatBufferBuilder builder) { diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/MetadataResponse.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/MetadataResponse.java index 4e0c384f57e..63d2cdd311b 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/MetadataResponse.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/MetadataResponse.java @@ -17,32 +17,22 @@ public final class MetadataResponse extends Table { public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; vtable_start = bb_pos - bb.getInt(bb_pos); vtable_size = bb.getShort(vtable_start); } public MetadataResponse __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; } - /** - * Buffer size in bytes required to hold the full response. If this value is larger than the - * maximum response size sent in the corresponding request then the metadata contents in this - * response are incomplete and must be re-requested with a larger response buffer allocated. - */ - public long fullResponseSize() { int o = __offset(4); return o != 0 ? bb.getLong(o + bb_pos) : 0L; } - public boolean mutateFullResponseSize(long full_response_size) { int o = __offset(4); if (o != 0) { bb.putLong(o + bb_pos, full_response_size); return true; } else { return false; } } /** * metadata for each table */ public TableMeta tableMetas(int j) { return tableMetas(new TableMeta(), j); } - public TableMeta tableMetas(TableMeta obj, int j) { int o = __offset(6); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } - public int tableMetasLength() { int o = __offset(6); return o != 0 ? __vector_len(o) : 0; } + public TableMeta tableMetas(TableMeta obj, int j) { int o = __offset(4); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } + public int tableMetasLength() { int o = __offset(4); return o != 0 ? __vector_len(o) : 0; } public static int createMetadataResponse(FlatBufferBuilder builder, - long full_response_size, int table_metasOffset) { - builder.startObject(2); - MetadataResponse.addFullResponseSize(builder, full_response_size); + builder.startObject(1); MetadataResponse.addTableMetas(builder, table_metasOffset); return MetadataResponse.endMetadataResponse(builder); } - public static void startMetadataResponse(FlatBufferBuilder builder) { builder.startObject(2); } - public static void addFullResponseSize(FlatBufferBuilder builder, long fullResponseSize) { builder.addLong(0, fullResponseSize, 0L); } - public static void addTableMetas(FlatBufferBuilder builder, int tableMetasOffset) { builder.addOffset(1, tableMetasOffset, 0); } + public static void startMetadataResponse(FlatBufferBuilder builder) { builder.startObject(1); } + public static void addTableMetas(FlatBufferBuilder builder, int tableMetasOffset) { builder.addOffset(0, tableMetasOffset, 0); } public static int createTableMetasVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startTableMetasVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endMetadataResponse(FlatBufferBuilder builder) { diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/TransferRequest.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/TransferRequest.java index eb8b83a49a8..2f1bc1ceb7b 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/TransferRequest.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/TransferRequest.java @@ -17,38 +17,22 @@ public final class TransferRequest extends Table { public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; vtable_start = bb_pos - bb.getInt(bb_pos); vtable_size = bb.getShort(vtable_start); } public TransferRequest __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; } - /** - * peer executor id to send response to - */ - public long executorId() { int o = __offset(4); return o != 0 ? bb.getLong(o + bb_pos) : 0L; } - public boolean mutateExecutorId(long executor_id) { int o = __offset(4); if (o != 0) { bb.putLong(o + bb_pos, executor_id); return true; } else { return false; } } - /** - * UCX message tag to use when sending the response - */ - public long responseTag() { int o = __offset(6); return o != 0 ? bb.getLong(o + bb_pos) : 0L; } - public boolean mutateResponseTag(long response_tag) { int o = __offset(6); if (o != 0) { bb.putLong(o + bb_pos, response_tag); return true; } else { return false; } } /** * array of table requests to transfer */ public BufferTransferRequest requests(int j) { return requests(new BufferTransferRequest(), j); } - public BufferTransferRequest requests(BufferTransferRequest obj, int j) { int o = __offset(8); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } - public int requestsLength() { int o = __offset(8); return o != 0 ? __vector_len(o) : 0; } + public BufferTransferRequest requests(BufferTransferRequest obj, int j) { int o = __offset(4); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } + public int requestsLength() { int o = __offset(4); return o != 0 ? __vector_len(o) : 0; } public static int createTransferRequest(FlatBufferBuilder builder, - long executor_id, - long response_tag, int requestsOffset) { - builder.startObject(3); - TransferRequest.addResponseTag(builder, response_tag); - TransferRequest.addExecutorId(builder, executor_id); + builder.startObject(1); TransferRequest.addRequests(builder, requestsOffset); return TransferRequest.endTransferRequest(builder); } - public static void startTransferRequest(FlatBufferBuilder builder) { builder.startObject(3); } - public static void addExecutorId(FlatBufferBuilder builder, long executorId) { builder.addLong(0, executorId, 0L); } - public static void addResponseTag(FlatBufferBuilder builder, long responseTag) { builder.addLong(1, responseTag, 0L); } - public static void addRequests(FlatBufferBuilder builder, int requestsOffset) { builder.addOffset(2, requestsOffset, 0); } + public static void startTransferRequest(FlatBufferBuilder builder) { builder.startObject(1); } + public static void addRequests(FlatBufferBuilder builder, int requestsOffset) { builder.addOffset(0, requestsOffset, 0); } public static int createRequestsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startRequestsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endTransferRequest(FlatBufferBuilder builder) { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala index a7941034228..884f138878d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala @@ -261,26 +261,16 @@ object ShuffleMetadata extends Logging{ }.toArray } - def buildMetaResponse(tables: Seq[TableMeta], maximumResponseSize: Long): ByteBuffer = { + def buildMetaResponse(tables: Seq[TableMeta]): ByteBuffer = { val fbb = new FlatBufferBuilder(1024, bbFactory) val tableOffsets = copyTables(fbb, tables) val tableMetasOffset = MetadataResponse.createTableMetasVector(fbb, tableOffsets) - val finIndex = MetadataResponse.createMetadataResponse(fbb, 0, tableMetasOffset) + val finIndex = MetadataResponse.createMetadataResponse(fbb, tableMetasOffset) fbb.finish(finIndex) - val bb = fbb.dataBuffer() - val responseSize = bb.remaining() - if (responseSize > maximumResponseSize) { - throw new IllegalStateException("response size is bigger than what receiver wants") - } - val materializedResponse = ShuffleMetadata.getMetadataResponse(bb) - materializedResponse.mutateFullResponseSize(responseSize) - bb + fbb.dataBuffer() } - def buildShuffleMetadataRequest(executorId: Long, - responseTag: Long, - blockIds : Seq[ShuffleBlockBatchId], - maxResponseSize: Long) : ByteBuffer = { + def buildShuffleMetadataRequest(blockIds : Seq[ShuffleBlockBatchId]) : ByteBuffer = { val fbb = new FlatBufferBuilder(1024, bbFactory) val blockIdOffsets = blockIds.map { blockId => BlockIdMeta.createBlockIdMeta(fbb, @@ -290,8 +280,7 @@ object ShuffleMetadata extends Logging{ blockId.endReduceId) } val blockIdVectorOffset = MetadataRequest.createBlockIdsVector(fbb, blockIdOffsets.toArray) - val finIndex = MetadataRequest.createMetadataRequest(fbb, executorId, responseTag, - maxResponseSize, blockIdVectorOffset) + val finIndex = MetadataRequest.createMetadataRequest(fbb, blockIdVectorOffset) fbb.finish(finIndex) fbb.dataBuffer() } @@ -333,10 +322,7 @@ object ShuffleMetadata extends Logging{ fbb.dataBuffer() } - def buildTransferRequest( - localExecutorId: Long, - responseTag: Long, - toIssue: Seq[(TableMeta, Long)]): ByteBuffer = { + def buildTransferRequest(toIssue: Seq[(TableMeta, Long)]): ByteBuffer = { val fbb = ShuffleMetadata.getBuilder val requestIds = new ArrayBuffer[Int](toIssue.size) toIssue.foreach { case (tableMeta, tag) => @@ -347,8 +333,7 @@ object ShuffleMetadata extends Logging{ tag)) } val requestVec = TransferRequest.createRequestsVector(fbb, requestIds.toArray) - val transferRequestOffset = TransferRequest.createTransferRequest(fbb, localExecutorId, - responseTag, requestVec) + val transferRequestOffset = TransferRequest.createTransferRequest(fbb, requestVec) fbb.finish(transferRequestOffset) fbb.dataBuffer() } @@ -360,8 +345,7 @@ object ShuffleMetadata extends Logging{ out.append("------------------------------------------------------------------------------\n") for (tableIndex <- 0 until res.tableMetasLength()) { val tableMeta = res.tableMetas(tableIndex) - out.append(s"table: $tableIndex rows=${tableMeta.rowCount}, " + - s"full_content_size=${res.fullResponseSize()}]\n") + out.append(s"table: $tableIndex rows=${tableMeta.rowCount}") } out.append(s"----------------------- END METADATA RESPONSE $state ----------------------\n") out.toString() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index bd01d2b120e..5a6e5a901a1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -935,6 +935,14 @@ object RapidsConf { .bytesConf(ByteUnit.BYTE) .createWithDefault(1024 * 1024 * 1024) + val SHUFFLE_UCX_ACTIVE_MESSAGES_MODE = conf("spark.rapids.shuffle.ucx.activeMessages.mode") + .doc("Set to 'rndv', 'eager', or 'auto' to indicate what UCX Active Message mode to " + + "use. We set 'rndv' (Rendezvous) by default because UCX 1.10.x doesn't support 'eager' " + + "fully. This restriction can be lifted if the user is running UCX 1.11+.") + .stringConf + .checkValues(Set("rndv", "eager", "auto")) + .createWithDefault("rndv") + val SHUFFLE_UCX_USE_WAKEUP = conf("spark.rapids.shuffle.ucx.useWakeup") .doc("When set to true, use UCX's event-based progress (epoll) in order to wake up " + "the progress thread when needed, instead of a hot loop.") @@ -1014,10 +1022,11 @@ object RapidsConf { .createWithDefault(1000) val SHUFFLE_MAX_METADATA_SIZE = conf("spark.rapids.shuffle.maxMetadataSize") - .doc("The maximum size of a metadata message used in the shuffle.") + .doc("The maximum size of a metadata message that the shuffle plugin will keep in its " + + "direct message pool. ") .internal() .bytesConf(ByteUnit.BYTE) - .createWithDefault(50 * 1024) + .createWithDefault(500 * 1024) val SHUFFLE_COMPRESSION_CODEC = conf("spark.rapids.shuffle.compression.codec") .doc("The GPU codec used to compress shuffle data when using RAPIDS shuffle. " + @@ -1484,6 +1493,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val shuffleTransportMaxReceiveInflightBytes: Long = get( SHUFFLE_TRANSPORT_MAX_RECEIVE_INFLIGHT_BYTES) + lazy val shuffleUcxActiveMessagesMode: String = get(SHUFFLE_UCX_ACTIVE_MESSAGES_MODE) + lazy val shuffleUcxUseWakeup: Boolean = get(SHUFFLE_UCX_USE_WAKEUP) lazy val shuffleUcxUseSockaddr: Boolean = get(SHUFFLE_UCX_LISTENER_ENABLED) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala index 30cbc88d0a0..b4318d6e487 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala @@ -16,10 +16,10 @@ package com.nvidia.spark.rapids.shuffle -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} +import ai.rapids.cudf.{Cuda, MemoryBuffer} import com.nvidia.spark.rapids.{Arm, RapidsBuffer, ShuffleMetadata, StorageTier} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.format.{BufferMeta, BufferTransferRequest, TransferRequest} +import com.nvidia.spark.rapids.format.{BufferMeta, BufferTransferRequest} import org.apache.spark.internal.Logging @@ -44,14 +44,14 @@ import org.apache.spark.internal.Logging * start, it lasts through all buffers being transmitted, and ultimately finishes when a * TransferResponse is sent back to the client. * - * @param request a transfer request + * @param transaction a request transaction * @param sendBounceBuffers - an object that contains a device and potentially a host * buffer also * @param requestHandler - impl of trait that interfaces to the catalog * @param serverStream - CUDA stream to use for copies. */ class BufferSendState( - request: RefCountedDirectByteBuffer, + transaction: Transaction, sendBounceBuffers: SendBounceBuffers, requestHandler: RapidsShuffleRequestHandler, serverStream: Cuda.Stream = Cuda.DEFAULT_STREAM) @@ -62,20 +62,26 @@ class BufferSendState( override def size: Long = tableSize } - private[this] val transferRequest = ShuffleMetadata.getTransferRequest(request.getBuffer()) - private[this] val bufferMetas = new Array[BufferMeta](transferRequest.requestsLength()) private[this] var isClosed = false - private[this] val blocksToSend: Seq[SendBlock] = { - val btr = new BufferTransferRequest() // for reuse - (0 until transferRequest.requestsLength()).map { ix => - val bufferTransferRequest = transferRequest.requests(btr, ix) - withResource(requestHandler.acquireShuffleBuffer( - bufferTransferRequest.bufferId())) { table => - bufferMetas(ix) = table.meta.bufferMeta() - new SendBlock(bufferTransferRequest.bufferId(), - bufferTransferRequest.tag(), table.size) + private[this] val (bufferMetas: Array[BufferMeta], blocksToSend: Seq[SendBlock]) = { + withResource(transaction.releaseMessage()) { msg => + val transferRequest = ShuffleMetadata.getTransferRequest(msg.getBuffer()) + + val bufferMetas = new Array[BufferMeta](transferRequest.requestsLength()) + + val btr = new BufferTransferRequest() // for reuse + val blocksToSend = (0 until transferRequest.requestsLength()).map { ix => + val bufferTransferRequest = transferRequest.requests(btr, ix) + withResource(requestHandler.acquireShuffleBuffer( + bufferTransferRequest.bufferId())) { table => + bufferMetas(ix) = table.meta.bufferMeta() + new SendBlock(bufferTransferRequest.bufferId(), + bufferTransferRequest.tag(), table.size) + } } + + (bufferMetas, blocksToSend) } } @@ -96,8 +102,8 @@ class BufferSendState( private[this] var acquiredBuffs: Seq[RangeBuffer] = Seq.empty - def getTransferRequest: TransferRequest = synchronized { - transferRequest + def getRequestTransaction: Transaction = synchronized { + transaction } def hasNext: Boolean = synchronized { hasMoreBlocks } @@ -120,7 +126,6 @@ class BufferSendState( } isClosed = true freeBounceBuffers() - request.close() releaseAcquiredToCatalog() } @@ -195,7 +200,7 @@ class BufferSendState( } logDebug(s"Sending ${buffsToSend} for transfer request, " + - s" [peer_executor_id=${transferRequest.executorId()}]") + s" [peer_executor_id=${transaction.peerExecutorId()}]") buffsToSend } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index 35badc64b8d..1c7bcaf738f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -92,8 +92,6 @@ case class PendingTransferRequest(client: RapidsShuffleClient, * @param exec Executor used to handle tasks that take time, and should not be in the * transport's thread * @param clientCopyExecutor Executors used to handle synchronous mem copies - * @param maximumMetadataSize The maximum metadata buffer size we are able to request - * TODO: this should go away */ class RapidsShuffleClient( localExecutorId: Long, @@ -101,7 +99,6 @@ class RapidsShuffleClient( transport: RapidsShuffleTransport, exec: Executor, clientCopyExecutor: Executor, - maximumMetadataSize: Long, devStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage, catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog) extends Logging with Arm { @@ -115,21 +112,9 @@ class RapidsShuffleClient( * @param rapidsShuffleFetchHandler the handler (iterator) to callback to */ case class HandleMetadataResponse(tx: Transaction, - resp: RefCountedDirectByteBuffer, shuffleRequests: Seq[ShuffleBlockBatchId], rapidsShuffleFetchHandler: RapidsShuffleFetchHandler) - /** - * Represents retry due to metadata being larger than expected. - * - * @param shuffleRequests request to retry - * @param rapidsShuffleFetchHandler the handler (iterator) to callback to - * @param fullResponseSize response size to allocate to fit the server's response in full - */ - case class FetchRetry(shuffleRequests: Seq[ShuffleBlockBatchId], - rapidsShuffleFetchHandler: RapidsShuffleFetchHandler, - fullResponseSize: Long) - /** * Used to have this client handle the enclosed [[BufferReceiveState]] asynchronously. * @@ -157,10 +142,8 @@ class RapidsShuffleClient( private[this] def handleOp(op: Any): Unit = { // functions we dispatch to must not throw op match { - case HandleMetadataResponse(tx, resp, shuffleRequests, rapidsShuffleFetchHandler) => - doHandleMetadataResponse(tx, resp, shuffleRequests, rapidsShuffleFetchHandler) - case FetchRetry(shuffleRequests, rapidsShuffleFetchHandler, fullResponseSize) => - doFetch(shuffleRequests, rapidsShuffleFetchHandler, fullResponseSize) + case HandleMetadataResponse(tx, shuffleRequests, rapidsShuffleFetchHandler) => + doHandleMetadataResponse(tx, shuffleRequests, rapidsShuffleFetchHandler) case IssueBufferReceives(bufferReceiveState) => doIssueBufferReceives(bufferReceiveState) case HandleBounceBufferReceive(tx, bufferReceiveState) => @@ -183,54 +166,36 @@ class RapidsShuffleClient( clientCopyExecutor.execute(() => handleOp(op)) } + /** * Starts a fetch request for all the shuffleRequests, using `handler` to communicate * events back to the iterator. * * @param shuffleRequests blocks to fetch * @param handler iterator to callback to - * @param metadataSize metadata size to use for this fetch */ def doFetch(shuffleRequests: Seq[ShuffleBlockBatchId], - handler: RapidsShuffleFetchHandler, - metadataSize: Long = maximumMetadataSize): Unit = { + handler: RapidsShuffleFetchHandler): Unit = { try { withResource(new NvtxRange("Client.fetch", NvtxColor.PURPLE)) { _ => if (shuffleRequests.isEmpty) { throw new IllegalStateException("Sending empty blockIds in the MetadataRequest?") } - // get a metadata response tag so we can send it with the request - val responseTag = connection.assignResponseTag - - // serialize a request, note that this includes the responseTag in the message - val metaReq = new RefCountedDirectByteBuffer(ShuffleMetadata.buildShuffleMetadataRequest( - localExecutorId, // needed s.t. the server knows what endpoint to pick - responseTag, - shuffleRequests, - metadataSize)) + val metaReq = new RefCountedDirectByteBuffer( + ShuffleMetadata.buildShuffleMetadataRequest(shuffleRequests)) logDebug(s"Requesting block_ids=[$shuffleRequests] from connection $connection, req: \n " + s"${ShuffleMetadata.printRequest( ShuffleMetadata.getMetadataRequest(metaReq.getBuffer()))}") - val resp = transport.getMetaBuffer(metadataSize) - // make request - connection.request( - AddressLengthTag.from( - metaReq.acquire(), - connection.composeRequestTag(RequestType.MetadataRequest)), - AddressLengthTag.from( - resp.acquire(), - responseTag), - tx => { - try { - asyncOrBlock(HandleMetadataResponse(tx, resp, shuffleRequests, handler)) - } finally { - metaReq.close() - } - }) + connection.request(RequestType.MetadataRequest, metaReq.acquire(), tx => { + withResource(metaReq) { _ => + logDebug(s"at callback for ${tx}") + asyncOrBlock(HandleMetadataResponse(tx, shuffleRequests, handler)) + } + }) } } catch { case t: Throwable => @@ -242,53 +207,41 @@ class RapidsShuffleClient( * Function to handle MetadataResponses, as a result of the [[HandleMetadataResponse]] event. * * @param tx live metadata response transaction to be closed in this handler - * @param resp response buffer, to be closed in this handler * @param shuffleRequests blocks to fetch * @param handler iterator to callback to */ private[this] def doHandleMetadataResponse( tx: Transaction, - resp: RefCountedDirectByteBuffer, shuffleRequests: Seq[ShuffleBlockBatchId], handler: RapidsShuffleFetchHandler): Unit = { - try { - val start = System.currentTimeMillis() - val handleMetaRange = new NvtxRange("Client.handleMeta", NvtxColor.CYAN) - try { - tx.getStatus match { - case TransactionStatus.Success => - // start the receives - val respBuffer = resp.getBuffer() - val metadataResponse: MetadataResponse = ShuffleMetadata.getMetadataResponse(respBuffer) - - logDebug(s"Received from ${tx} response: \n:" + - s"${ShuffleMetadata.printResponse("received response", metadataResponse)}") - - if (metadataResponse.fullResponseSize() <= respBuffer.capacity()) { - // signal to the handler how many batches are expected - handler.start(metadataResponse.tableMetasLength()) - - // queue up the receives - queueTransferRequests(metadataResponse, handler) - } else { - // NOTE: this path hasn't been tested yet. - logWarning("Large metadata message received, widening the receive size.") - asyncOrBlock( - FetchRetry(shuffleRequests, handler, metadataResponse.fullResponseSize())) + withResource(tx) { _ => + withResource(tx.releaseMessage()) { resp => + withResource(new NvtxRange("Client.handleMeta", NvtxColor.CYAN)) { _ => + try { + tx.getStatus match { + case TransactionStatus.Success => + // start the receives + val metadataResponse: MetadataResponse = + ShuffleMetadata.getMetadataResponse(resp.getBuffer()) + + logDebug(s"Received from ${tx} response: \n:" + + s"${ShuffleMetadata.printResponse("received response", metadataResponse)}") + + // signal to the handler how many batches are expected + handler.start(metadataResponse.tableMetasLength()) + + // queue up the receives + queueTransferRequests(metadataResponse, handler) + case _ => + handler.transferError( + tx.getErrorMessage.getOrElse(s"Unsuccessful metadata request ${tx}")) } - case _ => - handler.transferError( - tx.getErrorMessage.getOrElse(s"Unsuccessful metadata request ${tx}")) + } catch { + case t: Throwable => + handler.transferError("Error occurred while handling metadata", t) + } } - } finally { - logDebug(s"Metadata response handled in ${TransportUtils.timeDiffMs(start)} ms") - handleMetaRange.close() - resp.close() - tx.close() } - } catch { - case t: Throwable => - handler.transferError("Error occurred while handling metadata", t) } } @@ -324,13 +277,13 @@ class RapidsShuffleClient( private def receiveBuffers(bufferReceiveState: BufferReceiveState): Transaction = { val alt = bufferReceiveState.next() - logDebug(s"Issuing receive for ${TransportUtils.formatTag(alt.tag)}") + logDebug(s"Issuing receive for $alt") connection.receive(alt, tx => { tx.getStatus match { case TransactionStatus.Success => - logDebug(s"Handling response for ${TransportUtils.formatTag(alt.tag)}") + logDebug(s"Handling response for $alt") asyncOnCopyThread(HandleBounceBufferReceive(tx, bufferReceiveState)) case _ => try { val errMsg = s"Unsuccessful buffer receive ${tx}" @@ -353,49 +306,31 @@ class RapidsShuffleClient( private[this] def sendTransferRequest(toIssue: BufferReceiveState): Unit = { val requestsToIssue = toIssue.getRequests logDebug(s"Sending a transfer request for " + - s"${requestsToIssue.map(r => TransportUtils.formatTag(r.tag)).mkString(",")}") - - // get a tag that the server can use to send its reply - val responseTag = connection.assignResponseTag + s"${requestsToIssue.map(r => TransportUtils.toHex(r.tag)).mkString(",")}") val transferReq = new RefCountedDirectByteBuffer( - ShuffleMetadata.buildTransferRequest(localExecutorId, responseTag, - requestsToIssue.map(i => (i.tableMeta, i.tag)))) - - if (transferReq.getBuffer().remaining() > maximumMetadataSize) { - throw new IllegalStateException("Trying to send a transfer request metadata buffer that " + - "is larger than the limit.") - } - - val res = transport.getMetaBuffer(maximumMetadataSize) + ShuffleMetadata.buildTransferRequest(requestsToIssue.map(i => (i.tableMeta, i.tag)))) //issue the buffer transfer request - connection.request( - AddressLengthTag.from( - transferReq.acquire(), - connection.composeRequestTag(RequestType.TransferRequest)), - AddressLengthTag.from( - res.acquire(), - responseTag), - tx => { - try { - // make sure all bufferTxs are still valid (e.g. resp says that they have STARTED) - val transferResponse = ShuffleMetadata.getTransferResponse(res.getBuffer()) - (0 until transferResponse.responsesLength()).foreach(r => { - val response = transferResponse.responses(r) - if (response.state() != TransferState.STARTED) { - // we could either re-issue the request, cancelling and releasing memory - // or we could re-issue, and leave the old receive waiting - // for now, leaving the old receive waiting. - throw new IllegalStateException("NOT IMPLEMENTED") - } - }) - } finally { - transferReq.close() - res.close() - tx.close() + connection.request(RequestType.TransferRequest, transferReq.acquire(), tx => { + withResource(tx.releaseMessage()) { res => + withResource(transferReq) { _ => + withResource(tx) { _ => + // make sure all bufferTxs are still valid (e.g. resp says that they have STARTED) + val transferResponse = ShuffleMetadata.getTransferResponse(res.getBuffer()) + (0 until transferResponse.responsesLength()).foreach(r => { + val response = transferResponse.responses(r) + if (response.state() != TransferState.STARTED) { + // we could either re-issue the request, cancelling and releasing memory + // or we could re-issue, and leave the old receive waiting + // for now, leaving the old receive waiting. + throw new IllegalStateException("NOT IMPLEMENTED") + } + }) + } } - }) + } + }) } /** diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala index cb2736b3f1e..41c1b485843 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala @@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.storage.{BlockManagerId, ShuffleBlockBatchId} - /** * Trait used for the server to get buffer metadata (for metadata requests), and * also to acquire a buffer (for transfer requests) @@ -97,10 +96,8 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, * When a transfer request is received during a callback, the handle code is offloaded via this * event to the server thread. * @param tx the live transaction that should be closed by the handler - * @param metaRequestBuffer contains the metadata request that should be closed by the - * handler */ - case class HandleMeta(tx: Transaction, metaRequestBuffer: RefCountedDirectByteBuffer) + case class HandleMeta(tx: Transaction) /** * When transfer request is received (to begin sending buffers), the handling is offloaded via @@ -133,16 +130,17 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, */ def start(): Unit = { port = serverConnection.startManagementPort(originalShuffleServerId.host) - // kick off our first receives - doIssueReceive(RequestType.MetadataRequest) - doIssueReceive(RequestType.TransferRequest) + + // register request type interest against the transport + registerRequestHandler(RequestType.MetadataRequest) + registerRequestHandler(RequestType.TransferRequest) } def handleOp(serverTask: Any): Unit = { try { serverTask match { - case HandleMeta(tx, metaRequestBuffer) => - doHandleMeta(tx, metaRequestBuffer) + case HandleMeta(tx) => + doHandleMetadataRequest(tx) case HandleTransferRequest(wt: Seq[BufferSendState]) => doHandleTransferRequest(wt) } @@ -203,7 +201,7 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, if (sendBounceBuffers.nonEmpty) { val pendingTransfer = pendingTransfersQueue.poll() bssToIssue.append(new BufferSendState( - pendingTransfer.metaRequest, + pendingTransfer.tx, sendBounceBuffers.head, // there's only one bounce buffer here for now pendingTransfer.requestHandler, serverStream)) @@ -238,125 +236,82 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, * * @param requestType The request type received */ - private def doIssueReceive(requestType: RequestType.Value): Unit = { - logDebug(s"Waiting for a new connection. Posting ${requestType} receive.") - val metaRequest = transport.getMetaBuffer(rapidsConf.shuffleMaxMetadataSize) - - val alt = AddressLengthTag.from( - metaRequest.acquire(), - serverConnection.composeRequestTag(requestType)) - - serverConnection.receive(alt, - tx => { - val handleMetaRange = new NvtxRange("Handle Meta Request", NvtxColor.PURPLE) - try { - if (requestType == RequestType.MetadataRequest) { - doIssueReceive(RequestType.MetadataRequest) - doHandleMeta(tx, metaRequest) - } else { - val pendingTransfer = PendingTransferResponse(metaRequest, requestHandler) - // tell the bssExec to wake up to try to handle the new BufferSendState + private def registerRequestHandler(requestType: RequestType.Value): Unit = { + logDebug(s"Registering ${requestType} request callback") + serverConnection.registerRequestHandler(requestType, tx => { + withResource(new NvtxRange("Handle Meta Request", NvtxColor.PURPLE)) { _ => + requestType match { + case RequestType.MetadataRequest => + doHandleMetadataRequest(tx) + case RequestType.TransferRequest => + val pendingTransfer = PendingTransferResponse(tx, requestHandler) bssExec.synchronized { pendingTransfersQueue.add(pendingTransfer) bssExec.notifyAll() } logDebug(s"Got a transfer request ${pendingTransfer} from ${tx}. " + s"Pending requests [new=${pendingTransfersQueue.size}, " + - s"continuing=${bssContinueQueue.size}]") - doIssueReceive(RequestType.TransferRequest) - } - } finally { - handleMetaRange.close() - tx.close() + s"continuing=${bssContinueQueue.size}]") } - }) - } - - case class PendingTransferResponse( - metaRequest: RefCountedDirectByteBuffer, - requestHandler: RapidsShuffleRequestHandler) - - /** - * Function to handle `MetadataRequest`s. It will populate and issue a - * `MetadataResponse` response for the appropriate client. - * - * @param tx the inbound [[Transaction]] - * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message. - */ - def doHandleMeta(tx: Transaction, metaRequest: RefCountedDirectByteBuffer): Unit = { - val doHandleMetaRange = new NvtxRange("doHandleMeta", NvtxColor.PURPLE) - val start = System.currentTimeMillis() - try { - if (tx.getStatus == TransactionStatus.Error) { - logError("error getting metadata request: " + tx) - metaRequest.close() // the buffer is not going to be handed anywhere else, so lets close it - } else { - logDebug(s"Received metadata request: $tx => $metaRequest") - handleMetadataRequest(metaRequest) } - } finally { - logDebug(s"Metadata request handled in ${TransportUtils.timeDiffMs(start)} ms") - doHandleMetaRange.close() - } + }) } + case class PendingTransferResponse(tx: Transaction, requestHandler: RapidsShuffleRequestHandler) + /** * Handles the very first message that a client will send, in order to request Table/Buffer info. - * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message. + * @param tx: [[Transaction]] - a transaction object that carries status and payload */ - def handleMetadataRequest(metaRequest: RefCountedDirectByteBuffer): Unit = { - try { - val req = ShuffleMetadata.getMetadataRequest(metaRequest.getBuffer()) - - // target executor to respond to - val peerExecutorId = req.executorId() - - // tag to use for the response message - val responseTag = req.responseTag() - - logDebug(s"Received request req:\n: ${ShuffleMetadata.printRequest(req)}") - logDebug(s"HandleMetadataRequest for peerExecutorId $peerExecutorId and " + - s"responseTag ${TransportUtils.formatTag(req.responseTag())}") - - // NOTE: MetaUtils will have a simpler/better way of handling creating a response. - // That said, at this time, I see some issues with that approach from the flatbuffer - // library, so the code to create the metadata response will likely change. - val responseTables = (0 until req.blockIdsLength()).flatMap { i => - val blockId = req.blockIds(i) - // this is getting shuffle buffer ids - requestHandler.getShuffleBufferMetas( - ShuffleBlockBatchId(blockId.shuffleId(), blockId.mapId(), - blockId.startReduceId(), blockId.endReduceId())) - } - - val metadataResponse = - ShuffleMetadata.buildMetaResponse(responseTables, req.maxResponseSize()) - // Wrap the buffer so we keep a reference to it, and we destroy it later on .close - val respBuffer = new RefCountedDirectByteBuffer(metadataResponse) - val materializedResponse = ShuffleMetadata.getMetadataResponse(metadataResponse) - - logDebug(s"Response will be at tag ${TransportUtils.formatTag(responseTag)}:\n"+ - s"${ShuffleMetadata.printResponse("responding", materializedResponse)}") - - val response = AddressLengthTag.from(respBuffer.acquire(), responseTag) - - // Issue the send against [[peerExecutorId]] as described by the metadata message - val tx = serverConnection.send(peerExecutorId, response, tx => { - try { + def doHandleMetadataRequest(tx: Transaction): Unit = { + withResource(tx) { _ => + withResource(new NvtxRange("doHandleMeta", NvtxColor.PURPLE)) { _ => + withResource(tx.releaseMessage()) { metaRequest => if (tx.getStatus == TransactionStatus.Error) { - logError(s"Error sending metadata response in tx $tx") + logError("error getting metadata request: " + tx) } else { - val stats = tx.getStats - logDebug(s"Sent metadata ${stats.sendSize} in ${stats.txTimeMs} ms") + val req = ShuffleMetadata.getMetadataRequest(metaRequest.getBuffer()) + + logDebug(s"Received request req:\n: ${ShuffleMetadata.printRequest(req)}") + logDebug(s"HandleMetadataRequest for peerExecutorId ${tx.peerExecutorId()} and " + + s"tx ${tx}") + + // NOTE: MetaUtils will have a simpler/better way of handling creating a response. + // That said, at this time, I see some issues with that approach from the flatbuffer + // library, so the code to create the metadata response will likely change. + val responseTables = (0 until req.blockIdsLength()).flatMap { i => + val blockId = req.blockIds(i) + // this is getting shuffle buffer ids + requestHandler.getShuffleBufferMetas( + ShuffleBlockBatchId(blockId.shuffleId(), blockId.mapId(), + blockId.startReduceId(), blockId.endReduceId())) + } + + val metadataResponse = + ShuffleMetadata.buildMetaResponse(responseTables) + // Wrap the buffer so we keep a reference to it, and we destroy it later on .close + val respBuffer = new RefCountedDirectByteBuffer(metadataResponse) + val materializedResponse = ShuffleMetadata.getMetadataResponse(metadataResponse) + + logDebug(s"Response will be at header ${TransportUtils.toHex(tx.getHeader)}:\n" + + s"${ShuffleMetadata.printResponse("responding", materializedResponse)}") + + val responseTx = tx.respond(respBuffer.getBuffer(), responseTx => { + withResource(responseTx) { responseTx => + withResource(respBuffer) { _ => + if (responseTx.getStatus == TransactionStatus.Error) { + logError(s"Error sending metadata response in tx $tx") + } else { + val stats = responseTx.getStats + logDebug(s"Sent metadata ${stats.sendSize} in ${stats.txTimeMs} ms") + } + } + } + }) + logDebug(s"Waiting for send metadata to complete: $responseTx") } - } finally { - respBuffer.close() - tx.close() } - }) - logDebug(s"Waiting for send metadata to complete: $tx") - } finally { - metaRequest.close() + } } } @@ -388,15 +343,15 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, bssBuffers.foreach { case (bufferSendState, buffersToSend) => - val transferRequest = bufferSendState.getTransferRequest - serverConnection.send(transferRequest.executorId(), buffersToSend, bufferTx => - try { + val peerExecutorId = bufferSendState.getRequestTransaction.peerExecutorId() + serverConnection.send(peerExecutorId, buffersToSend, bufferTx => + withResource(bufferTx) { _ => logDebug(s"Done with the send for ${bufferSendState} with ${buffersToSend}") if (bufferSendState.hasNext) { // continue issuing sends. logDebug(s"Buffer send state ${bufferSendState} is NOT done. " + - s"Still pending: ${pendingTransfersQueue.size}.") + s"Still pending: ${pendingTransfersQueue.size}.") bssExec.synchronized { bssContinueQueue.add(bufferSendState) bssExec.notifyAll() @@ -404,28 +359,35 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } else { val transferResponse = bufferSendState.getTransferResponse() - logDebug(s"Handling transfer request for ${transferRequest.executorId()} " + - s"with ${buffersToSend}") + val requestTx = bufferSendState.getRequestTransaction + + logDebug(s"Handling transfer request ${requestTx} for " + + s"${peerExecutorId} " + + s"with ${buffersToSend}") // send the transfer response - serverConnection.send( - transferRequest.executorId, - AddressLengthTag.from(transferResponse.acquire(), transferRequest.responseTag()), + requestTx.respond(transferResponse.acquire(), transferResponseTx => { - transferResponse.close() - transferResponseTx.close() + withResource(transferResponseTx) { _ => + withResource(transferResponse) { _ => + transferResponseTx.getStatus match { + case TransactionStatus.Cancelled | TransactionStatus.Error => + logError(s"Error while handling TransferResponse: " + + s"${transferResponseTx.getErrorMessage}") + case _ => + } + } + } }) // wake up the bssExec since bounce buffers became available - logDebug(s"Buffer send state ${buffersToSend.tag} is done. Closing. " + - s"Still pending: ${pendingTransfersQueue.size}.") + logDebug(s"Buffer send state ${buffersToSend} is done. Closing. " + + s"Still pending: ${pendingTransfersQueue.size}.") bssExec.synchronized { bufferSendState.close() bssExec.notifyAll() } } - } finally { - bufferTx.close() }) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala index 584794bbb38..5ae535e9b37 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala @@ -79,7 +79,7 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long, } override def toString: String = { - s"AddressLengthTag[address=$address, length=$length, tag=${TransportUtils.formatTag(tag)}]" + s"AddressLengthTag[address=$address, length=$length, tag=${TransportUtils.toHex(tag)}]" } /** @@ -105,20 +105,6 @@ object AddressLengthTag { tag, Some(memoryBuffer)) } - - /** - * Construct an [[AddressLengthTag]] given a `ByteBuffer` - * This is used for metadata messages, and the buffers are direct. - * @param byteBuffer the buffer the [[AddressLengthTag]] should point to - * @param tag the transport tag to use to send/receive this buffer - * @return an instance of [[AddressLengthTag]] - */ - def from(byteBuffer: ByteBuffer, tag: Long): AddressLengthTag = { - new AddressLengthTag( - TransportUtils.getAddress(byteBuffer), - byteBuffer.remaining(), - tag) - } } trait TransactionCallback { @@ -149,24 +135,37 @@ trait ServerConnection extends Connection { /** * Function to send bounce buffers to a peer * @param peerExecutorId peer's executor id to target - * @param bounceBuffers bounce buffers to send + * @param buffer an [[AddressLengthTag]] for a buffer to send * @param cb callback to trigger once done * @return the [[Transaction]], which can be used to block wait for this send. */ def send(peerExecutorId: Long, - bounceBuffers: Seq[AddressLengthTag], + buffer: AddressLengthTag, cb: TransactionCallback): Transaction /** - * Function to send bounce buffers to a peer - * @param peerExecutorId peer's executor id to target - * @param header an [[AddressLengthTag]] containing a metadata message to send - * @param cb callback to trigger once done - * @return the [[Transaction]], which can be used to block wait for this send. + * Registers a callback that will be called any type a `RequestType` message is + * received by this `ServerConnection` + * @param requestType see `RequestType` enum + * @param cb triggered for a success or error on this request */ - def send(peerExecutorId: Long, - header: AddressLengthTag, - cb: TransactionCallback): Transaction + def registerRequestHandler(requestType: RequestType.Value, cb: TransactionCallback): Unit + + /** + * Respond to a request. + * @param peerExecutorId - executor to send response to + * @param messageType - the type of message this is + * @param header - a long that should match a request header, requester use this header + * to disambiguate messages + * @param response - a direct `ByteBuffer` to be transmitted + * @param cb callback to trigger once this respond completes + * @return a [[Transaction]] that can be used to block while this transaction is not done + */ + def respond(peerExecutorId: Long, + messageType: RequestType.Value, + header: Long, + response: ByteBuffer, + cb: TransactionCallback): Transaction } /** @@ -193,27 +192,19 @@ object RequestType extends Enumeration { */ trait ClientConnection extends Connection { /** - * This performs a request/response, where the request is read from one - * `AddressLengthTag` `request`, and the response is populated at the memory - * described by `response`. + * This performs a request/response for a request of type `RequestType`. The response + * `Transaction` on completion will call the callback (`cb`). The caller of `request` + * must close, or consume the response in the transaction, otherwise we can leak. * - * @param request the populated request buffer [[AddressLengthTag]] - * @param response the response buffer [[AddressLengthTag]] where the response will be - * stored when the request succeeds. + * @param requestType value of the `RequestType` enum + * @param request the populated request direct buffer `ByteBuffer` * @param cb callback to handle transaction status. If successful the memory described * using "response" will hold the response as expected, otherwise its contents * are not defined. * @return a transaction representing the request */ - def request(request: AddressLengthTag, - response: AddressLengthTag, - cb: TransactionCallback): Transaction - - /** - * This function assigns tags that are valid for responses in this connection. - * @return a Long tag to use for a response - */ - def assignResponseTag: Long + def request(requestType: RequestType.Value, request: ByteBuffer, + cb: TransactionCallback): Transaction /** * This function assigns tags for individual buffers to be received in this connection. @@ -236,18 +227,6 @@ trait ClientConnection extends Connection { * in each case. */ trait Connection { - /** - * Both the client and the server need to compose request tags depending on the - * type of request being sent or handled. - * - * Note it is up to the implemented to compose the tag in whatever way makes - * most sense for the underlying transport. - * - * @param requestType the type of request this tag is for - * @return a Long tag to be used for this request - */ - def composeRequestTag(requestType: RequestType.Value): Long - /** * Function to receive a buffer * @param alt an [[AddressLengthTag]] to receive the message @@ -256,6 +235,7 @@ trait Connection { */ def receive(alt: AddressLengthTag, cb: TransactionCallback): Transaction + } object TransactionStatus extends Enumeration { @@ -290,6 +270,17 @@ case class TransactionStats(txTimeMs: Double, * outside of [[waitForCompletion]] produces undefined behavior. */ trait Transaction extends AutoCloseable { + /** + * Get the peer executor id if available + * @note this can throw if the `Transaction` was not created due to an Active Message + */ + def peerExecutorId(): Long + + /** + * Return this transaction's header, for debug purposes + */ + def getHeader: Long + /** * Get the status this transaction is in. Callbacks use this to handle various transaction states * (e.g. success, error, etc.) @@ -316,6 +307,25 @@ trait Transaction extends AutoCloseable { * deadlock. */ def waitForCompletion(): Unit + + /** + * Hands over a message (a host-side request or response at the moment) + * that is held in the Transaction + * @note The caller must call `close` on the returned message + * @return a direct ref counted byte buffer, possible from the byte buffer pool + */ + def releaseMessage(): RefCountedDirectByteBuffer + + /** + * For `Request` transactions, `respond` will be able to reply to a peer who issued + * the request + * @note this is only available for server-side transactions, and will throw + * if attempted from a client + * @param response a direct ByteBuffer + * @param cb triggered when the response succeds/fails + * @return a `Transaction` object that can be used to wait for this response to complete + */ + def respond(response: ByteBuffer, cb: TransactionCallback): Transaction } /** @@ -364,7 +374,7 @@ trait RapidsShuffleTransport extends AutoCloseable { * @param size size of buffer required * @return the ref counted buffer */ - def getMetaBuffer(size: Long): RefCountedDirectByteBuffer + def getDirectByteBuffer(size: Long): RefCountedDirectByteBuffer /** * (throttle) Adds a set of requests to be throttled as limits allowed. @@ -421,7 +431,7 @@ trait RapidsShuffleTransport extends AutoCloseable { * @param bufferSize the size of direct `ByteBuffer` to allocate. */ class DirectByteBufferPool(bufferSize: Long) extends Logging { - val buffers = new ConcurrentLinkedQueue[RefCountedDirectByteBuffer]() + val buffers = new ConcurrentLinkedQueue[ByteBuffer]() val high = new AtomicInteger(0) def getBuffer(size: Long): RefCountedDirectByteBuffer = { @@ -429,19 +439,20 @@ class DirectByteBufferPool(bufferSize: Long) extends Logging { throw new IllegalStateException(s"Buffers of size $bufferSize are the only ones supported, " + s"asked for $size") } - var buff = buffers.poll() + val buff = buffers.poll() if (buff == null) { high.incrementAndGet() logDebug(s"Allocating new direct buffer, high watermark = $high") - buff = new RefCountedDirectByteBuffer(ByteBuffer.allocateDirect(bufferSize.toInt), Some(this)) + new RefCountedDirectByteBuffer(ByteBuffer.allocateDirect(bufferSize.toInt), Option(this)) + } else { + buff.clear() + new RefCountedDirectByteBuffer(buff, Option(this)) } - buff.getBuffer().clear() - buff } def releaseBuffer(buff: RefCountedDirectByteBuffer): Boolean = { logDebug(s"Free direct buffers ${buffers.size()}") - buffers.offer(buff) + buffers.offer(buff.getBuffer()) } } @@ -465,6 +476,8 @@ class RefCountedDirectByteBuffer( var refCount: Int = 0 + var closed: Boolean = false + /** * Adds one to the ref count. Caller should call .close() when done * @return wrapped buffer @@ -480,11 +493,16 @@ class RefCountedDirectByteBuffer( */ def getBuffer(): ByteBuffer = bb + def isClosed: Boolean = synchronized { closed } + /** * Decrements the ref count. If the ref count reaches 0, the buffer is * either returned to the (optional) pool or destroyed. */ override def close(): Unit = synchronized { + if (closed) { + throw new IllegalStateException("Close called too many times!") + } refCount = refCount - 1 if (refCount <= 0) { if (pool.isDefined) { @@ -493,6 +511,7 @@ class RefCountedDirectByteBuffer( unsafeDestroy() // not pooled, should disappear } } + closed = true } /** @@ -513,10 +532,14 @@ class RefCountedDirectByteBuffer( * A set of util functions used throughout */ object TransportUtils { - def formatTag(tag: Long): String = { + def toHex(tag: Long): String = { f"0x$tag%016X" } + def toHex(tag: Int): String = { + f"0x$tag%08X" + } + def copyBuffer(src: ByteBuffer, dst: ByteBuffer, size: Int): Unit = { val copyMetaRange = new NvtxRange("Transport.CopyBuffer", NvtxColor.RED) try { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala index 3edef3d7f9d..9d442c899b5 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala @@ -73,14 +73,14 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val shuffleRequests = RapidsShuffleTestHelper.getShuffleBlocks val contigBuffSize = 100000 val numBatches = 3 - val tableMetas = - RapidsShuffleTestHelper.mockMetaResponse(mockTransport, contigBuffSize, numBatches) + val (tableMetas, response) = + RapidsShuffleTestHelper.mockMetaResponse(mockTransaction, contigBuffSize, numBatches) // initialize metadata fetch client.doFetch(shuffleRequests.map(_._1), mockHandler) // the connection saw one request (for metadata) - assertResult(1)(mockConnection.requests.size) + assertResult(1)(mockConnection.requests) // upon a successful response, the `start()` method in the fetch handler // will be called with 3 expected batches @@ -98,6 +98,8 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val tm = ptrs(t).tableMeta verifyTableMeta(expected, tm) } + + assert(response.isClosed) } test("successful degenerate metadata fetch") { @@ -106,13 +108,13 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val numRows = 100000 val numBatches = 3 - RapidsShuffleTestHelper.mockDegenerateMetaResponse(mockTransport, numBatches) + RapidsShuffleTestHelper.mockDegenerateMetaResponse(mockTransaction, numBatches) // initialize metadata fetch client.doFetch(shuffleRequests.map(_._1), mockHandler) // the connection saw one request (for metadata) - assertResult(1)(mockConnection.requests.size) + assertResult(1)(mockConnection.requests) // upon a successful response, the `start()` method in the fetch handler // will be called with 3 expected batches @@ -132,12 +134,12 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val shuffleRequests = RapidsShuffleTestHelper.getShuffleBlocks val contigBuffSize = 100000 - RapidsShuffleTestHelper.mockMetaResponse( - mockTransport, contigBuffSize, 3) + val (_, response) = RapidsShuffleTestHelper.mockMetaResponse( + mockTransaction, contigBuffSize, 3) client.doFetch(shuffleRequests.map(_._1), mockHandler) - assertResult(1)(mockConnection.requests.size) + assertResult(1)(mockConnection.requests) // upon an errored response, the start handler will not be called verify(mockHandler, times(0)).start(any()) @@ -148,6 +150,8 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { // the transport will receive no pending requests (for buffers) for queuing verify(mockTransport, times(0)).queuePending(any()) + assert(response.isClosed) + newMocks() } } @@ -156,12 +160,12 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { when(mockTransaction.getStatus).thenThrow(new RuntimeException("test exception")) val shuffleRequests = RapidsShuffleTestHelper.getShuffleBlocks val contigBuffSize = 100000 - RapidsShuffleTestHelper.mockMetaResponse( - mockTransport, contigBuffSize, 3) + var (_, response) = RapidsShuffleTestHelper.mockMetaResponse( + mockTransaction, contigBuffSize, 3) client.doFetch(shuffleRequests.map(_._1), mockHandler) - assertResult(1)(mockConnection.requests.size) + assertResult(1)(mockConnection.requests) // upon an errored response, the start handler will not be called verify(mockHandler, times(0)).start(any()) @@ -173,6 +177,8 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { // the transport will receive no pending requests (for buffers) for queuing verify(mockTransport, times(0)).queuePending(any()) + assert(response.isClosed) + newMocks() } @@ -181,7 +187,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val numRows = 25001 val tableMeta = - RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransport, numRows) + RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransaction, numRows) // 10000 in bytes ~ 2500 rows (minus validity/offset buffers) worth of contiguous // single column int table, so we need 10 buffer-lengths to receive all of 25000 rows, @@ -251,7 +257,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { } // after closing, we should have freed our bounce buffers. - assertResult(true)(bounceBuffer.isClosed) + assert(bounceBuffer.isClosed) } } } @@ -262,7 +268,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val numRows = 100 val tableMeta = - RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransport, numRows) + RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransaction, numRows) val sizePerBuffer = 10000 val expectedReceives = 1 closeOnExcept(getBounceBuffer(sizePerBuffer)) { bounceBuffer => @@ -304,7 +310,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { assertResult(tableMeta.bufferMeta().size())(receivedBuff.getLength) // after closing, we should have freed our bounce buffers. - assertResult(true)(bounceBuffer.isClosed) + assert(bounceBuffer.isClosed) } } @@ -314,7 +320,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val numRows = 500 val tableMetas = (0 until 5).map { - _ => RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransport, numRows) + _ => RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransaction, numRows) } // 20000 in bytes ~ 5000 rows (minus validity/offset buffers) worth of contiguous @@ -365,7 +371,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { dmbCaptor.getAllValues().toArray().map(_.asInstanceOf[DeviceMemoryBuffer].getLength).sum) // after closing, we should have freed our bounce buffers. - assertResult(true)(bounceBuffer.isClosed) + assert(bounceBuffer.isClosed) } } @@ -375,7 +381,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val numRows = 500 val tableMetas = (0 until 20).map { - _ => RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransport, numRows) + _ => RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransaction, numRows) } // 20000 in bytes ~ 5000 rows (minus validity/offset buffers) worth of contiguous @@ -427,7 +433,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { dmbCaptor.getAllValues().toArray().map(_.asInstanceOf[DeviceMemoryBuffer].getLength).sum) // after closing, we should have freed our bounce buffers. - assertResult(true)(bounceBuffer.isClosed) + assert(bounceBuffer.isClosed) } } @@ -438,7 +444,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val numRows = 100000 val tableMeta = - RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransport, numRows) + RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransaction, numRows) // error condition, so it doesn't matter much what we set here, only the first // receive will happen @@ -465,7 +471,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { verify(mockConnection, times(1)).request(any(), any(), any()) // ensure we closed the BufferReceiveState => releasing the bounce buffers - assertResult(true)(bounceBuffer.isClosed) + assert(bounceBuffer.isClosed) } newMocks() @@ -477,7 +483,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { val numRows = 100000 val tableMeta = - RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransport, numRows) + RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransaction, numRows) // error condition, so it doesn't matter much what we set here, only the first // receive will happen @@ -503,7 +509,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { verify(mockConnection, times(1)).request(any(), any(), any()) // ensure we closed the BufferReceiveState => releasing the bounce buffers - assertResult(true)(bounceBuffer.isClosed) + assert(bounceBuffer.isClosed) } newMocks() @@ -526,8 +532,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { mockTable: TableMeta, consumed: ConsumedBatchFromBounceBuffer): Unit = { assertResult(mockTable.bufferMeta().size())(consumed.contigBuffer.getLength) - assertResult(true)( - areBuffersEqual(source, consumed.contigBuffer)) + assert(areBuffersEqual(source, consumed.contigBuffer)) } class MockBlock(val hmb: HostMemoryBuffer) extends BlockWithSize { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala index cf198a3d4b0..ee6e627610d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any(), any())).thenReturn(client) - doNothing().when(client).doFetch(any(), ac.capture(), any()) + doNothing().when(client).doFetch(any(), ac.capture()) cl.start() val handler = ac.getValue.asInstanceOf[RapidsShuffleFetchHandler] @@ -102,7 +102,7 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any(), any())).thenReturn(client) - doNothing().when(client).doFetch(any(), ac.capture(), any()) + doNothing().when(client).doFetch(any(), ac.capture()) cl.start() val handler = ac.getValue.asInstanceOf[RapidsShuffleFetchHandler] @@ -144,7 +144,7 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any(), any())).thenReturn(client) - doNothing().when(client).doFetch(any(), ac.capture(), any()) + doNothing().when(client).doFetch(any(), ac.capture()) // signal a timeout to the iterator when(cl.pollForResult(any())).thenReturn(None) @@ -177,7 +177,7 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any(), any())).thenReturn(client) - doNothing().when(client).doFetch(any(), ac.capture(), any()) + doNothing().when(client).doFetch(any(), ac.capture()) val bufferId = ShuffleReceivedBufferId(1) val mockBuffer = mock[RapidsBuffer] diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala index e71e4ae8503..9dfa157b2d0 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala @@ -94,15 +94,16 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { } test("sending tables that fit within one bounce buffer") { - val mockTransferRequest = - RapidsShuffleTestHelper.prepareMetaTransferRequest(10, 1000) + val mockTx = mock[Transaction] + val transferRequest = RapidsShuffleTestHelper.prepareMetaTransferRequest(10, 1000) + when(mockTx.releaseMessage()).thenReturn(transferRequest) val bb = closeOnExcept(getSendBounceBuffer(10000)) { bounceBuffer => withResource((0 until 10).map(_ => DeviceMemoryBuffer.allocate(1000))) { deviceBuffers => val receiveSide = deviceBuffers.map(b => new MockBlockWithSize(b)) val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers) - withResource(new BufferSendState(mockTransferRequest, bounceBuffer, handler)) { bss => + withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss => assert(bss.hasNext) val alt = bss.next() val receiveBlocks = receiveWindow.next() @@ -119,19 +120,22 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { } bounceBuffer } - assertResult(true)(bb.deviceBounceBuffer.isClosed) + assert(bb.deviceBounceBuffer.isClosed) + assert(transferRequest.isClosed) newMocks() } test("sending tables that require two bounce buffer lengths") { - val mockTransferRequest = RapidsShuffleTestHelper.prepareMetaTransferRequest(20, 1000) + val mockTx = mock[Transaction] + val transferRequest = RapidsShuffleTestHelper.prepareMetaTransferRequest(20, 1000) + when(mockTx.releaseMessage()).thenReturn(transferRequest) val bb = closeOnExcept(getSendBounceBuffer(10000)) { bounceBuffer => withResource((0 until 20).map(_ => DeviceMemoryBuffer.allocate(1000))) { deviceBuffers => val receiveSide = deviceBuffers.map(b => new MockBlockWithSize(b)) val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers) - withResource(new BufferSendState(mockTransferRequest, bounceBuffer, handler)) { bss => + withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss => var buffs = bss.next() var receiveBlocks = receiveWindow.next() compareRanges(bounceBuffer, receiveBlocks) @@ -153,11 +157,14 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { } bounceBuffer } - assertResult(true)(bb.deviceBounceBuffer.isClosed) + assert(bb.deviceBounceBuffer.isClosed) + assert(transferRequest.isClosed) } test("sending buffers larger than bounce buffer") { - val mockTransferRequest = RapidsShuffleTestHelper.prepareMetaTransferRequest(20, 10000) + val mockTx = mock[Transaction] + val transferRequest = RapidsShuffleTestHelper.prepareMetaTransferRequest(20, 10000) + when(mockTx.releaseMessage()).thenReturn(transferRequest) val bb = closeOnExcept(getSendBounceBuffer(10000)) { bounceBuffer => withResource((0 until 20).map(_ => DeviceMemoryBuffer.allocate(123000))) { deviceBuffers => @@ -165,7 +172,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { val receiveSide = deviceBuffers.map(b => new MockBlockWithSize(b)) val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) - withResource(new BufferSendState(mockTransferRequest, bounceBuffer, handler)) { bss => + withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss => (0 until 246).foreach { _ => bss.next() val receiveBlocks = receiveWindow.next() @@ -180,6 +187,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { } bounceBuffer } - assertResult(true)(bb.deviceBounceBuffer.isClosed) + assert(bb.deviceBounceBuffer.isClosed) + assert(transferRequest.isClosed) } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index b945bbe78c8..b2a707dbcfd 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -16,6 +16,7 @@ package com.nvidia.spark.rapids.shuffle +import java.nio.ByteBuffer import java.util.concurrent.Executor import scala.collection.mutable.ArrayBuffer @@ -148,7 +149,6 @@ class RapidsShuffleTestHelper extends FunSuite mockTransport, mockExecutor, mockCopyExecutor, - 1024, mockStorage, mockCatalog)) } @@ -179,27 +179,25 @@ object RapidsShuffleTestHelper extends MockitoSugar with Arm { } def mockMetaResponse( - mockTransport: RapidsShuffleTransport, + mockTransaction: Transaction, numRows: Long, - numBatches: Int, - maximumResponseSize: Long = 10000): Seq[TableMeta] = + numBatches: Int): (Seq[TableMeta], RefCountedDirectByteBuffer) = withMockContiguousTable(numRows) { ct => val tableMetas = (0 until numBatches).map(b => buildMockTableMeta(b, ct)) - val res = ShuffleMetadata.buildMetaResponse(tableMetas, maximumResponseSize) + val res = ShuffleMetadata.buildMetaResponse(tableMetas) val refCountedRes = new RefCountedDirectByteBuffer(res) - when(mockTransport.getMetaBuffer(any())).thenReturn(refCountedRes) - tableMetas + when(mockTransaction.releaseMessage()).thenReturn(refCountedRes) + (tableMetas, refCountedRes) } def mockDegenerateMetaResponse( - mockTransport: RapidsShuffleTransport, - numBatches: Int, - maximumResponseSize: Long = 10000): Seq[TableMeta] = { - val tableMetas = (0 until numBatches).map(_ => buildDegenerateMockTableMeta()) - val res = ShuffleMetadata.buildMetaResponse(tableMetas, maximumResponseSize) + mockTransaction: Transaction, + numBatches: Int): (Seq[TableMeta], RefCountedDirectByteBuffer) = { + val tableMetas = (0 until numBatches).map(b => buildDegenerateMockTableMeta()) + val res = ShuffleMetadata.buildMetaResponse(tableMetas) val refCountedRes = new RefCountedDirectByteBuffer(res) - when(mockTransport.getMetaBuffer(any())).thenReturn(refCountedRes) - tableMetas + when(mockTransaction.releaseMessage()).thenReturn(refCountedRes) + (tableMetas, refCountedRes) } def prepareMetaTransferRequest(numTables: Int, numRows: Long): RefCountedDirectByteBuffer = @@ -207,7 +205,7 @@ object RapidsShuffleTestHelper extends MockitoSugar with Arm { val tableMetaTags = (0 until numTables).map { t => (buildMockTableMeta(t, ct), t.toLong) } - val trBuffer = ShuffleMetadata.buildTransferRequest(1, 123, tableMetaTags) + val trBuffer = ShuffleMetadata.buildTransferRequest(tableMetaTags) val refCountedRes = new RefCountedDirectByteBuffer(trBuffer) refCountedRes } @@ -218,14 +216,14 @@ object RapidsShuffleTestHelper extends MockitoSugar with Arm { } def prepareMetaTransferResponse( - mockTransport: RapidsShuffleTransport, + mockTransaction: Transaction, numRows: Long): TableMeta = withMockContiguousTable(numRows) { ct => val tableMeta = buildMockTableMeta(1, ct) val bufferMeta = tableMeta.bufferMeta() val res = ShuffleMetadata.buildBufferTransferResponse(Seq(bufferMeta)) val refCountedRes = new RefCountedDirectByteBuffer(res) - when(mockTransport.getMetaBuffer(any())).thenReturn(refCountedRes) + when(mockTransaction.releaseMessage()).thenReturn(refCountedRes) tableMeta } @@ -259,16 +257,8 @@ class ImmediateExecutor extends Executor { } class MockConnection(mockTransaction: Transaction) extends ClientConnection { - val requests = new ArrayBuffer[AddressLengthTag] + var requests: Int = 0 val receiveLengths = new ArrayBuffer[Long] - override def request( - request: AddressLengthTag, - response: AddressLengthTag, - cb: TransactionCallback): Transaction = { - requests.append(request) - cb(mockTransaction) - mockTransaction - } override def receive(alt: AddressLengthTag, cb: TransactionCallback): Transaction = { receiveLengths.append(alt.length) @@ -278,8 +268,13 @@ class MockConnection(mockTransaction: Transaction) extends ClientConnection { override def getPeerExecutorId: Long = 0 - override def assignResponseTag: Long = 1L override def assignBufferTag(msgId: Int): Long = 2L - override def composeRequestTag(requestType: RequestType.Value): Long = 3L + + override def request(requestType: RequestType.Value, + request: ByteBuffer, cb: TransactionCallback): Transaction = { + requests += 1 + cb(mockTransaction) + mockTransaction + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/UCXConnectionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/UCXConnectionSuite.scala new file mode 100644 index 00000000000..a1633691945 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/UCXConnectionSuite.scala @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shuffle + +import com.nvidia.spark.rapids.shuffle.ucx.UCXConnection._ +import org.scalatest.FunSuite + +class UCXConnectionSuite extends FunSuite { + test("generate active message id") { + Seq(0, 1, 2, 100000, Int.MaxValue).foreach { eId => + assertResult(eId)( + extractExecutorId(composeRequestHeader(eId, 123L))) + } + } + + test("negative executor ids are invalid") { + Seq(-1, -1 * Int.MaxValue).foreach { eId => + assertThrows[IllegalArgumentException]( + extractExecutorId(composeRequestHeader(eId, 123L))) + } + } + + test("executor id longer that doesn't fit in an int is invalid") { + assertThrows[IllegalArgumentException](composeRequestHeader(Long.MaxValue, 123L)) + } + + test("transaction ids can rollover") { + assertResult(0)(composeRequestHeader(0, 0x0000000100000000L)) + assertResult(10)(composeRequestHeader(0, 0x000000010000000AL)) + } +}