Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove getfielddims/getfieldinterpolations/getfieldnames for FieldHandler #647

Merged
merged 3 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Dofs/ConstraintHandler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ end
function add!(ch::ConstraintHandler{<:MixedDofHandler}, dbc::Dirichlet)
dbc_added = false
for fh in ch.dh.fieldhandlers
if dbc.field_name in getfieldnames(fh) && _in_cellset(ch.dh.grid, fh.cellset, dbc.faces; all=false)
if !isnothing(_find_field(fh, dbc.field_name)) && _in_cellset(ch.dh.grid, fh.cellset, dbc.faces; all=false)
# Dofs in `dbc` not in `fh` will be removed, hence `dbc.faces` must be copied.
# Recreating the `dbc` will create a copy of `dbc.faces`.
# In this case, add! will warn, unless `warn_not_in_cellset=false`
Expand All @@ -938,8 +938,8 @@ function add!(ch::ConstraintHandler, fh::FieldHandler, dbc::Dirichlet; warn_not_

# Extract stuff for the field
field_idx = find_field(fh, dbc.field_name)
interpolation = getfieldinterpolations(fh)[field_idx]
field_dim = getfielddims(fh)[field_idx]
interpolation = getfieldinterpolation(fh, field_idx)
field_dim = getfielddim(fh, field_idx)

if !all(c -> 0 < c <= field_dim, dbc.components)
error("components $(dbc.components) not within range of field :$(dbc.field_name) ($(field_dim) dimension(s))")
Expand Down
41 changes: 18 additions & 23 deletions src/Dofs/MixedDofHandler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ function Base.show(io::IO, ::MIME"text/plain", dh::MixedDofHandler)
end
end

getfieldnames(fh::FieldHandler) = [field.name for field in fh.fields]
getfielddims(fh::FieldHandler) = [field.dim for field in fh.fields]
getfieldinterpolations(fh::FieldHandler) = [field.interpolation for field in fh.fields]

"""
ndofs_per_cell(dh::AbstractDofHandler[, cell::Int=1])

Expand Down Expand Up @@ -120,7 +116,6 @@ end

"""
getfieldnames(dh::MixedDofHandler)
getfieldnames(fh::FieldHandler)

Return a vector with the names of all fields. Can be used as an iterable over all the fields
in the problem.
Expand Down Expand Up @@ -164,11 +159,11 @@ function add!(dh::MixedDofHandler, fh::FieldHandler)
_check_same_celltype(dh.grid, collect(fh.cellset))
_check_cellset_intersections(dh, fh)
# the field interpolations should have the same refshape as the cells they are applied to
refshapes_fh = getrefshape.(getfieldinterpolations(fh))
# extract the celltype from the first cell as the celltypes are all equal
cell_type = typeof(dh.grid.cells[first(fh.cellset)])
refshape_cellset = getrefshape(default_interpolation(cell_type))
for refshape in refshapes_fh
for field_idx in eachindex(fh.fields)
refshape = getrefshape(getfieldinterpolation(fh, field_idx))
refshape_cellset == refshape || error("The RefShapes of the fieldhandlers interpolations must correspond to the RefShape of the cells it is applied to.")
end

Expand Down Expand Up @@ -265,9 +260,7 @@ function __close!(dh::MixedDofHandler{dim}) where {dim}
dh,
cellnumbers,
dh.field_names,
getfieldnames(fh),
getfielddims(fh),
getfieldinterpolations(fh),
fh.fields,
nextdof,
vertexdicts,
edgedicts,
Expand All @@ -281,9 +274,10 @@ function __close!(dh::MixedDofHandler{dim}) where {dim}

end

function _close!(dh::MixedDofHandler{dim}, cellnumbers, global_field_names, field_names, field_dims, field_interpolations, nextdof, vertexdicts, edgedicts, facedicts) where {dim}
function _close!(dh::MixedDofHandler{dim}, cellnumbers, global_field_names, fields, nextdof, vertexdicts, edgedicts, facedicts) where {dim}
ip_infos = InterpolationInfo[]
for interpolation in field_interpolations
for field in fields
interpolation = field.interpolation
ip_info = InterpolationInfo(interpolation)
push!(ip_infos, ip_info)
# TODO: More than one face dof per face in 3D are not implemented yet. This requires
Expand All @@ -300,14 +294,15 @@ function _close!(dh::MixedDofHandler{dim}, cellnumbers, global_field_names, fiel
len_cell_dofs = length(dh.cell_dofs)
dh.cell_dofs_offset[ci] = len_cell_dofs + 1

for (local_num, field_name) in enumerate(field_names)
fi = findfirst(i->i == field_name, global_field_names)
@debug "\tfield: $(field_name)"
for (local_num, field) in pairs(fields)
# for (local_num, field_name) in enumerate(field_names)
fi = findfirst(i->i == field.name, global_field_names)
@debug "\tfield: $(field.name)"
ip_info = ip_infos[local_num]

# Distribute dofs for vertices
nextdof = add_vertex_dofs(
dh.cell_dofs, cell, vertexdicts[fi], field_dims[local_num],
dh.cell_dofs, cell, vertexdicts[fi], field.dim,
ip_info.nvertexdofs, nextdof
)

Expand All @@ -316,7 +311,7 @@ function _close!(dh::MixedDofHandler{dim}, cellnumbers, global_field_names, fiel
# Regular 3D element or 2D interpolation embedded in 3D space
nentitydofs = ip_info.dim == 3 ? ip_info.nedgedofs : ip_info.nfacedofs
nextdof = add_edge_dofs(
dh.cell_dofs, cell, edgedicts[fi], field_dims[local_num],
dh.cell_dofs, cell, edgedicts[fi], field.dim,
nentitydofs, nextdof
)
end
Expand All @@ -325,14 +320,14 @@ function _close!(dh::MixedDofHandler{dim}, cellnumbers, global_field_names, fiel
# they are added above as edge dofs.
if ip_info.dim == dim
nextdof = add_face_dofs(
dh.cell_dofs, cell, facedicts[fi], field_dims[local_num],
dh.cell_dofs, cell, facedicts[fi], field.dim,
ip_info.nfacedofs, nextdof
)
end

# Distribute internal dofs for cells
nextdof = add_cell_dofs(
dh.cell_dofs, field_dims[local_num], ip_info.ncelldofs, nextdof
dh.cell_dofs, field.dim, ip_info.ncelldofs, nextdof
)
end

Expand Down Expand Up @@ -460,7 +455,7 @@ See also: [`find_field(dh::MixedDofHandler, field_name::Symbol)`](@ref), [`_find
function find_field(fh::FieldHandler, field_name::Symbol)
field_idx = _find_field(fh, field_name)
if field_idx === nothing
error("Did not find field :$field_name in FieldHandler (existing fields: $(getfieldnames(fh)))")
error("Did not find field :$field_name in FieldHandler (existing fields: $([field.name for field in fh.fields]))")
end
return field_idx
end
Expand Down Expand Up @@ -580,9 +575,9 @@ function reshape_to_nodes(dh::MixedDofHandler, u::Vector{T}, fieldname::Symbol)

for fh in dh.fieldhandlers
# check if this fh contains this field, otherwise continue to the next
field_pos = findfirst(i->i == fieldname, getfieldnames(fh))
field_pos === nothing && continue
offset = field_offset(fh, fieldname)
field_idx = _find_field(fh, fieldname)
field_idx === nothing && continue
offset = field_offset(fh, field_idx)

reshape_field_data!(data, dh, u, offset, field_dim, fh.cellset)
end
Expand Down
2 changes: 1 addition & 1 deletion src/Dofs/apply_analytical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function apply_analytical!(
ip_geos = _default_interpolations(dh)

for (fh, ip_geo) in zip(dh.fieldhandlers, ip_geos)
fieldname ∈ getfieldnames(fh) || continue
isnothing(_find_field(fh, fieldname)) && continue
field_idx = find_field(fh, fieldname)
ip_fun = getfieldinterpolation(fh, field_idx)
field_dim = getfielddim(fh, field_idx)
Expand Down
2 changes: 1 addition & 1 deletion src/PointEval/PointEvalHandler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ get_func_interpolations(dh::DH, fieldname) where DH<:DofHandler = [getfieldinter
function get_func_interpolations(dh::DH, fieldname) where DH<:MixedDofHandler
func_interpolations = Union{Interpolation,Nothing}[]
for fh in dh.fieldhandlers
j = findfirst(i -> i === fieldname, getfieldnames(fh))
j = _find_field(fh, fieldname)
if j === nothing
push!(func_interpolations, missing)
else
Expand Down
2 changes: 1 addition & 1 deletion test/test_apply_analytical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
function _global_dof_range(dh::MixedDofHandler, field_name::Symbol)
dofs = Set{Int}()
for fh in dh.fieldhandlers
if field_name ∈ Ferrite.getfieldnames(fh)
if !isnothing(Ferrite._find_field(fh, field_name))
_global_dof_range!(dofs, dh, fh, field_name, fh.cellset)
end
end
Expand Down