From 36e453c61e8ddc5913ffc411f9e9bd2cc7090d6a Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Tue, 25 Apr 2023 14:41:43 -0700 Subject: [PATCH] Ensure memory copy integrity during transport setup (#731) --- src/transport.cc | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/transport.cc b/src/transport.cc index 0f5cbf6fa..48ae3e02d 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -10,6 +10,7 @@ #include "bootstrap.h" #define ENABLE_TIMER 0 #include "timer.h" +#include <string.h> struct ncclTransport* ncclTransports[NTRANSPORTS] = { &p2pTransport, @@ -134,8 +135,13 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* struct ncclConnector* conn = comm->channels[c].peers[sendPeer].send + connIndex; NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn), ret, fail); conn->connected = 1; - CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); - CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); + do { + struct ncclConnInfo connInfo; + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); + CUDACHECKGOTO(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[sendPeer].send[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->hostStream.cudaStream), ret, fail); + CUDACHECKGOTO(hipStreamSynchronize(comm->hostStream.cudaStream), ret, fail); + if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break; + } while (true); } } TIME_STOP(3); @@ -145,7 +151,13 @@ ncclResult_t 
ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* struct ncclConnector* conn = comm->channels[c].peers[recvPeer].recv + connIndex; NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn), ret, fail); conn->connected = 1; - CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); + do { + struct ncclConnInfo connInfo; + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); + CUDACHECKGOTO(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[recvPeer].recv[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->hostStream.cudaStream), ret, fail); + CUDACHECKGOTO(hipStreamSynchronize(comm->hostStream.cudaStream), ret, fail); + if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break; + } while (true); } } TIME_STOP(4);