Skip to content
This repository has been archived by the owner on Mar 1, 2023. It is now read-only.

Commit

Permalink
adding a version of MPIStateArrays that allows for traversing through…
Browse files Browse the repository at this point in the history
… wrapper types

Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>
  • Loading branch information
2 people authored and blallen committed Jun 30, 2020
1 parent e8865d4 commit e79b891
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 38 deletions.
100 changes: 69 additions & 31 deletions src/Arrays/MPIStateArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ using LazyArrays
using LinearAlgebra
using MPI
using StaticArrays
using Adapt

using ..TicToc
using ..VariableTemplates: @vars, varsindex

include("CMBuffers.jl")
using .CMBuffers

using Base.Broadcast: Broadcasted, BroadcastStyle, ArrayStyle

# This is so we can do things like
Expand All @@ -21,8 +25,6 @@ Base.similar(::Type{A}, ::Type{FT}, dims...) where {A <: Array, FT} =
# Type-level `similar` for any `CuArray` subtype `A`: the body does not use `A`,
# so the result is determined by the element type `FT` and `dims` alone.
function Base.similar(
    ::Type{A},
    ::Type{FT},
    dims...,
) where {A <: CuArray, FT}
    return similar(CuArray{FT}, dims...)
end

include("CMBuffers.jl")
using .CMBuffers

# `cpuify` normalizes a value for use on the host side:
# - any `AbstractArray` is converted to a plain `Array`
# - a `Real` scalar is returned unchanged
function cpuify(x::AbstractArray)
    return convert(Array, x)
end

function cpuify(x::Real)
    return x
