Skip to content

Commit

Permalink
Ensure memory copy integrity during transport setup (#731)
Browse files: browse the repository at this point in the history
  • Loading branch information
wenkaidu authored Apr 25, 2023
1 parent 79df139 commit 36e453c
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/transport.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "bootstrap.h"
#define ENABLE_TIMER 0
#include "timer.h"
#include <cstring>

struct ncclTransport* ncclTransports[NTRANSPORTS] = {
&p2pTransport,
Expand Down Expand Up @@ -134,8 +135,13 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
struct ncclConnector* conn = comm->channels[c].peers[sendPeer].send + connIndex;
NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn), ret, fail);
conn->connected = 1;
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail);
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail);
do {
struct ncclConnInfo connInfo;
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail);
CUDACHECKGOTO(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[sendPeer].send[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->hostStream.cudaStream), ret, fail);
CUDACHECKGOTO(hipStreamSynchronize(comm->hostStream.cudaStream), ret, fail);
if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break;
} while (true);
}
}
TIME_STOP(3);
Expand All @@ -145,7 +151,13 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
struct ncclConnector* conn = comm->channels[c].peers[recvPeer].recv + connIndex;
NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn), ret, fail);
conn->connected = 1;
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail);
do {
struct ncclConnInfo connInfo;
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail);
CUDACHECKGOTO(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[recvPeer].recv[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->hostStream.cudaStream), ret, fail);
CUDACHECKGOTO(hipStreamSynchronize(comm->hostStream.cudaStream), ret, fail);
if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break;
} while (true);
}
}
TIME_STOP(4);
Expand Down

0 comments on commit 36e453c

Please sign in to comment.