Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multisampler #148

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
df6f389
started working on multisampler
mwien Feb 29, 2024
e785e56
minor fixes
mwien Feb 29, 2024
f4cb6c5
add full multisampler
Mar 1, 2024
82ed06c
remove comment
Mar 1, 2024
ef04dea
remove printlns
Mar 1, 2024
c33bd49
Fixes
mschauer Mar 1, 2024
113866d
define penalty
mschauer Mar 1, 2024
030c087
Tuning
mschauer Mar 1, 2024
37ca082
More extensive tests
mschauer Mar 1, 2024
e69b300
rename for clarity
mwien Mar 1, 2024
27fa658
fix copy
mwien Mar 1, 2024
6698330
minor rename
mwien Mar 1, 2024
f8e7b20
remove comment
mwien Mar 1, 2024
00fea92
Fix test
mschauer Mar 1, 2024
953bbb2
Merge pull request #1 from mwien/multisampler2
mwien Mar 1, 2024
0d8f92e
I don't think we can save all graphs
mschauer Mar 2, 2024
5b12aac
fix comment
mwien Mar 2, 2024
0009661
Update src/multisampler.jl
mschauer Mar 2, 2024
35ad2e6
Ups
mschauer Mar 3, 2024
df05f9c
Make sure running init first
mschauer Mar 3, 2024
3f44570
Tricky business involving assuring that log(Pi) < 0
mschauer Mar 3, 2024
750f225
Count kills
mschauer Mar 3, 2024
aea3578
Fixes
mschauer Mar 3, 2024
c2ec4d0
Once more
mschauer Mar 3, 2024
112246d
Update src/multisampler.jl
mschauer Mar 3, 2024
a289505
Update src/multisampler.jl
mschauer Mar 3, 2024
bc4a3f0
Older versions do not support typed globals
mschauer Mar 3, 2024
6256f22
Factoring
mschauer Mar 6, 2024
cc38750
Cosmetics
mschauer Mar 6, 2024
0c5b3fe
Return a bit later
mschauer Mar 6, 2024
6fd67cf
Cosmetics
mschauer Mar 7, 2024
92d8396
Show temp
mschauer Mar 7, 2024
698dbde
Bug fix (never copy dead states)
mschauer Mar 7, 2024
7e3db9e
Need to monitor time.
mschauer Mar 7, 2024
2cec3f4
Score
mschauer Mar 7, 2024
110f294
Move opposite direction
mschauer Mar 7, 2024
67c730b
No need for baseline
mschauer Mar 7, 2024
0e11569
Allow stopping at target
mschauer Mar 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.15.1"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand All @@ -26,6 +27,17 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TabularDisplay = "3eeacb1d-13c2-54cc-9b18-30c86af3cadb"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"

[weakdeps]
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[extensions]
GraphMakieExt = "GraphMakie"
GraphRecipesExt = ["GraphRecipes", "Plots"]
TikzGraphsExt = "TikzGraphs"

[compat]
Combinatorics = "1.0"
DelimitedFiles = "1.6, 1.7, 1.8, 1.9"
Expand Down Expand Up @@ -55,11 +67,6 @@ ThreadsX = "0.1"
TikzGraphs = "1.3, 1.4"
julia = "1.6, 1.7, 1.8, 1.9, 1.10"

[extensions]
GraphMakieExt = "GraphMakie"
GraphRecipesExt = ["GraphRecipes", "Plots"]
TikzGraphsExt = "TikzGraphs"

