From bd964b1528912be017ad3407dcde7693fff22d25 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 27 Nov 2019 12:50:22 -0500 Subject: [PATCH] port improvements from vchuravy/NCCL.jl --- examples/scaffold.jl | 50 ++++++++++++++++++++++++++++++++++++++++++++ src/base.jl | 14 +++++++++++++ src/collective.jl | 6 ++++++ src/communicator.jl | 19 +++++++++++++++++ src/group.jl | 8 ++++++- 5 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 examples/scaffold.jl diff --git a/examples/scaffold.jl b/examples/scaffold.jl new file mode 100644 index 0000000..f61f6d4 --- /dev/null +++ b/examples/scaffold.jl @@ -0,0 +1,50 @@ +import MPI +import NCCL +using CuArrays +using CUDAdrv +using CUDAnative + +MPI.Init() +comm = MPI.COMM_WORLD +myrank = MPI.Comm_rank(comm) +nranks = MPI.Comm_size(comm) + +# Issues: +# - Avoid allocations during allReduce + +print(stdout, ENV) + +@info "MPI initialized" myrank nranks + +if myrank == 0 + uid = NCCL.UniqueID() +else + uid = nothing +end +uid = MPI.bcast(uid, 0, comm)::NCCL.UniqueID + +dev = CuDevice(parse(Int, first(split(ENV["CUDA_VISIBLE_DEVICES"], ",")))) +@info "NCCL uid bcast" myrank uid dev +CUDAnative.device!(dev) + +cuComm = NCCL.Communicator(nranks, uid, myrank) + +recv = CuArray{Float32}(undef, 1024) +send = CuArray{Float32}(undef, 1024) +fill!(send, float(myrank)) + +# Stream to do communication on +stream = CuStream() + +event = CuEvent(CUDAdrv.EVENT_DISABLE_TIMING) +NCCL.allReduce(+, send, recv, cuComm, stream) +CUDAdrv.record(event, stream) # mark communication as done + +# Enqueue a marker on CuDefaultStream to wait on the communication +wait(event) +# Now do work on CuDefaultStream() +# ... 
+ +synchronize(stream) +NCCL.destroy(cuComm) +MPI.Finalize() diff --git a/src/base.jl b/src/base.jl index b74ed18..75f7928 100644 --- a/src/base.jl +++ b/src/base.jl @@ -30,3 +30,17 @@ function ncclDataType(T::DataType) throw(ArgumentError("ncclDataType equivalent for input type $T does not exist!")) end end + +function ncclReductionOp(T::DataType) + if T == typeof(+) + return ncclSum + elseif T == typeof(*) + return ncclProd + elseif T == typeof(min) + return ncclMin + elseif T == typeof(max) + return ncclMax + else + throw(ArgumentError("ncclReductionOp equivalent for input function type $T does not exist!")) + end +end diff --git a/src/collective.jl b/src/collective.jl index 125b005..4dd0f20 100644 --- a/src/collective.jl +++ b/src/collective.jl @@ -2,6 +2,12 @@ export Allreduce!, Broadcast!, Reduce!, Allgather!, ReduceScatter! +function allReduce!(::Op, sendbuf, recvbuf, comm::Communicator; stream=CUDAdrv.CuDefaultStream()) where Op + op = ncclReductionOp(Op) + @assert size(sendbuf) == size(recvbuf) + Allreduce!(sendbuf, recvbuf, length(sendbuf), op, comm, stream=stream) +end + function Allreduce!(sendbuf, recvbuf, count::Integer, op, comm::Communicator; stream::CuStream=CuDefaultStream() ) data_type = ncclDataType(eltype(recvbuf)) ncclAllReduce(sendbuf, recvbuf, count, data_type, op, comm.handle, stream) diff --git a/src/communicator.jl b/src/communicator.jl index 8f50421..f062423 100644 --- a/src/communicator.jl +++ b/src/communicator.jl @@ -28,6 +28,15 @@ end # creates a new communicator (multi thread/process version) +""" + Communicator(nranks, uid, rank) + +Creates a new Communicator (multi thread/process version). +`rank` must be between `0` and `nranks-1` and unique within a communicator +clique. Each rank is associated to a CUDA device which has to be set before +calling `Communicator`. Implicitly synchronized with other ranks so it must +be called by different threads/processes or used within `group`. 
+""" function Communicator(nranks, comm_id, rank) handle_ref = Ref{ncclComm_t}(C_NULL) ncclCommInitRank(handle_ref, nranks, comm_id.internal, rank) @@ -70,3 +79,13 @@ function rank(comm::Communicator) ncclCommUserRank(comm.handle, rank_ref) return rank_ref[] end + +function abort(comm::Communicator) + ncclCommAbort(comm.handle) +end + +function getError(comm::Communicator) + ref = Ref{ncclResult_t}() + ncclCommGetAsyncError(comm.handle, ref) + return NCCLError(ref[]) +end diff --git a/src/group.jl b/src/group.jl index 650965c..bff8cee 100644 --- a/src/group.jl +++ b/src/group.jl @@ -1,6 +1,12 @@ # Group calls -export groupStart, groupEnd +export groupStart, groupEnd, group groupStart() = ncclGroupStart() groupEnd() = ncclGroupEnd() + +function group(f) + groupStart() + f() + groupEnd() +end