This is a pneumonia classification project that addresses class imbalance by using generative adversarial networks (GANs) to synthesize images of the minority class. In addition, a spatial attention mechanism is introduced into ResNet18 to improve the generalization performance of the classifier!
🔥 Workflow
- ✅ Mar 21, 2023. Create the "MedGAN-ResLite" project repository and find MedMNIST.
- ✅ Mar 22, 2023. Generate pneumonia samples with DCGAN.
- ✅ May 30, 2023. Replace the original loss function with the hinge adversarial loss.
- ✅ April 1, 2023. DCGAN + Spectral Normalization.
- ✅ April 4, 2023. Add DCGAN metrics: Inception Score, FID, and KID; merge and split the dataset.
- ✅ April 5, 2023. Override the dataset class via inheritance.
- ✅ April 6, 2023. Write train, eval, and infer scripts for the classifier, and build a new model by modifying the input and output shapes of the pre-trained model. Add metrics: accuracy, AUC, F1, and confusion matrix.
- ✅ April 7, 2023. Add scripts: export_onnx.py and inference_onnx.py.
- ✅ April 8, 2023. Tune the hyperparameters of DCGAN.
- ✅ April 10, 2023. Explore where to add CBAM attention modules and how many to add.
- ✅ April 14, 2023. Ablation study: GAN, DCGAN, DCGAN + Hinge, DCGAN + SN, DCGAN + SH.
- ✅ April 21, 2023. Attention mechanism visualization using CAM.
- ✅ April 25, 2023. Make a Presentation.

Coming Back!

- ✅ Mar 10, 2024. Preprocess the Chest X-ray 2017 dataset following the MedMNIST [paper] practices.
- ✅ Mar 11, 2024. Train the GAN & CNN again!
- ✅ Mar 13, 2024. Tried histogram equalization, but it did not work well~
- ✅ Mar 15, 2024. Tried adding residual connections to the GAN, but it did not help and slowed down training~
- ✅ Mar 20, 2024. Tried the WGAN training strategy with the Wasserstein distance, but it did not work well~
- ✅ Mar 24, 2024. Add a pruning example using NNI.
- ✅ May 15, 2024. Release the PulmoInsight web application!
- ✅ May 21, 2024. Release MedGAN-ResLite-V2!

✅ | ✅❗ | ❓ |
---|---|---|
Finished successfully! | Finished, but failed! | Unfinished! |
- ❓ Experiment with more challenging datasets, such as ChestXRay2017, Kaggle, etc.
- ❓ Consider learnable resizing, e.g., using transposed convolution instead of interpolation to upscale images.
- ✅❗ May 23, 2023. Try multi-scale fusion.
- ✅❗ May 25, 2023. Introduce class information into DCGAN to generate class-conditional samples.【cDCGAN】
- ❓ Replace original Loss function with Wasserstein distance.
- ❓ Apply ensemble learning methods, such as voting evaluation.
- ❓ Pruning: one-shot + iterative, including L1✅, L2✅, FPGM✅, BNScale.
- ❓ Build the pruned model automatically.
- ❓ Knowledge distillation: design a lightweight network A, and use the pruned model to guide A.
- ❓ Deploy the model on CPU and NCS2 using OpenVINO 【Python ✅ and C++ versions】.
- ❓ Deploy on the web using Django or Flask.
- ❓ Explore the influence of the attention mechanism on deep and shallow networks.
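The L1 criterion in the pruning item ranks the filters of a convolution layer by the L1 norm of their weights and removes the smallest ones. NNI ships ready-made pruners for this; the NumPy sketch below only illustrates the ranking step and does not use NNI's API. The function name `l1_filter_mask` and the 50% sparsity are hypothetical.

```python
import numpy as np


def l1_filter_mask(weight: np.ndarray, sparsity: float) -> np.ndarray:
    """Boolean keep-mask over the output filters of a conv layer.

    weight:   conv weight, shape (out_channels, in_channels, kH, kW).
    sparsity: fraction of filters to prune, in [0, 1).
    False marks a filter that would be removed.
    """
    out_channels = weight.shape[0]
    # L1 norm of each filter: sum of |w| over (in_channels, kH, kW).
    norms = np.abs(weight).reshape(out_channels, -1).sum(axis=1)
    n_prune = int(round(out_channels * sparsity))
    mask = np.ones(out_channels, dtype=bool)
    if n_prune > 0:
        # Filters with the smallest L1 norms are pruned first.
        mask[np.argsort(norms, kind="stable")[:n_prune]] = False
    return mask


# Hypothetical 4-filter layer; filter i is filled with a constant value.
w = np.stack([np.full((2, 3, 3), v) for v in (0.1, 2.0, 0.5, 1.0)])
keep = l1_filter_mask(w, sparsity=0.5)
pruned_w = w[keep]  # physically drop the pruned filters
print(keep, pruned_w.shape)
```

For iterative pruning, the same ranking would be repeated at increasing sparsity with fine-tuning between rounds; when filters are physically removed, the matching input channels of the next layer have to be removed as well.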
Clone the repo and install the dependencies in requirements.txt.
git clone git@github.com:MaitreChen/MedGAN-ResLite.git
cd MedGAN-ResLite
pip install -r requirements.txt
You can download the dataset from this link. It includes the original real PneumoniaMNIST dataset and the fake dataset synthesized by the GAN (see data/README.md for details).
The dataset directory structure is as follows:
MedGAN-ResLite/
|__ data/
    |__ real/
        |__ train/
            |__ normal/
                |__ img_1.png
                |__ ...
            |__ pneumonia/
                |__ img_1.png
                |__ ...
        |__ val/
            |__ normal/
            |__ pneumonia/
        |__ test/
            |__ ...
    |__ fake/
        |__ ...
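To check the class imbalance that the GAN is meant to compensate for, one can count the images in each class folder of a split such as data/real/train. This is a minimal sketch assuming the folder layout above and PNG/JPEG files; the helper names `count_images` and `synthetic_needed` are hypothetical and not part of the repository.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def count_images(split_dir: Path) -> dict:
    """Count image files in each class sub-folder of split_dir (e.g. data/real/train)."""
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


def synthetic_needed(counts: dict) -> dict:
    """Number of extra images each class needs to match the largest class."""
    largest = max(counts.values())
    return {name: largest - n for name, n in counts.items()}


train_dir = Path("data/real/train")
if train_dir.is_dir():
    train_counts = count_images(train_dir)
    print(train_counts, synthetic_needed(train_counts))
```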
You can download the pretrained checkpoints from this link and put them in the pretrained/ folder. They include the resnet18-sam and sh-dcgan models.
🚀 Quick start (the results will be saved in the figures/classifier_torch folder):
python infer_classifier.py --ckpt-path pretrained/resnet18-sam.pth --image-path imgs/pneumonia_img1.png
🌜Here are the options in more detail:
Option | Description |
---|---|
--ckpt-path | Path to the checkpoint containing the pre-trained weights used for inference. |
--image-path | Path to the input image for inference. |
--device | Inference device, cpu or cuda (default: cpu). |
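Around the model call, inference boils down to normalizing the input image and turning the output logits into a label. The NumPy sketch below assumes a single-channel input, placeholder normalization values `MEAN = STD = 0.5`, and the class order `(normal, pneumonia)`; all three are hypothetical, so check utils/image_utils.py and infer_classifier.py for the real ones.

```python
import numpy as np

# Hypothetical values; the real ones are set in utils/image_utils.py.
MEAN, STD = 0.5, 0.5
CLASS_NAMES = ("normal", "pneumonia")  # assumed label order


def preprocess(gray: np.ndarray) -> np.ndarray:
    """uint8 HxW grayscale image -> float32 NCHW array of shape (1, 1, H, W)."""
    x = gray.astype(np.float32) / 255.0  # scale to [0, 1]
    x = (x - MEAN) / STD                 # normalize with dataset statistics
    return x[np.newaxis, np.newaxis, :, :]


def postprocess(logits: np.ndarray) -> tuple:
    """Raw logits of shape (num_classes,) -> (predicted label, its probability)."""
    z = logits - logits.max()            # subtract the max for numerical stability
    probs = np.exp(z) / np.exp(z).sum()  # softmax
    idx = int(probs.argmax())
    return CLASS_NAMES[idx], float(probs[idx])


img = np.full((28, 28), 255, dtype=np.uint8)
print(preprocess(img).shape)  # (1, 1, 28, 28)
print(postprocess(np.array([0.2, 1.7])))
```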
📛Note
If you want to visualize the attention mechanism, run the following command and the results will be saved in the figures/heatmap folder.
python utils/cam.py --image-path imgs/pneumonia_img1.png
More information about CAM can be found here!💖
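For a network that ends in global average pooling followed by one linear layer (as ResNet18 does), the classic CAM for class c is the sum of the last convolutional feature maps weighted by that class's linear-layer weights. This NumPy sketch shows only that formula; utils/cam.py may implement a different variant, and the function name is hypothetical.

```python
import numpy as np


def class_activation_map(features: np.ndarray, fc_weight: np.ndarray,
                         class_idx: int) -> np.ndarray:
    """Classic CAM for a network ending in global average pooling + one linear layer.

    features:  last conv feature maps for one image, shape (C, h, w).
    fc_weight: weights of the final linear layer, shape (num_classes, C).
    Returns an (h, w) map min-max scaled to [0, 1].
    """
    # M_c(x, y) = sum_k w_k^c * f_k(x, y): weight each channel by its class weight.
    cam = np.tensordot(fc_weight[class_idx], features, axes=([0], [0]))
    cam = cam - cam.min()
    peak = cam.max()
    return cam / peak if peak > 0 else cam


rng = np.random.default_rng(0)
feats = rng.random((512, 7, 7))          # e.g. ResNet18 layer4 output for a 224x224 input
weights = rng.standard_normal((2, 512))  # hypothetical 2-class head
heatmap = class_activation_map(feats, weights, class_idx=1)
print(heatmap.shape)  # (7, 7)
```

To get the overlay saved in figures/heatmap, the (h, w) map would then be upsampled to the input resolution (e.g., bilinearly) and blended with the image as a color map.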
🚀 Quick start (the results will be saved in the figures/generator_torch folder):
python infer_generator.py --ckpt-path pretrained/sn-dcgan.pth --batch-size 1 --mode -1
📛Note
If you want to generate fake images for training, or a sprite image, run the following commands:
- Generate a sprite map 【results saved in outs/sprite】:
  python infer_generator.py --ckpt-path pretrained/sn-dcgan.pth --batch-size 64 --mode 0
- Generate a batch of images 【results saved in outs/singles】:
  python infer_generator.py --ckpt-path pretrained/sn-dcgan.pth --batch-size 50 --mode 1

💨 When generating a batch of images, you can set --batch-size to whatever you like❤
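A sprite map simply tiles a batch of generated images into one grid; with --batch-size 64 that gives an 8x8 grid. The NumPy sketch below assumes grayscale images of shape (N, H, W) and a 2-pixel border; the repository may do this differently (for example with torchvision's `make_grid`), and `make_sprite` is a hypothetical name.

```python
import math

import numpy as np


def make_sprite(images: np.ndarray, padding: int = 2, pad_value: int = 0) -> np.ndarray:
    """Tile a batch of grayscale images (N, H, W) into one grid image.

    Columns = ceil(sqrt(N)); unused cells in the last row stay at pad_value.
    """
    n, h, w = images.shape
    ncols = math.ceil(math.sqrt(n))
    nrows = math.ceil(n / ncols)
    sprite = np.full(
        (nrows * (h + padding) + padding, ncols * (w + padding) + padding),
        pad_value,
        dtype=images.dtype,
    )
    for i, img in enumerate(images):
        r, c = divmod(i, ncols)  # row-major placement
        top = padding + r * (h + padding)
        left = padding + c * (w + padding)
        sprite[top:top + h, left:left + w] = img
    return sprite


batch = np.random.default_rng(0).integers(0, 256, size=(64, 28, 28), dtype=np.uint8)
print(make_sprite(batch).shape)  # (242, 242)
```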
python eval.py --ckpt-path pretrained/resnet18-sam.pth
To evaluate the GAN, make sure torch-fidelity (listed in requirements.txt) is installed❗
Then, prepare two datasets❗
- the training dataset in the data/merge folder 【real images】
- the generated dataset in the outs folder 【fake images】

Once everything is ready, run the following command:
fidelity --gpu 0 --fid --input1 data/merge --input2 data/outs/singles
More information about fidelity can be found here!💖
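The FID reported by torch-fidelity is the Fréchet distance between Gaussians fitted to Inception-v3 features of the real and generated images: `FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2))`. The NumPy sketch below only shows that final formula on precomputed feature matrices (feature extraction is omitted), and it uses the fact that `Tr((S1 S2)^(1/2))` equals the sum of the square roots of the eigenvalues of `S1^(1/2) S2 S1^(1/2)` to avoid a general matrix square root. Function names are hypothetical.

```python
import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative values from round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(feats1: np.ndarray, feats2: np.ndarray) -> float:
    """FID-style distance between two feature sets of shape (n_samples, dim)."""
    mu1, mu2 = feats1.mean(axis=0), feats2.mean(axis=0)
    sigma1 = np.atleast_2d(np.cov(feats1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(feats2, rowvar=False))
    # Tr((S1 S2)^{1/2}) = sum of sqrt-eigenvalues of the symmetric S1^{1/2} S2 S1^{1/2}.
    root1 = _sqrtm_psd(sigma1)
    middle = root1 @ sigma2 @ root1
    eig = np.clip(np.linalg.eigvalsh((middle + middle.T) / 2), 0.0, None)
    tr_covmean = np.sqrt(eig).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(500, 8))
fake = rng.normal(0.5, 1.0, size=(500, 8))
print(frechet_distance(real, real))  # close to 0 (up to round-off)
print(frechet_distance(real, fake))  # larger, since the means differ
```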
python train_classifier.py
💝 More details about training on your own dataset
Please refer to data/config.yaml and README.md.
In addition, you need to set the normalization parameters (mean and std)! Please refer to utils/image_utils.py.
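When switching to your own dataset, the mean and std can be measured on the training images. This sketch assumes grayscale uint8 images scaled to [0, 1] and computes a single pixel-level mean and population std with running sums, so the dataset never has to sit in memory at once; image loading (e.g., with Pillow) is left out, and `dataset_mean_std` is a hypothetical name.

```python
from typing import Iterable, Tuple

import numpy as np


def dataset_mean_std(images: Iterable[np.ndarray]) -> Tuple[float, float]:
    """Pixel mean and std over grayscale uint8 images, after scaling to [0, 1]."""
    total, total_sq, count = 0.0, 0.0, 0
    for img in images:
        x = img.astype(np.float64) / 255.0
        total += x.sum()
        total_sq += np.square(x).sum()
        count += x.size
    mean = total / count
    # Population variance: E[x^2] - E[x]^2, clamped against round-off.
    var = max(total_sq / count - mean ** 2, 0.0)
    return float(mean), float(np.sqrt(var))


imgs = [np.zeros((28, 28), np.uint8), np.full((28, 28), 255, np.uint8)]
print(dataset_mean_std(imgs))  # (0.5, 0.5)
```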
python train_dcgan.py
If you want to export the ONNX model for ONNXRuntime or OpenVINO, please refer to README.md!
To use ONNXRuntime, refer to README.md and onnx/inference_onnx.py!
To use OpenVINO, refer to README.md!
Method | Inception Score | FID | KID |
---|---|---|---|
GAN | 2.20 | 260.15 | 0.42 |
DCGAN | 2.20 | 259.72 | 0.39 |
SH-DCGAN | 2.20 | 206.14 | 0.31 |
Method | Inception Score | FID | KID |
---|---|---|---|
DCGAN | 2.20 | 259.72 | 0.39 |
DCGAN + Hinge | 2.20 | 252.42 | 0.38 |
DCGAN + SN | 2.20 | 232.59 | 0.35 |
SH-DCGAN | 2.20 | 206.14 | 0.31 |
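SH-DCGAN combines the two ingredients ablated above: the hinge adversarial loss and spectral normalization (SN). The NumPy sketch below shows both in isolation: the hinge losses computed from raw discriminator scores, and SN as dividing a weight matrix by its largest singular value estimated with power iteration (a conv weight would first be reshaped to `(out_channels, -1)`). It is not taken from train_dcgan.py; in PyTorch, SN is available as `torch.nn.utils.spectral_norm`, which runs one power iteration per forward pass and reuses its vector between steps, whereas this standalone sketch iterates to convergence.

```python
import numpy as np


def d_hinge_loss(real_scores: np.ndarray, fake_scores: np.ndarray) -> float:
    """Discriminator hinge loss: E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))]."""
    return float(np.maximum(0.0, 1.0 - real_scores).mean()
                 + np.maximum(0.0, 1.0 + fake_scores).mean())


def g_hinge_loss(fake_scores: np.ndarray) -> float:
    """Generator hinge loss: -E[D(G(z))]."""
    return float(-fake_scores.mean())


def spectral_normalize(weight: np.ndarray, n_iters: int = 50, seed: int = 0) -> np.ndarray:
    """Divide a 2-D weight matrix by its largest singular value (power iteration)."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(weight.shape[0])
    for _ in range(n_iters):
        v = weight.T @ u
        v /= np.linalg.norm(v)
        u = weight @ v
        u /= np.linalg.norm(u)
    sigma = u @ weight @ v  # estimate of the top singular value
    return weight / sigma


print(d_hinge_loss(np.array([2.0, 0.5]), np.array([-2.0, 0.0])))  # 0.75
print(g_hinge_loss(np.array([1.0, 3.0])))  # -2.0
```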
Model | Accuracy (%) | Precision (%) | Recall (%) | F1 score (%) |
---|---|---|---|---|
AlexNet | 90.16 | 90.16 | 90.16 | 90.16 |
VGG16 | 91.22 | 92.23 | 91.22 | 91.17 |
VGG19 | 91.76 | 92.70 | 91.76 | 91.71 |
ResNet34 | 92.55 | 93.26 | 92.55 | 92.52 |
ResNet50 | 91.15 | 92.44 | 92.15 | 92.14 |
MobileNetV2 | 92.29 | 92.60 | 92.29 | 92.27 |
ResNet18 | 92.02 | 92.02 | 92.02 | 92.02 |
ResNet18-SAM | 93.48 | 93.82 | 93.48 | 93.47 |
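The README does not state how precision, recall, and F1 are averaged over the two classes. This sketch assumes support-weighted averaging (the same convention as scikit-learn's `average="weighted"`) and computes all four columns from a confusion matrix; under this convention weighted recall is always identical to accuracy, because the class weights cancel the per-class supports. The confusion matrix values and the function name are hypothetical.

```python
import numpy as np


def weighted_metrics(cm: np.ndarray) -> dict:
    """Accuracy and support-weighted precision/recall/F1 from a confusion matrix.

    cm[i, j] = number of samples with true class i predicted as class j.
    """
    cm = cm.astype(np.float64)
    tp = np.diag(cm)
    support = cm.sum(axis=1)    # true samples per class
    predicted = cm.sum(axis=0)  # predictions per class
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    weights = support / support.sum()
    return {
        "accuracy": float(tp.sum() / cm.sum()),
        "precision": float(weights @ precision),
        "recall": float(weights @ recall),
        "f1": float(weights @ f1),
    }


# Hypothetical confusion matrix (rows/cols: normal, pneumonia), not from the table above.
cm = np.array([[200, 34], [10, 380]])
print(weighted_metrics(cm))
```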
For any questions or suggestions about this project, feel free to open an issue!
You can also contact hbchenstu@outlook.com.
Thank you, and I hope you have a pleasant experience~~💓🧡💛💚💙💜