From 6967c52ccd396a62886463608a0280a1348896eb Mon Sep 17 00:00:00 2001 From: nikopj Date: Fri, 19 Jul 2024 16:11:02 -0400 Subject: [PATCH 1/2] handle complex datatype --- src/base.jl | 1 + src/collective.jl | 29 +++--- test/runtests.jl | 250 ++++++++++++++++++++++++---------------------- 3 files changed, 150 insertions(+), 130 deletions(-) diff --git a/src/base.jl b/src/base.jl index cb1efc9..65c03cc 100644 --- a/src/base.jl +++ b/src/base.jl @@ -53,3 +53,4 @@ ncclDataType_t(::Type{UInt64}) = ncclUint64 ncclDataType_t(::Type{Float16}) = ncclFloat16 ncclDataType_t(::Type{Float32}) = ncclFloat32 ncclDataType_t(::Type{Float64}) = ncclFloat64 +ncclDataType_t(::Type{Complex{T}}) where {T} = ncclDataType_t(T) diff --git a/src/collective.jl b/src/collective.jl index 054fa3a..1608cb6 100644 --- a/src/collective.jl +++ b/src/collective.jl @@ -1,3 +1,6 @@ +count(X::CuArray{T}) where {T} = length(X) +count(X::CuArray{Complex{T}}) where {T} = 2*length(X) + """ NCCL.Allreduce!( sendbuf, recvbuf, op, comm::Communicator; @@ -11,11 +14,11 @@ or [`NCCL.avg`](@ref)), writing the result to `recvbuf` to all ranks. """ function Allreduce!(sendbuf, recvbuf, op, comm::Communicator; stream::CuStream=default_device_stream(comm)) - count = length(recvbuf) - @assert length(sendbuf) == count + a_count = count(recvbuf) + @assert count(sendbuf) == a_count data_type = ncclDataType_t(eltype(recvbuf)) _op = ncclRedOp_t(op) - ncclAllReduce(sendbuf, recvbuf, count, data_type, _op, comm, stream) + ncclAllReduce(sendbuf, recvbuf, a_count, data_type, _op, comm, stream) return recvbuf end @@ -47,8 +50,8 @@ Copies array the `sendbuf` on rank `root` to `recvbuf` on all ranks. function Broadcast!(sendbuf, recvbuf, comm::Communicator; root::Integer=0, stream::CuStream=default_device_stream(comm)) data_type = ncclDataType_t(eltype(recvbuf)) - count = length(recvbuf) - ncclBroadcast(sendbuf, recvbuf, count, data_type, root, comm, stream) + a_count = count(recvbuf) + ncclBroadcast(sendbuf, recvbuf, a_count, data_type, root, comm, stream) return recvbuf end function Broadcast!(sendrecvbuf, comm::Communicator; root::Integer=0, @@ -72,9 +75,9 @@ or `[`NCCL.avg`](@ref)`), writing the result to `recvbuf` on rank `root`. function Reduce!(sendbuf, recvbuf, op, comm::Communicator; root::Integer=0, stream::CuStream=default_device_stream(comm)) data_type = ncclDataType_t(eltype(recvbuf)) - count = length(recvbuf) + a_count = count(recvbuf) _op = ncclRedOp_t(op) - ncclReduce(sendbuf, recvbuf, count, data_type, _op, root, comm, stream) + ncclReduce(sendbuf, recvbuf, a_count, data_type, _op, root, comm, stream) return recvbuf end function Reduce!(sendrecvbuf, op, comm::Communicator; root::Integer=0, @@ -96,9 +99,9 @@ Concatenate `sendbuf` from each rank into `recvbuf` on all ranks. function Allgather!(sendbuf, recvbuf, comm::Communicator; stream::CuStream=default_device_stream(comm)) data_type = ncclDataType_t(eltype(recvbuf)) - sendcount = length(sendbuf) - @assert length(recvbuf) == sendcount * size(comm) - ncclAllGather(sendbuf, recvbuf, sendcount, data_type, comm, stream) + senda_count = count(sendbuf) + @assert count(recvbuf) == senda_count * size(comm) + ncclAllGather(sendbuf, recvbuf, senda_count, data_type, comm, stream) return recvbuf end @@ -117,10 +120,10 @@ scattered over the devices such that `recvbuf` on each rank will contain the """ function ReduceScatter!(sendbuf, recvbuf, op, comm::Communicator; stream::CuStream=default_device_stream(comm)) - recvcount = length(recvbuf) - @assert length(sendbuf) == recvcount * size(comm) + recva_count = count(recvbuf) + @assert count(sendbuf) == recva_count * size(comm) data_type = ncclDataType_t(eltype(recvbuf)) _op = ncclRedOp_t(op) - ncclReduceScatter(sendbuf, recvbuf, recvcount, data_type, _op, comm, stream) + ncclReduceScatter(sendbuf, recvbuf, recva_count, data_type, _op, comm, stream) return recvbuf end diff --git a/test/runtests.jl b/test/runtests.jl index df695c0..1dc2dfb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,171 +26,187 @@ end devs = CUDA.devices() comms = NCCL.Communicators(devs) - @testset "sum" begin - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 + @testset "$T" for T in (Float64, ComplexF64) + @testset "sum" begin + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), N)) + recvbuf[ii] = CUDA.zeros(T, N) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + end + end + answer = sum(1:length(devs)) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) + end + end + + @testset "NCCL.avg" begin + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), N)) + recvbuf[ii] = CUDA.zeros(T, N) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii]) + end + end + answer = sum(1:length(devs)) / length(devs) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .≈ answer) + end + end + end +end + +@testset "Broadcast!" begin + devs = CUDA.devices() + comms = NCCL.Communicators(devs) + + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + root = 0 for (ii, dev) in enumerate(devs) CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) + sendbuf[ii] = (ii - 1) == root ? CuArray(fill(T(1.0), 512)) : CUDA.zeros(T, 512) + recvbuf[ii] = CUDA.zeros(T, 512) end NCCL.group() do for ii in 1:length(devs) - NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + NCCL.Broadcast!(sendbuf[ii], recvbuf[ii], comms[ii]; root) end end - answer = sum(1:length(devs)) + answer = 1.0 for (ii, dev) in enumerate(devs) device!(ii - 1) crecv = collect(recvbuf[ii]) @test all(crecv .== answer) end end +end - @testset "NCCL.avg" begin - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 +@testset "Reduce!" begin + devs = CUDA.devices() + comms = NCCL.Communicators(devs) + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + root = 0 for (ii, dev) in enumerate(devs) CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) + sendbuf[ii] = CuArray(fill(T(ii), 512)) + recvbuf[ii] = CUDA.zeros(T, 512) end NCCL.group() do for ii in 1:length(devs) - NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii]) + NCCL.Reduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]; root) end end - answer = sum(1:length(devs)) / length(devs) for (ii, dev) in enumerate(devs) + answer = (ii - 1) == root ? sum(1:length(devs)) : 0.0 device!(ii - 1) crecv = collect(recvbuf[ii]) - @test all(crecv .≈ answer) + @test all(crecv .== answer) end end end -@testset "Broadcast!" begin +@testset "Allgather!" begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - root = 0 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = (ii - 1) == root ? CuArray(fill(Float64(1.0), 512)) : CUDA.zeros(Float64, 512) - recvbuf[ii] = CUDA.zeros(Float64, 512) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Broadcast!(sendbuf[ii], recvbuf[ii], comms[ii]; root) - end - end - answer = 1.0 - for (ii, dev) in enumerate(devs) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) - end -end -@testset "Reduce!" begin - devs = CUDA.devices() - comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - root = 0 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), 512)) - recvbuf[ii] = CUDA.zeros(Float64, 512) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Reduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]; root) + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), 512)) + recvbuf[ii] = CUDA.zeros(T, length(devs)*512) end - end - for (ii, dev) in enumerate(devs) - answer = (ii - 1) == root ? sum(1:length(devs)) : 0.0 - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) - end -end - -@testset "Allgather!" begin - devs = CUDA.devices() - comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), 512)) - recvbuf[ii] = CUDA.zeros(Float64, length(devs)*512) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Allgather!(sendbuf[ii], recvbuf[ii], comms[ii]) + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allgather!(sendbuf[ii], recvbuf[ii], comms[ii]) + end + end + answer = vec(repeat(1:length(devs), inner=512)) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end - end - answer = vec(repeat(1:length(devs), inner=512)) - for (ii, dev) in enumerate(devs) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) end end @testset "ReduceScatter!" begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(vec(repeat(collect(1:length(devs)), inner=2))) - recvbuf[ii] = CUDA.zeros(Float64, 2) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.ReduceScatter!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(vec(repeat(collect(1:length(devs)), inner=2))) + recvbuf[ii] = CUDA.zeros(T, 2) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.ReduceScatter!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + end + end + for (ii, dev) in enumerate(devs) + answer = length(devs)*ii + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end - end - for (ii, dev) in enumerate(devs) - answer = length(devs)*ii - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) end end @testset "Send/Recv" begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) - end - NCCL.group() do - for ii in 1:length(devs) - comm = comms[ii] - dest = mod(NCCL.rank(comm)+1, NCCL.size(comm)) - source = mod(NCCL.rank(comm)-1, NCCL.size(comm)) - NCCL.Send(sendbuf[ii], comm; dest) - NCCL.Recv!(recvbuf[ii], comm; source) + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), N)) + recvbuf[ii] = CUDA.zeros(T, N) + end + + NCCL.group() do + for ii in 1:length(devs) + comm = comms[ii] + dest = mod(NCCL.rank(comm)+1, NCCL.size(comm)) + source = mod(NCCL.rank(comm)-1, NCCL.size(comm)) + NCCL.Send(sendbuf[ii], comm; dest) + NCCL.Recv!(recvbuf[ii], comm; source) + end + end + for (ii, dev) in enumerate(devs) + answer = mod1(ii - 1, length(devs)) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end - end - for (ii, dev) in enumerate(devs) - answer = mod1(ii - 1, length(devs)) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) end end From 22fd8c8b283d25dbb37edbf2e1f3f483aad211ff Mon Sep 17 00:00:00 2001 From: nikopj Date: Tue, 13 Aug 2024 21:17:00 -0400 Subject: [PATCH 2/2] added whitespace to force push --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index dfde1e5..2b044a7 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ NCCL.jl ======= -A Julia wrapper for the [NVIDIA Collective Communications Library (NCCL)](https://developer.nvidia.com/nccl). +A Julia wrapper for the [NVIDIA Collective Communications Library (NCCL)](https://developer.nvidia.com/nccl).