Commit

Cherry-pick HYPRE extension from Ferrite-FEM/Ferrite.jl#486 - which in turn is based on work done by @fredrikekre
termi-official committed Jun 7, 2023
1 parent 6c38b6b commit cbe1d30
Showing 4 changed files with 189 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -11,9 +11,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"

[weakdeps]
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b"

[extensions]
FerriteDistributedHYPREAssembly = "HYPRE"
FerriteDistributedMetisPartitioning = "Metis"

[compat]
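Declaring HYPRE under [weakdeps] and mapping it to FerriteDistributedHYPREAssembly under [extensions] means the HYPRE-specific code is compiled and loaded only when HYPRE is present in the user's session. A minimal sketch of the resulting user-side behaviour (package names taken from the Project.toml above; nothing else is assumed):

using FerriteDistributed   # loads only the base package
using HYPRE                # additionally triggers loading of the FerriteDistributedHYPREAssembly extension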
20 changes: 20 additions & 0 deletions ext/FerriteDistributedHYPREAssembly.jl
@@ -0,0 +1,20 @@
"""
Module containing the code for distributed assembly via HYPRE.jl
"""
module FerriteDistributedHYPREAssembly

using FerriteDistributed
# TODO remove me. These are merely hotfixes to split the extensions transiently via an internal API.
import FerriteDistributed: getglobalgrid, num_local_true_dofs, num_local_dofs, global_comm, interface_comm, global_rank, compute_owner, remote_entities
using MPI
using HYPRE
using Base: @propagate_inbounds

include("FerriteDistributedHYPREAssembly/assembler.jl")
include("FerriteDistributedHYPREAssembly/conversion.jl")

function __init__()
@info "FerriteHYPRE extension loaded."
end

end # module FerriteDistributedHYPREAssembly
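Whether the extension actually loaded can be checked from the REPL with Base.get_extension (available since Julia 1.9); this snippet is illustrative and not part of the commit:

using FerriteDistributed, HYPRE
Base.get_extension(FerriteDistributed, :FerriteDistributedHYPREAssembly) !== nothing  # true once both packages are loaded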
48 changes: 48 additions & 0 deletions ext/FerriteDistributedHYPREAssembly/assembler.jl
@@ -0,0 +1,48 @@
function Ferrite.create_sparsity_pattern(::Type{<:HYPREMatrix}, dh::Ferrite.AbstractDofHandler, ch::Union{ConstraintHandler,Nothing}=nothing; kwargs...)
    K = create_sparsity_pattern(dh, ch; kwargs...)
    fill!(K.nzval, 1)
    return HYPREMatrix(K)
end

###########################################
## HYPREAssembler and associated methods ##
###########################################

struct HYPREAssembler <: Ferrite.AbstractSparseAssembler
    A::HYPRE.HYPREAssembler
end

Ferrite.matrix_handle(a::HYPREAssembler) = a.A.A.A # :)
Ferrite.vector_handle(a::HYPREAssembler) = a.A.b.b # :)

function Ferrite.start_assemble(K::HYPREMatrix, f::HYPREVector)
    return HYPREAssembler(HYPRE.start_assemble!(K, f))
end

function Ferrite.assemble!(a::HYPREAssembler, dofs::AbstractVector{<:Integer}, ke::AbstractMatrix, fe::AbstractVector)
    HYPRE.assemble!(a.A, dofs, ke, fe)
end

function Ferrite.end_assemble(a::HYPREAssembler)
    HYPRE.finish_assemble!(a.A)
end

## Methods for arrayutils.jl ##

function Ferrite.addindex!(A::HYPREMatrix, v, i::Int, j::Int)
    nrows = HYPRE_Int(1)
    ncols = Ref{HYPRE_Int}(1)
    rows = Ref{HYPRE_BigInt}(i)
    cols = Ref{HYPRE_BigInt}(j)
    values = Ref{HYPRE_Complex}(v)
    HYPRE.@check HYPRE_IJMatrixAddToValues(A.ijmatrix, nrows, ncols, rows, cols, values)
    return A
end

function Ferrite.addindex!(b::HYPREVector, v, i::Int)
    nvalues = HYPRE_Int(1)
    indices = Ref{HYPRE_BigInt}(i)
    values = Ref{HYPRE_Complex}(v)
    HYPRE.@check HYPRE_IJVectorAddToValues(b.ijvector, nvalues, indices, values)
    return b
end
119 changes: 119 additions & 0 deletions ext/FerriteDistributedHYPREAssembly/conversion.jl
@@ -0,0 +1,119 @@
# Hypre to Ferrite vector
function hypre_to_ferrite!(u::Vector{T}, uh::HYPREVector, dh::Ferrite.AbstractDofHandler) where {T}
    # Copy solution from HYPRE to Julia
    uj = Vector{Float64}(undef, num_local_true_dofs(dh))
    copy!(uj, uh)

    my_rank = global_rank(getglobalgrid(dh))

    # Helper to gather which global dofs and values have to be sent to which process
    gdof_value_send = [Dict{Int,Float64}() for i ∈ 1:MPI.Comm_size(MPI.COMM_WORLD)]
    # Helper to get the global dof to local dof mapping
    rank_recv_count = [0 for i ∈ 1:MPI.Comm_size(MPI.COMM_WORLD)]
    gdof_to_ldof = Dict{Int,Int}()

    next_dof = 1
    for (ldof, rank) ∈ enumerate(dh.ldof_to_rank)
        if rank == my_rank
            u[ldof] = uj[next_dof]
            next_dof += 1
        else
            # We have to sync these later.
            gdof_to_ldof[dh.ldof_to_gdof[ldof]] = ldof
            rank_recv_count[rank] += 1
        end
    end

    # TODO speed this up and provide a better API
    dgrid = getglobalgrid(dh)
    for sv ∈ get_shared_vertices(dgrid)
        lvi = sv.local_idx
        my_rank != compute_owner(dgrid, sv) && continue
        for field_idx in 1:num_fields(dh)
            if Ferrite.has_vertex_dofs(dh, field_idx, lvi)
                local_dofs = Ferrite.vertex_dofs(dh, field_idx, lvi)
                global_dofs = dh.ldof_to_gdof[local_dofs]
                for receiver_rank ∈ keys(remote_entities(sv))
                    for i ∈ 1:length(global_dofs)
                        # Note that u already has the correct values for all locally owned dofs due to the loop above!
                        gdof_value_send[receiver_rank][global_dofs[i]] = u[local_dofs[i]]
                    end
                end
            end
        end
    end

    for se ∈ get_shared_edges(dgrid)
        lei = se.local_idx
        my_rank != compute_owner(dgrid, se) && continue
        for field_idx in 1:num_fields(dh)
            if Ferrite.has_edge_dofs(dh, field_idx, lei)
                local_dofs = Ferrite.edge_dofs(dh, field_idx, lei)
                global_dofs = dh.ldof_to_gdof[local_dofs]
                for receiver_rank ∈ keys(remote_entities(se))
                    for i ∈ 1:length(global_dofs)
                        # Note that u already has the correct values for all locally owned dofs due to the loop above!
                        gdof_value_send[receiver_rank][global_dofs[i]] = u[local_dofs[i]]
                    end
                end
            end
        end
    end

    for sf ∈ get_shared_faces(dgrid)
        lfi = sf.local_idx
        my_rank != compute_owner(dgrid, sf) && continue
        for field_idx in 1:num_fields(dh)
            if Ferrite.has_face_dofs(dh, field_idx, lfi)
                local_dofs = Ferrite.face_dofs(dh, field_idx, lfi)
                global_dofs = dh.ldof_to_gdof[local_dofs]
                for receiver_rank ∈ keys(remote_entities(sf))
                    for i ∈ 1:length(global_dofs)
                        # Note that u already has the correct values for all locally owned dofs due to the loop above!
                        gdof_value_send[receiver_rank][global_dofs[i]] = u[local_dofs[i]]
                    end
                end
            end
        end
    end

    Ferrite.@debug println("preparing to distribute $gdof_value_send (R$my_rank)")

    # TODO precompute graph as it is static
    graph_source = Cint[my_rank-1]
    graph_dest = Cint[]
    for r ∈ 1:MPI.Comm_size(MPI.COMM_WORLD)
        !isempty(gdof_value_send[r]) && push!(graph_dest, r-1)
    end

    graph_degree = Cint[length(graph_dest)]
    graph_comm = MPI.Dist_graph_create(MPI.COMM_WORLD, graph_source, graph_degree, graph_dest)
    indegree, outdegree, _ = MPI.Dist_graph_neighbors_count(graph_comm)

    inranks = Vector{Cint}(undef, indegree)
    outranks = Vector{Cint}(undef, outdegree)
    MPI.Dist_graph_neighbors!(graph_comm, inranks, outranks)

    send_count = [length(gdof_value_send[outrank+1]) for outrank ∈ outranks]
    recv_count = [rank_recv_count[inrank+1] for inrank ∈ inranks]

    send_gdof = Cint[]
    for outrank ∈ outranks
        append!(send_gdof, Cint.(keys(gdof_value_send[outrank+1])))
    end
    recv_gdof = Vector{Cint}(undef, sum(recv_count))
    MPI.Neighbor_alltoallv!(VBuffer(send_gdof, send_count), VBuffer(recv_gdof, recv_count), graph_comm)

    send_val = Cdouble[]
    for outrank ∈ outranks
        append!(send_val, Cdouble.(values(gdof_value_send[outrank+1])))
    end
    recv_val = Vector{Cdouble}(undef, sum(recv_count))
    MPI.Neighbor_alltoallv!(VBuffer(send_val, send_count), VBuffer(recv_val, recv_count), graph_comm)

    for (gdof, val) ∈ zip(recv_gdof, recv_val)
        u[gdof_to_ldof[gdof]] = val
    end

    return u
end
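The exchange above relies on MPI neighborhood collectives: each rank declares which ranks it sends to, MPI builds a distributed graph communicator from those edges, and Neighbor_alltoallv! moves variable-length buffers only along the declared edges. A self-contained toy sketch of the same pattern (a simple ring exchange, not taken from this commit):

using MPI
MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
nranks = MPI.Comm_size(comm)
# Each rank sends one value to the next rank in the ring and receives one from the previous rank.
sources = Cint[rank]
destinations = Cint[mod(rank + 1, nranks)]
degrees = Cint[length(destinations)]
graph_comm = MPI.Dist_graph_create(comm, sources, degrees, destinations)
send_val = Cdouble[rank]
recv_val = Vector{Cdouble}(undef, 1)           # exactly one incoming neighbor in a ring
MPI.Neighbor_alltoallv!(VBuffer(send_val, [1]), VBuffer(recv_val, [1]), graph_comm)
println("rank $rank received $(recv_val[1])")
MPI.Finalize()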
