Official implementation for "Flex-MoE: Modeling Arbitrary Modality Combination via the Flexible Mixture-of-Experts" accepted by NeurIPS 2024 (Spotlight).
- Authors: Sukwon Yun, Inyoung Choi, Jie Peng, Yangfan Wu, Jingxuan Bao, Qiyiwen Zhang, Jiayi Xin, Qi Long and Tianlong Chen
Multimodal learning has gained increasing importance across various fields, offering the ability to integrate data from diverse sources such as images, text, and personalized records, which are frequently observed in medical domains. However, in scenarios where some modalities are missing, many existing frameworks struggle to accommodate arbitrary modality combinations, often relying heavily on a single modality or complete data. This oversight of potential modality combinations limits their applicability in real-world situations. To address this challenge, we propose Flex-MoE (Flexible Mixture-of-Experts), a new framework designed to flexibly incorporate arbitrary modality combinations while maintaining robustness to missing data. The core idea of Flex-MoE is to first address missing modalities using a new missing modality bank that integrates observed modality combinations with the corresponding missing ones. This is followed by a uniquely designed Sparse MoE framework. Specifically, Flex-MoE first trains experts using samples with all modalities to inject generalized knowledge through the generalized router (
conda create -n flex-moe python=3.10 -y
conda activate flex-moe
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia -y
conda install nvidia::cudatoolkit conda-forge::cudatoolkit-dev -y
conda install -c conda-forge gxx=10.4 -y
export LD_LIBRARY_PATH=/path_to_anaconda3/flex-moe/lib:$LD_LIBRARY_PATH
export CUDA_HOME=/path_to_anaconda3/envs/flex-moe
git clone https://github.com/laekov/fastmoe.git
cd fastmoe
USE_NCCL=0 python setup.py install
cd ..
conda install dm-tree scikit-learn tqdm pandas -y
conda install -c conda-forge scanpy nibabel -y
*For the details, please refer to data/adni/README.md
-
To access ADNI dataset, please first visit ADNI website: https://adni.loni.usc.edu/ and apply for the data access here: https://ida.loni.usc.edu/collaboration/access/appApply.jsp?project=ADNI.
-
Once you get an access, login to IDA and dowlonad necessary files for each modality.
- Image:
- {Subject ID}_{Scan date}
- Genomic:
- ADNI 1 SNP genotype data - PLINK
- ADNI GO/2 SNP genotype data - Complete PLINK for sets 1 - 9
- ADNI GO/2 SNP genotype data - Complete PLINK for sets 10 - 15
- ADNI3 SNP genotype data Set 1 - PLINK
- ADNI3 SNP genotype data Set 2 - PLINK
- Clinical:
- MEDHIST
- NEUROEXM
- PTDEMOG
- RECCMEDS
- VITALS
- Biospecimen:
- APOERES
- UPENNBIOMK_ROCHE_ELECSYS
- Image:
-
Once the files are ready, please run preprocessing codes in each modality.
To run Flex-MoE in other datasets based on existing dataloader format, please customize below two parts.
-
main.py
- As in line 114 in
main.py
, please define dataset specificmodality_dict
. - While setting the modality name, please (1) use small letter while (2) enusring each modality's first letter is unique.
- Example:
if args.data == 'your_new': modality_dict = {'a_modality':0, 'b_modality': 1, 'c_modality': 2}
- As in line 114 in
-
data.py
- Please (1) set an appropriate path to the dataset and (2) cutomize each modality with its own speicifc preprocessing (e.g., initial fillling, MinMaxScaler).
- Example:
def load_and_preprocess_new_data(args, modality_dict): # Paths a_modality_path = './data/your_dataset/a_modality' b_modality_path = './data/your_dataset/b_modality' c_modality_path = './data/your_dataset/c_modality' label_df = pd.read_csv('./data/your_dataset/your_label_csv_file', index_col='your_index') labels = label_df['your_label'].values.astype(np.int64) with open('./data/your_dataset/your_index_train_valid_test_split.json') as json_file: data_split = json.load(json_file) id_to_idx = {id: idx for idx, id in enumerate(label_df.index)} common_idx_list = [] observed_idx_arr = np.zeros((labels.shape[0], args.n_full_modalities), dtype=bool) # Load modalities (a_modality as an example) if 'A' in args.modality or 'a' in args.modality: if args.initial_filling == 'mean': path = a_modality_path + '_mean.csv' else: path = a_modality_path + '.csv' ...
Run Example
python main.py \
--data adni \
--modality IG \
--lr 1e-3 \
--num_experts 16 \
--num_layers_fus 1 \
--top_k 4 \
--train_epochs 50 \
--warm_up_epochs 5 \
--hidden_dim 128 \
--batch_size 16 \
--n_runs 3 \
--device 0
IGCB Result
sh run_adni_igcb.sh
@misc{yun2024flexmoemodelingarbitrarymodality,
title={Flex-MoE: Modeling Arbitrary Modality Combination via the Flexible Mixture-of-Experts},
author={Sukwon Yun and Inyoung Choi and Jie Peng and Yangfan Wu and Jingxuan Bao and Qiyiwen Zhang and Jiayi Xin and Qi Long and Tianlong Chen},
year={2024},
eprint={2410.08245},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.08245},
}