Skip to content

Commit

Permalink
add support for composed estimands in from_param_file (#190)
Browse files Browse the repository at this point in the history
* add support for composed estimands in from_param_file

* up manifest
  • Loading branch information
olivierlabayle authored Apr 4, 2024
1 parent 7795fe7 commit 58148df
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 274 deletions.
170 changes: 73 additions & 97 deletions Manifest.toml

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/tl_inputs/from_actors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ function control_case_settings(::Type{TMLE.StatisticalATE}, treatments, data)
end

function addEstimands!(estimands, treatments, variables, data; positivity_constraint=0.)
freqs = TargeneCore.frequency_table(data, treatments)
freqs = TMLE.frequency_table(data, treatments)
# This loop adds all ATE estimands where all other treatments than
# the bQTL are fixed, at the order 1, this is the simple bQTL's ATE
for setting in control_case_settings(TMLE.StatisticalATE, treatments, data)
Expand All @@ -134,7 +134,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr
treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]),
outcome_extra_covariates = variables.covariates
)
if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
update_estimands_from_outcomes!(estimands, Ψ, variables.targets)
end
end
Expand All @@ -147,7 +147,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr
treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]),
outcome_extra_covariates = variables.covariates
)
if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
update_estimands_from_outcomes!(estimands, Ψ, variables.targets)
end
end
Expand Down
81 changes: 66 additions & 15 deletions src/tl_inputs/from_param_files.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ MismatchedCaseControlEncodingError() =

NoRemainingParamsError(positivity_constraint) = ArgumentError(string("No parameter passed the given positivity constraint: ", positivity_constraint))

"""
    MismatchedVariableError(variable)

Build the `ArgumentError` used when the components of a `ComposedEstimand`
do not all share the same `variable` variables (e.g. `"treatments"`).
"""
function MismatchedVariableError(variable)
    message = string(
        "Each component of a ComposedEstimand should contain the same ",
        variable,
        " variables."
    )
    return ArgumentError(message)
end

function check_genotypes_encoding(val::NamedTuple, type)
if !(typeof(val.case) <: type && typeof(val.control) <: type)
Expand All @@ -27,17 +28,66 @@ check_genotypes_encoding(val::T, type) where T =
T <: type || throw(MismatchedCaseControlEncodingError())


"""
    get_treatments(Ψ)

Return the treatment variable names of `Ψ`, i.e. the keys of
`Ψ.treatment_values`.
"""
get_treatments(Ψ) = keys(Ψ.treatment_values)

"""
    get_treatments(Ψ::ComposedEstimand)

Return the treatments shared by every component in `Ψ.args`.

The treatments of the first component are taken as the reference; every
other component is compared against it.

# Throws
- `ArgumentError` (built by `MismatchedVariableError("treatments")`) if a
  component's treatments differ from those of the first component.
"""
function get_treatments(Ψ::ComposedEstimand)
    treatments = get_treatments(first(Ψ.args))
    # Dropping the first element yields nothing to check when there is a
    # single component, so no explicit length guard is needed.
    for arg in Iterators.drop(Ψ.args, 1)
        get_treatments(arg) == treatments || throw(MismatchedVariableError("treatments"))
    end
    return treatments
end

"""
    get_confounders(Ψ)

Return, as a single flat `Tuple`, the confounders of every entry of
`Ψ.treatment_confounders`, concatenated in iteration order. Repeated
confounders are kept as they appear.
"""
get_confounders(Ψ) = Tuple(Iterators.flatten(Ψ.treatment_confounders))

"""
    get_confounders(Ψ::ComposedEstimand)

Return the confounders shared by every component in `Ψ.args`.

The confounders of the first component are taken as the reference; every
other component is compared against it.

# Throws
- `ArgumentError` (built by `MismatchedVariableError("confounders")`) if a
  component's confounders differ from those of the first component.
"""
function get_confounders(Ψ::ComposedEstimand)
    confounders = get_confounders(first(Ψ.args))
    # Nothing is iterated when there is a single component.
    for arg in Iterators.drop(Ψ.args, 1)
        get_confounders(arg) == confounders || throw(MismatchedVariableError("confounders"))
    end
    return confounders
end

"""
    get_outcome_extra_covariates(Ψ)

Return the `outcome_extra_covariates` property of `Ψ`.
"""
function get_outcome_extra_covariates(Ψ)
    return getproperty(Ψ, :outcome_extra_covariates)
end

