- Environment
- Download, Extract, and Generate Metadata for Datasets
- Reproducing Paper Results
- Additional Support/Issues?
- Citation
We use Miniconda to manage the environment. Our Python version is 3.11.5. To create the environment, run the following command:
conda env create -f environment.yml -n mtl-group-robustness-env
To activate the environment, run the following command:
conda activate mtl-group-robustness-env
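As an optional sanity check (not part of the original setup steps), you can confirm that the activated environment provides the expected interpreter:

```python
# Optional sanity check: run inside the activated conda environment.
# The repository targets Python 3.11.5, so we only verify major.minor here.
import sys

assert sys.version_info[:2] == (3, 11), f"Unexpected Python version: {sys.version}"
print("Python version OK:", sys.version.split()[0])
```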
To download, extract, and format the datasets as expected by the code, run the following script. This will store the data and metadata in the `data` folder, which already contains the `civilcomments-small` dataset.
python3 ./src/setup_datasets.py dataset_name --download --data_path data
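For example, replacing `dataset_name` with `waterbirds` fetches the Waterbirds dataset used in the training example below. A quick way to confirm that the script populated the `data` folder is a short listing; the sketch below assumes nothing about file names beyond the `data` directory itself:

```python
# Sketch: list whatever setup_datasets.py placed under data/.
# Assumes only that the script writes datasets and metadata into this folder.
from pathlib import Path

for entry in sorted(Path("data").iterdir()):
    print(entry)
```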
The `./src/hparams.yaml` file includes the optimal hyperparameters for each method across all five datasets. To get started, execute the following command to generate Python scripts for training with the best hyperparameters.
python3 ./src/generate_hyper_search_scripts.py --dataset waterbirds --method erm_mt_l1
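If you want to see which settings the generator draws from, you can inspect `./src/hparams.yaml` directly. The sketch below only prints the file's top-level keys, since its exact layout is not documented here, and it assumes PyYAML is available in the environment:

```python
# Sketch: peek at the tuned hyperparameter file without assuming its schema.
import yaml  # assumes PyYAML is installed in the conda environment

with open("./src/hparams.yaml") as f:
    hparams = yaml.safe_load(f)

# Print only the top-level structure.
if isinstance(hparams, dict):
    for key, value in hparams.items():
        print(key, type(value).__name__)
else:
    print(type(hparams).__name__)
```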
The generator will create a txt file in the `hparams_files` folder, containing the Python script for five seeds. It will also generate an executable bash file in the `scripts` folder. To start training, run the following command:
sbatch ./scripts/train_waterbirds_erm_mt_l1_hp.sh
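If you are not on a SLURM cluster, one alternative (a sketch, not part of the original workflow) is to run the generated commands sequentially. This assumes the txt file in `hparams_files` contains one training command per line; the file name below is hypothetical and should be replaced with the one the generator actually produced.

```python
# Sketch: run the generated training commands one after another without SLURM.
# NOTE: the path below is hypothetical; use the txt file that
# generate_hyper_search_scripts.py actually wrote to hparams_files/.
import shlex
import subprocess
from pathlib import Path

commands_file = Path("hparams_files/train_waterbirds_erm_mt_l1_hp.txt")  # hypothetical name

for line in commands_file.read_text().splitlines():
    line = line.strip()
    if not line or line.startswith("#"):
        continue
    print("Running:", line)
    subprocess.run(shlex.split(line), check=True)
```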
Training will store the best results as a json file for each run in the `models_params` folder.
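Once the runs finish, the per-run json files in `models_params` can be collected with a short script. The sketch below makes no assumptions about the json schema and simply prints each file's contents; adapt it to whatever keys the training code actually records.

```python
# Sketch: print every per-run result file written to models_params/.
# The json schema is not documented here, so we just dump each file.
import json
from pathlib import Path

for result_file in sorted(Path("models_params").glob("*.json")):
    with open(result_file) as f:
        results = json.load(f)
    print(result_file.name)
    print(json.dumps(results, indent=2))
```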
If you face any issues with our code or reproducing our results, please raise a GitHub issue or contact Atharva Kulkarni (atharvak@cs.cmu.edu).
@article{
kulkarni2024multitask,
title={Multitask Learning Can Improve Worst-Group Outcomes},
author={Atharva Kulkarni and Lucio M. Dery and Amrith Setlur and Aditi Raghunathan and Ameet Talwalkar and Graham Neubig},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2024},
url={https://openreview.net/forum?id=sPlhAIp6mk},
note={}
}