ImgX-DiffSeg is a Jax-based deep learning toolkit using Flax for biomedical image segmentation.
This repository includes the implementation of the following works:
- A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models
- Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation
🚧 The codebase is still under active development for more enhancements and applications. Please check release notes for more information. 🚧
📫 Please feel free to create an issue to request features or reach out for collaborations. 📫
Current supported functionalities are summarized as follows.
Data sets
See the readme for further details.
- Muscle ultrasound from Marzola et al. 2021.
- Male pelvic MR from Li et al. 2022.
- AMOS CT from Ji et al. 2022.
- Brain MR from Baid et al. 2021.
Algorithms
- Supervised segmentation.
- Diffusion-based segmentation.
- Gaussian noise based diffusion.
- Noise prediction (epsilon-parameterization) or ground truth prediction (x0-parameterization).
- Importance sampling for timestep.
- Recycling training strategies, including xt-recycling and xT-recycling.
- Self-conditioning training strategies, including Chen et al. 2022 and Watson et al. 2023.
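To illustrate how the two parameterizations listed above relate, here is a minimal sketch in plain Python. It is not the toolkit's Flax implementation. It assumes the standard Gaussian forward process x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps with a linear beta schedule; the function names and schedule values are illustrative assumptions, not taken from this repository.

```python
import math
import random


def alpha_bar(t, num_timesteps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative product of (1 - beta_s) for s = 0..t under a linear schedule.

    The schedule endpoints are assumed values for illustration only.
    """
    prod = 1.0
    for s in range(t + 1):
        beta = beta_start + (beta_end - beta_start) * s / (num_timesteps - 1)
        prod *= 1.0 - beta
    return prod


def add_noise(x0, eps, a_bar):
    """Forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps."""
    return [math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * e for x, e in zip(x0, eps)]


def x0_from_eps(x_t, eps, a_bar):
    """Recover x0 from a noise prediction (epsilon-parameterization)."""
    return [(x - math.sqrt(1.0 - a_bar) * e) / math.sqrt(a_bar) for x, e in zip(x_t, eps)]


def eps_from_x0(x_t, x0, a_bar):
    """Recover the noise from a ground-truth prediction (x0-parameterization)."""
    return [(x - math.sqrt(a_bar) * v) / math.sqrt(1.0 - a_bar) for x, v in zip(x_t, x0)]


rng = random.Random(0)
x0 = [1.0, -1.0, 1.0, -1.0]  # toy mask values scaled to [-1, 1]
eps = [rng.gauss(0.0, 1.0) for _ in x0]
a_bar = alpha_bar(t=500)
x_t = add_noise(x0, eps, a_bar)
```

Given a perfect prediction of either quantity, the other follows by rearranging the forward process, so the two parameterizations differ in what the network is trained to output rather than in the information they carry.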
Models
- U-Net with Transformers supporting 2D and 3D images.
- Efficient attention.
Training
- Patch-based training.
- Data augmentation with anisotropic support, including
- Random affine: rotation, scaling, shearing, shifting.
- Random gamma adjustment.
- Random flip.
- Multi-device training (one model per device) with pmap.
- Mixed precision training.
- Gradient clipping and accumulation.
- Early stopping.
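As a rough illustration of the gradient clipping and accumulation item above, the sketch below averages gradients over micro-batches and then clips them by their global L2 norm. This is an assumed, simplified formulation in plain Python; the toolkit's actual clipping rule and accumulation scheme may differ, and the function names are hypothetical.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients so that their joint L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / norm) if norm > 0.0 else 1.0
    return [g * scale for g in grads]


def accumulate(micro_batch_grads):
    """Average gradients over micro-batches before a single optimizer step."""
    n = len(micro_batch_grads)
    return [sum(col) / n for col in zip(*micro_batch_grads)]


# Two micro-batches, two parameters each (toy values).
micro_batch_grads = [[3.0, 4.0], [9.0, 12.0]]
mean_grads = accumulate(micro_batch_grads)
clipped = clip_by_global_norm(mean_grads, max_norm=5.0)
```

Accumulating before clipping means the clip threshold applies to the effective large-batch gradient, which is one common design choice among several.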
The following instructions have been tested only on TPU-v3-8. The docker container uses the root user.
- TPU VMs often have limited disk space. A RAM disk can be used to help.
sudo mkdir /tmp/ramdisk
sudo chmod 777 /tmp/ramdisk
sudo mount -t tmpfs -o size=256G imgxramdisk /tmp/ramdisk
cd /tmp/ramdisk/
- Build the docker image inside the repository.
sudo docker build --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) -f docker/Dockerfile.tpu -t imgx .
where
--build-arg provides argument values.
-f provides the docker file.
-t tags the docker image.
- Run the Docker container.
mkdir -p $(cd ../ && pwd)/tensorflow_datasets
sudo docker run -it --rm --privileged --network host \
-v "$(pwd)":/app/ImgX \
-v "$(cd ../ && pwd)"/tensorflow_datasets:/root/tensorflow_datasets \
imgx bash
- Install the package inside the container.
make pip
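Because disk space on the TPU VM is limited, it may help to check how much space is free before building data sets. The following is a minimal sketch using only the Python standard library; the helper names are hypothetical and the threshold is left to the caller.

```python
import shutil


def free_gib(path):
    """Free space at path in GiB, using the standard-library disk_usage call."""
    return shutil.disk_usage(path).free / 2**30


def has_space(path, needed_gib):
    """True if at least needed_gib GiB are free at path."""
    return free_gib(path) >= needed_gib


# "." is used so the sketch runs anywhere; on the TPU VM the path
# would be the RAM disk mounted at /tmp/ramdisk.
print(f"{free_gib('.'):.1f} GiB free")
```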
CUDA >= 11.8 is required. The docker container uses a non-root user. Docker image used may be removed.
- Build the docker image inside the repository.
docker build --build-arg HOST_UID=$(id -u) --build-arg HOST_GID=$(id -g) -f docker/Dockerfile -t imgx .
where
--build-arg provides argument values.
-f provides the docker file.
-t tags the docker image.
- Run the Docker container.
mkdir -p $(cd ../ && pwd)/tensorflow_datasets
docker run -it --rm --gpus all \
-v "$(pwd)":/app/ImgX \
-v "$(cd ../ && pwd)"/tensorflow_datasets:/home/app/tensorflow_datasets \
imgx bash
where
--rm removes the container once it exits.
-v maps the current folder into the container.
- Install the package inside the container.
make pip
Download Miniforge from GitHub and install it.
conda install -y -n base conda-libmamba-solver
conda config --set solver libmamba
conda env update -f docker/environment_mac_m1.yml
Install Conda and then create the environment.
conda install -y -n base conda-libmamba-solver
conda config --set solver libmamba
conda env update -f docker/environment.yml
Activate the environment and install the package.
conda activate imgx
make pip
💪 We are working on a toy example for using custom data sets without TFDS. Thanks for your understanding.
Use the following commands to (re)build all data sets. Check the README of datasets for details. In particular, the BraTS 2021 data set must be downloaded manually.
make build_dataset
make rebuild_dataset
Or build the selected data set by running one of the following commands.
tfds build imgx/datasets/male_pelvic_mr
tfds build imgx/datasets/amos_ct
tfds build imgx/datasets/muscle_us
tfds build imgx/datasets/brats2021_mr # requires downloading data manually
Example commands to use two GPUs for training, validation, and testing. The outputs are stored under wandb/latest-run/files/, where
- ckpt stores the model checkpoints and corresponding validation metrics.
- test_evaluation stores the predictions on the test set and corresponding metrics.
# limit to two GPUs if using NVIDIA GPUs
export CUDA_VISIBLE_DEVICES="0,1"
# select the data set to use (keep only one of the following lines)
export DATASET_NAME="male_pelvic_mr"
export DATASET_NAME="amos_ct"
export DATASET_NAME="muscle_us"
export DATASET_NAME="brats2021_mr"
# Vanilla segmentation
imgx_train data=${DATASET_NAME} task=seg
imgx_test --log_dir wandb/latest-run/
# Diffusion-based segmentation
imgx_train data=${DATASET_NAME} task=gaussian_diff_seg
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
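When scripting several runs, the command lines above can be assembled programmatically. The sketch below only builds the strings and does not run anything. The data set names, task names, and flags are copied from the commands above; the `DATASETS` and `TASKS` tuples and the `build_commands` helper are hypothetical and not part of the package.

```python
import shlex

# Values copied from the example commands; the tuples themselves are illustrative.
DATASETS = ("male_pelvic_mr", "amos_ct", "muscle_us", "brats2021_mr")
TASKS = ("seg", "gaussian_diff_seg")


def build_commands(dataset, task, log_dir="wandb/latest-run/",
                   num_timesteps=5, sampler="DDPM"):
    """Return the train and test command lines as strings."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown data set: {dataset}")
    if task not in TASKS:
        raise ValueError(f"unknown task: {task}")
    train = ["imgx_train", f"data={dataset}", f"task={task}"]
    test = ["imgx_test", "--log_dir", log_dir]
    # Sampler options are only passed for the diffusion task, as in the examples.
    if task == "gaussian_diff_seg":
        test += ["--num_timesteps", str(num_timesteps), "--sampler", sampler]
    return [shlex.join(train), shlex.join(test)]


for line in build_commands("amos_ct", "gaussian_diff_seg", sampler="DDIM"):
    print(line)
```

Validating the names up front avoids launching a long training job with a mistyped data set.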
Optionally, for debugging purposes, use the flag debug=True to run the experiment with a small data set and smaller models.
imgx_train data=${DATASET_NAME} task=seg debug=True
imgx_test --log_dir wandb/latest-run/
imgx_train data=${DATASET_NAME} task=gaussian_diff_seg debug=True
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
Install pre-commit hooks:
pre-commit install
wily build .
Update hooks, and re-verify all files.
pre-commit autoupdate
pre-commit run --all-files
Run the command below to run the tests and get a coverage report. As JAX tests require two CPUs each, -n 4 runs 4 parallel workers and therefore requires 8 CPUs in total.
pytest --cov=imgx -n 4 imgx -k "not integration"
-k "not integration" excludes integration tests, which require downloading the muscle ultrasound and AMOS CT data sets.
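The CPU requirement for the -n option can be worked out for other machines with a small calculation. This sketch assumes, as stated above, two CPUs per parallel test worker; the helper names are hypothetical.

```python
import os

CPUS_PER_WORKER = 2  # from the text above: JAX tests require two CPUs


def max_workers(total_cpus, cpus_per_worker=CPUS_PER_WORKER):
    """Largest worker count that fits in total_cpus (at least 1)."""
    return max(1, total_cpus // cpus_per_worker)


def required_cpus(workers, cpus_per_worker=CPUS_PER_WORKER):
    """CPUs needed to run the given number of workers."""
    return workers * cpus_per_worker


available = os.cpu_count() or 1
print(f"-n {max_workers(available)} fits on {available} CPUs")
```

For example, `required_cpus(4)` gives 8, matching the figure quoted for -n 4.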
For integration tests, run the command below. -s enables printing of stdout. This test may take 40-60 minutes.
pytest imgx/integration_test.py -s
To test the Jupyter notebooks, run the command below.
pytest --nbmake examples/**/*.ipynb
- Segment Anything (PyTorch)
- MONAI (PyTorch)
- Cross Institution Few Shot Segmentation (PyTorch)
- MedSegDiff (PyTorch)
- MedSegDiff (PyTorch, lucidrains)
- DeepReg (TensorFlow)
- Scenic (JAX)
- DeepMind Research (JAX)
- Haiku (JAX)
- Flax (JAX)
This work was supported by the EPSRC grant (EP/T029404/1), the Wellcome/EPSRC Centre for Interventional and Surgical Sciences (203145Z/16/Z), the International Alliance for Cancer Early Detection, an alliance between Cancer Research UK (C28070/A30912, C73666/A31378), Canary Center at Stanford University, the University of Cambridge, OHSU Knight Cancer Institute, University College London and the University of Manchester, and Cloud TPUs from Google's TPU Research Cloud (TRC).
If you find the code base and method useful in your research, please cite the relevant paper:
@article{melba:2023:016:fu,
title = "A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models",
author = "Fu, Yunguan and Li, Yiwen and Saeed, Shaheer U. and Clarkson, Matthew J. and Hu, Yipeng",
journal = "Machine Learning for Biomedical Imaging",
volume = "2",
issue = "Special Issue for Generative Models",
year = "2023",
pages = "507--546",
issn = "2766-905X",
doi = "https://doi.org/10.59275/j.melba.2023-fbe4",
url = "https://melba-journal.org/2023:016"
}
@article{fu2023importance,
title={Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation},
author={Fu, Yunguan and Li, Yiwen and Saeed, Shaheer U and Clarkson, Matthew J and Hu, Yipeng},
journal={arXiv preprint arXiv:2303.06040},
year={2023},
doi={10.48550/arXiv.2303.06040},
url={https://arxiv.org/abs/2303.06040},
}