From e79b8913058f7ce2f67954ad8141f6de64541d40 Mon Sep 17 00:00:00 2001 From: leios Date: Wed, 6 May 2020 06:46:59 +0900 Subject: [PATCH] adding a version of MPIStateArrays that allows for traversing through wrapper types Co-authored-by: Valentin Churavy --- src/Arrays/MPIStateArrays.jl | 100 +++++++++++++++++++++--------- src/Numerics/DGMethods/DGModel.jl | 4 +- test/Arrays/broadcasting.jl | 4 +- test/Arrays/reductions.jl | 6 +- test/Arrays/reshape.jl | 36 +++++++++++ test/Arrays/runtests.jl | 1 + 6 files changed, 113 insertions(+), 38 deletions(-) create mode 100644 test/Arrays/reshape.jl diff --git a/src/Arrays/MPIStateArrays.jl b/src/Arrays/MPIStateArrays.jl index 52e3559afce..8a6e53ac731 100644 --- a/src/Arrays/MPIStateArrays.jl +++ b/src/Arrays/MPIStateArrays.jl @@ -7,10 +7,14 @@ using LazyArrays using LinearAlgebra using MPI using StaticArrays +using Adapt using ..TicToc using ..VariableTemplates: @vars, varsindex +include("CMBuffers.jl") +using .CMBuffers + using Base.Broadcast: Broadcasted, BroadcastStyle, ArrayStyle # This is so we can do things like @@ -21,8 +25,6 @@ Base.similar(::Type{A}, ::Type{FT}, dims...) where {A <: Array, FT} = Base.similar(::Type{A}, ::Type{FT}, dims...) where {A <: CuArray, FT} = similar(CuArray{FT}, dims...) -include("CMBuffers.jl") -using .CMBuffers cpuify(x::AbstractArray) = convert(Array, x) cpuify(x::Real) = x @@ -101,9 +103,9 @@ mutable struct MPIStateArray{ sendreq = fill(MPI.REQUEST_NULL, nnabr) recvreq = fill(MPI.REQUEST_NULL, nnabr) - # If vmap is not on the device we need to copy it up (we also do not want to - # put it up everytime, so if it's already on the device then we do not do - # anything). + # If vmap is not on the device we need to copy it up (we also do not + # want to put it up every time, so if it's already on the device then we + # do not do anything). # # Better way than checking the type names? 
# XXX: Use Adapt.jl vmaprecv = adapt(DA, vmaprecv) @@ -239,6 +241,27 @@ function MPIStateArray{FT, V}( ) end +# MPIDestArray is a union of MPIStateArray and all possible wrappers +@eval const MPIDestArray = Union{ + MPIStateArray, + $( + ( + :($W where {T, N, Dst, Src <: MPIStateArray}) for + (W, _) in Adapt._wrappers + )... + ), +} + +# This creates 2 adaptors for finding the realdata (RealviewAdaptor) and +# data (RawAdaptor) of an adapted MPIStateArray +struct RealviewAdaptor end +Adapt.adapt_storage(to::RealviewAdaptor, arr::MPIStateArray) = arr.realdata +realview(Q) = adapt(RealviewAdaptor(), Q) + +struct RawAdaptor end +Adapt.adapt_storage(to::RawAdaptor, arr::MPIStateArray) = arr.data +rawview(Q) = adapt(RawAdaptor(), Q) + # FIXME: should general cases be handled? function Base.similar( Q::MPIStateArray{OLDFT, V}, @@ -281,9 +304,17 @@ Base.setindex!(Q::MPIStateArray, x...; kw...) = Base.eltype(Q::MPIStateArray, x...; kw...) = eltype(Q.data, x...; kw...) -Base.Array(Q::MPIStateArray) = Array(Q.data) +Base.Array(Q::MPIDestArray) = Array(rawview(Q)) + +Base.fill!(Q::MPIDestArray, x) = fill!(parent(Q), x) + +for (W, ctor) in Adapt._wrappers + @eval begin + BroadcastStyle(::Type{<:$W}) where {T, N, Dst, Src <: MPIDestArray} = + BroadcastStyle(Dst) + end +end -# broadcasting stuff # find the first MPIStateArray among `bc` arguments # based on https://docs.julialang.org/en/v1/manual/interfaces/#Selecting-an-appropriate-output-array-1 @@ -302,40 +333,43 @@ function Base.similar( end # transform all arguments of `bc` from MPIStateArrays to Arrays +function transform_broadcasted(bc::Broadcasted, dest) + transform_broadcasted(bc, rawview(dest)) +end + function transform_broadcasted(bc::Broadcasted, ::Array) transform_array(bc) end + function transform_array(bc::Broadcasted) Broadcasted(bc.f, transform_array.(bc.args), bc.axes) end -transform_array(mpisa::MPIStateArray) = mpisa.realdata -transform_array(x) = x + +transform_array(x) = realview(x) 
Base.copyto!(dest::Array, src::MPIStateArray) = copyto!(dest, src.data) +Base.copyto!(dest::MPIStateArray, src::Array) = copyto!(dest.data, src) -function Base.copyto!(dest::MPIStateArray, src::MPIStateArray) - copyto!(dest.realdata, src.realdata) +function Base.copyto!(dest::MPIDestArray, src::AbstractArray) + copyto!(rawview(dest), src) dest end -@inline function Base.copyto!(dest::MPIStateArray, bc::Broadcasted{Nothing}) - # check for the case a .= b, where b is an array - if bc.f === identity && bc.args isa Tuple{AbstractArray} - if bc.args isa Tuple{MPIStateArray} - realindices = CartesianIndices(( - axes(dest.data)[1:(end - 1)]..., - dest.realelems, - )) - copyto!(dest.data, realindices, bc.args[1].data, realindices) - else - copyto!(dest.data, bc.args[1]) - end - else - copyto!(dest.realdata, transform_broadcasted(bc, dest.data)) - end +function Base.copyto!(dest::MPIDestArray, src::MPIDestArray) + copyto!(rawview(dest), rawview(src)) dest end +@inline function Base.copyto!(dest::MPIDestArray, bc::Broadcasted{Nothing}) + copyto!(realview(dest), transform_broadcasted(bc, dest)) + dest +end + +@inline Base.copyto!( + dest::MPIDestArray, + bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}, +) = copyto!(dest, convert(Broadcasted{Nothing}, bc)) + """ begin_ghost_exchange!(Q::MPIStateArray; dependencies = nothing) @@ -737,15 +771,19 @@ function Base.mapreduce( MPI.Allreduce(cpuify(locreduce), max, Q.mpicomm) end -# helpers: `array_device` and `realview` +# `array_device` is a helper that enables +# testing ODESolvers and LinearSolvers without using MPIStateArrays. +# It could potentially be useful elsewhere and exported but probably needs +# a better name; for example, `array_device` is also defined in CUDAdrv + array_device(::Union{Array, SArray, MArray}) = CPU() array_device(::CuArray) = CUDADevice() -array_device(s::SubArray) = array_device(parent(s)) array_device(Q::MPIStateArray) = array_device(Q.data) -realview(Q::Union{Array, SArray, MArray}) = Q 
-realview(Q::MPIStateArray) = Q.realdata -realview(Q::CuArray) = Q +for (W, _) in Adapt._wrappers + @eval array_device(wrapper::$W where {T, N, Dst, Src}) = + array_device(parent(wrapper)) +end # transform all arguments of `bc` from MPIStateArrays to CuArrays # and replace CPU function with GPU variants diff --git a/src/Numerics/DGMethods/DGModel.jl b/src/Numerics/DGMethods/DGModel.jl index b45bc70f92a..1ea95bb43ac 100644 --- a/src/Numerics/DGMethods/DGModel.jl +++ b/src/Numerics/DGMethods/DGModel.jl @@ -579,7 +579,7 @@ function init_ode_state(dg::DGModel, args...; init_on_cpu = false) else h_state_conservative = similar(state_conservative, Array) h_state_auxiliary = similar(state_auxiliary, Array) - h_state_auxiliary .= state_auxiliary + copyto!(h_state_auxiliary, state_auxiliary) event = kernel_init_state_conservative!(CPU(), Np)( balance_law, Val(dim), @@ -592,7 +592,7 @@ function init_ode_state(dg::DGModel, args...; init_on_cpu = false) ndrange = Np * nrealelem, ) wait(event) # XXX: This could be `wait(device, event)` once KA supports that. - state_conservative .= h_state_conservative + copyto!(state_conservative, h_state_conservative) end event = Event(device) diff --git a/test/Arrays/broadcasting.jl b/test/Arrays/broadcasting.jl index 28127bf3e85..c7784474e4d 100644 --- a/test/Arrays/broadcasting.jl +++ b/test/Arrays/broadcasting.jl @@ -16,8 +16,8 @@ const mpicomm = MPI.COMM_WORLD QA = MPIStateArray{Float32}(mpicomm, ArrayType, localsize...) QB = similar(QA) - QA .= A - QB .= B + copyto!(QA, A) + copyto!(QB, B) @test Array(QA) == A @test Array(QB) == B diff --git a/test/Arrays/reductions.jl b/test/Arrays/reductions.jl index bbc4757896f..553b556733b 100644 --- a/test/Arrays/reductions.jl +++ b/test/Arrays/reductions.jl @@ -19,7 +19,7 @@ mpirank = MPI.Comm_rank(mpicomm) globalA = vcat([A for _ in 1:mpisize]...) QA = MPIStateArray{Float32}(mpicomm, ArrayType, localsize...) 
- QA .= A + copyto!(QA, A) @test norm(QA, 1) ≈ norm(globalA, 1) @@ -36,7 +36,7 @@ mpirank = MPI.Comm_rank(mpicomm) globalB = vcat([B for _ in 1:mpisize]...) QB = similar(QA) - QB .= B + copyto!(QB, B) @test isapprox(euclidean_distance(QA, QB), norm(globalA .- globalB)) @test isapprox(dot(QA, QB), dot(globalA, globalB)) @@ -44,7 +44,7 @@ mpirank = MPI.Comm_rank(mpicomm) C = fill(Float32(mpirank + 1), localsize) globalC = vcat([fill(i, localsize) for i in 1:mpisize]...) QC = similar(QA) - QC .= C + copyto!(QC, C) @test sum(QC) == sum(globalC) @test Array(sum(QC; dims = (1, 3))) == sum(globalC; dims = (1, 3)) diff --git a/test/Arrays/reshape.jl b/test/Arrays/reshape.jl new file mode 100644 index 00000000000..bac941dc2ab --- /dev/null +++ b/test/Arrays/reshape.jl @@ -0,0 +1,36 @@ +using MPI +using Test +using ClimateMachine +using ClimateMachine.MPIStateArrays + +ClimateMachine.init() +ArrayType = ClimateMachine.array_type() +mpicomm = MPI.COMM_WORLD +FT = Float32 +Q = MPIStateArray{FT}(mpicomm, ArrayType, 4, 4, 4) +Qb = reshape(Q, (16, 4, 1)); + +Q .= 1 +Qb .= 1 + +@testset "MPIStateArray Reshape basics" begin + ClimateMachine.gpu_allowscalar(true) + @test minimum(Q[:] .== 1) + @test minimum(Qb[:] .== 1) + + @test eltype(Qb) == Float32 + @test size(Qb) == (16, 4, 1) + + fillval = 0.5f0 + fill!(Qb, fillval) + + @test Qb[1] == fillval + @test Qb[8, 1, 1] == fillval + @test Qb[end] == fillval + + @test Array(Qb) == fill(fillval, 16, 4, 1) + + Qb[8, 1, 1] = 2fillval + @test Qb[8, 1, 1] != fillval + ClimateMachine.gpu_allowscalar(false) +end diff --git a/test/Arrays/runtests.jl b/test/Arrays/runtests.jl index db87be2f101..24d43dd6226 100644 --- a/test/Arrays/runtests.jl +++ b/test/Arrays/runtests.jl @@ -7,4 +7,5 @@ include(joinpath("..", "testhelpers.jl")) runmpi(joinpath(@__DIR__, "reductions.jl")) runmpi(joinpath(@__DIR__, "reductions.jl"), ntasks = 3) runmpi(joinpath(@__DIR__, "varsindex.jl")) + runmpi(joinpath(@__DIR__, "reshape.jl")) end