
Code to reproduce the paper "Do causal predictors generalize better to new domains?"


socialfoundations/causal-features

 
 


❗ This repository is heavily based on the TableShift GitHub repository. Our code is built on top of TableShift and code from Ricardo Sandoval and Hardt & Kim (2023).

Predictors from causal features do not generalize better to new domains

This is code to reproduce experiments in the paper:

Vivian Y. Nastl and Moritz Hardt. "Do causal predictors generalize better to new domains?", 2024.

Quickstart

Simply clone the repo, enter the root directory and create a local conda environment.

git clone https://github.com/socialfoundations/causal-features.git
cd causal-features
# set up the environment
conda env create -f environment.yml

Then activate the environment and test the local execution environment:

conda activate tableshift
# test the install by running the training script
python examples/run_expt.py

The final command prints detailed logging output as the script executes. When you see "training completed! test accuracy: 0.6221", your environment is ready to go. (The accuracy may vary slightly due to randomness.)

Reproducing the experiments in the paper

The training script we run is located at experiments_causal/run_experiment.py. It takes the following arguments:

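  • experiment (the task and feature selection to run, given as a string identifier; see Dataset Availability)
  • model (the model to use; see the table below)
  • cache_dir (directory to cache raw data files to)
  • save_dir (directory to save result files to)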
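For illustration, a single run can be launched from Python by assembling the command line and passing it to subprocess. This is a minimal sketch: build_command is a hypothetical helper, and while --experiment is the flag name documented below, the flag names --model, --cache_dir and --save_dir are assumptions based on the argument names above. Check the script's argument parser before relying on them.

```python
import shlex
import sys

SCRIPT = "experiments_causal/run_experiment.py"


def build_command(experiment: str, model: str, cache_dir: str, save_dir: str) -> list[str]:
    """Return the argument vector for one training run (hypothetical helper).

    Only --experiment is a documented flag; the other flag names are assumed.
    """
    return [
        sys.executable, SCRIPT,
        "--experiment", experiment,
        "--model", model,
        "--cache_dir", cache_dir,    # assumed flag name
        "--save_dir", save_dir,      # assumed flag name
    ]


if __name__ == "__main__":
    cmd = build_command("acsincome_causal", "xgb", "tmp", "results")
    print(shlex.join(cmd))
    # To actually start the run from the repository root:
    # import subprocess; subprocess.run(cmd, check=True)
```

Passing a list rather than a single shell string avoids quoting problems if a directory path contains spaces.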

The full list of model names is given below. For more details on each algorithm, see TableShift.

Model | Name in TableShift
XGBoost | xgb
LightGBM | lightgbm
SAINT | saint
NODE | node
Group DRO | group_dro
MLP | mlp
Tabular ResNet | resnet
Adversarial Label DRO | aldro
CORAL | deepcoral
MMD | mmd
DRO | dro
DANN | dann
TabTransformer | tabtransformer
MixUp | mixup
Label Group DRO | label_group_dro
IRM | irm
VREX | vrex
FT-Transformer | ft_transformer
IB-IRM | ib_irm
CausIRL CORAL | causirl_coral
CausIRL MMD | causirl_mmd
AND-Mask | and_mask
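To catch typos before submitting many jobs, the table can be transcribed into a Python mapping. The MODELS dictionary below copies the identifiers from the table; check_model is a hypothetical helper and is not part of this repository or of TableShift.

```python
# Identifiers transcribed from the table above (display name -> identifier).
MODELS = {
    "XGBoost": "xgb",
    "LightGBM": "lightgbm",
    "SAINT": "saint",
    "NODE": "node",
    "Group DRO": "group_dro",
    "MLP": "mlp",
    "Tabular ResNet": "resnet",
    "Adversarial Label DRO": "aldro",
    "CORAL": "deepcoral",
    "MMD": "mmd",
    "DRO": "dro",
    "DANN": "dann",
    "TabTransformer": "tabtransformer",
    "MixUp": "mixup",
    "Label Group DRO": "label_group_dro",
    "IRM": "irm",
    "VREX": "vrex",
    "FT-Transformer": "ft_transformer",
    "IB-IRM": "ib_irm",
    "CausIRL CORAL": "causirl_coral",
    "CausIRL MMD": "causirl_mmd",
    "AND-Mask": "and_mask",
}


def check_model(identifier: str) -> str:
    """Hypothetical helper: return identifier if it appears in the table, else raise."""
    valid = set(MODELS.values())
    if identifier not in valid:
        raise ValueError(f"unknown model {identifier!r}; expected one of {sorted(valid)}")
    return identifier


print(check_model("ft_transformer"))  # ft_transformer
```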

All experiments were run as jobs submitted to a centralized cluster running the open-source HTCondor scheduler. The script that launches the jobs is located at experiments_causal/launch_experiments.py.
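The launch script itself is not reproduced here. As a rough sketch of the idea, assuming one cluster job per combination of experiment identifier and model, the job list can be enumerated with itertools.product. The function job_grid and the example tasks and models are illustrative choices, and the sketch does not submit anything to HTCondor.

```python
import itertools


# Illustrative sketch, not experiments_causal/launch_experiments.py:
# enumerate one job per (experiment identifier, model) pair.
def job_grid(tasks, feature_suffixes, models):
    """Return a list of (experiment, model) pairs covering every combination."""
    experiments = [task + suffix for task, suffix in itertools.product(tasks, feature_suffixes)]
    return list(itertools.product(experiments, models))


jobs = job_grid(
    tasks=["acsincome", "brfss_diabetes"],
    feature_suffixes=["_causal", "_arguablycausal", "_anticausal"],
    models=["xgb", "mlp"],
)
print(len(jobs))  # 12
print(jobs[0])    # ('acsincome_causal', 'xgb')
```

Each pair would then correspond to one submitted job; how the repository's script actually builds and submits its HTCondor jobs is not shown here.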

Raw results of experiments

We provide the raw results of our experiments in the folder experiments_causal/results/. It contains one JSON file for each combination of task, feature selection, and trained model.
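A minimal sketch for loading these files, assuming only that they are valid JSON somewhere below experiments_causal/results/. The file naming and the keys inside each file are not documented in this README, so the sketch does not interpret them; load_results is a hypothetical helper.

```python
import json
from pathlib import Path


def load_results(results_dir="experiments_causal/results"):
    """Load every *.json file below results_dir, keyed by its path relative to results_dir."""
    root = Path(results_dir)
    return {
        path.relative_to(root).as_posix(): json.loads(path.read_text(encoding="utf-8"))
        for path in sorted(root.rglob("*.json"))
    }
```

Called from the repository root, load_results() returns a dictionary with one entry per result file; rglob is used so that files in subfolders are also picked up, in case the results are organized that way.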

Reproducing the figures in the paper

Use the following Python scripts (the synthetic experiments are in a Jupyter notebook):

  • Main result:
    • Figure in introduction: experiments_causal/plot_paper_introduction_figure.py
    • Figures in section "Empirical results": experiments_causal/plot_paper_figures.py
  • Appendix:
    • Main results: experiments_causal/plot_paper_appendix_figures.py, experiments_causal/plot_paper_appendix_figures_extra.py, experiments_causal/plot_paper_appendix_figures_extra2.py
    • Anti-causal features: experiments_causal/plot_paper_appendix_figures.py
    • Causal machine learning: experiments_causal/plot_add_on_causalml.py
    • Causal discovery: experiments_causal/plot_add_on_causal_discovery.py
    • Random subsets: experiments_causal/plot_add_on_random_subsets.py
    • Ablation study: experiments_causal/plot_experiment_ablation.py
    • Empirical results across machine learning models: experiments_causal/plot_add_on_models.py
    • Synthetic experiments: experiments_causal/synthetic_experiments.ipynb

Dataset Availability

The datasets in our paper are either publicly available or offer open credentialized access. The datasets with open credentialized access require signing a data use agreement. For the tasks ICU Mortality and ICU Length of Stay, completing the CITI training course "Data or Specimens Only Research" is also required, as these datasets contain sensitive personal information. Hence, these datasets must be fetched manually and stored locally.

A list of datasets, their names in our code, and their access levels is given below. The string identifier is the value to pass to the --experiment flag of experiments_causal/run_experiment.py. The causal, arguably causal, and anti-causal feature sets are selected by appending _causal, _arguablycausal, and _anticausal to the string identifier; the combined causal and anti-causal features use the suffix _causal_anticausal. Where available, the parents estimated by a causal discovery algorithm are selected by appending the algorithm's abbreviation in lowercase, for example acsincome_pc. Random subsets are indexed from 0 to 500 and are selected via the suffix _random_test_{index}.
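The naming scheme can be summarized in a small helper. experiment_id is hypothetical and not part of the repository; it assumes that only one suffix is applied at a time and that the random-subset range 0 to 500 is inclusive. Causal-discovery suffixes exist only for some tasks, so a generated identifier is not guaranteed to be valid.

```python
# Hypothetical helper encoding the naming scheme described above.
FEATURE_SUFFIXES = {
    "causal": "_causal",
    "arguablycausal": "_arguablycausal",
    "anticausal": "_anticausal",
    "causal_anticausal": "_causal_anticausal",
}


def experiment_id(task, features=None, discovery=None, random_index=None):
    """Build the value for --experiment from a task's string identifier.

    At most one of features, discovery, random_index may be given (an assumption).
    """
    chosen = [opt is not None for opt in (features, discovery, random_index)]
    if sum(chosen) > 1:
        raise ValueError("pass at most one of features, discovery, random_index")
    if features is not None:
        return task + FEATURE_SUFFIXES[features]
    if discovery is not None:
        # Causal discovery: algorithm abbreviation in lowercase, e.g. acsincome_pc.
        return f"{task}_{discovery.lower()}"
    if random_index is not None:
        # Assumes the documented range 0 to 500 is inclusive.
        if not 0 <= random_index <= 500:
            raise ValueError("random subsets are indexed from 0 to 500")
        return f"{task}_random_test_{random_index}"
    return task


print(experiment_id("acsincome", discovery="PC"))  # acsincome_pc
print(experiment_id("acsincome", features="arguablycausal"))  # acsincome_arguablycausal
```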

Task | String identifier | Availability | Source | Preprocessing
Voting | anes | Public Credentialized Access | American National Election Studies (ANES) | TableShift
ASSISTments | assistments | Public | Kaggle | TableShift
Childhood Lead | nhanes_lead | Public | National Health and Nutrition Examination Survey (NHANES) | TableShift
College Scorecard | college_scorecard | Public | College Scorecard | TableShift
Diabetes | brfss_diabetes | Public | Behavioral Risk Factor Surveillance System (BRFSS) | TableShift
Food Stamps | acsfoodstamps | Public | American Community Survey (via folktables) |
Hospital Readmission | diabetes_readmission | Public | UCI | TableShift
Hypertension | brfss_blood_pressure | Public | Behavioral Risk Factor Surveillance System (BRFSS) | TableShift
ICU Length of Stay | mimic_extract_los_3 | Public Credentialized Access | MIMIC-III via MIMIC-Extract | TableShift
ICU Mortality | mimic_extract_mort_hosp | Public Credentialized Access | MIMIC-III via MIMIC-Extract | TableShift
Income | acsincome | Public | American Community Survey (via folktables) | TableShift
Public Health Insurance | acspubcov | Public | American Community Survey (via folktables) | TableShift
Sepsis | physionet | Public | PhysioNet | TableShift
Unemployment | acsunemployment | Public | American Community Survey (via folktables) | TableShift
Utilization | meps | Public | Medical Expenditure Panel Survey | Hardt & Kim (2023)
Poverty | sipp | Public | Survey of Income and Program Participation | Hardt & Kim (2023)

TableShift handles the preprocessing of the data files as part of its implementation. For the tasks Utilization and Poverty, follow the instructions by Hardt & Kim (2023) in backward_predictor/README.md.

Differences from TableShift

The following files and folders were changed for our experiments:

  • created folder experiments_causal with Python scripts to run experiments, launch experiments on a cluster, and plot figures for the paper
  • created folder backward_predictor with preprocessing files adapted from Hardt & Kim (2023), including backward_predictor/sipp/data/data_cleaning.ipynb © Ricardo Sandoval, 2024
  • added tasks meps and sipp, as well as feature selections of all tasks in their respective Python scripts in the folder tableshift/datasets
  • added data source for meps and sipp in tableshift/core/data_source.py
  • added tasks meps and sipp, as well as feature selections of all tasks in tableshift/core/tasks.py
  • added configurations for tasks and their feature selections in tableshift/configs/non_benchmark_configs.py
  • added models ib_erm, ib_irm, causirl_coral, causirl_mmd and and_mask in tableshift/models, adapted from Gulrajani & Lopez-Paz (2021)
  • added configurations for hyperparameters of added models in tableshift/configs/hparams.py
  • added computation of balanced accuracy in tableshift/models/torchutils.py and adapted tableshift/models/compat.py accordingly
  • minor fixes in tableshift/core/features.py, tableshift/core/tabular_dataset.py and tableshift/models/training.py
  • added the packages paretoset==1.2.3 and seaborn==0.13.0 in requirements.txt

Citing

This repository contains code and supplementary materials for the following preprint:

@misc{nastl2024predictors,
      title={Do causal predictors generalize better to new domains?}, 
      author={Vivian Y. Nastl and Moritz Hardt},
      year={2024},
      eprint={2402.09891},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
