Skip to content

Commit

Permalink
Scm/equations (#155)
Browse files Browse the repository at this point in the history
* added retrieval of equations from CausalGraph

* Exclude test_regression.ipynb from merge with main

* added DAG checking, OLS residuals/intercept, Tables.jl support, and a topological sort of the SCM DAG for data generation

* put ols function outside of estimate function

* removed DataFrames.jl as dependency, exported SCM functions in CI.jl and inserted a runnable test for them in test/equations.jl

* Delete Manifest.toml

* Reverted Project.toml

* readded svg figure

* included equations.jl in runtests.jl

* Removed TikzGraphs from test/equations.jl
  • Loading branch information
JorgeLuizFranco authored Jun 24, 2024
1 parent aac0ac8 commit 408ab81
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/CausalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ export pdag2dag!, pdag_to_dag_meek!, pdag_to_dag_dortarsi!
export count_moves_uniform, randcpdag, UniformScore, causalzigzag, dagzigzag
export keyedreduce

export estimate_equations, generate_data

#include("pinv.jl")
include("graphs.jl")
include("combinations_without.jl")
Expand All @@ -55,6 +57,7 @@ include("sampler.jl")
include("dag_sampler.jl")
include("misc2.jl")
include("exact.jl")
include("equations.jl")
#include("mcs.jl")

# Compatibility with the new "Package Extensions" (https://github.com/JuliaLang/julia/pull/47695)
Expand All @@ -73,4 +76,4 @@ function __init__()

end
end
end # end of module
end # end of module
135 changes: 135 additions & 0 deletions src/equations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
using LinearAlgebra, Graphs, Tables, Random, Statistics

# Define the SCM struct
"""
    SCM

A linear Structural Causal Model (SCM): one equation per variable, together with the
directed graph that determines which variables enter each equation.

The entry at position `i` of `variables`, `coefficients` and `residuals` describes the
same variable.

# Fields
- `variables::Vector{String}`: the variable names.
- `coefficients::Vector{Vector{Float64}}`: one coefficient vector per variable. As filled
  in by `estimate_equations`, the first entry is the intercept and the remaining entries
  are the coefficients of the variable's parents; a variable without parents stores only
  its mean.
- `residuals::Vector{Vector{Float64}}`: one residual vector per variable (observed value
  minus fitted value).
- `dag::DiGraph`: the directed graph representing the structure of the SCM.
"""
struct SCM
    variables::Vector{String}
    coefficients::Vector{Vector{Float64}}
    residuals::Vector{Vector{Float64}}
    dag::DiGraph
end

"""
    ols_compute(X, y)

Fit `y` on the columns of `X` by ordinary least squares, with an intercept.

A column of ones is prepended to `X`, so the first returned coefficient is the intercept
and the remaining ones follow the column order of `X`.

# Returns
- `(coef, resids)`: the coefficient vector and the residuals `y - fitted`.
"""
function ols_compute(X, y)
    nobs = size(X, 1)
    # Design matrix: intercept column first, then the regressors.
    design = [ones(nobs) X]
    beta = design \ y
    return beta, y - design * beta
end

# Function to estimate equations and return an SCM struct
"""
    estimate_equations(t, est_g::DiGraph)::SCM

Estimate one linear equation per column of the table `t`, using the structure of the
directed acyclic graph `est_g`.

Column `i` of `t` corresponds to vertex `i` of `est_g`. A column whose vertex has parents
is regressed (with intercept) on the parent columns, taken in the order returned by
`inneighbors`. A column without parents is modelled by its mean.

# Arguments
- `t`: a table containing the data for estimation (any Tables.jl-compatible source).
- `est_g::DiGraph`: a directed acyclic graph representing the structure of the SCM.

# Returns
- `SCM`: the column names, their coefficient vectors (intercept first), the residuals
  and the graph `est_g`.

# Throws
- `ArgumentError`: if `t` is not a Tables.jl table, if `est_g` is cyclic, or if `est_g`
  has more vertices than `t` has columns.
"""
function estimate_equations(t, est_g::DiGraph)::SCM
    Tables.istable(t) || throw(ArgumentError("Argument supports just Tables.jl types"))

    # Check if it is a DAG
    if is_cyclic(est_g)
        throw(ArgumentError("The provided graph is cyclic -> est_g::DiGraph should be a DAG."))
    end

    columns = Tables.columns(t)
    # Use the Tables.jl accessors rather than the schema: the schema may be `nothing`,
    # and the column names themselves (not their positions) must be stored in the SCM.
    colnames = collect(Symbol, Tables.columnnames(columns))

    nvertices = nv(est_g)
    if nvertices > length(colnames)
        throw(ArgumentError(
            "est_g has $nvertices vertices but the table has only " *
            "$(length(colnames)) columns",
        ))
    end

    var_names = String[]
    coefficients = Vector{Vector{Float64}}()
    residuals = Vector{Vector{Float64}}()

    for (idx, name) in enumerate(colnames)
        y = Tables.getcolumn(columns, name)
        # Columns beyond the last vertex have no parents.
        parents = idx <= nvertices ? inneighbors(est_g, idx) : Int[]

        if isempty(parents)
            intercept = mean(y)
            push!(coefficients, [intercept])
            push!(residuals, y .- intercept)
        else
            X = hcat((Tables.getcolumn(columns, colnames[p]) for p in parents)...)
            coef, resid = ols_compute(X, y)
            push!(coefficients, coef)
            push!(residuals, resid)
        end

        # Exactly one entry per column, so position `idx` in the SCM vectors always
        # matches vertex `idx` of the graph.
        push!(var_names, String(name))
    end

    return SCM(var_names, coefficients, residuals, est_g)
end

# Function to generate data from the SCM
"""
    generate_data(scm::SCM, N::Int; rng=Random.default_rng())::NamedTuple

Generate `N` synthetic observations from the linear SCM `scm`.

Vertices are visited in topological order, so the parents of a variable are always
generated before the variable itself. Each variable is its linear predictor (intercept
plus coefficient-weighted parents) plus Gaussian noise whose standard deviation is the
sample standard deviation of the stored residuals of that variable.

# Arguments
- `scm::SCM`: the structural causal model.
- `N::Int`: the number of data points to generate.

# Keywords
- `rng::AbstractRNG`: random number generator used for the noise. Defaults to the
  global default RNG; pass a seeded generator for reproducible output.

# Returns
- `NamedTuple`: one `Vector{Float64}` of length `N` per generated variable, in the
  order of `scm.variables`.
"""
function generate_data(scm::SCM, N::Int; rng::AbstractRNG=Random.default_rng())::NamedTuple
    generated = Dict{Symbol, Vector{Float64}}()

    for idx in topological_sort_by_dfs(scm.dag)
        coef = scm.coefficients[idx]
        noise = std(scm.residuals[idx]) .* randn(rng, N)
        parents = inneighbors(scm.dag, idx)

        if length(coef) == 1 || isempty(parents)
            # Intercept-only equation.
            sample = coef[1] .+ noise
        else
            # Same column layout as in estimation: intercept first, then the parents
            # in `inneighbors` order.
            X = hcat(ones(N), (generated[Symbol(scm.variables[p])] for p in parents)...)
            sample = X * coef .+ noise
        end

        generated[Symbol(scm.variables[idx])] = sample
    end

    # Return the columns in the order of `scm.variables`; iterating the Dict directly
    # would yield them in hash order.
    syms = unique(Symbol(v) for v in scm.variables if haskey(generated, Symbol(v)))
    return NamedTuple{Tuple(syms)}(Tuple(generated[s] for s in syms))
end
39 changes: 39 additions & 0 deletions test/equations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using CausalInference
using Random
using Test

# End-to-end check of the SCM helpers: learn a graph with GES, estimate the linear
# equations on it, and generate a synthetic data set from the fitted model.
@testset "SCM equations" begin
    Random.seed!(1)

    # Generate some sample data to use with the GES algorithm
    N = 2000 # number of data points

    # define simple linear model with added noise
    x = randn(N)
    v = x + randn(N) * 0.25
    w = x + randn(N) * 0.25
    z = v + w + randn(N) * 0.25
    s = z + randn(N) * 0.25

    df = (x=x, v=v, w=w, z=z, s=s)

    est_g, score = ges(df; penalty=1.0, parallel=true)
    est_dag = pdag2dag!(est_g)

    scm = estimate_equations(df, est_dag)

    # One equation per column, with residuals for every observation.
    @test length(scm.variables) == length(df)
    @test length(scm.coefficients) == length(df)
    @test length(scm.residuals) == length(df)
    @test all(length(r) == N for r in scm.residuals)
    # Every equation includes an intercept, so its residuals average to zero.
    @test all(abs(sum(r)) / N < 1e-8 for r in scm.residuals)

    df_generated = generate_data(scm, N)

    @test length(df_generated) == length(df)
    @test all(col -> length(col) == N, df_generated)
    @test all(col -> all(isfinite, col), df_generated)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ include("witness.jl")
include("fci.jl")
include("klentropy.jl")
include("backdoor.jl")
include("equations.jl")

1 comment on commit 408ab81

@pat-alt
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mschauer can we get this registered so we can use it with CounterfactualExplanations.jl? 😃 cc @JorgeLuizFranco

Please sign in to comment.