Skip to content
This repository has been archived by the owner on Mar 1, 2023. It is now read-only.

Replace copy_stack_field_down! with broadcasts of reshaped MPIStateArrays #1301

Merged
merged 2 commits into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/APIs/BalanceLaws/BalanceLaws.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ transform_post_gradient_laplacian!
```@docs
wavespeed
boundary_state!
nodal_update_auxiliary_state!
update_auxiliary_state!
update_auxiliary_state_gradient!
nodal_update_auxiliary_state!
```
2 changes: 1 addition & 1 deletion src/BalanceLaws/BalanceLaws.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ export BalanceLaw,
source!,
wavespeed,
boundary_state!,
nodal_update_auxiliary_state!,
update_auxiliary_state!,
update_auxiliary_state_gradient!,
nodal_update_auxiliary_state!,
vars_integrals,
integral_load_auxiliary_state!,
integral_set_auxiliary_state!,
Expand Down
3 changes: 2 additions & 1 deletion src/Numerics/DGMethods/DGMethods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ import ..BalanceLaws:
reverse_integral_load_auxiliary_state!,
reverse_integral_set_auxiliary_state!

export DGModel, init_ode_state, restart_ode_state, restart_auxiliary_state
export DGModel,
init_ode_state, restart_ode_state, restart_auxiliary_state, basic_grid_info

include("NumericalFluxes.jl")
include("DGModel.jl")
Expand Down
72 changes: 32 additions & 40 deletions src/Numerics/DGMethods/DGModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,36 @@ end
# Include the remainder model for composing DG models and balance laws
include("remainder.jl")

# Gather the basic size information of the grid attached to `dg` into a named
# tuple. The entries, in order, are:
#   Nq            - polynomial order of the grid plus one
#   Nqk           - 1 when the grid dimensionality is 2, otherwise Nq
#   Nfp           - Nq * Nqk (presumably the number of points on a face)
#   Np            - dofs_per_element(grid)
#   nvertelem     - topology.stacksize
#   nhorzelem     - number of elements divided by the stack size
#   nhorzrealelem - number of real elements divided by the stack size
#   nrealelem     - number of real elements
# Callers may destructure the result positionally, so the order of the entries
# is part of the interface.
function basic_grid_info(dg::DGModel)
    grid = dg.grid
    topology = grid.topology

    Nq = polynomialorder(grid) + 1
    Nqk = dimensionality(grid) == 2 ? 1 : Nq

    nvertelem = topology.stacksize
    nrealelem = length(topology.realelems)

    return (
        Nq = Nq,
        Nqk = Nqk,
        Nfp = Nq * Nqk,
        Np = dofs_per_element(grid),
        nvertelem = nvertelem,
        nhorzelem = div(length(topology.elems), nvertelem),
        nhorzrealelem = div(nrealelem, nvertelem),
        nrealelem = nrealelem,
    )
end

