[CVPR 2024] VoCo: A Simple-yet-Effective Volume Contrastive Learning Framework for 3D Medical Image Analysis

Luffy03/VoCo

VoCo

Paper PDF

Code for CVPR 2024 paper, "VoCo: A Simple-yet-Effective Volume Contrastive Learning Framework for 3D Medical Image Analysis"

Authors: Linshan Wu, Jiaxin Zhuang, and Hao Chen

This work presents VoCo, a simple-yet-effective contrastive learning framework for pre-training on large-scale 3D medical images. Our models pre-trained on 10k CT images are available, and our models pre-trained on 160k CT images are now available too!

Our extension is at Large-Scale-Medical, which provides stronger models, larger-scale datasets, various training recipes, and more downstream tasks!

Abstract

Self-Supervised Learning (SSL) has demonstrated promising results in 3D medical image analysis. However, the lack of high-level semantics in pre-training still heavily hinders the performance of downstream tasks. We observe that 3D medical images contain relatively consistent contextual position information, i.e., consistent geometric relations between different organs, which offers a potential way to learn consistent semantic representations during pre-training. In this paper, we propose a simple-yet-effective Volume Contrast (VoCo) framework to leverage these contextual position priors for pre-training. Specifically, we first generate a group of base crops from different regions while enforcing feature discrepancy among them, and employ them as class assignments of different regions. Then, we randomly crop sub-volumes and predict which class each one belongs to (i.e., at which region it is located) by contrasting its similarity to the different base crops, which can be seen as predicting the contextual positions of different sub-volumes. Through this pretext task, VoCo implicitly encodes the contextual position priors into model representations without the guidance of annotations, enabling us to effectively improve the performance of downstream tasks that require high-level semantics. Extensive experimental results on six downstream tasks demonstrate the superior effectiveness of VoCo.
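The position-prediction pretext task described in the abstract can be illustrated with a small NumPy sketch. This is not the repository's implementation, and every name in it is hypothetical. We assume a 2D n x n grid of equally sized base crops in the x-y plane, square random crops, a cosine similarity clipped at 0 as the predicted "position", an L1 loss against the crop's overlap proportions, and a mean-absolute-cosine term standing in for the feature-discrepancy constraint among base crops.

```python
import numpy as np

def overlap_targets(crop_x, crop_y, crop_size, grid_n, base_size):
    """Hypothetical helper: fraction of a square random crop's x-y area that
    falls inside each cell of an n x n grid of base crops (row-major order)."""
    targets = np.zeros(grid_n * grid_n)
    for i in range(grid_n):
        for j in range(grid_n):
            bx, by = i * base_size, j * base_size
            # Length of the 1D overlap along each axis (0 if disjoint).
            ox = max(0, min(crop_x + crop_size, bx + base_size) - max(crop_x, bx))
            oy = max(0, min(crop_y + crop_size, by + base_size) - max(crop_y, by))
            targets[i * grid_n + j] = ox * oy / crop_size ** 2
    return targets

def similarity_logits(crop_feat, base_feats):
    """Cosine similarity of one crop feature (d,) to k base-crop features (k, d),
    clipped at 0 so it shares the [0, 1] range of the targets (assumption)."""
    a = crop_feat / np.linalg.norm(crop_feat)
    b = base_feats / np.linalg.norm(base_feats, axis=1, keepdims=True)
    return np.clip(b @ a, 0.0, None)

def position_loss(logits, targets):
    """Mean absolute error between predicted similarities and overlap targets (assumed loss)."""
    return float(np.abs(logits - targets).mean())

def discrepancy_loss(base_feats):
    """Mean |cosine similarity| between distinct base crops; lower values mean
    the base crops are easier to tell apart (assumed regularizer)."""
    b = base_feats / np.linalg.norm(base_feats, axis=1, keepdims=True)
    sim = b @ b.T
    off_diag = sim[~np.eye(len(b), dtype=bool)]
    return float(np.abs(off_diag).mean())
```

For example, `overlap_targets(32, 32, 64, grid_n=4, base_size=64)` returns 0.25 for base crops 0, 1, 4 and 5 and 0 elsewhere: the crop straddles four base crops equally, so a good encoder should be about equally similar to each of them.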


Usage

Load Pre-trained Weights

import argparse

import torch
from monai.networks.nets import SwinUNETR

# roi = 96 is an assumed default; each SwinUNETR ROI dimension must be divisible by 32.
roi = 96

parser = argparse.ArgumentParser(description="Swin UNETR")
parser.add_argument("--roi_x", default=roi, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=roi, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=roi, type=int, help="roi size in z direction")
parser.add_argument("--feature_size", default=48, type=int, help="feature size")
parser.add_argument("--in_channels", default=1, type=int, help="number of input channels")
parser.add_argument("--out_channels", default=14, type=int, help="number of output channels")
parser.add_argument("--use_checkpoint", default=True, action=argparse.BooleanOptionalAction,
                    help="use gradient checkpointing to save memory")
parser.add_argument("--pretrained_checkpoint", required=True, type=str,
                    help="path to the downloaded VoCo pre-trained weights")
args = parser.parse_args()

# img_size is deprecated in newer MONAI releases; drop it if your version rejects it.
model = SwinUNETR(
        img_size=(args.roi_x, args.roi_y, args.roi_z),
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        use_checkpoint=args.use_checkpoint,
        use_v2=True)

# Load on CPU so the weights can be restored without a GPU.
state_dict = torch.load(args.pretrained_checkpoint, map_location=torch.device('cpu'))
# Strip the "module." prefix added by (Distributed)DataParallel training.
if "module." in list(state_dict.keys())[0]:
    print("Tag 'module.' found in state dict - fixing!")
    for key in list(state_dict.keys()):
        state_dict[key.replace("module.", "")] = state_dict.pop(key)
# Rename "swin_vit" to MONAI's attribute name "swinViT".
if "swin_vit" in list(state_dict.keys())[0]:
    print("Tag 'swin_vit' found in state dict - fixing!")
    for key in list(state_dict.keys()):
        state_dict[key.replace("swin_vit", "swinViT")] = state_dict.pop(key)
# strict=False tolerates keys (e.g. the segmentation head) the checkpoint does not provide;
# print the mismatch counts so a silent failure to load is visible.
msg = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(msg.missing_keys)}, unexpected keys: {len(msg.unexpected_keys)}")
print("Using pretrained voco ema self-supervised Swin UNETR backbone weights !")
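Because `strict=False` silently skips keys that do not match, it can help to check the renaming on its own. The sketch below factors the two renaming passes above into hypothetical helpers, `normalize_keys` and `key_report`, that can be tested without MONAI or a checkpoint. Unlike the code above, `normalize_keys` renames every key unconditionally instead of first checking only the first key; that is a design choice of this sketch, not the repository's behavior.

```python
def normalize_keys(state_dict):
    """Return a copy of state_dict with every 'module.' removed and every
    'swin_vit' renamed to 'swinViT' (hypothetical helper)."""
    return {key.replace("module.", "").replace("swin_vit", "swinViT"): value
            for key, value in state_dict.items()}

def key_report(ckpt_keys, model_keys):
    """Compare checkpoint keys with model keys (hypothetical helper)."""
    ckpt, model = set(ckpt_keys), set(model_keys)
    return {
        "matched": len(ckpt & model),
        "missing": sorted(model - ckpt),     # in the model, absent from the checkpoint
        "unexpected": sorted(ckpt - model),  # in the checkpoint, unknown to the model
    }
```

With a loaded model, `key_report(normalize_keys(state_dict).keys(), model.state_dict().keys())` shows how many backbone parameters actually matched; a `matched` count of zero usually means a prefix was not handled.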

Prepare Dataset

First, you need to download the pre-training dataset. The datasets in the 10k collection are all open-source, so you can download them yourself, or you can download them from our Hugging Face repo. Note: the 10k dataset was collected by Dr. Jiaxin Zhuang.

data
├── BTCV
├── TCIAcovid19
├── Luna16-jx
├── stoic21
├── Totalsegmentator_dataset
├── Flare23
├── LiDC
└── HNSCC_convert_v1
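Before pre-training, a quick check that the folders in the layout above are present can save a failed run. This is a minimal sketch, assuming the dataset root is a folder named `data` in the working directory; `missing_datasets` is a hypothetical helper, not part of this repo.

```python
from pathlib import Path

# Sub-folder names copied from the layout above.
EXPECTED = ["BTCV", "TCIAcovid19", "Luna16-jx", "stoic21",
            "Totalsegmentator_dataset", "Flare23", "LiDC", "HNSCC_convert_v1"]

def missing_datasets(root="data"):
    """Return the expected dataset folders that do not exist under root."""
    root_path = Path(root)
    return [name for name in EXPECTED if not (root_path / name).is_dir()]

if __name__ == "__main__":
    missing = missing_datasets()
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All expected dataset folders found.")
```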

Pre-Training

(1) Note that this repo presents the version of our 10k pre-training, so some details may differ from our paper.

(2) To accelerate training, we use MONAI's "PersistentDataset" to pre-cache the dataset, which requires extra storage. This is important in our code. If you don't have enough storage, you can switch it back to a non-persistent dataset in "utils/data_utils.py".
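MONAI's PersistentDataset stores the output of the deterministic part of the transform chain (everything before the first random transform) on disk, so later epochs skip that preprocessing; the cached files are where the extra storage goes. The toy sketch below shows the same trade-off with a hypothetical `cached_preprocess` function. It is an illustration of disk caching keyed by a hash of the input, not MONAI's implementation.

```python
import hashlib
import pickle
from pathlib import Path

def cached_preprocess(item, preprocess, cache_dir="cache"):
    """Toy disk cache: run preprocess(item) once, then reuse the pickled result.

    The cache key is an MD5 hash of repr(item); this loosely mirrors the idea
    behind PersistentDataset but is NOT MONAI's implementation.
    """
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    key = hashlib.md5(repr(item).encode("utf-8")).hexdigest()
    path = cache / f"{key}.pkl"
    if path.exists():
        # Cache hit: disk space is traded for skipping the expensive step.
        with path.open("rb") as f:
            return pickle.load(f)
    # Cache miss: run the expensive, deterministic step (e.g. loading and
    # resampling a CT volume) and store the result for the next epoch.
    result = preprocess(item)
    with path.open("wb") as f:
        pickle.dump(result, f)
    return result
```

Every cached item occupies disk space roughly the size of its preprocessed volume, which is why a large pre-training set needs significant extra storage.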

To pre-train:

sh train.sh

Finetune

Our fine-tuning code will be available soon; in the meantime, you can directly use the code in MONAI.

More fine-tuning implementations are in preparation!

Acknowledgement

We thank MONAI, on which part of our code is based.

Citation ✏️ 📄

If you find this repo useful for your research, please consider citing the paper as follows:

@article{wu2024large,
  title={Large-Scale 3D Medical Image Pre-training with Geometric Context Priors},
  author={Wu, Linshan and Zhuang, Jiaxin and Chen, Hao},
  journal={arXiv preprint arXiv:2410.09890},
  year={2024}
}
@InProceedings{voco,
    author    = {Wu, Linshan and Zhuang, Jiaxin and Chen, Hao},
    title     = {VoCo: A Simple-yet-Effective Volume Contrastive Learning Framework for 3D Medical Image Analysis},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2024},
    pages     = {22873-22882}
}
