Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change shuffle metadata messages to use UCX Active Messages #2409

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Name | Description | Default Value
<a name="shuffle.transport.earlyStart"></a>spark.rapids.shuffle.transport.earlyStart|Enable early connection establishment for RAPIDS Shuffle|true
<a name="shuffle.transport.earlyStart.heartbeatInterval"></a>spark.rapids.shuffle.transport.earlyStart.heartbeatInterval|Shuffle early start heartbeat interval (milliseconds)|5000
<a name="shuffle.transport.maxReceiveInflightBytes"></a>spark.rapids.shuffle.transport.maxReceiveInflightBytes|Maximum aggregate amount of bytes that be fetched at any given time from peers during shuffle|1073741824
<a name="shuffle.ucx.activeMessages.mode"></a>spark.rapids.shuffle.ucx.activeMessages.mode|Set to 'rndv', 'eager', or 'auto' to indicate what UCX Active Message mode to use. We set 'rndv' (Rendezvous) by default because UCX 1.10.x doesn't support 'eager' fully. This restriction can be lifted if the user is running UCX 1.11+.|rndv
<a name="shuffle.ucx.managementServerHost"></a>spark.rapids.shuffle.ucx.managementServerHost|The host to be used to start the management server|null
<a name="shuffle.ucx.useWakeup"></a>spark.rapids.shuffle.ucx.useWakeup|When set to true, use UCX's event-based progress (epoll) in order to wake up the progress thread when needed, instead of a hot loop.|true
<a name="sql.batchSizeBytes"></a>spark.rapids.sql.batchSizeBytes|Set the target number of bytes for a GPU batch. Splits sizes for input data is covered by separate configs. The maximum setting is 2 GB to avoid exceeding the cudf row count limit of a column.|2147483647
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,7 +68,7 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon

