A JAX-based deep learning framework for image segmentation using diffusion models.

mathpluscode/ImgX-DiffSeg

ImgX-DiffSeg

ImgX-DiffSeg is a JAX-based deep learning toolkit using Flax for biomedical image segmentation.

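This repository includes the implementation of the following work:

  • A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models (Machine Learning for Biomedical Imaging, 2023).
  • Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation (arXiv:2303.06040, 2023).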

🚧 The codebase is still under active development for more enhancements and applications. Please check release notes for more information. 🚧

📫 Please feel free to create an issue to request features or reach out for collaborations. 📫

Features

Currently supported functionalities are summarized as follows.

Data sets

See the readme for further details.

Algorithms

Models

Training

  • Patch-based training.
  • Data augmentation with anisotropic support, including
    • Random affine: rotation, scaling, shearing, shifting.
    • Random gamma adjustment.
    • Random flip.
  • Multi-device training (one model per device) with pmap.
  • Mixed precision training.
  • Gradient clipping and accumulation.
  • Early stopping.
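The gradient clipping, gradient accumulation, and early stopping features listed above can be illustrated with a minimal, framework-agnostic sketch in plain Python. It treats gradients as flat lists of floats and assumes a higher-is-better validation metric (such as a Dice score). The names `clip_by_global_norm`, `average_micro_batch_grads`, and `EarlyStopping` are hypothetical and do not reflect ImgX's actual implementation, which operates on JAX arrays; in JAX projects, global-norm clipping is commonly done with `optax.clip_by_global_norm`.

```python
import math
from typing import List, Optional


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale all entries so that their joint L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return list(grads)  # already small enough, leave unchanged
    scale = max_norm / global_norm
    return [g * scale for g in grads]


def average_micro_batch_grads(micro_batch_grads: List[List[float]]) -> List[float]:
    """Gradient accumulation: average per-micro-batch gradients before one update."""
    n = len(micro_batch_grads)
    return [sum(values) / n for values in zip(*micro_batch_grads)]


class EarlyStopping:
    """Signal a stop once the metric has not improved for `patience` checks."""

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best: Optional[float] = None
        self.num_bad_checks = 0

    def should_stop(self, metric: float) -> bool:
        # Assumption: higher is better (e.g. validation Dice score).
        if self.best is None or metric > self.best + self.min_delta:
            self.best = metric
            self.num_bad_checks = 0
        else:
            self.num_bad_checks += 1
        return self.num_bad_checks >= self.patience


if __name__ == "__main__":
    grads = average_micro_batch_grads([[3.0, 0.0], [3.0, 8.0]])
    print(grads)  # [3.0, 4.0]
    print(clip_by_global_norm(grads, max_norm=1.0))  # norm 5 is scaled down to 1
```

Clipping is applied after accumulation here so that the norm limit refers to the gradient of the full effective batch; this ordering is a design choice of the sketch.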

Installation

TPU with Docker

The following instructions have been tested only for TPU-v3-8. The docker container uses the root user.

  1. TPU hosts often have limited disk space, so a RAM disk can help.

    sudo mkdir /tmp/ramdisk
    sudo chmod 777 /tmp/ramdisk
    sudo mount -t tmpfs -o size=256G imgxramdisk /tmp/ramdisk
    cd /tmp/ramdisk/
  2. Build the docker image inside the repository.

    sudo docker build --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) -f docker/Dockerfile.tpu -t imgx .

    where

    • --build-arg provides argument values.
    • -f specifies the Dockerfile.
    • -t tags the Docker image.
  3. Run the Docker container.

    mkdir -p $(cd ../ && pwd)/tensorflow_datasets
    sudo docker run -it --rm --privileged --network host \
    -v "$(pwd)":/app/ImgX \
    -v "$(cd ../ && pwd)"/tensorflow_datasets:/root/tensorflow_datasets \
    imgx bash
  4. Install the package inside the container.

    make pip

GPU with Docker

CUDA >= 11.8 is required. The Docker container uses a non-root user. Note that the Docker base image used here may be removed upstream in the future.

  1. Build the docker image inside the repository.

    docker build --build-arg HOST_UID=$(id -u) --build-arg HOST_GID=$(id -g) -f docker/Dockerfile -t imgx .

    where

    • --build-arg provides argument values.
    • -f specifies the Dockerfile.
    • -t tags the Docker image.
  2. Run the Docker container.

    mkdir -p $(cd ../ && pwd)/tensorflow_datasets
    docker run -it --rm --gpus all \
    -v "$(pwd)":/app/ImgX \
    -v "$(cd ../ && pwd)"/tensorflow_datasets:/home/app/tensorflow_datasets \
    imgx bash

    where

    • --rm removes the container once you exit it.
    • -v mounts a host folder (here, the repository and the data set folder) into the container.
  3. Install the package inside the container.

    make pip

Local with Conda

Install Conda for Mac M1

Download Miniforge from GitHub and install it.

conda install -y -n base conda-libmamba-solver
conda config --set solver libmamba
conda env update -f docker/environment_mac_m1.yml

Install Conda for Linux / Mac Intel

Install Conda and then create the environment.

conda install -y -n base conda-libmamba-solver
conda config --set solver libmamba
conda env update -f docker/environment.yml

Activate Conda Environment

Activate the environment and install the package.

conda activate imgx
make pip

Build Data Sets

⚠️ If you use your own data set, the following steps are not needed. Please check the existing data sets for examples of using TFDS. Note that the BraTS 2021 data set, in particular, requires manual downloading.

💪 We are working on a toy example for using custom data sets without TFDS, thanks for your understanding.

Use the following commands to build or rebuild all data sets. Check the README of each data set for details. In particular, manual downloading is required for the BraTS 2021 data set.

make build_dataset
make rebuild_dataset

Alternatively, build a selected data set by running one of the following commands.

tfds build imgx/datasets/male_pelvic_mr
tfds build imgx/datasets/amos_ct
tfds build imgx/datasets/muscle_us
tfds build imgx/datasets/brats2021_mr  # requires downloading data manually

Experiment

Training and Testing

Example commands for using two GPUs for training, validation, and testing. The outputs are stored under wandb/latest-run/files/, where

  • ckpt stores the model checkpoints and corresponding validation metrics.
  • test_evaluation stores the prediction on the test set and corresponding metrics.
# limit to two GPUs if using NVIDIA GPUs
export CUDA_VISIBLE_DEVICES="0,1"

# select ONE data set to use (run only one of the following exports)
export DATASET_NAME="male_pelvic_mr"
export DATASET_NAME="amos_ct"
export DATASET_NAME="muscle_us"
export DATASET_NAME="brats2021_mr"

# Vanilla segmentation
imgx_train data=${DATASET_NAME} task=seg
imgx_test --log_dir wandb/latest-run/

# Diffusion-based segmentation
imgx_train data=${DATASET_NAME} task=gaussian_diff_seg
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
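The `--num_timesteps 5 --sampler DDIM` options above run inference with only a few denoising steps. As a rough illustration of the general DDIM idea, and not ImgX's code, the sketch below selects an evenly spaced subsequence of training timesteps and applies the deterministic (eta = 0) DDIM update to a scalar. The function names, the linear beta schedule from 1e-4 to 0.02, and the 1000 training timesteps are all assumptions for illustration. The stochastic DDPM sampler is not shown.

```python
import math
from typing import Callable, List


def linear_alpha_bars(num_train_steps: int, beta_start: float = 1e-4,
                      beta_end: float = 0.02) -> List[float]:
    """Cumulative products alpha_bar_t of (1 - beta_t) for a linear beta schedule."""
    alpha_bars, prod = [], 1.0
    for t in range(num_train_steps):
        beta = beta_start + (beta_end - beta_start) * t / max(num_train_steps - 1, 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def ddim_timesteps(num_train_steps: int, num_sample_steps: int) -> List[int]:
    """Evenly spaced subsequence of training timesteps, ordered noisiest first."""
    stride = num_train_steps // num_sample_steps
    return [i * stride for i in range(num_sample_steps)][::-1]


def ddim_step(x_t: float, eps_pred: float, alpha_bar_t: float,
              alpha_bar_prev: float) -> float:
    """Deterministic DDIM update (eta = 0) to the previous kept timestep."""
    # Estimate the clean sample x_0 from the predicted noise.
    x0_pred = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # Move the estimate to the noise level of the previous kept timestep.
    return math.sqrt(alpha_bar_prev) * x0_pred + math.sqrt(1.0 - alpha_bar_prev) * eps_pred


def ddim_sample(x_start: float, predict_eps: Callable[[float, int], float],
                alpha_bars: List[float], timesteps: List[int]) -> float:
    """Run the DDIM loop; after the last kept timestep, alpha_bar_prev is 1 (no noise)."""
    x = x_start
    for i, t in enumerate(timesteps):
        prev = alpha_bars[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        x = ddim_step(x, predict_eps(x, t), alpha_bars[t], prev)
    return x
```

For example, `ddim_timesteps(1000, 5)` gives `[800, 600, 400, 200, 0]`. With an oracle noise predictor that always returns the true noise, `ddim_sample` recovers the clean value exactly (up to floating-point error), which is a useful sanity check for any DDIM implementation.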

Optionally, for debugging purposes, use the flag debug=True to run the experiment with a small dataset and smaller models.

imgx_train data=${DATASET_NAME} task=seg debug=True
imgx_test --log_dir wandb/latest-run/
imgx_train data=${DATASET_NAME} task=gaussian_diff_seg debug=True
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM

Code Quality

Pre-commit

Install pre-commit hooks and build the wily cache:

pre-commit install
wily build .

Update the hooks and re-verify all files.

pre-commit autoupdate
pre-commit run --all-files

Code Test

Run the command below to test and get a coverage report. As each JAX test worker requires two CPUs, -n 4 runs 4 parallel workers and therefore requires 8 CPUs in total.

pytest --cov=imgx -n 4 imgx -k "not integration"
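The CPU arithmetic above (two CPUs per worker, so -n 4 needs 8 CPUs) can be turned around to choose a worker count for a given machine. The helper below is a hypothetical sketch, not part of ImgX, and assumes that each pytest-xdist worker needs exactly two CPUs, as stated in the text.

```python
import os
from typing import Optional

CPUS_PER_WORKER = 2  # assumption from the text: each JAX test worker needs two CPUs


def max_pytest_workers(num_cpus: Optional[int] = None) -> int:
    """Largest -n value with workers * CPUS_PER_WORKER <= available CPUs (at least 1)."""
    if num_cpus is None:
        num_cpus = os.cpu_count() or 1  # os.cpu_count() may return None
    return max(1, num_cpus // CPUS_PER_WORKER)


if __name__ == "__main__":
    n = max_pytest_workers()
    print(f'pytest --cov=imgx -n {n} imgx -k "not integration"')
```

On an 8-CPU machine this suggests -n 4, matching the command above; on smaller machines it falls back to a single worker rather than oversubscribing.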

-k "not integration" excludes the integration tests, which require downloading the muscle ultrasound and AMOS CT data sets.

For integration tests, run the command below; -s disables output capturing so that stdout is printed. This test may take 40-60 minutes.

pytest imgx/integration_test.py -s

To test the Jupyter notebooks, run the command below.

pytest --nbmake examples/**/*.ipynb

References

Acknowledgement

This work was supported by the EPSRC grant (EP/T029404/1), the Wellcome/EPSRC Centre for Interventional and Surgical Sciences (203145Z/16/Z), the International Alliance for Cancer Early Detection, an alliance between Cancer Research UK (C28070/A30912, C73666/A31378), Canary Center at Stanford University, the University of Cambridge, OHSU Knight Cancer Institute, University College London and the University of Manchester, and Cloud TPUs from Google's TPU Research Cloud (TRC).

Citation

If you find the code base and method useful in your research, please cite the relevant paper:

@article{melba:2023:016:fu,
  title={A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models},
  author={Fu, Yunguan and Li, Yiwen and Saeed, Shaheer U. and Clarkson, Matthew J. and Hu, Yipeng},
  journal={Machine Learning for Biomedical Imaging},
  volume={2},
  issue={Special Issue for Generative Models},
  year={2023},
  pages={507--546},
  issn={2766-905X},
  doi={10.59275/j.melba.2023-fbe4},
  url={https://melba-journal.org/2023:016}
}

@article{fu2023importance,
  title={Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation},
  author={Fu, Yunguan and Li, Yiwen and Saeed, Shaheer U and Clarkson, Matthew J and Hu, Yipeng},
  journal={arXiv preprint arXiv:2303.06040},
  year={2023},
  doi={10.48550/arXiv.2303.06040},
  url={https://arxiv.org/abs/2303.06040},
}