"""
    get_outcome_extra_covariates(Ψ::ComposedEstimand)

Return the outcome extra covariates shared by every component in `Ψ.args`.

The covariates of the first component are taken as the reference; every
other component is compared against it.

# Throws
- `ArgumentError` (built by
  `MismatchedVariableError("outcome extra covariates")`) if a component's
  covariates differ from those of the first component.
"""
function get_outcome_extra_covariates(Ψ::ComposedEstimand)
    outcome_extra_covariates = get_outcome_extra_covariates(first(Ψ.args))
    # Nothing is iterated when there is a single component.
    for arg in Iterators.drop(Ψ.args, 1)
        get_outcome_extra_covariates(arg) == outcome_extra_covariates ||
            throw(MismatchedVariableError("outcome extra covariates"))
    end
    return outcome_extra_covariates
end

"""
    get_outcome(Ψ)

Return the `outcome` property of `Ψ`.
"""
function get_outcome(Ψ)
    return getproperty(Ψ, :outcome)
end

"""
    get_outcome(Ψ::ComposedEstimand)

Return the outcome shared by every component in `Ψ.args`.

The outcome of the first component is taken as the reference; every other
component is compared against it.

# Throws
- `ArgumentError` (built by `MismatchedVariableError("outcome")`) if a
  component's outcome differs from that of the first component.
"""
function get_outcome(Ψ::ComposedEstimand)
    outcome = get_outcome(first(Ψ.args))
    # Nothing is iterated when there is a single component.
    for arg in Iterators.drop(Ψ.args, 1)
        get_outcome(arg) == outcome || throw(MismatchedVariableError("outcome"))
    end
    return outcome
end

function get_variables(estimands, traits, pcs)
genetic_variants = Set{Symbol}()
others = Set{Symbol}()
pcs = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(pcs)))
alltraits = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(traits)))
for Ψ in estimands
treatments = keys.treatment_values)
confounders = Iterators.flatten((Tconf for Tconf Ψ.treatment_confounders))
treatments = get_treatments(Ψ)
confounders = get_confounders(Ψ)
outcome_extra_covariates = get_outcome_extra_covariates(Ψ)
push!(
others,
Ψ.outcome_extra_covariates...,
outcome_extra_covariates...,
confounders...,
treatments...
)
Expand Down Expand Up @@ -123,6 +173,8 @@ function adjust_parameter_sections(Ψ::T, variants_alleles, pcs) where T<:TMLE.E
return T(outcome=Ψ.outcome, treatment_values=treatments, treatment_confounders=confounders, outcome_extra_covariates=Ψ.outcome_extra_covariates)
end

"""
    adjust_parameter_sections(Ψ::ComposedEstimand, variants_alleles, pcs)

Apply `adjust_parameter_sections` to each component in `Ψ.args` and rebuild
a `ComposedEstimand` with the same composition function `Ψ.f` and the
adjusted components collected into a `Tuple`.
"""
function adjust_parameter_sections(Ψ::ComposedEstimand, variants_alleles, pcs)
    adjusted_args = Tuple(
        adjust_parameter_sections(arg, variants_alleles, pcs) for arg in Ψ.args
    )
    return ComposedEstimand(Ψ.f, adjusted_args)
end

function append_from_valid_estimands!(
estimands::Vector{<:TMLE.Estimand},
Expand All @@ -136,29 +188,28 @@ function append_from_valid_estimands!(
# Update the treatments' and confounders' sections of Ψ
Ψ = adjust_parameter_sections(Ψ, variants_alleles, variables.pcs)
# Update frequency tables with current treatments
treatments = sorted_treatment_names(Ψ)
treatments = get_treatments(Ψ)
if !haskey(frequency_tables, treatments)
frequency_tables[treatments] = TargeneCore.frequency_table(data, collect(treatments))
frequency_tables[treatments] = TMLE.frequency_table(data, treatments)
end
# Check if parameter satisfies positivity
satisfies_positivity(Ψ, frequency_tables[treatments];
positivity_constraint=positivity_constraint) || return
# Expand wildcard to all outcomes
if Ψ.outcome === :ALL
update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes)
else
# Ψ.target || MissingVariableError(variable)
push!(estimands, Ψ)
if TMLE.satisfies_positivity(Ψ, frequency_tables[treatments]; positivity_constraint=positivity_constraint)
# Expand wildcard to all outcomes
if get_outcome(Ψ) === :ALL
update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes)
else
push!(estimands, Ψ)
end
end
end