private[this] lazy val ucx = {
logWarning("UCX Shuffle Transport Enabled")
val ucxImpl = new UCX(shuffleServerId, rapidsConf)
val ucxImpl = new UCX(this, shuffleServerId, rapidsConf)
ucxImpl.init()

initBounceBufferPools(bounceBufferSize,
Expand All @@ -88,7 +88,7 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
ucxImpl
}

override def getMetaBuffer(size: Long): RefCountedDirectByteBuffer = {
override def getDirectByteBuffer(size: Long): RefCountedDirectByteBuffer = {
if (size > rapidsConf.shuffleMaxMetadataSize) {
logWarning(s"Large metadata message size $size B, larger " +
s"than ${rapidsConf.shuffleMaxMetadataSize} B. " +
Expand Down Expand Up @@ -251,8 +251,7 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
clientConnection,
this,
clientExecutor,
clientCopyExecutor,
rapidsConf.shuffleMaxMetadataSize)
clientCopyExecutor)
})
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

package com.nvidia.spark.rapids.shuffle.ucx

import java.nio.ByteBuffer
import java.util.concurrent.{ConcurrentLinkedQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.ReentrantLock
Expand All @@ -24,7 +25,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.shuffle.{AddressLengthTag, Transaction, TransactionCallback, TransactionStats, TransactionStatus, TransportUtils}
import com.nvidia.spark.rapids.shuffle.{AddressLengthTag, RefCountedDirectByteBuffer, RequestType, Transaction, TransactionCallback, TransactionStats, TransactionStatus, TransportUtils}
import org.openucx.jucx.ucp.UcpRequest

import org.apache.spark.internal.Logging
Expand All @@ -42,6 +43,12 @@ private[ucx] object UCXTransactionType extends Enumeration {
private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
extends Transaction with Logging {

// Active Messages: header used to disambiguate responses for a request
private var header: Option[Long] = None

// Type of request this transaction is handling, used to simplify the `respond` method
private var messageType: Option[RequestType.Value] = None

// various threads can access the status during the course of a Transaction
// the UCX progress thread, client/server pools, and the executor task thread
@volatile private[this] var status = TransactionStatus.NotStarted
Expand All @@ -67,14 +74,12 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
* This will mark the tag as having an error for debugging purposes.
*
* @param tag the tag involved in the error
* @param errorMsg error description from UCX
*/
def handleTagError(tag: Long, errorMsg: String): Unit = {
def handleTagError(tag: Long): Unit = {
if (registeredByTag.contains(tag)) {
val origBuff = registeredByTag(tag)
errored += origBuff
}
errorMessage = Some(errorMsg)
}

/**
Expand Down Expand Up @@ -106,7 +111,7 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)

private var hadError: Boolean = false

private[ucx] var txCallback: TransactionStatus.Value => Unit = _
private var txCallback: TransactionStatus.Value => Unit = _

// Start and end times used for metrics
private var start: Long = 0
Expand Down Expand Up @@ -244,7 +249,7 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
def registerForSend(alt: AddressLengthTag): Unit = {
registeredByTag.put(alt.tag, alt)
registered += alt
logTrace(s"Assigned tag for send ${TransportUtils.formatTag(alt.tag)} for message at " +
logTrace(s"Assigned tag for send ${TransportUtils.toHex(alt.tag)} for message at " +
s"buffer ${alt.address} with size ${alt.length}")
}

Expand All @@ -254,7 +259,7 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
def registerForReceive(alt: AddressLengthTag): Unit = {
registered += alt
registeredByTag.put(alt.tag, alt)
logTrace(s"Assigned tag for receive ${TransportUtils.formatTag(alt.tag)} for message at " +
logTrace(s"Assigned tag for receive ${TransportUtils.toHex(alt.tag)} for message at " +
s"buffer ${alt.address} with size ${alt.length}")
}

Expand Down Expand Up @@ -323,6 +328,9 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
hadError = true
}
}
// close any active message we may have
activeMessageData.foreach(_.close())
activeMessageData = None
} catch {
case t: Throwable =>
if (ex == null) {
Expand Down Expand Up @@ -360,5 +368,86 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
}

var callbackCalled: Boolean = false

private var activeMessageData: Option[RefCountedDirectByteBuffer] = None

override def respond(response: ByteBuffer,
jlowe marked this conversation as resolved.
Show resolved Hide resolved
cb: TransactionCallback): Transaction = {
logDebug(s"Responding to ${peerExecutorId} at ${TransportUtils.toHex(this.getHeader)} " +
s"with ${response}")

conn match {
case serverConnection: UCXServerConnection =>
serverConnection.respond(peerExecutorId(), messageType.get, this.getHeader, response, cb)
case _ =>
throw new IllegalStateException("Tried to respond using a client connection. " +
"This is not supported.")
}
}

def complete(status: TransactionStatus.Value,
messageType: Option[RequestType.Value] = None,
header: Option[Long] = None,
message: Option[RefCountedDirectByteBuffer] = None,
errorMessage: Option[String] = None): Unit = {
setHeader(header)
setActiveMessageData(message)
setMessageType(messageType)
setErrorMessage(errorMessage)
setHeader(header)
txCallback(status)
}

def completeWithError(errorMsg: String): Unit = {
complete(TransactionStatus.Error,
errorMessage = Option(errorMsg))
}

def completeCancelled(requestType: RequestType.Value, hdr: Long): Unit = {
complete(TransactionStatus.Cancelled,
messageType = Option(requestType),
header = Option(hdr))
}

def completeWithSuccess(
messageType: RequestType.Value,
hdr: Option[Long],
message: Option[RefCountedDirectByteBuffer]): Unit = {
complete(TransactionStatus.Success,
messageType = Option(messageType),
header = hdr,
message = message)
}
abellina marked this conversation as resolved.
Show resolved Hide resolved

// Reference count is not updated here. The caller is responsible to close
private[ucx] def setActiveMessageData(data: Option[RefCountedDirectByteBuffer]): Unit = {
activeMessageData = data
}

// Reference count is not updated here. The caller is responsible to close
override def releaseMessage(): RefCountedDirectByteBuffer = {
val msg = activeMessageData.get
activeMessageData = None
msg
}

private[ucx] def setHeader(id: Option[Long]): Unit = header = id

override def getHeader: Long = {
require(header.nonEmpty,
"Attempted to get an Active Message header, but it was not set!")
header.get
}

private[ucx] def setMessageType(msgType: Option[RequestType.Value]): Unit = {
messageType = msgType
}

private[ucx] def setErrorMessage(errorMsg: Option[String]): Unit = {
errorMessage = errorMessage
}

override def peerExecutorId(): Long =
UCXConnection.extractExecutorId(getHeader)
}

11 changes: 1 addition & 10 deletions sql-plugin/src/main/format/ShuffleMetadataRequest.fbs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2020, NVIDIA CORPORATION.
// Copyright (c) 2019-2021, NVIDIA CORPORATION.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -24,15 +24,6 @@ table BlockIdMeta {

/// Flat buffer for Rapids UCX Shuffle Metadata Request.
table MetadataRequest {
/// Spark executor ID
executor_id: long;

/// UCX message tag to use when sending the response
response_tag: long;

/// maximum size in bytes for the response message.
max_response_size: long;

/// array of shuffle block descriptors for which metadata is needed
block_ids : [BlockIdMeta];
}
Expand Down
5 changes: 0 additions & 5 deletions sql-plugin/src/main/format/ShuffleMetadataResponse.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ table TableMeta {

/// Flat buffer for Rapids UCX Shuffle Metadata Response
table MetadataResponse {
/// Buffer size in bytes required to hold the full response. If this value is larger than the
/// maximum response size sent in the corresponding request then the metadata contents in this
/// response are incomplete and must be re-requested with a larger response buffer allocated.
full_response_size: long;

/// metadata for each table
table_metas: [TableMeta];
}
Expand Down
6 changes: 0 additions & 6 deletions sql-plugin/src/main/format/ShuffleTransferRequest.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ table BufferTransferRequest {

/// Flat buffer for Rapids UCX Shuffle Transfer Request.
table TransferRequest {
/// peer executor id to send response to
executor_id: long;

/// UCX message tag to use when sending the response
response_tag: long;

/// array of table requests to transfer
requests : [BufferTransferRequest];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,22 @@ public final class MetadataRequest extends Table {
public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; vtable_start = bb_pos - bb.getInt(bb_pos); vtable_size = bb.getShort(vtable_start); }
public MetadataRequest __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; }

/**
* Spark executor ID
*/
public long executorId() { int o = __offset(4); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
public boolean mutateExecutorId(long executor_id) { int o = __offset(4); if (o != 0) { bb.putLong(o + bb_pos, executor_id); return true; } else { return false; } }
/**
* UCX message tag to use when sending the response
*/
public long responseTag() { int o = __offset(6); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
public boolean mutateResponseTag(long response_tag) { int o = __offset(6); if (o != 0) { bb.putLong(o + bb_pos, response_tag); return true; } else { return false; } }
/**
* maximum size in bytes for the response message.
*/
public long maxResponseSize() { int o = __offset(8); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
public boolean mutateMaxResponseSize(long max_response_size) { int o = __offset(8); if (o != 0) { bb.putLong(o + bb_pos, max_response_size); return true; } else { return false; } }
/**
* array of shuffle block descriptors for which metadata is needed
*/
public BlockIdMeta blockIds(int j) { return blockIds(new BlockIdMeta(), j); }
public BlockIdMeta blockIds(BlockIdMeta obj, int j) { int o = __offset(10); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int blockIdsLength() { int o = __offset(10); return o != 0 ? __vector_len(o) : 0; }
public BlockIdMeta blockIds(BlockIdMeta obj, int j) { int o = __offset(4); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int blockIdsLength() { int o = __offset(4); return o != 0 ? __vector_len(o) : 0; }

public static int createMetadataRequest(FlatBufferBuilder builder,
long executor_id,
long response_tag,
long max_response_size,
int block_idsOffset) {
builder.startObject(4);
MetadataRequest.addMaxResponseSize(builder, max_response_size);
MetadataRequest.addResponseTag(builder, response_tag);
MetadataRequest.addExecutorId(builder, executor_id);
builder.startObject(1);
MetadataRequest.addBlockIds(builder, block_idsOffset);
return MetadataRequest.endMetadataRequest(builder);
}

public static void startMetadataRequest(FlatBufferBuilder builder) { builder.startObject(4); }
public static void addExecutorId(FlatBufferBuilder builder, long executorId) { builder.addLong(0, executorId, 0L); }
public static void addResponseTag(FlatBufferBuilder builder, long responseTag) { builder.addLong(1, responseTag, 0L); }
public static void addMaxResponseSize(FlatBufferBuilder builder, long maxResponseSize) { builder.addLong(2, maxResponseSize, 0L); }
public static void addBlockIds(FlatBufferBuilder builder, int blockIdsOffset) { builder.addOffset(3, blockIdsOffset, 0); }
public static void startMetadataRequest(FlatBufferBuilder builder) { builder.startObject(1); }
public static void addBlockIds(FlatBufferBuilder builder, int blockIdsOffset) { builder.addOffset(0, blockIdsOffset, 0); }
public static int createBlockIdsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startBlockIdsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static int endMetadataRequest(FlatBufferBuilder builder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,22 @@ public final class MetadataResponse extends Table {
public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; vtable_start = bb_pos - bb.getInt(bb_pos); vtable_size = bb.getShort(vtable_start); }
public MetadataResponse __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; }

/**
* Buffer size in bytes required to hold the full response. If this value is larger than the
* maximum response size sent in the corresponding request then the metadata contents in this
* response are incomplete and must be re-requested with a larger response buffer allocated.
*/
public long fullResponseSize() { int o = __offset(4); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
public boolean mutateFullResponseSize(long full_response_size) { int o = __offset(4); if (o != 0) { bb.putLong(o + bb_pos, full_response_size); return true; } else { return false; } }
/**
* metadata for each table
*/
public TableMeta tableMetas(int j) { return tableMetas(new TableMeta(), j); }
public TableMeta tableMetas(TableMeta obj, int j) { int o = __offset(6); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int tableMetasLength() { int o = __offset(6); return o != 0 ? __vector_len(o) : 0; }
public TableMeta tableMetas(TableMeta obj, int j) { int o = __offset(4); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int tableMetasLength() { int o = __offset(4); return o != 0 ? __vector_len(o) : 0; }

public static int createMetadataResponse(FlatBufferBuilder builder,
long full_response_size,
int table_metasOffset) {
builder.startObject(2);
MetadataResponse.addFullResponseSize(builder, full_response_size);
builder.startObject(1);
MetadataResponse.addTableMetas(builder, table_metasOffset);
return MetadataResponse.endMetadataResponse(builder);
}

public static void startMetadataResponse(FlatBufferBuilder builder) { builder.startObject(2); }
public static void addFullResponseSize(FlatBufferBuilder builder, long fullResponseSize) { builder.addLong(0, fullResponseSize, 0L); }
public static void addTableMetas(FlatBufferBuilder builder, int tableMetasOffset) { builder.addOffset(1, tableMetasOffset, 0); }
public static void startMetadataResponse(FlatBufferBuilder builder) { builder.startObject(1); }
public static void addTableMetas(FlatBufferBuilder builder, int tableMetasOffset) { builder.addOffset(0, tableMetasOffset, 0); }
public static int createTableMetasVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startTableMetasVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static int endMetadataResponse(FlatBufferBuilder builder) {
Expand Down
Loading