[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Expand All @@ -71,9 +78,3 @@ TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[targets]
test = ["Test", "StatsBase", "DelimitedFiles"]

[weakdeps]
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
3 changes: 3 additions & 0 deletions src/CausalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ using Base.Iterators
using Memoization, LRUCache
using ThreadsX
using LinkedLists
using DataStructures

import Base: ==, show

export multisampler
export exactscorebased
export ancestors, descendants, alt_test_dsep, test_covariate_adjustment, alt_test_backdoor, find_dsep, find_min_dsep, find_covariate_adjustment, find_backdoor_adjustment, find_frontdoor_adjustment, find_min_covariate_adjustment, find_min_backdoor_adjustment, find_min_frontdoor_adjustment, list_dseps, list_covariate_adjustment, list_backdoor_adjustment, list_frontdoor_adjustment
export dsep, skeleton, gausscitest, dseporacle, partialcor
Expand Down Expand Up @@ -56,6 +58,7 @@ include("dag_sampler.jl")
include("misc2.jl")
include("exact.jl")
#include("mcs.jl")
include("multisampler.jl")

# Compatibility with the new "Package Extensions" (https://github.com/JuliaLang/julia/pull/47695)
const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension)
Expand Down
143 changes: 143 additions & 0 deletions src/multisampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# State of a single sampler: a graph plus the bookkeeping carried from move to move.
# New states are always built as fresh `Sample`s by `init` and the `apply*` functions.
struct Sample
# graph held by this sampler (the up/down moves obtain it from `next_CPDAG`)
# NOTE(review): `DiGraph` without a type parameter is not a concrete field type — confirm this is intended
g::DiGraph
# time stamp of this state; set to the `nextτ` of the action that produced it
τ::Float64
# direction flag: `sampleaction` tests `dir == 1`, `applyflip` multiplies it by -1
dir::Int8
# counter incremented by `applyup` and decremented by `applydown`
# (presumably the number of edges — TODO confirm)
total::Int32
# running score: every up/down move adds its `Δscoreval`
scoreval::Float64
end

# A scheduled move for one sampler. `multisampler` executes it as
# `apply(samplers, i, τ, args...)`, which returns the sampler's next `Sample`.
struct Action
# time at which the move happens; also used as the priority in the event queue
τ::Float64
# one of the `apply*` functions (`applyup`, `applydown`, `applyflip`, `applycopy`)
apply::Function
# extra positional arguments splatted into `apply` after `(samplers, i, τ)`
args::Tuple{Vararg{Any}}
end

"""
    expcoldness(τ, k=0.0005)

Exponential coldness schedule: return `exp(k*τ)` for time `τ` and growth rate `k`.
"""
expcoldness(τ, k=0.0005) = exp(k * τ)

"""
    Dexpcoldness(τ, k=0.0005)

Derivative in `τ` of the exponential coldness schedule `exp(k*τ)`: return `k*exp(k*τ)`.
"""
Dexpcoldness(τ, k=0.0005) = k * exp(k * τ)

# Build a fresh `Sample` from explicit values. The first two positional arguments are
# ignored; they keep the call shape `(samplers, i, nextτ, args...)` shared by the
# `apply*` functions.
init(_, _, nextτ, g, dir, total, scoreval) = Sample(g, nextτ, dir, total, scoreval)

# Perform an `:up` move on sampler `i`: the new graph comes from
# `next_CPDAG(g, :up, x, y, T)`, the `total` counter grows by one and `Δscoreval` is
# added to the running score. The direction flag is carried over unchanged.
function applyup(samplers, i, nextτ, x, y, T, Δscoreval)
    parent = samplers[i]
    return Sample(
        next_CPDAG(parent.g, :up, x, y, T),
        nextτ,
        parent.dir,
        parent.total + 1,
        parent.scoreval + Δscoreval,
    )
end

# Perform a `:down` move on sampler `i`: the new graph comes from
# `next_CPDAG(g, :down, x, y, H)`, the `total` counter shrinks by one and `Δscoreval` is
# added to the running score. The direction flag is carried over unchanged.
function applydown(samplers, i, nextτ, x, y, H, Δscoreval)
    parent = samplers[i]
    return Sample(
        next_CPDAG(parent.g, :down, x, y, H),
        nextτ,
        parent.dir,
        parent.total - 1,
        parent.scoreval + Δscoreval,
    )
end

# Reverse the direction flag of sampler `i`; graph, counter and score stay as they are.
function applyflip(samplers, i, nextτ)
    parent = samplers[i]
    flipped = -1 * parent.dir
    return Sample(parent.g, nextτ, flipped, parent.total, parent.scoreval)
end

