HYPRE extension #9

Merged 1 commit on Jun 7, 2023
2 changes: 2 additions & 0 deletions Project.toml
@@ -11,9 +11,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"

[weakdeps]
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b"

[extensions]
FerriteDistributedHYPREAssembly = "HYPRE"
FerriteDistributedMetisPartitioning = "Metis"

[compat]
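The `[weakdeps]`/`[extensions]` entries above mean the HYPRE glue code is only compiled and loaded when the user's environment provides both FerriteDistributed and HYPRE. A minimal usage sketch (Julia >= 1.9 assumed; only the package and extension names come from this diff):

using FerriteDistributed
using HYPRE   # loading the weak dependency activates FerriteDistributedHYPREAssembly

# Optional: grab the extension module itself, e.g. for debugging
ext = Base.get_extension(FerriteDistributed, :FerriteDistributedHYPREAssembly)
@assert ext !== nothing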
20 changes: 20 additions & 0 deletions ext/FerriteDistributedHYPREAssembly.jl
@@ -0,0 +1,20 @@
"""
Module containing the code for distributed assembly via HYPRE.jl
"""
module FerriteDistributedHYPREAssembly

using FerriteDistributed
# TODO remove me. These are merely hotfixes to split the extensions transiently via an internal API.
import FerriteDistributed: getglobalgrid, num_local_true_dofs, num_local_dofs, global_comm, interface_comm, global_rank, compute_owner, remote_entities
using MPI
using HYPRE
using Base: @propagate_inbounds

include("FerriteDistributedHYPREAssembly/assembler.jl")
include("FerriteDistributedHYPREAssembly/conversion.jl")

function __init__()
    @info "FerriteDistributedHYPREAssembly extension loaded."
end

end # module FerriteDistributedHYPREAssembly
48 changes: 48 additions & 0 deletions ext/FerriteDistributedHYPREAssembly/assembler.jl
@@ -0,0 +1,48 @@
function Ferrite.create_sparsity_pattern(::Type{<:HYPREMatrix}, dh::Ferrite.AbstractDofHandler, ch::Union{ConstraintHandler,Nothing}=nothing; kwargs...)
    K = create_sparsity_pattern(dh, ch; kwargs...)
    fill!(K.nzval, 1)
    return HYPREMatrix(K)
end

###########################################
## HYPREAssembler and associated methods ##
###########################################

struct HYPREAssembler <: Ferrite.AbstractSparseAssembler
    A::HYPRE.HYPREAssembler
end

Ferrite.matrix_handle(a::HYPREAssembler) = a.A.A.A # :)
Ferrite.vector_handle(a::HYPREAssembler) = a.A.b.b # :)

function Ferrite.start_assemble(K::HYPREMatrix, f::HYPREVector)
    return HYPREAssembler(HYPRE.start_assemble!(K, f))
end

function Ferrite.assemble!(a::HYPREAssembler, dofs::AbstractVector{<:Integer}, ke::AbstractMatrix, fe::AbstractVector)
    HYPRE.assemble!(a.A, dofs, ke, fe)
end

function Ferrite.end_assemble(a::HYPREAssembler)
    HYPRE.finish_assemble!(a.A)
end

## Methods for arrayutils.jl ##

function Ferrite.addindex!(A::HYPREMatrix, v, i::Int, j::Int)
    nrows = HYPRE_Int(1)
    ncols = Ref{HYPRE_Int}(1)
    rows = Ref{HYPRE_BigInt}(i)
    cols = Ref{HYPRE_BigInt}(j)
    values = Ref{HYPRE_Complex}(v)
    HYPRE.@check HYPRE_IJMatrixAddToValues(A.ijmatrix, nrows, ncols, rows, cols, values)
    return A
end

function Ferrite.addindex!(b::HYPREVector, v, i::Int)
    nvalues = HYPRE_Int(1)
    indices = Ref{HYPRE_BigInt}(i)
    values = Ref{HYPRE_Complex}(v)
    HYPRE.@check HYPRE_IJVectorAddToValues(b.ijvector, nvalues, indices, values)
    return b
end
119 changes: 119 additions & 0 deletions ext/FerriteDistributedHYPREAssembly/conversion.jl
@@ -0,0 +1,119 @@
# HYPRE to Ferrite vector
function hypre_to_ferrite!(u::Vector{T}, uh::HYPREVector, dh::Ferrite.AbstractDofHandler) where {T}
    # Copy solution from HYPRE to Julia
    uj = Vector{Float64}(undef, num_local_true_dofs(dh))
    copy!(uj, uh)

    my_rank = global_rank(getglobalgrid(dh))

    # Helper to gather which global dofs and values have to be sent to which process
    gdof_value_send = [Dict{Int,Float64}() for i ∈ 1:MPI.Comm_size(MPI.COMM_WORLD)]
    # Helper to get the global dof to local dof mapping
    rank_recv_count = [0 for i ∈ 1:MPI.Comm_size(MPI.COMM_WORLD)]
    gdof_to_ldof = Dict{Int,Int}()

    next_dof = 1
    for (ldof, rank) ∈ enumerate(dh.ldof_to_rank)
        if rank == my_rank
            u[ldof] = uj[next_dof]
            next_dof += 1
        else
            # We have to sync these later.
            gdof_to_ldof[dh.ldof_to_gdof[ldof]] = ldof
            rank_recv_count[rank] += 1
        end
    end

    # TODO speed this up and provide a better API
    dgrid = getglobalgrid(dh)
    for sv ∈ get_shared_vertices(dgrid)
        lvi = sv.local_idx
        my_rank != compute_owner(dgrid, sv) && continue
        for field_idx in 1:num_fields(dh)
            if Ferrite.has_vertex_dofs(dh, field_idx, lvi)
                local_dofs = Ferrite.vertex_dofs(dh, field_idx, lvi)
                global_dofs = dh.ldof_to_gdof[local_dofs]
                for receiver_rank ∈ keys(remote_entities(sv))
                    for i ∈ 1:length(global_dofs)
                        # Note that u already has the correct values for all locally owned dofs due to the loop above!
                        gdof_value_send[receiver_rank][global_dofs[i]] = u[local_dofs[i]]
                    end
                end
            end
        end
    end

    for se ∈ get_shared_edges(dgrid)
        lei = se.local_idx
        my_rank != compute_owner(dgrid, se) && continue
        for field_idx in 1:num_fields(dh)
            if Ferrite.has_edge_dofs(dh, field_idx, lei)
                local_dofs = Ferrite.edge_dofs(dh, field_idx, lei)
                global_dofs = dh.ldof_to_gdof[local_dofs]
                for receiver_rank ∈ keys(remote_entities(se))
                    for i ∈ 1:length(global_dofs)
                        # Note that u already has the correct values for all locally owned dofs due to the loop above!
                        gdof_value_send[receiver_rank][global_dofs[i]] = u[local_dofs[i]]
                    end
                end
            end
        end
    end

    for sf ∈ get_shared_faces(dgrid)
        lfi = sf.local_idx
        my_rank != compute_owner(dgrid, sf) && continue
        for field_idx in 1:num_fields(dh)
            if Ferrite.has_face_dofs(dh, field_idx, lfi)
                local_dofs = Ferrite.face_dofs(dh, field_idx, lfi)
                global_dofs = dh.ldof_to_gdof[local_dofs]
                for receiver_rank ∈ keys(remote_entities(sf))
                    for i ∈ 1:length(global_dofs)
                        # Note that u already has the correct values for all locally owned dofs due to the loop above!
                        gdof_value_send[receiver_rank][global_dofs[i]] = u[local_dofs[i]]
                    end
                end
            end
        end
    end

    Ferrite.@debug println("preparing to distribute $gdof_value_send (R$my_rank)")

    # TODO precompute the graph, as it is static
    graph_source = Cint[my_rank-1]
    graph_dest = Cint[]
    for r ∈ 1:MPI.Comm_size(MPI.COMM_WORLD)
        !isempty(gdof_value_send[r]) && push!(graph_dest, r-1)
    end

    graph_degree = Cint[length(graph_dest)]
    graph_comm = MPI.Dist_graph_create(MPI.COMM_WORLD, graph_source, graph_degree, graph_dest)
    indegree, outdegree, _ = MPI.Dist_graph_neighbors_count(graph_comm)

    inranks = Vector{Cint}(undef, indegree)
    outranks = Vector{Cint}(undef, outdegree)
    MPI.Dist_graph_neighbors!(graph_comm, inranks, outranks)

    send_count = [length(gdof_value_send[outrank+1]) for outrank ∈ outranks]
    recv_count = [rank_recv_count[inrank+1] for inrank ∈ inranks]

    send_gdof = Cint[]
    for outrank ∈ outranks
        append!(send_gdof, Cint.(keys(gdof_value_send[outrank+1])))
    end
    recv_gdof = Vector{Cint}(undef, sum(recv_count))
    MPI.Neighbor_alltoallv!(VBuffer(send_gdof, send_count), VBuffer(recv_gdof, recv_count), graph_comm)

    send_val = Cdouble[]
    for outrank ∈ outranks
        append!(send_val, Cdouble.(values(gdof_value_send[outrank+1])))
    end
    recv_val = Vector{Cdouble}(undef, sum(recv_count))
    MPI.Neighbor_alltoallv!(VBuffer(send_val, send_count), VBuffer(recv_val, recv_count), graph_comm)

    for (gdof, val) ∈ zip(recv_gdof, recv_val)
        u[gdof_to_ldof[gdof]] = val
    end

    return u
end
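A hedged end-to-end sketch of how the conversion above is meant to be used after a solve. The solver setup follows HYPRE.jl's documented PCG/BoomerAMG combination, and `K`, `f`, `dh` are assumptions carried over from the assembly sketch earlier, not part of this diff:

# Solve the assembled system with a HYPRE solver (returns a HYPREVector)
xh = HYPRE.solve(HYPRE.PCG(; Precond = HYPRE.BoomerAMG()), K, f)

# Scatter locally owned dofs and fetch ghost dofs from their owners via MPI
u = Vector{Float64}(undef, num_local_dofs(dh))
hypre_to_ferrite!(u, xh, dh)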