function adjusted_estimands(estimands, variables, data; positivity_constraint=0.)
final_estimands = TMLE.Estimand[]
variants_alleles = Dict(v => Set(unique(skipmissing(data[!, v]))) for v in variables.genetic_variants)
freqency_tables = Dict()
frequency_tables = Dict()
for Ψ in estimands
# If the genotypes encoding is a string representation make sure they match the actual genotypes
append_from_valid_estimands!(final_estimands, freqency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint)
append_from_valid_estimands!(final_estimands, frequency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint)
end

length(final_estimands) > 0 || throw(NoRemainingParamsError(positivity_constraint))
Expand Down
68 changes: 21 additions & 47 deletions src/tl_inputs/tl_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ NotAllVariantsFoundError(rsids) =
ArgumentError(string("Some variants were not found in the genotype files: ", join(rsids, ", ")))

NotBiAllelicOrUnphasedVariantError(rsid) = ArgumentError(string("Variant: ", rsid, " is not bi-allelic or not unphased."))

"""
bgen_files(snps, bgen_prefix)
Expand Down Expand Up @@ -103,47 +104,8 @@ function call_genotypes(bgen_prefix::String, query_rsids::Set{<:AbstractString},
return genotypes
end

"""
    sorted_treatment_names(Ψ)

Return the keys of `Ψ.treatment_values` as a sorted `Tuple`.
"""
sorted_treatment_names(Ψ) = Tuple(sort!(collect(keys(Ψ.treatment_values))))

"""
    setting_iterator(Ψ::TMLE.StatisticalIATE)

Return a lazy generator of `NamedTuple`s keyed by the sorted treatment names
of `Ψ`, one for each element of the Cartesian product of
`values(Ψ.treatment_values[T])` taken over the treatments `T`.
"""
function setting_iterator(Ψ::TMLE.StatisticalIATE)
    treatments = sorted_treatment_names(Ψ)
    return (
        NamedTuple{treatments}(collect(Tval)) for
        Tval in Iterators.product((values(Ψ.treatment_values[T]) for T in treatments)...)
    )
end

"""
    setting_iterator(Ψ::TMLE.StatisticalATE)

Return a lazy generator of two `NamedTuple`s keyed by the sorted treatment
names of `Ψ`: the first holds each treatment's `:case` value, the second each
treatment's `:control` value.
"""
function setting_iterator(Ψ::TMLE.StatisticalATE)
    treatments = sorted_treatment_names(Ψ)
    return (
        NamedTuple{treatments}([(Ψ.treatment_values[T][c]) for T in treatments])
        for c in (:case, :control)
    )
end

"""
    setting_iterator(Ψ::TMLE.StatisticalCM)

Return a one-element tuple containing the `NamedTuple` of `Ψ`'s treatment
values, keyed by the sorted treatment names.
"""
function setting_iterator(Ψ::TMLE.StatisticalCM)
    treatments = sorted_treatment_names(Ψ)
    return (NamedTuple{treatments}(Ψ.treatment_values[T] for T in treatments), )
end

"""
    satisfies_positivity(Ψ::TMLE.Estimand, freqs; positivity_constraint=0.01)

Return `true` when every setting produced by `setting_iterator(Ψ)` is a key
of `freqs` whose value is not below `positivity_constraint`; return `false`
as soon as one setting is missing from `freqs` or falls below the threshold.
"""
function satisfies_positivity(Ψ::TMLE.Estimand, freqs; positivity_constraint=0.01)
    for base_setting in setting_iterator(Ψ)
        # A setting absent from the table is treated as a positivity violation.
        if !haskey(freqs, base_setting) || freqs[base_setting] < positivity_constraint
            return false
        end
    end
    return true
end

"""
    frequency_table(data, treatments::AbstractVector)

Group `data` by the sorted `treatments` columns (skipping groups with missing
values) and return a `Dict` mapping each group key, converted to a
`NamedTuple`, to the group's row count divided by the total row count of
`data`.
"""
function frequency_table(data, treatments::AbstractVector)
    sorted_treatments = sort(treatments)
    n_rows = nrow(data)
    table = Dict()
    groups = groupby(data, sorted_treatments; skipmissing=true)
    foreach(pairs(groups)) do (key, group)
        table[NamedTuple(key)] = nrow(group) / n_rows
    end
    return table
end
"""
    TMLE.satisfies_positivity(Ψ::ComposedEstimand, freqs; positivity_constraint=0.01)