# Produce a new state that duplicates sampler `j` (graph, direction, counter and score),
# stamped with time `nextτ`. The index of the sampler being replaced (second positional
# argument) is not needed and therefore ignored.
function applycopy(samplers, _, nextτ, j)
    donor = samplers[j]
    return Sample(donor.g, nextτ, donor.dir, donor.total, donor.scoreval)
end

# for starters without turn move
# Module-level running baseline for the score: `multisampler` resets it from its
# `baseline` keyword and `sampleaction` raises it whenever a sample's `scoreval` is at
# least as large. Deliberately left without a type annotation: typed globals
# (`x::T = v`) are only accepted from Julia 1.8 on, while Project.toml still declares
# compatibility with Julia 1.6 and 1.7.
baseline_ = 0.0
mschauer marked this conversation as resolved.
Show resolved Hide resolved

# Draw the next event of sampler `i` and return it as an `Action`.
#
# Four competing exponential waiting times are sampled and the smallest one wins:
#   :dir    - move in the sampler's current direction (up if `dir == 1`, else down), rate ρ*λdir
#   :updown - undirected move, up with probability sup/(sup + sdown), rate σ*(sup + sdown)
#   :flip   - reverse the direction flag, rate ρ*max(dir*(sdown - sup), 0)
#   :term   - replace this sampler by a copy of a uniformly chosen sampler out of `1:M`,
#             rate Dcoldness(τ) * clamp(baseline_ - scoreval, eps(), threshold)
#
# `balance`, `prior`, `score` and `κ` are handed to `count_moves_new` unchanged, together
# with the coldness at the sampler's current time.
# Side effect: the global `baseline_` is raised to the sampler's `scoreval` if it is not
# already strictly larger.
function sampleaction(samplers, i, M, balance, prior, score, σ, ρ, κ, coldness, Dcoldness, threshold)
    global baseline_
    current = samplers[i]
    t0 = current.τ

    # move weights, score increments and move arguments for the current graph
    sup, sdown, Δscorevalup, Δscorevaldown, argsup, argsdown = count_moves_new(
        current.g, κ, balance, prior, score, coldness(t0), current.total
    )

    # event rates
    λdir = current.dir == 1 ? sup : sdown
    λupdown = sup + sdown
    λflip = max(current.dir*(-sup + sdown), 0.0)

    # assert exp(score) < 1.0, -score > 0: keep the baseline at or above the current score
    baseline_ - current.scoreval <= 0 && (baseline_ = current.scoreval)

    unit = exp(ULogarithmic, 0.0)
    slope = Dcoldness(t0)
    gap = clamp(baseline_ - current.scoreval, eps(), threshold)
    λterm = unit * slope * gap # TODO: prior

    # waiting times, drawn in the fixed order dir, updown, flip, term
    waits = (
        randexp()/(ρ*λdir),
        randexp()/(σ*λupdown),
        randexp()/(ρ*λflip),
        randexp()/abs(λterm),
    )
    Δτmin, a = findmin(waits)
    kind = (:dir, :updown, :flip, :term)[a]
    @assert Δτmin >= 0
    nextτ = t0 + waits[a]

    if kind == :dir
        goup = current.dir == 1
    elseif kind == :updown
        goup = rand() < sup/λupdown
    elseif kind == :flip
        return Action(nextτ, applyflip, ())
    elseif kind == :term
        return Action(nextτ, applycopy, (rand(1:M),))
    else
        @assert false
    end

    if goup
        return Action(nextτ, applyup, (argsup..., Δscorevalup))
    else
        return Action(nextτ, applydown, (argsdown..., Δscorevaldown))
    end
end
# Remark: κ = n - 1 is used as the default.
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, better

