Skip to content
This repository has been archived by the owner on Mar 1, 2023. It is now read-only.

Commit

Permalink
Merge #1301
Browse files Browse the repository at this point in the history
1301: Replace `copy_stack_field_down!` with broadcasts of reshaped MPIStateArrays r=blallen a=blallen

# Description

After #1071 is merged, we no longer need a dedicated kernel to copy the vertical velocity at the top of the ocean (`wz0`) down the entire stack for use in the source term for `η`. We can also use this PR to discuss how we want the `basic_grid_info` function to work. 



Co-authored-by: Brandon Allen <ballen@mit.edu>
  • Loading branch information
bors[bot] and blallen authored Jun 30, 2020
2 parents 1d18306 + 59f4828 commit 1f30473
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 88 deletions.
2 changes: 1 addition & 1 deletion docs/src/APIs/BalanceLaws/BalanceLaws.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ transform_post_gradient_laplacian!
```@docs
wavespeed
boundary_state!
nodal_update_auxiliary_state!
update_auxiliary_state!
update_auxiliary_state_gradient!
nodal_update_auxiliary_state!
```
2 changes: 1 addition & 1 deletion src/BalanceLaws/BalanceLaws.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ export BalanceLaw,
source!,
wavespeed,
boundary_state!,
nodal_update_auxiliary_state!,
update_auxiliary_state!,
update_auxiliary_state_gradient!,
nodal_update_auxiliary_state!,
vars_integrals,
integral_load_auxiliary_state!,
integral_set_auxiliary_state!,
Expand Down
3 changes: 2 additions & 1 deletion src/Numerics/DGMethods/DGMethods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ import ..BalanceLaws:
reverse_integral_load_auxiliary_state!,
reverse_integral_set_auxiliary_state!

export DGModel, init_ode_state, restart_ode_state, restart_auxiliary_state
export DGModel,
init_ode_state, restart_ode_state, restart_auxiliary_state, basic_grid_info

include("NumericalFluxes.jl")
include("DGModel.jl")
Expand Down
72 changes: 32 additions & 40 deletions src/Numerics/DGMethods/DGModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,36 @@ end
# Include the remainder model for composing DG models and balance laws
include("remainder.jl")

"""
    basic_grid_info(dg::DGModel)

Return a `NamedTuple` with size information about the grid of `dg`.

The fields, in order, are:
 - `Nq`: polynomial order of the grid plus one
 - `Nqk`: `1` when the grid dimensionality is 2, otherwise `Nq`
 - `Nfp`: `Nq * Nqk`
 - `Np`: `dofs_per_element(grid)`
 - `nvertelem`: `topology.stacksize`
 - `nhorzelem`: number of elements divided by `nvertelem`
 - `nhorzrealelem`: number of real elements divided by `nvertelem`
 - `nrealelem`: number of real elements

The field order is part of the contract: callers may destructure the result
positionally.
"""
function basic_grid_info(dg::DGModel)
    grid = dg.grid
    topology = grid.topology

    # nodal sizes
    Nq = polynomialorder(grid) + 1
    Nqk = dimensionality(grid) == 2 ? 1 : Nq

    # element counts
    nvertelem = topology.stacksize
    nrealelem = length(topology.realelems)

    return (
        Nq = Nq,
        Nqk = Nqk,
        Nfp = Nq * Nqk,
        Np = dofs_per_element(grid),
        nvertelem = nvertelem,
        nhorzelem = div(length(topology.elems), nvertelem),
        nhorzrealelem = div(nrealelem, nvertelem),
        nrealelem = nrealelem,
    )
end