"""
(dg::DGModel)(tendency, state_conservative, nothing, t, α, β)
Expand All @@ -55,7 +85,7 @@ Computes the tendency terms compatible with `IncrementODEProblem`
The 4-argument form will just compute
tendency .= dQdt(state_conservative, p, t)
tendency .= dQdt(state_conservative, p, t)
"""
function (dg::DGModel)(
Expand Down Expand Up @@ -108,7 +138,7 @@ function (dg::DGModel)(tendency, state_conservative, _, t, α, β)
if num_state_conservative < num_state_tendency && β != 1
# if we don't operate on the full state, then we need to scale here instead of volume_tendency!
tendency .*= β
β = β != 0 # if β==0 then we can avoid the memory load in volume_tendency!
β = β != 0 # if β==0 then we can avoid the memory load in volume_tendency!
end

communicate =
Expand Down Expand Up @@ -864,44 +894,6 @@ function courant(
MPI.Allreduce(rank_courant_max, max, topology.mpicomm)
end

# Launch `kernel_copy_stack_field_down!` on the device holding `state_auxiliary`
# and block until it has finished.
#
# Arguments:
#   dg              - DG model; supplies the grid and its topology
#   m               - balance law (not used by this function; kept for the
#                     common call signature)
#   state_auxiliary - MPIStateArray whose `data` is modified in place
#   fldin, fldout   - indices of the source and destination fields; they are
#                     passed to the kernel wrapped in `Val`
#   elems           - elements to process; defaults to all elements of the
#                     topology of `dg`. Only `first(elems)` and `last(elems)`
#                     are used, to derive the range of horizontal stacks.
#
# The default for `elems` is spelled in terms of `dg` because default-value
# expressions can only see the preceding arguments, not locals of the body.
function copy_stack_field_down!(
    dg::DGModel,
    m::BalanceLaw,
    state_auxiliary::MPIStateArray,
    fldin,
    fldout,
    elems = dg.grid.topology.elems,
)
    device = array_device(state_auxiliary)

    grid = dg.grid
    topology = grid.topology

    dim = dimensionality(grid)
    N = polynomialorder(grid)
    Nq = N + 1
    Nqk = dim == 2 ? 1 : Nq

    # Convert the element range into the range of horizontal stacks it touches
    nvertelem = topology.stacksize
    horzelems = fld1(first(elems), nvertelem):fld1(last(elems), nvertelem)

    # One workgroup of size (Nq, Nqk) per horizontal stack
    event = Event(device)
    event = kernel_copy_stack_field_down!(device, (Nq, Nqk))(
        Val(dim),
        Val(N),
        Val(nvertelem),
        state_auxiliary.data,
        horzelems,
        Val(fldin),
        Val(fldout);
        ndrange = (length(horzelems) * Nq, Nqk),
        dependencies = (event,),
    )
    wait(device, event)
end

function MPIStateArrays.MPIStateArray(dg::DGModel)
balance_law = dg.balance_law
grid = dg.grid
Expand Down
38 changes: 0 additions & 38 deletions src/Numerics/DGMethods/DGModel_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1921,44 +1921,6 @@ end
end
end

# TODO: Generalize to more than one field?
#
# For every horizontal stack index in `elems`, read field `fldin` at the last
# vertical node level (k = Nq) of the last element of the stack (presumably the
# top of the column -- confirm against the element ordering) and write that
# value into field `fldout` at every vertical node level of every element of
# the same stack.
#
# Each workgroup handles one entry of `elems`; the local index `(i, j)` selects
# the horizontal node within the element.
@kernel function kernel_copy_stack_field_down!(
    ::Val{dim},
    ::Val{N},
    ::Val{nvertelem},
    state_auxiliary,
    elems,
    ::Val{fldin},
    ::Val{fldout},
) where {dim, N, nvertelem, fldin, fldout}
    Nq = N + 1
    Nqj = dim == 2 ? 1 : Nq

    _eh = @index(Group, Linear)
    i, j = @index(Local, NTuple)

    @inbounds begin
        # Source value: node (i, j, k = Nq) of the last element in the stack
        ijk = i + Nq * ((j - 1) + Nqj * (Nq - 1))
        eh = elems[_eh]
        et = nvertelem + (eh - 1) * nvertelem
        val = state_auxiliary[ijk, fldin, et]

        # Write the value to every vertical level of every element in the stack
        for ev in 1:nvertelem
            e = ev + (eh - 1) * nvertelem
            @unroll for k in 1:Nq
                ijk = i + Nq * ((j - 1) + Nqj * (k - 1))
                state_auxiliary[ijk, fldout, e] = val
            end
        end
    end
end

@kernel function volume_divergence_of_gradients!(
balance_law::BalanceLaw,
::Val{dim},
Expand Down
21 changes: 14 additions & 7 deletions src/Ocean/HydrostaticBoussinesq/HydrostaticBoussinesqModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ using ...MPIStateArrays
using ...Mesh.Filters: apply!
using ...Mesh.Grids: VerticalDirection
using ...Mesh.Geometry
using ...DGMethods: DGModel, copy_stack_field_down!
using ...DGMethods
using ...DGMethods.NumericalFluxes
using ...BalanceLaws
using ...BalanceLaws: number_state_auxiliary

import ...DGMethods.NumericalFluxes: update_penalty!
import ...BalanceLaws:
Expand Down Expand Up @@ -629,13 +630,19 @@ function update_auxiliary_state_gradient!(
indefinite_stack_integral!(dg, m, Q, A, t, elems) # bottom -> top
reverse_indefinite_stack_integral!(dg, m, Q, A, t, elems) # top -> bottom

# We are unable to use vars (ie A.w) for this because this operation will
# return a SubArray, and adapt (used for broadcasting along reshaped arrays)
# has a limited recursion depth for the types allowed.
number_auxiliary = number_state_auxiliary(m, FT)
index_w = varsindex(vars_state_auxiliary(m, FT), :w)
index_wz0 = varsindex(vars_state_auxiliary(m, FT), :wz0)
Nq, Nqk, _, _, nelemv, nelemh, nhorzrealelem, _ = basic_grid_info(dg)

# project w(z=0) down the stack
# [1] to convert from range to integer
# copy_stack_field_down! doesn't like ranges
# eventually replace this with a reshape and broadcast
index_w = varsindex(vars_state_auxiliary(m, FT), :w)[1]
index_wz0 = varsindex(vars_state_auxiliary(m, FT), :wz0)[1]
copy_stack_field_down!(dg, m, A, index_w, index_wz0, elems)
data = reshape(A.data, Nq^2, Nqk, number_auxiliary, nelemv, nelemh)
flat_wz0 = @view data[:, end:end, index_w, end:end, 1:nhorzrealelem]
boxy_wz0 = @view data[:, :, index_wz0, :, 1:nhorzrealelem]
boxy_wz0 .= flat_wz0
blallen marked this conversation as resolved.
Show resolved Hide resolved

return true
end
Expand Down