"""
    multisampler(n, G = (DiGraph(n), 0); M = 10, balance = metropolis_balance,
                 prior = (_,_) -> 1.0, score = UniformScore(), σ = 0.0, ρ = 1.0,
                 κ = n - 1, baseline = 0.0, iterations = min(3*n^2, 50000),
                 schedule = (expcoldness, Dexpcoldness), threshold = Inf)

Run `M` interacting samplers on graphs with `n` vertices and return `(bestgraph, samplers)`.

All samplers start from the graph `G[1]` (with `G[2]` as their `total` counter) at time
`0.0` and score value `0.0`. The sampler whose pending action has the smallest time is
always advanced next; this is repeated `iterations * M` times.

# Keywords
- `M`: number of samplers.
- `balance`, `prior`, `score`, `κ`: passed on to `sampleaction`. `κ` is truncated to
  `n - 1` (with a warning) if it is `n` or larger.
- `σ`, `ρ`: rate multipliers for the undirected up/down event and for the directed and
  flip events, respectively.
- `baseline`: initial value of the module-level `baseline_`.
- `iterations`: number of events per sampler; the loop runs `iterations * M` events.
- `schedule`: tuple `(coldness, Dcoldness)` of functions of the time `τ`.
- `threshold`: upper clamp for the baseline-minus-score gap entering the copy rate.

# Returns
- `bestgraph`: graph of the sample with the largest `scoreval` produced during the run,
  or the start graph `G[1]` if no sample exceeded the initial score value `0.0`.
- `samplers`: the final state of all `M` samplers.

Side effects: overwrites the global `baseline_` and prints the fraction of events that
were copy events (`killratio`).
"""
function multisampler(n, G = (DiGraph(n), 0); M = 10, balance = metropolis_balance, prior = (_,_) -> 1.0, score=UniformScore(), σ = 0.0, ρ = 1.0, κ = n - 1, baseline = 0.0, iterations = min(3*n^2, 50000), schedule=(expcoldness, Dexpcoldness), threshold=Inf) #, verbose = false, save = true)
    if κ >= n
        κ = n - 1
        @warn "Truncate κ to $κ"
    end
    coldness, Dcoldness = schedule

    global baseline_
    baseline_ = baseline
    # init M samplers
    samplers = [Sample(G[1], 0.0, 1, G[2], 0.0) for _ = 1:M] # pass correct initial score?!
    nextaction = Vector{Action}(undef, M)
    # event queue: sampler index keyed by the time of its pending action
    queue = PriorityQueue{Int32, Float64}()

    for i = 1:M
        nextaction[i] = sampleaction(samplers, i, M, balance, prior, score, σ, ρ, κ, coldness, Dcoldness, threshold)
        enqueue!(queue, i, nextaction[i].τ)
    end

    # todo: multiply iterations by M to keep passed iteration number indep of M?
    # could also stop if one sampler has more than iterations many samples
    # but then @showprogress does not work so nicely?!
    iterations *= M
    # The best state seen so far is the start state: graph G[1] with score value 0.0.
    # (Starting from DiGraph(n) would return an unrelated empty graph whenever a
    # non-default G is passed and no sample ever beats the initial score.)
    bestgraph = G[1]
    bestscore = 0.0 # fix if correct initial score is given above
    count = 0 # number of copy events
    @showprogress for iter in 1:iterations
        i = dequeue!(queue) # sampler with the earliest pending action
        count += nextaction[i].apply == applycopy
        nextsample = nextaction[i].apply(samplers, i, nextaction[i].τ, nextaction[i].args...)
        if nextsample.scoreval > bestscore
            bestgraph = nextsample.g
            bestscore = nextsample.scoreval
        end
        samplers[i] = nextsample
        nextaction[i] = sampleaction(samplers, i, M, balance, prior, score, σ, ρ, κ, coldness, Dcoldness, threshold)
        enqueue!(queue, i, nextaction[i].τ)
        # todo: applyflip shouldn't increase counter
    end
    killratio = count/iterations
    @show killratio

    return bestgraph, samplers
end
87 changes: 87 additions & 0 deletions test/multisampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using Random, CausalInference, StatsBase, Statistics, Test, Graphs, LinearAlgebra
# Checks that `multisampler` with a linear coldness schedule recovers the expected graph
# on simulated Gaussian data, both as the best-scoring graph and as the most frequent
# graph among the final samplers.
# The two `Random.seed!` calls fix the data and the sampler run; keep statement order.
@testset "MultiSampler" begin
Random.seed!(1)

N = 400 # number of data points

