From 147fe93c6882bebf22184fcaca6dbc25fa1efa44 Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Tue, 25 Apr 2023 14:41:43 -0700 Subject: [PATCH] Ensure memory copy integrity during transport setup (#731) (cherry picked from commit 36e453c61e8ddc5913ffc411f9e9bd2cc7090d6a) --- src/transport.cc | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/transport.cc b/src/transport.cc index 04a329ed4..6b9c82c0a 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -10,6 +10,7 @@ #include "bootstrap.h" #define ENABLE_TIMER 0 #include "timer.h" +#include struct ncclTransport* ncclTransports[NTRANSPORTS] = { &p2pTransport, @@ -131,8 +132,13 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* struct ncclConnector* conn = comm->channels[c].peers[sendPeer].send + connIndex; NCCLCHECK(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn)); conn->connected = 1; - CUDACHECK(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sideStream)); - CUDACHECK(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sideStream)); + do { + struct ncclConnInfo connInfo; + CUDACHECK(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sideStream)); + CUDACHECK(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[sendPeer].send[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->sideStream)); + CUDACHECK(hipStreamSynchronize(comm->sideStream)); + if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break; + } while (true); } } TIME_STOP(3); @@ -142,7 +148,13 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* struct ncclConnector* conn = comm->channels[c].peers[recvPeer].recv + connIndex; NCCLCHECK(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn)); conn->connected = 1; - CUDACHECK(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sideStream)); + do { + struct ncclConnInfo connInfo; + CUDACHECK(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sideStream)); + CUDACHECK(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[recvPeer].recv[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->sideStream)); + CUDACHECK(hipStreamSynchronize(comm->sideStream)); + if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break; + } while (true); } } TIME_STOP(4);