
Merge pull request #99 from mschauer/cleanup
General cleanup and performance
mschauer authored Jul 26, 2023
2 parents a60db2c + 567859e commit 14074af
Showing 10 changed files with 12,067 additions and 125 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -32,6 +32,7 @@ TikzGraphsExt = "TikzGraphs"

[compat]
Combinatorics = "1.0"
DelimitedFiles = "1.6, 1.7, 1.8, 1.9"
Distances = "0.8, 0.9, 0.10"
Distributions = "0.22, 0.23, 0.24, 0.25"
GraphRecipes = "0.5"
@@ -54,6 +55,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[targets]
test = ["Test", "StatsBase"]
test = ["Test", "StatsBase", "DelimitedFiles"]
5,001 changes: 5,001 additions & 0 deletions gmG8.csv

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions gmG8adj.csv
@@ -0,0 +1,9 @@
"","V1","V2","V3","V4","V5","V6","V7","V8"
"1",FALSE,TRUE,FALSE,FALSE,FALSE,TRUE,FALSE,TRUE
"2",TRUE,FALSE,TRUE,FALSE,TRUE,FALSE,FALSE,FALSE
"3",FALSE,TRUE,FALSE,FALSE,FALSE,FALSE,FALSE,FALSE
"4",FALSE,FALSE,FALSE,FALSE,FALSE,FALSE,FALSE,FALSE
"5",FALSE,TRUE,FALSE,FALSE,FALSE,TRUE,FALSE,TRUE
"6",FALSE,FALSE,FALSE,FALSE,FALSE,FALSE,TRUE,FALSE
"7",FALSE,FALSE,FALSE,FALSE,FALSE,FALSE,FALSE,FALSE
"8",FALSE,FALSE,FALSE,FALSE,FALSE,FALSE,FALSE,FALSE
6,831 changes: 6,831 additions & 0 deletions nci60.csv

Large diffs are not rendered by default.

65 changes: 65 additions & 0 deletions nci60adj.csv

Large diffs are not rendered by default.

178 changes: 70 additions & 108 deletions src/ges.jl
@@ -1,37 +1,4 @@
####################################################################
# Data Structures
####################################################################

# Collection of variables to pass along to different functions in the GES algorithm.
struct ParseData{M, F<:AbstractFloat}
data::M # original data
normAugData::M # standardized by the mean and std of each column, appended with ones column at end
numFeatures::Int # number of columns
numObservations::Int # number of rows
penalty::F

function ParseData(data::M, penalty) where M

# Determine the element type of the input data
baseType = eltype(data)

# Get the dimensions of the input data
numObservations, numFeatures = size(data)

# Copy the data and standardize each column
normAugData = copy(data)
normAugData .-= mean(normAugData, dims=1) #subtract mean
normAugData ./= std(normAugData, dims=1) #divide by standard deviation

# Augment a column of ones on the end for linear regression
normAugData = [normAugData ones(baseType, numObservations)]

# Ensure that the penalty is the same type as the data (e.g. Float32)
penalty = baseType(penalty)

return new{M, baseType}(data, normAugData, numFeatures, numObservations, penalty)
end
end

# Simple structure to hold the current edge, a subset of neighbors, and a score change
Base.@kwdef struct Step{A,B}
@@ -40,57 +7,54 @@ Base.@kwdef struct Step{A,B}
Δscore::B = zero(B)
end

####################################################################
# Base Function Overloads
####################################################################

# Print method to display the Step
function show(io::IO, nextstep::Step{A,B}) where {A,B}
print(io, "Edge: $(nextstep.edge), Subset: $(nextstep.subset), Δscore: $(nextstep.Δscore)")
end

# The @memoize macro has to check whether its ParseData argument is the same.
# We only ever construct one immutable ParseData, so we compare the underlying data.
==(a::T, b::T) where T <: ParseData = a.data === b.data


####################################################################
# Main Entry point for the Algorithm
####################################################################

"""
ges(data; verbose=false)
ges(X; method=:gaussian_bic, penalty=0.5, parallel=false, verbose=false)
Compute a causal graph for the given observed data using GES.
Compute a causal graph for the given observed data `X` (variables in columns) using GES.
Returns the CPDAG, the score improvement relative to the empty graph, and timing
measurements for the first and second search phase.
"""
function ges(data; penalty=1.0, parallel=false, verbose=false)

# Data type / precision
score = zero(eltype(data))

# Parse the input data
dataParsed = ParseData(data, penalty)


return ges(dataParsed.numFeatures, dataParsed; score, parallel, verbose)
function ges(X::AbstractMatrix; method=:gaussian_bic, penalty=0.5, parallel=false, verbose=false)
score = zero(eltype(X)) # initial score
n, d = size(X)
d ≥ n && @warn "High dimensional data (n ≤ p), ges might not terminate."
if method == :gaussian_bic
C = Symmetric(cov(X, dims = 1, corrected=false))
return ges(d, GaussianScore(C, n, penalty); score, parallel, verbose)
elseif method == :gaussian_bic_raw
return ges(d, GaussianScoreQR(X, penalty); score, parallel, verbose)
else
throw(ArgumentError("method=$method"))
end
end
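
A minimal usage sketch of the new entry point (illustrative names and data; assumes CausalInference is loaded):

using CausalInference, Random
Random.seed!(1)
x = randn(1000); v = x .+ randn(1000); w = v .+ randn(1000)
X = [x v w]                                     # 1000×3 matrix, variables in columns
g, Δscore, (t1, t2) = ges(X; method=:gaussian_bic, penalty=1.0)
vpairs(g)                                       # edge list of the estimated CPDAG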

"""
ges(n, data; score=0.0, parallel=false, verbose=false)
ges(n, local_score; score=0.0, parallel=false, verbose=false)
Estimate a causal graph for the given observed data using GES.
Internal method called by `ges`.
"""
function ges(n, data; score=0.0, parallel=false, verbose=false)
# Create an empty graph with one node for each feature
g = DiGraph(n)

parallel && Threads.nthreads() == 1 && @warn "Only one thread available"
verbose && println("Start forward search")
g, score = ges_search_insert!(g, score, data, parallel, verbose)
t1 = @elapsed g, score = ges_search_insert!(g, score, data, parallel, verbose)

verbose && println("Start backward search")
g, score = ges_search_delete!(g, score, data, verbose)
t2 = @elapsed g, score = ges_search_delete!(g, score, data, verbose)

return g, score
return g, score, (t1, t2)
end
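
The internal method can also be called directly with a precomputed score object; a minimal sketch, assuming `X` from the example above, that LinearAlgebra and Statistics are loaded, and the `GaussianScore` type defined further down in this file:

using LinearAlgebra, Statistics
n, d = size(X)
C = Symmetric(cov(X, dims=1, corrected=false))  # uncorrected MLE covariance, as in ges above
g, Δscore, (t1, t2) = ges(d, GaussianScore(C, n, 1.0); verbose=true)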

####################################################################
@@ -159,7 +123,7 @@ function ges_search_insert!(g, score, data, parallel, verbose)

# If the score did not improve...
if step.Δscore ≤ 0
verbose && println(vpairs(g))
#verbose && println(vpairs(g))
break
end
verbose && println(step)
@@ -186,14 +150,14 @@ function ges_search_delete!(g, score, data, verbose)

# If the score did not improve...
if step.Δscore ≤ 0
verbose && println(vpairs(g))
#verbose && println(vpairs(g))
break
end
verbose && println(step)
# Use the insert or delete operator to update the graph
Delete!(g, step)
score += step.Δscore
verbose && println(vpairs(g))
#verbose && println(vpairs(g))

# Convert the PDAG to a complete PDAG
# Undirect all edges unless they participate in a v-structure
@@ -356,7 +320,7 @@ export score_dag
function score_dag(g, data) # g dag
s = 0.0
for v in vertices(g)
s += score(data, inneighbors(g, v), v)
s += local_score(data, inneighbors(g, v), v)
end
s
end
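
A sketch relating `score_dag` to the improvement returned by `ges` (assuming `g`, `Δscore`, `C`, `n`, `d` from the sketches above; `pdag2dag!` picks a DAG from the CPDAG's equivalence class, as in the tests below):

s_empty = score_dag(DiGraph(d), GaussianScore(C, n, 1.0))  # score of the empty graph
s_dag = score_dag(pdag2dag!(copy(g)), GaussianScore(C, n, 1.0))
s_dag - s_empty ≈ Δscore                                   # the improvement ges reports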
@@ -375,70 +339,68 @@ hash(a::GaussianScore, u::UInt) = hash(a.hash, u)

# compare https://github.com/py-why/causal-learn/blob/f51195473b316662b6f7dce68cd73d734766a6a3/causallearn/score/LocalScoreFunction.py
"""
score(os::GaussianScore, p, v)
local_score(os::GaussianScore, p, v)
Local Gaussian BIC score.
Local Gaussian BIC score. Memoized for `GaussianScore{Float64, Symmetric{Float64, Matrix{Float64}}}`.
"""
@memoize LRU(maxsize=1_000_000) function local_score(os::GaussianScore, p, v)
function local_score(os::GaussianScore, p, v)
length(p) > 2 && return local_score_mem(os, p, v)
local_score_(os, p, v)
end
@memoize LRU{Tuple{Tuple{GaussianScore{Float64, Symmetric{Float64, Matrix{Float64}}}, Vector{Int64}, Int64}, Tuple{}}, Float64}(maxsize=100_000) function local_score_mem(os::GaussianScore{Float64, Symmetric{Float64, Matrix{Float64}}}, p, v)
local_score_(os, p, v)
end
function local_score_mem(os::GaussianScore, p, v)
local_score_(os, p, v)
end

function local_score_(os::GaussianScore, p, v)
k = length(p)
C = os.C
penalty = os.penalty
n = os.n
k = length(p) # dimension
if k == 0
Cp = C[v, v]
elseif k == 1
p_ = p[]
c = C[p_, v]
Cp = C[v, v] - c*(C[p_, p_]\c)
else # compute conditional variance
Cp = C[v, v] - dot(C[v, p], C[p, p]\C[p, v])
c = @view C[p, v]
Cp = C[v, v] - dot(c, (@view C[p, p])\c)
end
- n*log(Cp) - penalty*k*log(n)
(-n*(1 + log(Cp)) - penalty*(1 + k)*log(n))/2
end



@memoize LRU(maxsize=1_000_000) function local_score(dataParsed::ParseData{Matrix{A}}, nodeParents, node) where A
struct GaussianScoreQR{T, S<:AbstractMatrix{T}}
X::S # Data matrix
penalty::Float64
hash::UInt
end
GaussianScoreQR(X, penalty) = (Xc = X .- mean(X, dims=1); GaussianScoreQR(Xc, penalty, hash((Xc, penalty))))
export GaussianScoreQR
isequal(a::T, b::T) where T <: GaussianScoreQR = ((a.X === b.X) || isequal(a.X, b.X)) && isequal(a.penalty, b.penalty)
hash(a::GaussianScoreQR, u::UInt) = hash(a.hash, u)

# Unpack some variables from the dataParsed structure
n = A(dataParsed.numObservations) #convert datatype
data = dataParsed.normAugData
p = dataParsed.penalty

# The last column of dataParsed.normAugData is all ones, which is required for a linear regression with an intercept. If there are no node parents, model the child node with just the intercept; otherwise use the parents and the intercept
if isempty(nodeParents)
parentsAndIncept = [dataParsed.numFeatures+1]
else
parentsAndIncept = [nodeParents; dataParsed.numFeatures+1]
end

# To calculate the score we need a mean-squared
# error, which we can get by regressing the child node onto the parents
# Create variables for a linear regression y = X*b

# Use views to avoid creating copies of the data in memory
# X is the design matrix, augmented with a column of ones at the end
# X has also been standardized so that mean(column)=0 and std(column)=1
# y is data from the child node being tested
@views begin
y = data[:,node]
X = data[:,parentsAndIncept]
@memoize LRU(maxsize=1_000_000) function local_score(os::GaussianScoreQR, p, v)
X = os.X
penalty = os.penalty
n = size(X, 1)
k = length(p) # dimension
y = @view X[:, v]
if k == 0
Cp = var(y; mean=0.0, corrected=false)
else # compute conditional variance
x = @view X[:, p]
Cp = var(y - x*(x\y); mean=0.0, corrected=false)
end
(-n*(1 + log(Cp)) - penalty*(1 + k)*log(n))/2
end
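
A sketch of plugging the raw-data score in directly (the same path `method=:gaussian_bic_raw` takes in `ges` above; `X` and `Δscore` as in the earlier sketches):

g_raw, Δ_raw, _ = ges(size(X, 2), GaussianScoreQR(X, 1.0))
Δ_raw ≈ Δscore  # both scores agree up to numerical error, cf. `@test s ≈ sb` below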

# Perform a linear regression
b = X \ y

# Get the estimation
ŷ = X*b

# Next we want to calculate the log likelihood that these are the parents of our node
# score = log(P(data|Model)) ≈ -BIC/2
# because we're only comparing log likelihoods we'll ignore the 1/2 factor
# when P(⋅) is Gaussian, log(P(data|Model)) takes the form:
# -k⋅log(n) - n⋅log(mse)
# k is the number of free parameters and mse is mean squared error
k = length(parentsAndIncept) #includes the intercept
mse = sum(x->x^2, y-ŷ) / n

# return the final score we want to maximize (which is technically -BIC)
return -n*log(mse) - p*k*log(n)
end



4 changes: 2 additions & 2 deletions test/ges.jl
@@ -22,9 +22,9 @@ end
X = [x v w z s]


g, _ = ges(X)
g, _ = ges(X; penalty=2.0)

@test sort(collect(Graphs.edges(g))) == sort(Edge.([1 => 2
@test sort(vpairs(g)) == sort(([1 => 2
1 => 3
2 => 1
2 => 4
42 changes: 42 additions & 0 deletions test/gesvsR.jl
@@ -0,0 +1,42 @@
using Distributions
using Graphs
using DelimitedFiles
using LinearAlgebra
using Test
using CausalInference
using Random
@testset "ges vs GES from pcalg (R)" begin
Random.seed!(1)
data = [:nci60, :gmG8][2]

score_R = Dict([:nci60 => -27741.897919271432, :gmG8 => -20106.287947087083])[data]

if data == :nci60
X = Matrix(Float64.(readdlm(joinpath(@__DIR__, "..", "nci60.csv"), ',')[2:end, 2:end]))
elseif data == :gmG8
X = Matrix(Float64.(readdlm(joinpath(@__DIR__, "..", "gmG8.csv"), ',')[2:end, 2:end]))
end
n, d = size(X)
C = Symmetric(cov(X, dims = 1, corrected=false))

if data == :nci60
g3 = DiGraph(map(x->x=="TRUE", readdlm(joinpath(@__DIR__, "..", "nci60adj.csv"), ',')[2:end, 2:end]))
else
g3 = DiGraph(map(x->x=="TRUE", readdlm(joinpath(@__DIR__, "..", "gmG8adj.csv"), ',')[2:end, 2:end]))
end
penalty = 1.0

g2, s = ges(X; penalty, method=:gaussian_bic_raw)
g2b, sb, (t1b, t2b) = ges(X; penalty)
@test g2 == g2b
@test s ≈ sb
#g2c, sc, (t1c, t2c) = ges(X; penalty, parallel=true)
@test score_R ≈ score_dag(DiGraph(d), GaussianScore(C, n, penalty)) + s
@show score_R ≈ score_dag(pdag2dag!(copy(g2)), GaussianScore(C, n, penalty))
@show score_R ≈ score_dag(pdag2dag!(copy(g3)), GaussianScore(C, n, penalty))

@test isempty(symdiff(vpairs(g2), vpairs(g2b)))

@test isempty(symdiff(vpairs(g2), vpairs(g3)))

end