By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang.
This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers", published at CVPR 2022.
(Figure: homogeneous predictions.)
(Figure: heterogeneous predictions.)
For the semantic segmentation task on NYUDv2 (official dataset), we provide a download link for the dataset. The provided data was originally preprocessed in an external repository, and we have added depth data to it.
For the image-to-image translation task, we use the Taskonomy sample dataset, which is available from the provided download link.
Please modify the data paths in the code; the places to change are marked with the comment 'Modify data path'.
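To locate every place that needs editing, one option is a small stdlib-only scan for the marker comment. This is a hypothetical helper, not part of the repository; it assumes the markers appear in `.py` files with exactly the casing quoted above (`grep -rn "Modify data path" --include="*.py" .` does the same job from a shell).

```python
import os
from typing import Iterable, List, Tuple

MARKER = "Modify data path"  # comment text quoted in the instructions above


def scan_lines(lines: Iterable[str], marker: str = MARKER) -> List[Tuple[int, str]]:
    """Return (1-based line number, stripped line) for every line containing marker."""
    return [(i, line.strip()) for i, line in enumerate(lines, start=1) if marker in line]


def find_data_path_markers(code_root: str, marker: str = MARKER) -> List[Tuple[str, int, str]]:
    """Walk code_root and report marked lines in Python files as (path, line, text)."""
    hits = []
    for dirpath, _dirnames, filenames in os.walk(code_root):
        for fname in sorted(filenames):
            if fname.endswith(".py"):
                path = os.path.join(dirpath, fname)
                with open(path, encoding="utf-8", errors="ignore") as f:
                    hits.extend((path, n, text) for n, text in scan_lines(f, marker))
    return hits
```

For example, `find_data_path_markers("semantic_segmentation")` returns one `(path, line, text)` tuple per marked line, which can be printed as `path:line: text`.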
Requirements:
python==3.6
pytorch==1.7.1
torchvision==0.8.2
numpy==1.19.2
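To check a local environment against the versions listed above, a stdlib-only script like the following could be used. It is an illustrative sketch, not something the repository ships: it treats a pin as a prefix of the installed release (so `python==3.6` accepts 3.6.13, and `pytorch==1.7.1` accepts a CUDA build such as `1.7.1+cu110`), which is an assumed matching rule. The pip distribution for PyTorch is named `torch`, hence the name mapping.

```python
import sys
from typing import Dict, List, Optional

# Pins copied from the list above.
REQUIREMENTS = """\
python==3.6
pytorch==1.7.1
torchvision==0.8.2
numpy==1.19.2
"""

# The pip distribution for PyTorch is named "torch".
DIST_NAMES = {"pytorch": "torch"}


def parse_pins(text: str) -> Dict[str, str]:
    """Parse 'name==version' lines into a {name: version} dict."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def version_matches(pinned: str, installed: str) -> bool:
    """True if the installed release starts with the pinned components.
    A local suffix such as '+cu110' is ignored (assumed matching rule)."""
    release = installed.split("+", 1)[0].split(".")
    wanted = pinned.split(".")
    return release[:len(wanted)] == wanted


def installed_version(name: str) -> Optional[str]:
    """Return the installed version of a package (or the interpreter), else None."""
    if name == "python":
        return ".".join(str(v) for v in sys.version_info[:3])
    dist = DIST_NAMES.get(name, name)
    try:
        from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
    except ImportError:  # Python 3.6/3.7: fall back to setuptools' pkg_resources
        import pkg_resources
        try:
            return pkg_resources.get_distribution(dist).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def report(pins: Dict[str, str], found: Dict[str, Optional[str]]) -> List[str]:
    """List human-readable mismatches between pins and installed versions."""
    problems = []
    for name, pinned in pins.items():
        got = found.get(name)
        if got is None:
            problems.append(f"{name}: not installed (expected {pinned})")
        elif not version_matches(pinned, got):
            problems.append(f"{name}: found {got}, expected {pinned}")
    return problems


if __name__ == "__main__":
    pins = parse_pins(REQUIREMENTS)
    found = {name: installed_version(name) for name in pins}
    for line in report(pins, found) or ["All pinned versions match."]:
        print(line)
```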
For semantic segmentation, first enter the corresponding code directory:
cd semantic_segmentation
Download the ImageNet-pretrained SegFormer backbone weights (e.g., mit_b3.pth) from the provided link, and move the file into the folder 'pretrained'.
Training command for segmentation with RGB and depth inputs:
python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2
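The `--lamda` flag above presumably weights the sparsity term of token fusion. As the paper describes the idea, tokens of each modality receive importance scores, an L1 penalty drives the scores of uninformative tokens toward zero, and tokens scoring below a threshold are substituted with features from the other modality. The NumPy sketch below illustrates that idea under stated assumptions: the linear scorer, the threshold of 0.02, and the identity cross-modal projection (plausible for pixel-aligned RGB and depth tokens) are illustrative choices, and all function names are hypothetical rather than the repository's API.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def token_scores(tokens, w, b=0.0):
    """Score each token in (0, 1). tokens: (N, C), w: (C,). Hypothetical linear scorer."""
    return sigmoid(tokens @ w + b)


def l1_sparsity_loss(scores_per_modality, lamda=1e-6):
    """Assumed role of --lamda: weight of an L1 penalty summed over all token scores."""
    return lamda * sum(float(np.abs(s).sum()) for s in scores_per_modality)


def substitute_tokens(tokens_self, tokens_other, scores_self, threshold=0.02):
    """Keep tokens scoring >= threshold; replace the rest with the spatially
    aligned tokens of the other modality (identity projection assumed)."""
    keep = scores_self >= threshold  # (N,) boolean mask
    return np.where(keep[:, None], tokens_self, tokens_other)


# Toy example: 4 aligned tokens with 8 channels per modality.
rng = np.random.default_rng(0)
rgb, depth = rng.normal(size=(4, 8)), rng.normal(size=(4, 8))
w_rgb, w_depth = rng.normal(size=8), rng.normal(size=8)
s_rgb, s_depth = token_scores(rgb, w_rgb), token_scores(depth, w_depth)

fused_rgb = substitute_tokens(rgb, depth, s_rgb)
fused_depth = substitute_tokens(depth, rgb, s_depth)
aux_loss = l1_sparsity_loss([s_rgb, s_depth])  # would be added to the task loss
```

The sketch only shows the forward substitution and the penalty term; it does not model how the scorer is trained end to end inside the transformer blocks, for which `main.py` and the model code are the reference.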
Evaluation command:
python main.py --gpu 0 --resume path_to_pth --evaluate # optionally use --save-img to visualize results
Checkpoint models, training logs, mask ratios and the single-scale performance on NYUDv2 are provided as follows:
Method | Backbone | Pixel Acc. (%) | Mean Acc. (%) | Mean IoU (%) | Download |
---|---|---|---|---|---|
CEN | ResNet101 | 76.2 | 62.8 | 51.1 | Google Drive |
CEN | ResNet152 | 77.0 | 64.4 | 51.6 | Google Drive |
Ours | SegFormer-B3 | 78.7 | 67.5 | 54.8 | Google Drive |
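The three metrics in the table are commonly computed from a class confusion matrix accumulated over the test set. The sketch below uses the standard definitions; the ignore label of 255 and the skipping of classes with undefined ratios (via `np.nanmean`) are assumptions, so the repository's evaluation code may differ in these details.

```python
import numpy as np


def confusion_matrix(gt, pred, num_classes, ignore_index=255):
    """Accumulate a (num_classes x num_classes) matrix: rows = ground truth, cols = prediction.
    Pixels labeled ignore_index (assumed 255) are skipped."""
    gt, pred = np.asarray(gt), np.asarray(pred)
    valid = gt != ignore_index
    idx = num_classes * gt[valid].astype(np.int64) + pred[valid].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def segmentation_scores(cm):
    """Pixel accuracy, mean class accuracy and mean IoU from a confusion matrix."""
    cm = cm.astype(np.float64)
    tp = np.diag(cm)
    gt_total = cm.sum(axis=1)    # pixels per ground-truth class
    pred_total = cm.sum(axis=0)  # pixels per predicted class
    pixel_acc = tp.sum() / cm.sum()
    with np.errstate(divide="ignore", invalid="ignore"):
        class_acc = tp / gt_total                # NaN for classes absent from ground truth
        iou = tp / (gt_total + pred_total - tp)  # NaN for classes absent from both
    return pixel_acc, np.nanmean(class_acc), np.nanmean(iou)


# Toy 2x2 label maps with 3 classes; multiply by 100 to compare with the table's percentages.
gt = np.array([[0, 0], [1, 2]])
pred = np.array([[0, 1], [1, 2]])
pa, macc, miou = segmentation_scores(confusion_matrix(gt, pred, num_classes=3))
# pa = 0.75, macc ≈ 0.833, miou ≈ 0.667
```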
A MindSpore implementation is available at: https://gitee.com/mindspore/models/tree/master/research/cv/TokenFusion
For image-to-image translation, first enter the corresponding code directory:
cd image2image_translation
Training command (translating Shade and Texture inputs to RGB):
python main.py --gpu 0 -c exp_name
The script automatically evaluates on the validation set every 5 training epochs.
Predicted images are saved automatically during training, using the following folder structure:
code_root/ckpt/exp_name/results
├── input0 # 1st modality input
├── input1 # 2nd modality input
├── fake0 # 1st branch output
├── fake1 # 2nd branch output
├── fake2 # ensemble output
├── best # current best output
│ ├── fake0
│ ├── fake1
│ └── fake2
└── real # ground truth output
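To score saved predictions offline, for example with an FID/KID tool, predictions and ground truth can be paired by file name. This sketch assumes that files in the `fake*` and `real` folders share names, which the layout suggests but the source does not state; `common_names` and `paired_outputs` are hypothetical helpers.

```python
from pathlib import Path
from typing import Iterable, List, Tuple


def common_names(pred_names: Iterable[str], real_names: Iterable[str]) -> List[str]:
    """File names present in both folders, sorted for a stable order."""
    return sorted(set(pred_names) & set(real_names))


def paired_outputs(results_dir: str, branch: str = "fake2", best: bool = True) -> List[Tuple[Path, Path]]:
    """Pair each prediction of `branch` (fake0, fake1 or fake2) with its ground truth.
    With best=True, predictions are read from results/best/<branch>."""
    root = Path(results_dir)
    pred_dir = root / "best" / branch if best else root / branch
    real_dir = root / "real"
    pred_files = {p.name for p in pred_dir.iterdir() if p.is_file()}
    real_files = {p.name for p in real_dir.iterdir() if p.is_file()}
    return [(pred_dir / n, real_dir / n) for n in common_names(pred_files, real_files)]
```

For example, `paired_outputs("code_root/ckpt/exp_name/results")` would pair the current best ensemble outputs with their ground-truth images.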
Checkpoint models:
Method | Task | FID | KID | Download |
---|---|---|---|---|
CEN | Texture+Shade->RGB | 62.6 | 1.65 | - |
Ours | Texture+Shade->RGB | 45.5 | 1.00 | Google Drive |
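For reference, these are the standard definitions of the two metrics in the table; lower is better for both. Here $\mu_r, \Sigma_r$ and $\mu_g, \Sigma_g$ are the mean and covariance of Inception features for real and generated images, $x_1, \dots, x_m$ and $y_1, \dots, y_n$ are the real and generated feature vectors, and $d$ is the feature dimension. The feature extractor and sample sizes behind the numbers above are not specified in this README.

```math
\begin{aligned}
\mathrm{FID} &= \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right) \\
\mathrm{KID} &= \frac{1}{m(m-1)} \sum_{i \neq j} k(x_i, x_j) + \frac{1}{n(n-1)} \sum_{i \neq j} k(y_i, y_j) - \frac{2}{mn} \sum_{i=1}^{m} \sum_{j=1}^{n} k(x_i, y_j), \qquad k(x, y) = \left(\tfrac{1}{d}\, x^\top y + 1\right)^3
\end{aligned}
```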
For 3D object detection on SUN RGB-D, data preparation, environment setup, and training scripts follow Group-Free and ImVoteNet.
For example, training with 4 GPUs:
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 2229 --nproc_per_node 4 train_dist.py --max_epoch 600 --val_freq 25 --save_freq 25 --lr_decay_epochs 420 480 540 --num_point 20000 --num_decoder_layers 6 --size_cls_agnostic --size_delta 0.0625 --heading_delta 0.04 --center_delta 0.1111111111111 --weight_decay 0.00000001 --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 --dataset sunrgbd --data_root . --use_img --log_dir log/exp_name
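Several flags in this command set the optimization schedule: training runs for 600 epochs (`--max_epoch`), presumably with validation and checkpointing every 25 epochs (`--val_freq`, `--save_freq`), and the learning rate is decayed at epochs 420, 480 and 540 (`--lr_decay_epochs`). `torch.distributed.launch` starts one process per GPU (`--nproc_per_node 4`, matching the four devices in `CUDA_VISIBLE_DEVICES`). Assuming a plain multi-step decay with a factor of 0.1 at each milestone (the actual factor, base learning rate, and any warmup are configured in `train_dist.py` and are not shown here), the per-epoch learning rate would follow this sketch:

```python
from typing import Sequence


def lr_at_epoch(epoch: int, base_lr: float,
                decay_epochs: Sequence[int] = (420, 480, 540),
                decay_rate: float = 0.1) -> float:
    """Multi-step decay: scale base_lr by decay_rate once per milestone reached.
    decay_rate=0.1 is an assumed value, not read from train_dist.py."""
    num_decays = sum(1 for m in decay_epochs if epoch >= m)
    return base_lr * decay_rate ** num_decays


if __name__ == "__main__":
    base = 1e-3  # hypothetical base learning rate for illustration
    for e in (0, 419, 420, 480, 540, 599):
        print(e, lr_at_epoch(e, base))
```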
If you find our work useful for your research, please consider citing the following paper.
@inproceedings{wang2022tokenfusion,
title={Multimodal Token Fusion for Vision Transformers},
author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2022}
}