[CVPR 2022] Code release for "Multimodal Token Fusion for Vision Transformers"

yikaiw/TokenFusion

Multimodal Token Fusion for Vision Transformers

By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang.

[Paper]

This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers", in CVPR 2022.

[Figure: homogeneous predictions]

[Figure: heterogeneous predictions]

Datasets

For the semantic segmentation task on NYUDv2 (official dataset), we provide a download link here. The provided dataset was originally preprocessed in this repository, and we added depth data to it.

For the image-to-image translation task, we use the sample dataset of Taskonomy; a download link for the sample dataset is here.

Please modify the data paths in the code at the places marked with the comment 'Modify data path'.

Dependencies

python==3.6
pytorch==1.7.1
torchvision==0.8.2
numpy==1.19.2
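
The pinned versions above can be checked against the active environment with a short script. This is a minimal sketch and not part of the repository: the helper names (`installed_version`, `matches_pin`, `report`) are hypothetical, and it assumes the `pytorch` entry is imported under the module name `torch` and that each package exposes a `__version__` attribute.

```python
import importlib

# Pins copied from the Dependencies list; "pytorch" is imported as "torch".
PINNED = {"torch": "1.7.1", "torchvision": "0.8.2", "numpy": "1.19.2"}


def installed_version(module_name):
    """Return the module's __version__ string, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def matches_pin(installed, pinned):
    """True if `installed` equals the pin after dropping a local suffix such as '+cu110'."""
    if installed is None:
        return False
    return installed.split("+")[0] == pinned


def report(pinned=PINNED):
    """Return {module: (installed, pinned, ok)} for every pinned module."""
    result = {}
    for name, want in pinned.items():
        have = installed_version(name)
        result[name] = (have, want, matches_pin(have, want))
    return result
```

Calling `report()` returns one entry per pinned package, so a mismatch or a missing package shows up as `ok == False` rather than as an import error later during training.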

Semantic Segmentation

First,

cd semantic_segmentation

Download the SegFormer pretrained model (pretrained on ImageNet) from weights, e.g., mit_b3.pth. Move this pretrained model to the folder 'pretrained'.

Training script for segmentation with RGB and depth input,

python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2

Evaluation script,

python main.py --gpu 0 --resume path_to_pth --evaluate  # optionally use --save-img to visualize results

Checkpoint models, training logs, mask ratios and the single-scale performance on NYUDv2 are provided as follows:

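| Method | Backbone | Pixel Acc. (%) | Mean Acc. (%) | Mean IoU (%) | Download |
|--------|----------|----------------|---------------|--------------|----------|
| CEN | ResNet101 | 76.2 | 62.8 | 51.1 | Google Drive |
| CEN | ResNet152 | 77.0 | 64.4 | 51.6 | Google Drive |
| Ours | SegFormer-B3 | 78.7 | 67.5 | 54.8 | Google Drive |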
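
For readers unfamiliar with the three metrics in the table above, the sketch below shows their standard definitions computed from a confusion matrix. It is an illustration only and is not taken from the repository's evaluation code: the function names, the ignore label of 255, and the choice to skip classes that never occur when averaging are all assumptions.

```python
def confusion_matrix(gt, pred, num_classes, ignore_label=255):
    """Count (true class, predicted class) pairs over flat label sequences."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for g, p in zip(gt, pred):
        if g == ignore_label:  # assumed convention for unlabeled pixels
            continue
        matrix[g][p] += 1
    return matrix


def segmentation_metrics(matrix):
    """Return (pixel_acc, mean_acc, mean_iou) in percent from a confusion matrix."""
    n = len(matrix)
    total = sum(sum(row) for row in matrix)
    if total == 0:
        raise ValueError("confusion matrix is empty")
    correct = sum(matrix[i][i] for i in range(n))
    accs, ious = [], []
    for i in range(n):
        gt_i = sum(matrix[i])                         # pixels whose true class is i
        pred_i = sum(matrix[j][i] for j in range(n))  # pixels predicted as class i
        union = gt_i + pred_i - matrix[i][i]
        if gt_i > 0:
            accs.append(matrix[i][i] / gt_i)
        if union > 0:
            ious.append(matrix[i][i] / union)
    pixel_acc = 100.0 * correct / total
    mean_acc = 100.0 * sum(accs) / len(accs)
    mean_iou = 100.0 * sum(ious) / len(ious)
    return pixel_acc, mean_acc, mean_iou
```

Pixel accuracy weights every pixel equally, while mean accuracy and mean IoU average over classes, which is why the latter two are lower when rare classes are segmented poorly.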

A MindSpore implementation is available at: https://gitee.com/mindspore/models/tree/master/research/cv/TokenFusion

Image-to-Image Translation

First,

cd image2image_translation

Training script, from Shade and Texture to RGB,

python main.py --gpu 0 -c exp_name

This script will auto-evaluate on the validation dataset every 5 training epochs.

Predicted images will be automatically saved during training, in the following folder structure:

code_root/ckpt/exp_name/results
  ├── input0  # 1st modality input
  ├── input1  # 2nd modality input
  ├── fake0   # 1st branch output 
  ├── fake1   # 2nd branch output
  ├── fake2   # ensemble output
  ├── best    # current best output
  │    ├── fake0
  │    ├── fake1
  │    └── fake2
  └── real    # ground truth output
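
To confirm that a run has produced the folders listed above, the tree can be inspected with a few lines of Python. This is a hedged sketch rather than a script shipped with the repository: `summarize_results` is a hypothetical helper, and the path passed to it is assumed to be the `results` folder of one experiment.

```python
import os

# Sub-folders named in the tree above, relative to code_root/ckpt/exp_name/results.
EXPECTED = [
    "input0", "input1", "fake0", "fake1", "fake2", "real",
    os.path.join("best", "fake0"),
    os.path.join("best", "fake1"),
    os.path.join("best", "fake2"),
]


def summarize_results(results_dir):
    """Map each expected sub-folder to its file count, or None if it is missing."""
    summary = {}
    for rel in EXPECTED:
        path = os.path.join(results_dir, rel)
        if os.path.isdir(path):
            summary[rel] = sum(
                1 for name in os.listdir(path)
                if os.path.isfile(os.path.join(path, name))
            )
        else:
            summary[rel] = None
    return summary
```

For example, `summarize_results("ckpt/exp_name/results")` returns a dictionary in which a `None` value marks a folder that has not been written yet.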

Checkpoint models:

| Method | Task | FID | KID | Download |
|--------|------|-----|-----|----------|
| CEN | Texture+Shade->RGB | 62.6 | 1.65 | - |
| Ours | Texture+Shade->RGB | 45.5 | 1.00 | Google Drive |

3D Object Detection (under construction)

Data preparation, environments, and training scripts follow Group-Free and ImVoteNet.

E.g.,

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \
    --master_port 2229 --nproc_per_node 4 train_dist.py \
    --max_epoch 600 --val_freq 25 --save_freq 25 \
    --lr_decay_epochs 420 480 540 \
    --num_point 20000 --num_decoder_layers 6 \
    --size_cls_agnostic --size_delta 0.0625 \
    --heading_delta 0.04 --center_delta 0.1111111111111 \
    --weight_decay 0.00000001 \
    --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 \
    --dataset sunrgbd --data_root . --use_img \
    --log_dir log/exp_name

Citation

If you find our work useful for your research, please consider citing the following paper.

@inproceedings{wang2022tokenfusion,
  title={Multimodal Token Fusion for Vision Transformers},
  author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}