Return `true` when every component in `Ψ.args` satisfies
`TMLE.satisfies_positivity` against `freqs` with the given
`positivity_constraint`.
"""
function TMLE.satisfies_positivity(Ψ::ComposedEstimand, freqs; positivity_constraint=0.01)
    return all(
        TMLE.satisfies_positivity(arg, freqs; positivity_constraint=positivity_constraint)
        for arg in Ψ.args
    )
end

read_txt_file(path::Nothing) = nothing
read_txt_file(path) = CSV.read(path, DataFrame, header=false)[!, 1]
Expand All @@ -164,15 +126,27 @@ function merge(traits, pcs, genotypes)
)
end

"""
    estimand_with_new_outcome(Ψ::T, outcome) where T

Build a new estimand of the same type `T` as `Ψ`, using `outcome` as its
outcome and copying `treatment_values`, `treatment_confounders` and
`outcome_extra_covariates` from `Ψ`. `T` must accept these four keyword
arguments.
"""
estimand_with_new_outcome(Ψ::T, outcome) where T = T(
    outcome=outcome,
    treatment_values=Ψ.treatment_values,
    treatment_confounders=Ψ.treatment_confounders,
    outcome_extra_covariates=Ψ.outcome_extra_covariates
)

"""
    update_estimands_from_outcomes!(estimands, Ψ::T, outcomes) where T

Push onto `estimands` one estimand per element of `outcomes`, each obtained
from `Ψ` with `estimand_with_new_outcome`.
"""
function update_estimands_from_outcomes!(estimands, Ψ::T, outcomes) where T
    for outcome in outcomes
        push!(estimands, estimand_with_new_outcome(Ψ, outcome))
    end
end

"""
    update_estimands_from_outcomes!(estimands, Ψ::ComposedEstimand, outcomes)

