
Commit

Try #1301:
bors[bot] authored Jun 30, 2020
2 parents 231921a + bc15c1c commit 3bcd81d
Showing 10 changed files with 163 additions and 50 deletions.
2 changes: 1 addition & 1 deletion docs/src/APIs/BalanceLaws/BalanceLaws.md
@@ -60,7 +60,7 @@ transform_post_gradient_laplacian!
```@docs
wavespeed
boundary_state!
nodal_update_auxiliary_state!
update_auxiliary_state!
update_auxiliary_state_gradient!
nodal_update_auxiliary_state!
```
100 changes: 69 additions & 31 deletions src/Arrays/MPIStateArrays.jl
@@ -7,10 +7,14 @@ using LazyArrays
using LinearAlgebra
using MPI
using StaticArrays
using Adapt

using ..TicToc
using ..VariableTemplates: @vars, varsindex

include("CMBuffers.jl")
using .CMBuffers

using Base.Broadcast: Broadcasted, BroadcastStyle, ArrayStyle

# This is so we can do things like
@@ -21,8 +25,6 @@ Base.similar(::Type{A}, ::Type{FT}, dims...) where {A <: Array, FT} =
Base.similar(::Type{A}, ::Type{FT}, dims...) where {A <: CuArray, FT} =
similar(CuArray{FT}, dims...)

include("CMBuffers.jl")
using .CMBuffers

cpuify(x::AbstractArray) = convert(Array, x)
cpuify(x::Real) = x
@@ -101,9 +103,9 @@ mutable struct MPIStateArray{
sendreq = fill(MPI.REQUEST_NULL, nnabr)
recvreq = fill(MPI.REQUEST_NULL, nnabr)

# If vmap is not on the device we need to copy it up (we also do not want to
# put it up everytime, so if it's already on the device then we do not do
# anything).
# If vmap is not on the device we need to copy it up (we also do not
# want to put it up every time, so if it's already on the device then we
# do not do anything).
#
# Better way than checking the type names?
# XXX: Use Adapt.jl vmaprecv = adapt(DA, vmaprecv)
@@ -239,6 +241,27 @@ function MPIStateArray{FT, V}(
)
end

# MPIDestArray is a union of MPIStateArray and all possible wrappers
@eval const MPIDestArray = Union{
MPIStateArray,
$(
(
:($W where {T, N, Dst, Src <: MPIStateArray}) for
(W, _) in Adapt._wrappers
)...
),
}
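
The concrete members of this union depend on Adapt's internal `_wrappers` list and thus on the Adapt version; the following is a minimal usage sketch, assuming reshaped arrays are among the covered wrappers (the setup mirrors the new `test/Arrays/reshape.jl` further down):

```julia
using MPI, ClimateMachine
using ClimateMachine.MPIStateArrays

ClimateMachine.init()
ArrayType = ClimateMachine.array_type()
mpicomm = MPI.COMM_WORLD

Q = MPIStateArray{Float32}(mpicomm, ArrayType, 4, 4, 4)
Qb = reshape(Q, (16, 4, 1))  # Base.ReshapedArray wrapping an MPIStateArray

Q isa MPIStateArrays.MPIDestArray   # true: MPIStateArray is in the union
Qb isa MPIStateArrays.MPIDestArray  # expected true: the reshape wrapper is in the union
```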

# This creates 2 adaptors for finding the realdata (RealviewAdaptor) and
# data (RawAdaptor) of an adapted MPIStateArray
struct RealviewAdaptor end
Adapt.adapt_storage(to::RealviewAdaptor, arr::MPIStateArray) = arr.realdata
realview(Q) = adapt(RealviewAdaptor(), Q)

struct RawAdaptor end
Adapt.adapt_storage(to::RawAdaptor, arr::MPIStateArray) = arr.data
rawview(Q) = adapt(RawAdaptor(), Q)
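
A hedged sketch of the two entry points, reusing `Q` and `Qb` from the sketch above and assuming Adapt's default recursion through wrapper types; explicit imports are used in case the helpers are not exported:

```julia
using ClimateMachine.MPIStateArrays: realview, rawview

# On a bare MPIStateArray the adaptors simply pick out a field:
realview(Q) === Q.realdata  # real (non-ghost) elements, used when broadcasting
rawview(Q) === Q.data       # full backing storage, ghost elements included

# On a wrapped MPIStateArray such as Qb = reshape(Q, dims), adapt recurses
# through the wrapper, so realview/rawview return the same kind of wrapper
# rebuilt around Q.realdata / Q.data rather than around Q itself.
```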

# FIXME: should general cases be handled?
function Base.similar(
Q::MPIStateArray{OLDFT, V},
@@ -281,9 +304,17 @@ Base.setindex!(Q::MPIStateArray, x...; kw...) =

Base.eltype(Q::MPIStateArray, x...; kw...) = eltype(Q.data, x...; kw...)

Base.Array(Q::MPIStateArray) = Array(Q.data)
Base.Array(Q::MPIDestArray) = Array(rawview(Q))

Base.fill!(Q::MPIDestArray, x) = fill!(parent(Q), x)

for (W, ctor) in Adapt._wrappers
@eval begin
BroadcastStyle(::Type{<:$W}) where {T, N, Dst, Src <: MPIDestArray} =
BroadcastStyle(Dst)
end
end

# broadcasting stuff

# find the first MPIStateArray among `bc` arguments
# based on https://docs.julialang.org/en/v1/manual/interfaces/#Selecting-an-appropriate-output-array-1
@@ -302,40 +333,43 @@ function Base.similar(
end

# transform all arguments of `bc` from MPIStateArrays to Arrays
function transform_broadcasted(bc::Broadcasted, dest)
transform_broadcasted(bc, rawview(dest))
end

function transform_broadcasted(bc::Broadcasted, ::Array)
transform_array(bc)
end

function transform_array(bc::Broadcasted)
Broadcasted(bc.f, transform_array.(bc.args), bc.axes)
end
transform_array(mpisa::MPIStateArray) = mpisa.realdata
transform_array(x) = x

transform_array(x) = realview(x)
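
For orientation, a hedged walk-through of the CPU (`Array`-backed) path, reusing `Q` from the earlier sketch; `transform_broadcasted` is accessed with a qualified name since it is internal:

```julia
using Base.Broadcast: broadcasted, instantiate

# A fused update such as `Q .= 2 .* Q .+ 1` reaches copyto! as a Broadcasted tree:
bc = instantiate(broadcasted(+, broadcasted(*, 2, Q), 1))

# transform_broadcasted dispatches on rawview(Q); on the Array path,
# transform_array rebuilds the tree, swapping each MPIStateArray leaf for its
# realdata, while numbers and plain arrays pass through realview unchanged:
bc2 = MPIStateArrays.transform_broadcasted(bc, Q)
# roughly: Broadcasted(+, (Broadcasted(*, (2, Q.realdata)), 1), axes(Q))
```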

Base.copyto!(dest::Array, src::MPIStateArray) = copyto!(dest, src.data)
Base.copyto!(dest::MPIStateArray, src::Array) = copyto!(dest.data, src)

function Base.copyto!(dest::MPIStateArray, src::MPIStateArray)
copyto!(dest.realdata, src.realdata)
function Base.copyto!(dest::MPIDestArray, src::AbstractArray)
copyto!(rawview(dest), src)
dest
end

@inline function Base.copyto!(dest::MPIStateArray, bc::Broadcasted{Nothing})
# check for the case a .= b, where b is an array
if bc.f === identity && bc.args isa Tuple{AbstractArray}
if bc.args isa Tuple{MPIStateArray}
realindices = CartesianIndices((
axes(dest.data)[1:(end - 1)]...,
dest.realelems,
))
copyto!(dest.data, realindices, bc.args[1].data, realindices)
else
copyto!(dest.data, bc.args[1])
end
else
copyto!(dest.realdata, transform_broadcasted(bc, dest.data))
end
function Base.copyto!(dest::MPIDestArray, src::MPIDestArray)
copyto!(rawview(dest), rawview(src))
dest
end

@inline function Base.copyto!(dest::MPIDestArray, bc::Broadcasted{Nothing})
copyto!(realview(dest), transform_broadcasted(bc, dest))
dest
end

@inline Base.copyto!(
dest::MPIDestArray,
bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}},
) = copyto!(dest, convert(Broadcasted{Nothing}, bc))
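
Taken together with the `MPIDestArray` methods above, an in-place broadcast assignment onto an MPIStateArray or one of its wrappers should be routed through the `realview`/`transform_broadcasted` path; a brief hedged sketch, reusing `Qb` from earlier:

```julia
# Both assignments end up in copyto!(realview(Qb), transform_broadcasted(bc, Qb)):
Qb .= 1             # scalar right-hand side, exercised by test/Arrays/reshape.jl
Qb .= 2 .* Qb .+ 1  # general fused broadcast
```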

"""
begin_ghost_exchange!(Q::MPIStateArray; dependencies = nothing)
@@ -737,15 +771,19 @@ function Base.mapreduce(
MPI.Allreduce(cpuify(locreduce), max, Q.mpicomm)
end

# helpers: `array_device` and `realview`
# `array_device` is a helper that enables
# testing ODESolvers and LinearSolvers without using MPIStateArrays
# They could potentially be useful elsewhere and exported but probably need
# better names, for example `array_device` is also defined in CUDAdrv

array_device(::Union{Array, SArray, MArray}) = CPU()
array_device(::CuArray) = CUDADevice()
array_device(s::SubArray) = array_device(parent(s))
array_device(Q::MPIStateArray) = array_device(Q.data)

realview(Q::Union{Array, SArray, MArray}) = Q
realview(Q::MPIStateArray) = Q.realdata
realview(Q::CuArray) = Q
for (W, _) in Adapt._wrappers
@eval array_device(wrapper::$W where {T, N, Dst, Src}) =
array_device(parent(wrapper))
end
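
A short usage sketch, again reusing `Q` from above; the device singletons come from the compute backend, and `CPU()` is shown assuming an `Array`-backed state:

```julia
using ClimateMachine.MPIStateArrays: array_device

# array_device reports where the data lives; wrapper methods defer to parent:
array_device(Q)                       # CPU() for Array-backed data, CUDADevice() on a GPU
array_device(reshape(Q, (16, 4, 1)))  # unwraps to the parent MPIStateArray, then its data
```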

# transform all arguments of `bc` from MPIStateArrays to CuArrays
# and replace CPU function with GPU variants
2 changes: 1 addition & 1 deletion src/BalanceLaws/BalanceLaws.jl
@@ -26,9 +26,9 @@ export BalanceLaw,
source!,
wavespeed,
boundary_state!,
nodal_update_auxiliary_state!,
update_auxiliary_state!,
update_auxiliary_state_gradient!,
nodal_update_auxiliary_state!,
vars_integrals,
integral_load_auxiliary_state!,
integral_set_auxiliary_state!,
3 changes: 2 additions & 1 deletion src/Numerics/DGMethods/DGMethods.jl
@@ -52,7 +52,8 @@ import ..BalanceLaws:
reverse_integral_load_auxiliary_state!,
reverse_integral_set_auxiliary_state!

export DGModel, init_ode_state, restart_ode_state, restart_auxiliary_state
export DGModel,
init_ode_state, restart_ode_state, restart_auxiliary_state, basic_grid_info

include("NumericalFluxes.jl")
include("DGModel.jl")
38 changes: 34 additions & 4 deletions src/Numerics/DGMethods/DGModel.jl
@@ -46,6 +46,36 @@ end
# Include the remainder model for composing DG models and balance laws
include("remainder.jl")

function basic_grid_info(dg::DGModel)
grid = dg.grid
topology = grid.topology

dim = dimensionality(grid)
N = polynomialorder(grid)

Nq = N + 1
Nqk = dim == 2 ? 1 : Nq
Nfp = Nq * Nqk
Np = dofs_per_element(grid)

nelem = length(topology.elems)
nvertelem = topology.stacksize
nhorzelem = div(nelem, nvertelem)
nrealelem = length(topology.realelems)
nhorzrealelem = div(nrealelem, nvertelem)

return (
Nq = Nq,
Nqk = Nqk,
Nfp = Nfp,
Np = Np,
nvertelem = nvertelem,
nhorzelem = nhorzelem,
nhorzrealelem = nhorzrealelem,
nrealelem = nrealelem,
)
end
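
A hedged usage sketch, assuming `dg` is an already-constructed `DGModel`:

```julia
info = basic_grid_info(dg)
info.Nq, info.Nqk   # quadrature points per direction (Nqk == 1 when dim == 2)
info.Np, info.Nfp   # degrees of freedom per element / per element face

# Positional destructuring also works, as in the ocean-model change below:
Nq, Nqk, _, _, nvertelem, nhorzelem, nhorzrealelem, nrealelem = basic_grid_info(dg)
```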

"""
(dg::DGModel)(tendency, state_conservative, nothing, t, α, β)
@@ -55,7 +85,7 @@ Computes the tendency terms compatible with `IncrementODEProblem`
The 4-argument form will just compute
tendency .= dQdt(state_conservative, p, t)
tendency .= dQdt(state_conservative, p, t)
"""
function (dg::DGModel)(
@@ -108,7 +138,7 @@ function (dg::DGModel)(tendency, state_conservative, _, t, α, β)
if num_state_conservative < num_state_tendency && β != 1
# if we don't operate on the full state, then we need to scale here instead of volume_tendency!
tendency .*= β
β = β != 0 # if β==0 then we can avoid the memory load in volume_tendency!
β = β != 0 # if β==0 then we can avoid the memory load in volume_tendency!
end

communicate =
@@ -579,7 +609,7 @@ function init_ode_state(dg::DGModel, args...; init_on_cpu = false)
else
h_state_conservative = similar(state_conservative, Array)
h_state_auxiliary = similar(state_auxiliary, Array)
h_state_auxiliary .= state_auxiliary
copyto!(h_state_auxiliary, state_auxiliary)
event = kernel_init_state_conservative!(CPU(), Np)(
balance_law,
Val(dim),
@@ -592,7 +622,7 @@ function init_ode_state(dg::DGModel, args...; init_on_cpu = false)
ndrange = Np * nrealelem,
)
wait(event) # XXX: This could be `wait(device, event)` once KA supports that.
state_conservative .= h_state_conservative
copyto!(state_conservative, h_state_conservative)
end

event = Event(device)
21 changes: 14 additions & 7 deletions src/Ocean/HydrostaticBoussinesq/HydrostaticBoussinesqModel.jl
@@ -11,9 +11,10 @@ using ...MPIStateArrays
using ...Mesh.Filters: apply!
using ...Mesh.Grids: VerticalDirection
using ...Mesh.Geometry
using ...DGMethods: DGModel, copy_stack_field_down!
using ...DGMethods
using ...DGMethods.NumericalFluxes
using ...BalanceLaws
using ...BalanceLaws: number_state_auxiliary

import ...DGMethods.NumericalFluxes: update_penalty!
import ...BalanceLaws:
@@ -629,13 +630,19 @@ function update_auxiliary_state_gradient!(
indefinite_stack_integral!(dg, m, Q, A, t, elems) # bottom -> top
reverse_indefinite_stack_integral!(dg, m, Q, A, t, elems) # top -> bottom

# We are unable to use vars (i.e., A.w) for this because this operation will
# return a SubArray, and adapt (used for broadcasting along reshaped arrays)
# has a limited recursion depth for the types allowed.
number_auxiliary = number_state_auxiliary(m, FT)
index_w = varsindex(vars_state_auxiliary(m, FT), :w)
index_wz0 = varsindex(vars_state_auxiliary(m, FT), :wz0)
Nq, Nqk, _, _, nelemv, nelemh, nhorzrealelem, _ = basic_grid_info(dg)

# project w(z=0) down the stack
# [1] to convert from range to integer
# copy_stack_field_down! doesn't like ranges
# eventually replace this with a reshape and broadcast
index_w = varsindex(vars_state_auxiliary(m, FT), :w)[1]
index_wz0 = varsindex(vars_state_auxiliary(m, FT), :wz0)[1]
copy_stack_field_down!(dg, m, A, index_w, index_wz0, elems)
data = reshape(A.data, Nq^2, Nqk, number_auxiliary, nelemv, nelemh)
flat_wz0 = @view data[:, end:end, index_w, end:end, 1:nhorzrealelem]
boxy_wz0 = @view data[:, :, index_wz0, :, 1:nhorzrealelem]
boxy_wz0 .= flat_wz0

return true
end
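
To make the broadcast shapes concrete, here is a small standalone sketch with illustrative sizes; the real code uses the dimensions returned by `basic_grid_info` and the variable indices from `varsindex`:

```julia
# Stand-in dimensions (Nq^2, Nqk, nvars, nelemv, nhorzrealelem) = (4, 3, 2, 5, 2),
# with variable 1 playing the role of w and variable 2 the role of wz0.
data = rand(4, 3, 2, 5, 2)

flat_wz0 = @view data[:, end:end, 1:1, end:end, :]  # w at the top node of the top element
boxy_wz0 = @view data[:, :, 2:2, :, :]              # wz0 at every node of every element

# The singleton vertical dimensions of flat_wz0 broadcast against the full
# dimensions of boxy_wz0, copying w(z = 0) down the whole stack:
boxy_wz0 .= flat_wz0
```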
4 changes: 2 additions & 2 deletions test/Arrays/broadcasting.jl
@@ -16,8 +16,8 @@ const mpicomm = MPI.COMM_WORLD
QA = MPIStateArray{Float32}(mpicomm, ArrayType, localsize...)
QB = similar(QA)

QA .= A
QB .= B
copyto!(QA, A)
copyto!(QB, B)

@test Array(QA) == A
@test Array(QB) == B
6 changes: 3 additions & 3 deletions test/Arrays/reductions.jl
@@ -19,7 +19,7 @@ mpirank = MPI.Comm_rank(mpicomm)
globalA = vcat([A for _ in 1:mpisize]...)

QA = MPIStateArray{Float32}(mpicomm, ArrayType, localsize...)
QA .= A
copyto!(QA, A)


@test norm(QA, 1) ≈ norm(globalA, 1)
@@ -36,15 +36,15 @@ mpirank = MPI.Comm_rank(mpicomm)
globalB = vcat([B for _ in 1:mpisize]...)

QB = similar(QA)
QB .= B
copyto!(QB, B)

@test isapprox(euclidean_distance(QA, QB), norm(globalA .- globalB))
@test isapprox(dot(QA, QB), dot(globalA, globalB))

C = fill(Float32(mpirank + 1), localsize)
globalC = vcat([fill(i, localsize) for i in 1:mpisize]...)
QC = similar(QA)
QC .= C
copyto!(QC, C)

@test sum(QC) == sum(globalC)
@test Array(sum(QC; dims = (1, 3))) == sum(globalC; dims = (1, 3))
36 changes: 36 additions & 0 deletions test/Arrays/reshape.jl
@@ -0,0 +1,36 @@
using MPI
using Test
using ClimateMachine
using ClimateMachine.MPIStateArrays

ClimateMachine.init()
ArrayType = ClimateMachine.array_type()
mpicomm = MPI.COMM_WORLD
FT = Float32
Q = MPIStateArray{FT}(mpicomm, ArrayType, 4, 4, 4)
Qb = reshape(Q, (16, 4, 1));

Q .= 1
Qb .= 1

@testset "MPIStateArray Reshape basics" begin
ClimateMachine.gpu_allowscalar(true)
@test minimum(Q[:] .== 1)
@test minimum(Qb[:] .== 1)

@test eltype(Qb) == Float32
@test size(Qb) == (16, 4, 1)

fillval = 0.5f0
fill!(Qb, fillval)

@test Qb[1] == fillval
@test Qb[8, 1, 1] == fillval
@test Qb[end] == fillval

@test Array(Qb) == fill(fillval, 16, 4, 1)

Qb[8, 1, 1] = 2fillval
@test Qb[8, 1, 1] != fillval
ClimateMachine.gpu_allowscalar(false)
end
1 change: 1 addition & 0 deletions test/Arrays/runtests.jl
@@ -7,4 +7,5 @@ include(joinpath("..", "testhelpers.jl"))
runmpi(joinpath(@__DIR__, "reductions.jl"))
runmpi(joinpath(@__DIR__, "reductions.jl"), ntasks = 3)
runmpi(joinpath(@__DIR__, "varsindex.jl"))
runmpi(joinpath(@__DIR__, "reshape.jl"))
end