end
Expand Down Expand Up @@ -101,9 +103,9 @@ mutable struct MPIStateArray{
sendreq = fill(MPI.REQUEST_NULL, nnabr)
recvreq = fill(MPI.REQUEST_NULL, nnabr)

# If vmap is not on the device we need to copy it up (we also do not want to
# put it up everytime, so if it's already on the device then we do not do
# anything).
# If vmap is not on the device we need to copy it up (we also do not
# want to copy it up every time, so if it's already on the device then we
# do not do anything).
#
# Better way than checking the type names?
# XXX: Use Adapt.jl vmaprecv = adapt(DA, vmaprecv)
Expand Down Expand Up @@ -239,6 +241,27 @@ function MPIStateArray{FT, V}(
)
end

# `MPIDestArray` is the union of `MPIStateArray` itself and one wrapper type per
# entry of `Adapt._wrappers`, each constrained so that its wrapped source type
# `Src` is an `MPIStateArray`.
#
# `@eval` is used so that the wrapper type expressions `W` can be interpolated
# into the `Union`; only the first element of each `(W, _)` pair is needed here.
#
# NOTE(review): `Adapt._wrappers` is underscore-prefixed and presumably internal
# to Adapt -- confirm it exists in every Adapt version this package allows.
@eval const MPIDestArray = Union{
MPIStateArray,
$(
(
:($W where {T, N, Dst, Src <: MPIStateArray}) for
(W, _) in Adapt._wrappers
)...
),
}

# Two marker types used with `adapt` to reach the storage fields of an
# MPIStateArray:
# - `RealviewAdaptor` replaces an MPIStateArray by its `realdata` field
# - `RawAdaptor` replaces an MPIStateArray by its `data` field
# `realview` and `rawview` are the entry points; they hand any `Q` to `adapt`
# (presumably Adapt also rebuilds known wrapper types around the replaced
# storage -- confirm against the Adapt documentation).
struct RealviewAdaptor end
struct RawAdaptor end

function Adapt.adapt_storage(::RealviewAdaptor, arr::MPIStateArray)
    return arr.realdata
end

function Adapt.adapt_storage(::RawAdaptor, arr::MPIStateArray)
    return arr.data
end

function realview(Q)
    return adapt(RealviewAdaptor(), Q)
end

function rawview(Q)
    return adapt(RawAdaptor(), Q)
end

# FIXME: should general cases be handled?
function Base.similar(
Q::MPIStateArray{OLDFT, V},
Expand Down Expand Up @@ -281,9 +304,17 @@ Base.setindex!(Q::MPIStateArray, x...; kw...) =

# `eltype` of an MPIStateArray is forwarded, together with any extra positional
# and keyword arguments, to the `data` field.
function Base.eltype(Q::MPIStateArray, x...; kw...)
    return eltype(Q.data, x...; kw...)
end

Base.Array(Q::MPIStateArray) = Array(Q.data)
# Materialize an MPIStateArray (or a wrapper around one) as a plain `Array`
# built from its raw view, i.e. the result of `rawview(Q)`.
function Base.Array(Q::MPIDestArray)
    return Array(rawview(Q))
end

# `fill!` on an `MPIDestArray` is delegated to `parent(Q)`; the value returned
# is whatever the delegated `fill!` call returns.
# NOTE(review): `MPIDestArray` includes `MPIStateArray` itself. Unless a more
# specific `fill!(::MPIStateArray, x)` or `parent(::MPIStateArray)` method
# exists elsewhere, `parent(Q)` would be `Q` and this call would recurse --
# confirm against the rest of the file.
function Base.fill!(Q::MPIDestArray, x)
    return fill!(parent(Q), x)
end

# For every wrapper type expression `W` in `Adapt._wrappers`, define the
# broadcast style of a wrapper whose source `Src` is an `MPIDestArray` to be
# `BroadcastStyle(Dst)`. The `ctor` element of each pair is not used here.
#
# NOTE(review): this relies on each `W` mentioning the type variable `Dst`; for
# a wrapper expression that does not, `Dst` would be unbound in the generated
# method -- confirm against the definition of `Adapt._wrappers`.
for (W, ctor) in Adapt._wrappers
@eval begin
BroadcastStyle(::Type{<:$W}) where {T, N, Dst, Src <: MPIDestArray} =
BroadcastStyle(Dst)
end
end

# broadcasting stuff

# find the first MPIStateArray among `bc` arguments
# based on https://docs.julialang.org/en/v1/manual/interfaces/#Selecting-an-appropriate-output-array-1
Expand All @@ -302,40 +333,43 @@ function Base.similar(
end

# transform all arguments of `bc` from MPIStateArrays to Arrays
# Generic entry point: replace `dest` by its raw view and dispatch again, so
# that the method matching the underlying storage type does the actual work.
transform_broadcasted(bc::Broadcasted, dest) =
    transform_broadcasted(bc, rawview(dest))

# Destination storage is a plain `Array`: rewrite the arguments of `bc` with
# `transform_array`.
transform_broadcasted(bc::Broadcasted, ::Array) = transform_array(bc)

# Rebuild `bc` with the same function and axes, applying `transform_array` to
# each argument (nested `Broadcasted` arguments hit this method again).
function transform_array(bc::Broadcasted)
    new_args = map(transform_array, bc.args)
    return Broadcasted(bc.f, new_args, bc.axes)
end
transform_array(mpisa::MPIStateArray) = mpisa.realdata
transform_array(x) = x

# Fallback for every non-`Broadcasted` argument: hand it to `realview`.
function transform_array(x)
    return realview(x)
end

# Copies between a plain `Array` and an MPIStateArray operate on the
# MPIStateArray's `data` field. Each method returns whatever the inner
# `copyto!` call returns.
function Base.copyto!(dest::Array, src::MPIStateArray)
    return copyto!(dest, src.data)
end

function Base.copyto!(dest::MPIStateArray, src::Array)
    return copyto!(dest.data, src)
end

function Base.copyto!(dest::MPIStateArray, src::MPIStateArray)
copyto!(dest.realdata, src.realdata)
function Base.copyto!(dest::MPIDestArray, src::AbstractArray)
copyto!(rawview(dest), src)
dest
end

@inline function Base.copyto!(dest::MPIStateArray, bc::Broadcasted{Nothing})
# check for the case a .= b, where b is an array
if bc.f === identity && bc.args isa Tuple{AbstractArray}
if bc.args isa Tuple{MPIStateArray}
realindices = CartesianIndices((
axes(dest.data)[1:(end - 1)]...,
dest.realelems,
))
copyto!(dest.data, realindices, bc.args[1].data, realindices)
else
copyto!(dest.data, bc.args[1])
end
else
copyto!(dest.realdata, transform_broadcasted(bc, dest.data))
end
function Base.copyto!(dest::MPIDestArray, src::MPIDestArray)
copyto!(rawview(dest), rawview(src))
dest
end

# Broadcast assignment into an `MPIDestArray`: the write target is
# `realview(dest)` and the broadcast expression is first rewritten by
# `transform_broadcasted`. The original `dest` is returned.
@inline function Base.copyto!(dest::MPIDestArray, bc::Broadcasted{Nothing})
    target = realview(dest)
    rewritten = transform_broadcasted(bc, dest)
    copyto!(target, rewritten)
    return dest
end

# Broadcasts with a zero-dimensional array style (e.g. `Q .= 1`) are converted
# to `Broadcasted{Nothing}` and forwarded to the `copyto!` method for that type.
@inline function Base.copyto!(
    dest::MPIDestArray,
    bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}},
)
    return copyto!(dest, convert(Broadcasted{Nothing}, bc))
end

"""
begin_ghost_exchange!(Q::MPIStateArray; dependencies = nothing)
Expand Down Expand Up @@ -737,15 +771,19 @@ function Base.mapreduce(
MPI.Allreduce(cpuify(locreduce), max, Q.mpicomm)
end

# helpers: `array_device` and `realview`
# `array_device` is a helper that enables
# testing ODESolvers and LinearSolvers without using MPIStateArrays.
# It could potentially be useful elsewhere and exported, but it probably needs
# a better name; for example, `array_device` is also defined in CUDAdrv.

# `array_device` returns the device object associated with an array:
# `CPU()` for `Array`, `SArray` and `MArray`, and `CUDADevice()` for `CuArray`.
function array_device(::Union{Array, SArray, MArray})
    return CPU()
end

function array_device(::CuArray)
    return CUDADevice()
end

# A `SubArray` reports the device of its parent, and an MPIStateArray the
# device of its `data` field.
function array_device(v::SubArray)
    return array_device(parent(v))
end

function array_device(Q::MPIStateArray)
    return array_device(Q.data)
end

realview(Q::Union{Array, SArray, MArray}) = Q
realview(Q::MPIStateArray) = Q.realdata
realview(Q::CuArray) = Q
# For every wrapper type expression `W` in `Adapt._wrappers`, define
# `array_device` on that wrapper as the device of `parent(wrapper)`.
# Only the first element of each `(W, _)` pair is used.
# NOTE(review): `Adapt._wrappers` is underscore-prefixed and presumably internal
# to Adapt -- confirm it is available in the supported Adapt versions.
for (W, _) in Adapt._wrappers
@eval array_device(wrapper::$W where {T, N, Dst, Src}) =
array_device(parent(wrapper))
end

# transform all arguments of `bc` from MPIStateArrays to CuArrays
# and replace CPU function with GPU variants
Expand Down
4 changes: 2 additions & 2 deletions src/Numerics/DGMethods/DGModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ function init_ode_state(dg::DGModel, args...; init_on_cpu = false)
else
h_state_conservative = similar(state_conservative, Array)
h_state_auxiliary = similar(state_auxiliary, Array)
h_state_auxiliary .= state_auxiliary
copyto!(h_state_auxiliary, state_auxiliary)
event = kernel_init_state_conservative!(CPU(), Np)(
balance_law,
Val(dim),
Expand All @@ -592,7 +592,7 @@ function init_ode_state(dg::DGModel, args...; init_on_cpu = false)
ndrange = Np * nrealelem,
)
wait(event) # XXX: This could be `wait(device, event)` once KA supports that.
state_conservative .= h_state_conservative
copyto!(state_conservative, h_state_conservative)
end

event = Event(device)
Expand Down
4 changes: 2 additions & 2 deletions test/Arrays/broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ const mpicomm = MPI.COMM_WORLD
QA = MPIStateArray{Float32}(mpicomm, ArrayType, localsize...)
QB = similar(QA)

QA .= A
QB .= B
copyto!(QA, A)
copyto!(QB, B)

@test Array(QA) == A
@test Array(QB) == B
Expand Down
6 changes: 3 additions & 3 deletions test/Arrays/reductions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mpirank = MPI.Comm_rank(mpicomm)
globalA = vcat([A for _ in 1:mpisize]...)

QA = MPIStateArray{Float32}(mpicomm, ArrayType, localsize...)
QA .= A
copyto!(QA, A)


@test norm(QA, 1) ≈ norm(globalA, 1)
Expand All @@ -36,15 +36,15 @@ mpirank = MPI.Comm_rank(mpicomm)
globalB = vcat([B for _ in 1:mpisize]...)

QB = similar(QA)
QB .= B
copyto!(QB, B)

@test isapprox(euclidean_distance(QA, QB), norm(globalA .- globalB))
@test isapprox(dot(QA, QB), dot(globalA, globalB))

C = fill(Float32(mpirank + 1), localsize)
globalC = vcat([fill(i, localsize) for i in 1:mpisize]...)
QC = similar(QA)
QC .= C
copyto!(QC, C)

@test sum(QC) == sum(globalC)
@test Array(sum(QC; dims = (1, 3))) == sum(globalC; dims = (1, 3))
Expand Down
36 changes: 36 additions & 0 deletions test/Arrays/reshape.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using MPI
using Test
using ClimateMachine
using ClimateMachine.MPIStateArrays

# Shared setup for the reshape testset below: initialize ClimateMachine, query
# the array type it selects, and use the world communicator.
ClimateMachine.init()
ArrayType = ClimateMachine.array_type()
mpicomm = MPI.COMM_WORLD
FT = Float32
# `Q` is a 4 x 4 x 4 MPIStateArray; `Qb` is `Q` reshaped to 16 x 4 x 1.
Q = MPIStateArray{FT}(mpicomm, ArrayType, 4, 4, 4)
Qb = reshape(Q, (16, 4, 1));

# Broadcast-assign the scalar 1 through both the plain array and the reshaped
# one; the testset checks that every entry then compares equal to 1.
Q .= 1
Qb .= 1

@testset "MPIStateArray Reshape basics" begin
    # NOTE(review): presumably this toggles scalar indexing on GPU arrays, which
    # the element-wise reads and writes below need; it is reset at the end.
    ClimateMachine.gpu_allowscalar(true)

    # Both arrays were broadcast-assigned 1 before the testset started.
    for arr in (Q, Qb)
        @test minimum(arr[:] .== 1)
    end

    # Shape and element type reported by the reshaped array.
    @test size(Qb) == (16, 4, 1)
    @test eltype(Qb) == Float32

    # `fill!` through the reshaped array, then probe it with linear, Cartesian
    # and `end` indexing, and through conversion to a plain `Array`.
    half = 0.5f0
    fill!(Qb, half)
    @test Qb[1] == half
    @test Qb[8, 1, 1] == half
    @test Qb[end] == half
    @test Array(Qb) == fill(half, 16, 4, 1)

    # A scalar write through the reshaped array must be visible on read-back.
    Qb[8, 1, 1] = 2half
    @test Qb[8, 1, 1] != half

    ClimateMachine.gpu_allowscalar(false)
end
1 change: 1 addition & 0 deletions test/Arrays/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ include(joinpath("..", "testhelpers.jl"))
runmpi(joinpath(@__DIR__, "reductions.jl"))
runmpi(joinpath(@__DIR__, "reductions.jl"), ntasks = 3)
runmpi(joinpath(@__DIR__, "varsindex.jl"))
runmpi(joinpath(@__DIR__, "reshape.jl"))
end

0 comments on commit e79b891

Please sign in to comment.