# define simple linear model with added noise
x = randn(N)
v = x + randn(N)*0.25
w = x + randn(N)*0.25
z = v + w + randn(N)*0.25
s = z + randn(N)*0.25

df = (x=x, v=v, w=w, z=z, s=s)
iterations = 1_000
penalty = 2.0 # increase to get more edges in truth
n = length(df) # vertices
Random.seed!(101)
# score built from the sample correlation matrix of the five variables
C = cor(CausalInference.Tables.matrix(df))
score = GaussianScore(C, N, penalty)
decay = 1e-5
# (coldness, derivative of coldness) as functions of the time τ
schedule = (τ -> 1.0 + τ*decay, τ -> decay) # linear
M = 20
baseline = 0.0
bestgraph, samplers = multisampler(n; M, ρ = 1.0, score, baseline, schedule, iterations)
#posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)

# maximum a posteriori estimate
MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
@test bestgraph == digraph(MAP, n)
# count the graphs held by the final samplers, most frequent first
cm = sort(countmap(vpairs.(getfield.(samplers, :g))), byvalue=true, rev=true)
# smallest and largest time reached by any sampler
Tmin, T = extrema(getfield.(samplers, :τ))
@show Tmin T schedule[1](T)
@test first(cm).first == MAP
end #testset

# Compares the graphs held by the final `multisampler` samplers with (a) a long
# `causalzigzag` run at the final coldness and (b) the tempered distribution computed
# directly from the scores.
# NOTE(review): this testset has the same name as the one above, which makes failure
# reports ambiguous — consider giving it a distinct name.
@testset "MultiSampler" begin
Random.seed!(1)
decay = 1e-5
# (coldness, derivative of coldness) as functions of the time τ
schedule = (τ -> 1.0 + τ*decay, τ -> decay) # linear

N = 200 # number of data points

# define simple linear model with added noise
x = randn(N)
v = x + randn(N)*0.25
w = x + randn(N)*0.25
z = v + w + randn(N)*0.25
s = z + randn(N)*0.25

df = (x=x, v=v, w=w, z=z, s=s)
iterations = 500
penalty = 2.0 # increase to get more edges in truth
n = length(df) # vertices
Random.seed!(101)
C = cor(CausalInference.Tables.matrix(df))
score = GaussianScore(C, N, penalty)
M = 100
bestgraph, samplers = multisampler(n; M, score, schedule, iterations)
# smallest and largest time reached by any sampler; coldness at the largest time
Tmin, T = extrema(getfield.(samplers, :τ))
coldness = schedule[1](T)
@show Tmin T coldness

# reference run: single-chain sampler at the final coldness, 100x as many iterations
gs = causalzigzag(n; score, κ=n-1, coldness, iterations=iterations*100)
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs)
# weights `ws` summed per graph, largest first
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)


# maximum a posteriori estimate
MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
@test bestgraph == digraph(MAP, n)
# proportion of final samplers holding each graph, most frequent first
cm = sort((proportionmap(vpairs.(getfield.(samplers, :g)))), byvalue=true, rev=true)
@test first(cm).first == MAP
# score of every graph seen among the final samplers
logΠ = map(g->score_dag(pdag2dag!(digraph(g, n)), score), collect(keys(cm)))
# scores raised to the final coldness and normalized to sum to one over the seen graphs
Π = normalize(exp.(coldness*(logΠ .- maximum(logΠ))), 1)
# NOTE(review): Πhat is only displayed; the final assertion uses values(cm) directly
Πhat = normalize(collect(values(cm)), 1)

display([Π Πhat])
# NOTE: rebinds `s` (the data column above) as an accumulator for the reference-run
# weight of the graphs seen among the final samplers
s = 0.0
for (i, k) in enumerate(keys(cm))
s += get(posterior, k, 0.0)
#@show cm[k] Π[i]
end
@show s
@test s > 0.98
@test norm(collect(values(cm)) - Π) < 0.04
end #testset
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ include("witness.jl")
include("fci.jl")
include("klentropy.jl")
include("backdoor.jl")
include("multisampler.jl")
Loading