"""
(dg::DGModel)(tendency, state_conservative, nothing, t, α, β)
Expand All @@ -55,7 +85,7 @@ Computes the tendency terms compatible with `IncrementODEProblem`
The 4-argument form will just compute
tendency .= dQdt(state_conservative, p, t)
tendency .= dQdt(state_conservative, p, t)
"""
function (dg::DGModel)(
Expand Down Expand Up @@ -108,7 +138,7 @@ function (dg::DGModel)(tendency, state_conservative, _, t, α, β)
if num_state_conservative < num_state_tendency && β != 1
# if we don't operate on the full state, then we need to scale here instead of volume_tendency!
tendency .*= β
β = β != 0 # if β==0 then we can avoid the memory load in volume_tendency!
β = β != 0 # if β==0 then we can avoid the memory load in volume_tendency!
end

communicate =
Expand Down Expand Up @@ -864,44 +894,6 @@ function courant(
MPI.Allreduce(rank_courant_max, max, topology.mpicomm)
end

"""
    copy_stack_field_down!(dg, m, state_auxiliary, fldin, fldout, elems)

Launch `kernel_copy_stack_field_down!` on the device holding `state_auxiliary`
and block until it has finished. The kernel copies the auxiliary field with
index `fldin` into the field with index `fldout` for each vertical stack of
elements touched by `elems`.

# Arguments
 - `dg`: the `DGModel` whose grid and topology describe the element stacks
 - `m`: the balance law (not used by this method)
 - `state_auxiliary`: the `MPIStateArray` that is read from and written to
 - `fldin`, `fldout`: field indices; they are lifted to `Val`s, so they should
   be single integers rather than ranges
 - `elems`: elements to process; defaults to all elements of the topology.
   Only `first(elems)` and `last(elems)` are used to build the range of
   horizontal stacks, so a contiguous range is assumed.
"""
function copy_stack_field_down!(
    dg::DGModel,
    m::BalanceLaw,
    state_auxiliary::MPIStateArray,
    fldin,
    fldout,
    # NOTE: the default must be spelled through `dg`; a bare `topology` is not
    # in scope when default arguments are evaluated.
    elems = dg.grid.topology.elems,
)
    device = array_device(state_auxiliary)

    grid = dg.grid
    topology = grid.topology

    dim = dimensionality(grid)
    N = polynomialorder(grid)
    Nq = N + 1
    Nqk = dim == 2 ? 1 : Nq

    # range of horizontal stacks covered by `elems`
    nvertelem = topology.stacksize
    horzelems = fld1(first(elems), nvertelem):fld1(last(elems), nvertelem)

    # one workgroup of (Nq, Nqk) work-items per horizontal stack
    event = Event(device)
    event = kernel_copy_stack_field_down!(device, (Nq, Nqk))(
        Val(dim),
        Val(N),
        Val(nvertelem),
        state_auxiliary.data,
        horzelems,
        Val(fldin),
        Val(fldout);
        ndrange = (length(horzelems) * Nq, Nqk),
        dependencies = (event,),
    )
    wait(device, event)
end

function MPIStateArrays.MPIStateArray(dg::DGModel)
balance_law = dg.balance_law
grid = dg.grid
Expand Down
38 changes: 0 additions & 38 deletions src/Numerics/DGMethods/DGModel_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1921,44 +1921,6 @@ end
end
end

# TODO: Generalize to more than one field?
#
# For each horizontal stack `elems[_eh]`, read field `fldin` from the last
# element of the stack on its last layer of nodes (k = Nq), and write that
# value into field `fldout` at every layer of every element in the same stack.
# Each work-item (i, j) handles one column of nodes.
#
# `dim`, `N` (polynomial order), `nvertelem` (elements per stack), `fldin` and
# `fldout` are passed as `Val`s so they are compile-time constants.
# `fldin` and `fldout` are expected to be different fields; otherwise the
# reads and writes of different work-items would overlap.
@kernel function kernel_copy_stack_field_down!(
    ::Val{dim},
    ::Val{N},
    ::Val{nvertelem},
    state_auxiliary,
    elems,
    ::Val{fldin},
    ::Val{fldout},
) where {dim, N, nvertelem, fldin, fldout}
    Nq = N + 1
    Nqj = dim == 2 ? 1 : Nq

    _eh = @index(Group, Linear)
    i, j = @index(Local, NTuple)

    @inbounds begin
        # Value to replicate: node (i, j) on the last layer (k = Nq) of the
        # last element `et` in this stack
        ijk = i + Nq * ((j - 1) + Nqj * (Nq - 1))
        eh = elems[_eh]
        et = nvertelem + (eh - 1) * nvertelem
        val = state_auxiliary[ijk, fldin, et]

        # Write the value to every layer of every element in the stack
        for ev in 1:nvertelem
            e = ev + (eh - 1) * nvertelem
            @unroll for k in 1:Nq
                ijk = i + Nq * ((j - 1) + Nqj * (k - 1))
                state_auxiliary[ijk, fldout, e] = val
            end
        end
    end
end

@kernel function volume_divergence_of_gradients!(
balance_law::BalanceLaw,
::Val{dim},
Expand Down
21 changes: 14 additions & 7 deletions src/Ocean/HydrostaticBoussinesq/HydrostaticBoussinesqModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ using ...MPIStateArrays
using ...Mesh.Filters: apply!
using ...Mesh.Grids: VerticalDirection
using ...Mesh.Geometry
using ...DGMethods: DGModel, copy_stack_field_down!
using ...DGMethods
using ...DGMethods.NumericalFluxes
using ...BalanceLaws
using ...BalanceLaws: number_state_auxiliary

import ...DGMethods.NumericalFluxes: update_penalty!
import ...BalanceLaws:
Expand Down Expand Up @@ -629,13 +630,19 @@ function update_auxiliary_state_gradient!(
indefinite_stack_integral!(dg, m, Q, A, t, elems) # bottom -> top
reverse_indefinite_stack_integral!(dg, m, Q, A, t, elems) # top -> bottom

# We cannot use the named variable accessors (i.e. A.w) here: they return a
# SubArray, and adapt (used when broadcasting over reshaped arrays) only
# recurses through a limited depth of the allowed wrapper types.
number_auxiliary = number_state_auxiliary(m, FT)
index_w = varsindex(vars_state_auxiliary(m, FT), :w)
index_wz0 = varsindex(vars_state_auxiliary(m, FT), :wz0)
Nq, Nqk, _, _, nelemv, nelemh, nhorzrealelem, _ = basic_grid_info(dg)

# project w(z=0) down the stack
# [1] to convert from range to integer
# copy_stack_field_down! doesn't like ranges
# eventually replace this with a reshape and broadcast
index_w = varsindex(vars_state_auxiliary(m, FT), :w)[1]
index_wz0 = varsindex(vars_state_auxiliary(m, FT), :wz0)[1]
copy_stack_field_down!(dg, m, A, index_w, index_wz0, elems)
data = reshape(A.data, Nq^2, Nqk, number_auxiliary, nelemv, nelemh)
flat_wz0 = @view data[:, end:end, index_w, end:end, 1:nhorzrealelem]
boxy_wz0 = @view data[:, :, index_wz0, :, 1:nhorzrealelem]
boxy_wz0 .= flat_wz0

return true
end
Expand Down

0 comments on commit 1f30473

Please sign in to comment.