Skip to content

Commit

Permalink
Merge pull request #59 from GiggleLiu/extension
Browse files Browse the repository at this point in the history
Move some dependencies to extensions
  • Loading branch information
houpc authored Oct 6, 2024
2 parents 43549d4 + 72f467c commit 1a40edc
Show file tree
Hide file tree
Showing 13 changed files with 149 additions and 130 deletions.
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@ version = "0.4.2"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[extensions]
MCIntegrationMPIExt = "MPI"
MCIntegrationCUDAExt = "CUDA"

[compat]
Graphs = "1"
Expand Down
Empty file added ext/MCIntegrationCUDAExt.jl
Empty file.
77 changes: 77 additions & 0 deletions ext/MCIntegrationMPIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
module MCIntegrationMPIExt

using MCIntegration: integrate, ParallelBackend

function MCIntegrators.integrate(f, x, config::MPIBackend)
end

function doReweightMPI!(config::Configuration, gamma, reweight_goal::Union{Vector{Float64},Nothing}, comm::MPI.Comm)
if MCUtility.mpi_master()
# only the master process will output results, no matter parallel = :mpi or :thread or :serial
doReweight!(config, gamma, reweight_goal)
end
reweight_array = Vector{Float64}(config.reweight)
MPI.Bcast!(reweight_array, 0, comm)
config.reweight .= reweight_array
end

function MPIreduceConfig!(c::Configuration, root=0, comm=MPI.COMM_WORLD)
# Need to reduce from workers:
# neval
# var.histogram
# visited, propose, accept
# normalization, observable

function histogram_reduce!(var)
if var isa Dist.CompositeVar
for v in var.vars
histogram_reduce!(v)
end
else
MCUtility.MPIreduce!(var.histogram)
end
end

########## variable that could be a number ##############
c.neval = MCUtility.MPIreduce(c.neval) # reduce the amount of the commuication
c.normalization = MCUtility.MPIreduce(c.normalization) # reduce the amount of the commuication
for o in eachindex(c.observable)
if c.observable[o] isa AbstractArray
MCUtility.MPIreduce!(c.observable[o]) # avoid memory allocation
else
c.observable[o] = MCUtility.MPIreduce(c.observable[o])
end
end
for v in c.var
histogram_reduce!(v)
end

########## variable that are vectors ##############
MCUtility.MPIreduce!(c.visited)
MCUtility.MPIreduce!(c.propose)
MCUtility.MPIreduce!(c.accept)
end

function MPIbcastConfig!(c::Configuration, root=0, comm=MPI.COMM_WORLD)
# need to broadcast from root to workers:
# reweight
# var.histogram
function histogram_bcast!(var)
if var isa Dist.CompositeVar
for v in var.vars
histogram_bcast!(v)
end
else
MCUtility.MPIbcast!(var.histogram)
end
end

########## variable that could be a number ##############
MCUtility.MPIbcast(c.reweight)

for v in c.var
histogram_bcast!(v)
end
end

end
File renamed without changes.
11 changes: 9 additions & 2 deletions src/MCIntegration.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
module MCIntegration
using MPI
using Printf, Dates
using Random
using Graphs
using Test
using ProgressMeter
# using Measurements

const RNG = Random.GLOBAL_RNG
const TINY = eps(Float64(0)) * 1e50 # 4.940656458412466e-274
const EPSILON = eps(Float64)

abstract type ParallelBackend end
struct DefaultBackend <: ParallelBackend end
struct MPIBackend <: ParallelBackend end

function integrate(f, x, config::DefaultBackend)
end

integrate(f, x) = integrate(f, x, Defaultbackend())

# this is how vegas python package does it
# cdef double TINY = 10 ** (sys.float_info.min_10_exp + 50) # smallest and biggest
# cdef double HUGE = 10 ** (sys.float_info.max_10_exp - 50) # with extra headroom
Expand Down
61 changes: 1 addition & 60 deletions src/configuration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,65 +333,6 @@ function addConfig!(c::Configuration, ic::Configuration)
end
end

function MPIreduceConfig!(c::Configuration, root=0, comm=MPI.COMM_WORLD)
# Need to reduce from workers:
# neval
# var.histogram
# visited, propose, accept
# normalization, observable

function histogram_reduce!(var)
if var isa Dist.CompositeVar
for v in var.vars
histogram_reduce!(v)
end
else
MCUtility.MPIreduce!(var.histogram)
end
end

########## variable that could be a number ##############
c.neval = MCUtility.MPIreduce(c.neval) # reduce the amount of the commuication
c.normalization = MCUtility.MPIreduce(c.normalization) # reduce the amount of the commuication
for o in eachindex(c.observable)
if c.observable[o] isa AbstractArray
MCUtility.MPIreduce!(c.observable[o]) # avoid memory allocation
else
c.observable[o] = MCUtility.MPIreduce(c.observable[o])
end
end
for v in c.var
histogram_reduce!(v)
end

########## variable that are vectors ##############
MCUtility.MPIreduce!(c.visited)
MCUtility.MPIreduce!(c.propose)
MCUtility.MPIreduce!(c.accept)
end

function MPIbcastConfig!(c::Configuration, root=0, comm=MPI.COMM_WORLD)
# need to broadcast from root to workers:
# reweight
# var.histogram
function histogram_bcast!(var)
if var isa Dist.CompositeVar
for v in var.vars
histogram_bcast!(v)
end
else
MCUtility.MPIbcast!(var.histogram)
end
end

########## variable that could be a number ##############
MCUtility.MPIbcast(c.reweight)

for v in c.var
histogram_bcast!(v)
end
end