Push onto `estimands` one `ComposedEstimand` per element of `outcomes`. Each
new estimand keeps the composition function `Ψ.f`, and every component in
`Ψ.args` is rebuilt with that outcome via `estimand_with_new_outcome`.
"""
function update_estimands_from_outcomes!(estimands, Ψ::ComposedEstimand, outcomes)
    for outcome in outcomes
        new_args = Tuple(estimand_with_new_outcome(arg, outcome) for arg in Ψ.args)
        push!(estimands, ComposedEstimand(Ψ.f, new_args))
    end
end
Expand Down
67 changes: 58 additions & 9 deletions test/tl_inputs/from_param_files.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,40 @@ include(joinpath(TESTDIR, "tl_inputs", "test_utils.jl"))
pcs = TargeneCore.read_csv_file(joinpath(TESTDIR, "data", "pcs.csv"))
# extraW, extraT, extraC are parsed from all param_files
estimands = make_estimands_configuration().estimands
# get_treatments, get_outcome, ...
## Simple Estimand
Ψ = estimands[1]
@test TargeneCore.get_outcome(Ψ) == :ALL
@test TargeneCore.get_treatments(Ψ) == keys.treatment_values)
@test TargeneCore.get_confounders(Ψ) == ()
@test TargeneCore.get_outcome_extra_covariates(Ψ) == ()
## ComposedEstimand
Ψ = estimands[5]
@test TargeneCore.get_outcome(Ψ) == :ALL
@test TargeneCore.get_treatments(Ψ) == keys.args[1].treatment_values)
@test TargeneCore.get_confounders(Ψ) == ()
@test TargeneCore.get_outcome_extra_covariates(Ψ) == (Symbol("22001"), )
## Bad ComposedEstimand
Ψ = ComposedEstimand(
TMLE.joint_estimand, (
CM(
outcome = "Y1",
treatment_values = (RSID_3 = "GG", RSID_198 = "AG"),
treatment_confounders = (RSID_3 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
),
CM(
outcome = "Y2",
treatment_values = (RSID_2 = "AA", RSID_198 = "AG"),
treatment_confounders = (RSID_2 = [:PC1], RSID_198 = []),
outcome_extra_covariates = []
))
)
@test_throws ArgumentError TargeneCore.get_outcome(Ψ) == :ALL
@test_throws ArgumentError TargeneCore.get_treatments(Ψ)
@test_throws ArgumentError TargeneCore.get_confounders(Ψ)
@test_throws ArgumentError TargeneCore.get_outcome_extra_covariates(Ψ)
# get_variables
variables = TargeneCore.get_variables(estimands, traits, pcs)
@test variables.genetic_variants == Set([:RSID_198, :RSID_2])
@test variables.outcomes == Set([:BINARY_1, :CONTINUOUS_2, :CONTINUOUS_1, :BINARY_2])
Expand All @@ -38,8 +72,9 @@ end
)
pcs = Set([:PC1, :PC2])
variants_alleles = Dict(:RSID_198 => Set(genotypes.RSID_198))
# AG is not in the genotypes but GA is
Ψ = make_estimands_configuration().estimands[4]
estimands = make_estimands_configuration().estimands
# RS198 AG is not in the genotypes but GA is
Ψ = estimands[4]
@test Ψ.treatment_values.RSID_198 == (case="AG", control="AA")
new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs)
@test new_Ψ.outcome == Ψ.outcome
Expand All @@ -50,6 +85,19 @@ end
RSID_2 = (case = "AA", control = "GG")
)

# ComposedEstimand
Ψ = estimands[5]
@test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG")
@test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA")
new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs)
for index in 1:length.args)
@test new_Ψ.args[index].outcome == Ψ.args[index].outcome
@test new_Ψ.args[index].outcome_extra_covariates == (Symbol(22001),)
@test new_Ψ.args[index].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2),)
end
@test new_Ψ.args[1].treatment_values == (RSID_198 = "GA", RSID_2 = "GG")
@test new_Ψ.args[2].treatment_values == (RSID_198 = "GA", RSID_2 = "AA")

# If the allele is not present
variants_alleles = Dict(:RSID_198 => Set(["AA"]))
@test_throws TargeneCore.AbsentAlleleError("RSID_198", "AG") TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs)
Expand Down Expand Up @@ -95,8 +143,8 @@ end

## Estimands file:
output_estimands = deserialize("final.estimands.jls").estimands
# There are 5 initial estimands containing a *
# Those are duplicated for each of the 4 targets.
# There are 5 initial estimands containing a :ALL
# Those are duplicated for each of the 4 outcomes.
@test length(output_estimands) == 20
# In all cases the PCs are appended to the confounders.
for Ψ output_estimands
Expand All @@ -120,10 +168,11 @@ end
@test Ψ.outcome_extra_covariates == (Symbol("22001"),)

# Input Estimand 5: GA is corrected to AG to match the data
elseif Ψ isa TMLE.StatisticalCM && Ψ.treatment_values == (RSID_198 = "AG", RSID_2 = "GG")
@test Ψ.treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2))
@test Ψ.outcome_extra_covariates == (Symbol("22001"),)

elseif Ψ isa TMLE.ComposedEstimand
@test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG")
@test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA")
@test Ψ.args[1].treatment_confounders == Ψ.args[2].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2))
@test Ψ.args[1].outcome_extra_covariates == Ψ.args[2].outcome_extra_covariates == (Symbol("22001"),)
else
throw(AssertionError(string("Which input did this output come from: ", Ψ)))
end
Expand All @@ -142,7 +191,7 @@ end
tl_inputs(parsed_args)
# The IATEs are the most sensitive
outestimands = deserialize("final.estimands.jls").estimands
@test allisa Union{TMLE.StatisticalCM, TMLE.StatisticalATE} for Ψ in outestimands)
@test allisa Union{TMLE.StatisticalCM, TMLE.StatisticalATE, ComposedEstimand} for Ψ in outestimands)
@test size(outestimands, 1) == 16

cleanup()
Expand Down
20 changes: 14 additions & 6 deletions test/tl_inputs/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ function cleanup(;prefix="final.")
end
end


function make_estimands_configuration()
estimands = [
IATE(
Expand All @@ -32,11 +31,20 @@ function make_estimands_configuration()
treatment_confounders = (RSID_2 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
),
CM(
outcome = "ALL",
treatment_values = (RSID_2 = "GG", RSID_198 = "GA"),
treatment_confounders = (RSID_2 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
ComposedEstimand(
TMLE.joint_estimand, (
CM(
outcome = "ALL",
treatment_values = (RSID_2 = "GG", RSID_198 = "AG"),
treatment_confounders = (RSID_2 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
),
CM(
outcome = "ALL",
treatment_values = (RSID_2 = "AA", RSID_198 = "AG"),
treatment_confounders = (RSID_2 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
))
)
]
return Configuration(estimands=estimands)
Expand Down
Loading

0 comments on commit 58148df

Please sign in to comment.