diff --git a/docs/src/APIs/BalanceLaws/BalanceLaws.md b/docs/src/APIs/BalanceLaws/BalanceLaws.md index 127bb0d3671..e701915f325 100644 --- a/docs/src/APIs/BalanceLaws/BalanceLaws.md +++ b/docs/src/APIs/BalanceLaws/BalanceLaws.md @@ -63,4 +63,5 @@ boundary_state! nodal_update_auxiliary_state! update_auxiliary_state! update_auxiliary_state_gradient! +nodal_update_auxiliary_state! ``` diff --git a/src/BalanceLaws/BalanceLaws.jl b/src/BalanceLaws/BalanceLaws.jl index 9a9ec4d068e..f886fe5e60d 100644 --- a/src/BalanceLaws/BalanceLaws.jl +++ b/src/BalanceLaws/BalanceLaws.jl @@ -26,9 +26,9 @@ export BalanceLaw, source!, wavespeed, boundary_state!, - nodal_update_auxiliary_state!, update_auxiliary_state!, update_auxiliary_state_gradient!, + nodal_update_auxiliary_state!, vars_integrals, integral_load_auxiliary_state!, integral_set_auxiliary_state!, diff --git a/src/Numerics/DGMethods/DGMethods.jl b/src/Numerics/DGMethods/DGMethods.jl index 5362aee1971..2ba789512a1 100644 --- a/src/Numerics/DGMethods/DGMethods.jl +++ b/src/Numerics/DGMethods/DGMethods.jl @@ -52,7 +52,8 @@ import ..BalanceLaws: reverse_integral_load_auxiliary_state!, reverse_integral_set_auxiliary_state! -export DGModel, init_ode_state, restart_ode_state, restart_auxiliary_state +export DGModel, + init_ode_state, restart_ode_state, restart_auxiliary_state, basic_grid_info include("NumericalFluxes.jl") include("DGModel.jl") diff --git a/src/Numerics/DGMethods/DGModel.jl b/src/Numerics/DGMethods/DGModel.jl index 1ea95bb43ac..86d1ba9e1a0 100644 --- a/src/Numerics/DGMethods/DGModel.jl +++ b/src/Numerics/DGMethods/DGModel.jl @@ -46,6 +46,36 @@ end # Include the remainder model for composing DG models and balance laws include("remainder.jl") +function basic_grid_info(dg::DGModel) + grid = dg.grid + topology = grid.topology + + dim = dimensionality(grid) + N = polynomialorder(grid) + + Nq = N + 1 + Nqk = dim == 2 ? 1 : Nq + Nfp = Nq * Nqk + Np = dofs_per_element(grid) + + nelem = length(topology.elems) + nvertelem = topology.stacksize + nhorzelem = div(nelem, nvertelem) + nrealelem = length(topology.realelems) + nhorzrealelem = div(nrealelem, nvertelem) + + return ( + Nq = Nq, + Nqk = Nqk, + Nfp = Nfp, + Np = Np, + nvertelem = nvertelem, + nhorzelem = nhorzelem, + nhorzrealelem = nhorzrealelem, + nrealelem = nrealelem, + ) +end + """ (dg::DGModel)(tendency, state_conservative, nothing, t, α, β) @@ -55,7 +85,7 @@ Computes the tendency terms compatible with `IncrementODEProblem` The 4-argument form will just compute - tendency .= dQdt(state_conservative, p, t) + tendency .= dQdt(state_conservative, p, t) """ function (dg::DGModel)( @@ -108,7 +138,7 @@ function (dg::DGModel)(tendency, state_conservative, _, t, α, β) if num_state_conservative < num_state_tendency && β != 1 # if we don't operate on the full state, then we need to scale here instead of volume_tendency! tendency .*= β - β = β != 0 # if β==0 then we can avoid the memory load in volume_tendency! + β = β != 0 # if β==0 then we can avoid the memory load in volume_tendency! end communicate = diff --git a/src/Ocean/HydrostaticBoussinesq/HydrostaticBoussinesqModel.jl b/src/Ocean/HydrostaticBoussinesq/HydrostaticBoussinesqModel.jl index 044c56c558b..1ae85f4574f 100644 --- a/src/Ocean/HydrostaticBoussinesq/HydrostaticBoussinesqModel.jl +++ b/src/Ocean/HydrostaticBoussinesq/HydrostaticBoussinesqModel.jl @@ -11,9 +11,10 @@ using ...MPIStateArrays using ...Mesh.Filters: apply! using ...Mesh.Grids: VerticalDirection using ...Mesh.Geometry -using ...DGMethods: DGModel, copy_stack_field_down! +using ...DGMethods using ...DGMethods.NumericalFluxes using ...BalanceLaws +using ...BalanceLaws: number_state_auxiliary import ...DGMethods.NumericalFluxes: update_penalty! import ...BalanceLaws: @@ -629,13 +630,21 @@ function update_auxiliary_state_gradient!( indefinite_stack_integral!(dg, m, Q, A, t, elems) # bottom -> top reverse_indefinite_stack_integral!(dg, m, Q, A, t, elems) # top -> bottom + # We are unable to use vars (ie A.w) for this because this operation will + # return a SubArray, and adapt (used for broadcasting along reshaped arrays) + # has a limited recursion depth for the types allowed. + number_auxiliary = number_state_auxiliary(m, FT) + index_w = varsindex(vars_state_auxiliary(m, FT), :w) + index_wz0 = varsindex(vars_state_auxiliary(m, FT), :wz0) + w = A.data[:, index_w, :] + wz0 = A.data[:, index_wz0, :] + # project w(z=0) down the stack - # [1] to convert from range to integer - # copy_stack_field_down! doesn't like ranges - # eventually replace this with a reshape and broadcast - index_w = varsindex(vars_state_auxiliary(m, FT), :w)[1] - index_wz0 = varsindex(vars_state_auxiliary(m, FT), :wz0)[1] - copy_stack_field_down!(dg, m, A, index_w, index_wz0, elems) + Nq, Nqk, _, _, nelemv, nelemh, nhorzrealelem, _ = basic_grid_info(dg) + data = reshape(A.data, Nq^2, Nqk, number_auxiliary, nelemv, nelemh) + flat_wz0 = @view data[:, end:end, index_w, end:end, 1:nhorzrealelem] + boxy_wz0 = @view data[:, :, index_wz0, :, 1:nhorzrealelem] + boxy_wz0 .= flat_wz0 return true end