function bcastConfig!(dest::Configuration, src::Configuration)
# need to broadcast from root to workers:
# reweight
Expand Down Expand Up @@ -506,4 +447,4 @@ function report(config::Configuration, total_neval=nothing)
# end
# println()

end
end
9 changes: 0 additions & 9 deletions src/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,3 @@ function doReweight!(config, gamma, reweight_goal)
# config.reweight ./= sum(config.reweight)
end

function doReweightMPI!(config::Configuration, gamma, reweight_goal::Union{Vector{Float64},Nothing}, comm::MPI.Comm)
if MCUtility.mpi_master()
# only the master process will output results, no matter parallel = :mpi or :thread or :serial
doReweight!(config, gamma, reweight_goal)
end
reweight_array = Vector{Float64}(config.reweight)
MPI.Bcast!(reweight_array, 0, comm)
config.reweight .= reweight_array
end
4 changes: 2 additions & 2 deletions src/mcmc/MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ using ..Dist

using ..MCUtility

using Random, MPI
using Random
using LinearAlgebra
using Printf, Dates

include("montecarlo.jl")
include("updates.jl")

end
end
14 changes: 7 additions & 7 deletions src/mcmc/montecarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ function montecarlo(config::Configuration{N,V,P,O,T}, integrand::Function, neval
# println("kwargs in mcmc: ", kwargs)

################## test integrand type stability ######################
if debug
if (length(config.var) == 1)
MCUtility.test_type_stability(integrand, (1, config.var[1], config))
else
MCUtility.test_type_stability(integrand, (1, config.var, config))
end
end
# if debug
# if (length(config.var) == 1)
# MCUtility.test_type_stability(integrand, (1, config.var[1], config))
# else
# MCUtility.test_type_stability(integrand, (1, config.var, config))
# end
# end
#######################################################################

curr = idx
Expand Down
29 changes: 13 additions & 16 deletions src/utility/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Utility data structures and functions
"""
module MCUtility
using Test
using ..MPI
using ..Threads
using LinearAlgebra

Expand All @@ -13,7 +11,6 @@ export StopWatch, check
include("color.jl")
export black, red, green, yellow, blue, magenta, cyan, white

include("parallel.jl")
export disable_threading

export locate, smooth, rescale
Expand All @@ -39,17 +36,17 @@ export locate, smooth, rescale
# return str
# end

function test_type_stability(f, args)
try
@inferred f(args...)
catch e
if isa(e, MethodError)
@warn("call $f with wrong args. Got $(args)")
else
@warn "Type instability issue detected for $f, it may makes the integration slow" exception = (e, catch_backtrace())
# @warn("Type instability issue detected for $f, it may makes the integration slow.\n$e")
end
end
end
# function test_type_stability(f, args)
# try
# @inferred f(args...)
# catch e
# if isa(e, MethodError)
# @warn("call $f with wrong args. Got $(args)")
# else
# @warn "Type instability issue detected for $f, it may makes the integration slow" exception = (e, catch_backtrace())
# # @warn("Type instability issue detected for $f, it may makes the integration slow.\n$e")
# end
# end
# end

end
end
30 changes: 15 additions & 15 deletions src/vegas/montecarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,21 @@ function montecarlo(config::Configuration{Ni,V,P,O,T}, integrand::Function, neva
diff = [config.dof[i] == config.maxdof for i in 1:Ni] # check if the dof is the same as the maxdof, if the same, then there is no need to update the padding probability

################## test integrand type stability ######################
if debug
if inplace
if (length(config.var) == 1)
MCUtility.test_type_stability(integrand, (config.var[1], weights, config))
else
MCUtility.test_type_stability(integrand, (config.var, weights, config))
end
else
if (length(config.var) == 1)
MCUtility.test_type_stability(integrand, (config.var[1], config))
else
MCUtility.test_type_stability(integrand, (config.var, config))
end
end
end
# if debug
# if inplace
# if (length(config.var) == 1)
# MCUtility.test_type_stability(integrand, (config.var[1], weights, config))
# else
# MCUtility.test_type_stability(integrand, (config.var, weights, config))
# end
# else
# if (length(config.var) == 1)
# MCUtility.test_type_stability(integrand, (config.var[1], config))
# else
# MCUtility.test_type_stability(integrand, (config.var, config))
# end
# end
# end
#######################################################################


Expand Down
4 changes: 2 additions & 2 deletions src/vegas_mc/VegasMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ using ..Dist

using ..MCUtility

using Random, MPI
using Random
using LinearAlgebra
using Printf, Dates

include("montecarlo.jl")
include("updates.jl")
end
end
30 changes: 15 additions & 15 deletions src/vegas_mc/montecarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,21 @@ function montecarlo(config::Configuration{N,V,P,O,T}, integrand::Function, neval
_weights = zeros(T, N)

################## test integrand type stability ######################
if debug
if inplace
if (length(config.var) == 1)
MCUtility.test_type_stability(integrand, (config.var[1], weights, config))
else
MCUtility.test_type_stability(integrand, (config.var, weights, config))
end
else
if (length(config.var) == 1)
MCUtility.test_type_stability(integrand, (config.var[1], config))
else
MCUtility.test_type_stability(integrand, (config.var, config))
end
end
end
# @debug begin
# if inplace
# if (length(config.var) == 1)
# MCUtility.test_type_stability(integrand, (config.var[1], weights, config))
# else
# MCUtility.test_type_stability(integrand, (config.var, weights, config))
# end
# else
# if (length(config.var) == 1)
# MCUtility.test_type_stability(integrand, (config.var[1], config))
# else
# MCUtility.test_type_stability(integrand, (config.var, config))
# end
# end
# end
#######################################################################

relativeWeights = zeros(T, N)
Expand Down

0 comments on commit 1a40edc

Please sign in to comment.