Rethinking Feature Extraction: Gradient-based Localized Feature Extraction for End-to-End Surgical Downstream Tasks
This repository contains the reference code for the paper "Rethinking Feature Extraction: Gradient-based Localized Feature Extraction for End-to-End Surgical Downstream Tasks." To learn more about the project, check out our presentation video.
We develop a detector-free, gradient-based localized feature extraction approach that enables end-to-end model training for downstream surgical tasks such as report generation and tool-tissue interaction graph prediction. We eliminate the need for object detection, region proposal, and feature extraction networks by extracting the features of interest from the discriminative regions in the feature maps of the classification models. Here, the discriminative regions are localized using gradient-based localization techniques (e.g., Grad-CAM). We show that our proposed approaches enable real-time deployment of end-to-end models for surgical downstream tasks.
If you find our code or paper useful, please cite it as follows:
@article{pang2022rethinking,
title={Rethinking Feature Extraction: Gradient-Based Localized Feature Extraction for End-To-End Surgical Downstream Tasks},
author={Pang, Winnie and Islam, Mobarakol and Mitheran, Sai and Seenivasan, Lalithkumar and Xu, Mengya and Ren, Hongliang},
journal={IEEE Robotics and Automation Letters},
volume={7},
number={4},
pages={12623--12630},
year={2022},
publisher={IEEE}
}
- Clone the repository
git clone https://github.com/PangWinnie0219/GradCAMDownstreamTask.git
- Install the required packages using the requirements.txt file:
pip install -r requirements.txt
Note: Python 3.6 is required to run our code.
We use the Cholec80 dataset and the Robotic Instrument Segmentation dataset from the MICCAI 2018 Endoscopic Vision Challenge.
Cholec80 dataset: Because the tissue label is required for the captioning and interaction tasks, we added one extra label at the end of the original tool annotations of all samples, as shown in the figure below. Since many types of tissue are present in the Cholec80 dataset (e.g., gallbladder, cystic plate and liver), the added tissue label does not refer to a specific tissue but to the interacting tissue. For simplicity, we assume that the interacting tissue appears in all frames of the Cholec80 dataset.
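The label extension described above amounts to appending a constant column to the tool annotations. The sketch below is illustrative only: it assumes the annotations are already loaded as one list of seven 0/1 tool labels per frame (Cholec80 annotates seven tools), and the helper name add_tissue_label is hypothetical, not part of this repository.

```python
from typing import List


def add_tissue_label(tool_rows: List[List[int]]) -> List[List[int]]:
    """Append an 'interacting tissue' label to every frame's tool labels.

    Hypothetical helper. Following the assumption stated above, the
    interacting tissue is present in every Cholec80 frame, so the new
    last column is always 1.
    """
    # row + [1] builds a new list, so the input annotations are not modified.
    return [row + [1] for row in tool_rows]


# Hypothetical example: two frames, each with 7 Cholec80 tool labels.
frames = [
    [1, 0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0],
]
print(add_tissue_label(frames))
```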
Run the following command to start training the classification model:
python3.6 baseline.py
Ensure save is set to True, as the saved checkpoint will be used for visualization and feature extraction later.
Alternatively, you can download the trained model files:
- GC-A: [miccai2018_9class_ResNet50_256,320_32_lr_0.001_dropout_0.2_best_checkpoint.pth.tar] (To be added)
- GC-B: [miccai2018_9class_cholecResNet50_256,320_32_lr_0.001_dropout_0.2_best_checkpoint.pth.tar] (To be added)
- GC-C: [miccai2018_11class_cholec_ResNet50_256,320_32_lr_0.001_best_checkpoint.pth.tar] (To be added)
- GC-D: [combine_miccai18_ResNet50_256,320_170_best_checkpoint.pth.tar] (To be added)
Place the trained model file inside ./best_model_checkpoints.
cd into the utils directory:
cd utils
You can visualize the Grad-CAM heatmap and bounding box using
python3.6 miccai_bbox.py
To select a specific frame and the heatmap of a specific class, set bidx and tclass, respectively. For example, to view the heatmap for class 3 of the 15th image in the dataset, run:
python3.6 miccai_bbox.py --bidx 15 --tclass 3
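For reference, the heatmap itself follows the standard Grad-CAM recipe: each channel of the last convolutional feature map is weighted by the spatial mean of its gradient with respect to the target class score, the weighted maps are summed, and a ReLU keeps only positive evidence. The NumPy sketch below assumes the activations and gradients are already available as arrays (in this repository they come from the classification model via code adapted from pytorch-grad-cam); the function name grad_cam is hypothetical, and upsampling the map to the input resolution is omitted.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Compute a Grad-CAM heatmap (hypothetical helper).

    activations: feature maps A of shape (K, H, W) from the last conv layer.
    gradients:   d(score of target class)/dA, same shape as activations.
    Returns an (H, W) heatmap min-max normalized to [0, 1].
    """
    # Channel weights: global average pooling of the gradients over H and W.
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum over channels, then ReLU to keep positive evidence only.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # (H, W)
    # Min-max normalize; return zeros if the map is constant.
    cam_range = cam.max() - cam.min()
    if cam_range == 0:
        return np.zeros_like(cam)
    return (cam - cam.min()) / cam_range


# Usage with random data shaped like a ResNet50 output for a 256x320 input.
rng = np.random.default_rng(0)
A = rng.random((2048, 8, 10))
G = rng.standard_normal((2048, 8, 10))
heatmap = grad_cam(A, G)
print(heatmap.shape)
```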
The threshold T_ROI can be set with threshold to see its effect on bounding box generation.
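To make the role of T_ROI concrete, here is a minimal sketch of turning a normalized heatmap into a bounding box: keep the pixels whose value is at least the threshold and take their extent. This is a simplification under stated assumptions; the repository's scripts may instead use contours or connected components, and the function name heatmap_to_bbox is hypothetical.

```python
from typing import Optional, Tuple

import numpy as np


def heatmap_to_bbox(cam: np.ndarray, threshold: float) -> Optional[Tuple[int, int, int, int]]:
    """Return (x1, y1, x2, y2) enclosing all pixels with cam >= threshold.

    Hypothetical helper: a single box around every surviving pixel.
    Coordinates are inclusive. Returns None if no pixel passes the threshold.
    """
    ys, xs = np.nonzero(cam >= threshold)
    if ys.size == 0:
        return None  # nothing survives the threshold
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


# A strong blob plus one weak pixel: a higher threshold gives a tighter box.
cam = np.zeros((5, 6))
cam[1:3, 2:5] = 0.9
cam[4, 0] = 0.3
print(heatmap_to_bbox(cam, 0.5))
print(heatmap_to_bbox(cam, 0.2))
```

Raising the threshold discards weakly activated pixels, so boxes shrink toward the most discriminative region; lowering it lets faint activations stretch the box.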
Examples of the Grad-CAM heatmap and bounding box can be found in the supplementary video.
If you are training the Grad-CAM model from scratch, set result_filename in the code accordingly. If you are using our checkpoints, set gc_model to 1, 2, 3 or 4 to load the checkpoint from GC-A, GC-B, GC-C or GC-D, respectively. If gc_model is 1 or 2, set cls to 9; otherwise, set cls to 11.
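The gc_model and cls rules above can be summarized as a lookup. The helper below, resolve_checkpoint, is hypothetical (the scripts set gc_model and cls directly in the code); it only restates the checkpoint file names listed above and assumes they are placed in ./best_model_checkpoints.

```python
import os
from typing import Tuple

# Checkpoint file names as listed above for GC-A to GC-D.
CHECKPOINTS = {
    1: "miccai2018_9class_ResNet50_256,320_32_lr_0.001_dropout_0.2_best_checkpoint.pth.tar",  # GC-A
    2: "miccai2018_9class_cholecResNet50_256,320_32_lr_0.001_dropout_0.2_best_checkpoint.pth.tar",  # GC-B
    3: "miccai2018_11class_cholec_ResNet50_256,320_32_lr_0.001_best_checkpoint.pth.tar",  # GC-C
    4: "combine_miccai18_ResNet50_256,320_170_best_checkpoint.pth.tar",  # GC-D
}


def resolve_checkpoint(gc_model: int, ckpt_dir: str = "./best_model_checkpoints") -> Tuple[str, int]:
    """Return (checkpoint path, cls) for a gc_model value (hypothetical helper)."""
    if gc_model not in CHECKPOINTS:
        raise ValueError(f"gc_model must be 1, 2, 3 or 4, got {gc_model}")
    # GC-A and GC-B are 9-class models; GC-C and GC-D are 11-class models.
    cls = 9 if gc_model in (1, 2) else 11
    return os.path.join(ckpt_dir, CHECKPOINTS[gc_model]), cls


path, cls = resolve_checkpoint(3)
print(os.path.basename(path), cls)
```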
This method is similar to the conventional feature extraction method: region images are cropped from the raw image, and the cropped region images are then forwarded to the feature extractor.
- Crop the region images based on the predicted bounding box
python3.6 utils/crop_bbox.py
- Forward the cropped region images to the model again
python3.6 image_extract_feature.py
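The cropping step above can be sketched as slicing the raw image with each predicted box. This is a minimal illustration, assuming the boxes are (x1, y1, x2, y2) in inclusive image-pixel coordinates and the image is an (H, W, C) array; crop_regions is a hypothetical name, and resizing each crop to the classifier's input size before the second forward pass is not shown.

```python
from typing import List, Sequence, Tuple

import numpy as np

BBox = Tuple[int, int, int, int]  # (x1, y1, x2, y2), inclusive pixel coordinates


def crop_regions(image: np.ndarray, bboxes: Sequence[BBox]) -> List[np.ndarray]:
    """Crop one region image per bounding box from an (H, W, C) image."""
    h, w = image.shape[:2]
    crops = []
    for x1, y1, x2, y2 in bboxes:
        # Clip the box to the image bounds.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w - 1, x2), min(h - 1, y2)
        # Inclusive box -> half-open slice; copy so crops are independent arrays.
        crops.append(image[y1:y2 + 1, x1:x2 + 1].copy())
    return crops


image = np.arange(4 * 5 * 3).reshape(4, 5, 3)
crops = crop_regions(image, [(1, 0, 2, 1), (3, 2, 10, 10)])
print([c.shape for c in crops])
```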
The features are extracted from the feature map of the classification model based on the bounding box coordinates.
python3.6 bbox_extract_feature.py
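Extracting features from the feature map with a bounding box means mapping the box from image coordinates onto the coarser feature-map grid and pooling the cells it covers. The sketch below assumes a (K, h, w) feature map, an input of 256x320 (height x width, as suggested by the checkpoint names) and average pooling inside the box; the repository's exact pooling may differ, and bbox_feature is a hypothetical name. torchvision.ops.roi_align is an alternative for sub-cell precision.

```python
from typing import Tuple

import numpy as np


def bbox_feature(feature_map: np.ndarray,
                 bbox: Tuple[int, int, int, int],
                 image_size: Tuple[int, int]) -> np.ndarray:
    """Average-pool a (K, h, w) feature map inside an image-space bbox.

    bbox is (x1, y1, x2, y2) in inclusive image pixels; image_size is
    (height, width) of the network input. Returns a (K,) feature vector.
    """
    _, fh, fw = feature_map.shape
    img_h, img_w = image_size
    x1, y1, x2, y2 = bbox
    # Map image pixels to feature-map cells (floor for start, ceil for end).
    fx1 = int(np.floor(x1 * fw / img_w))
    fy1 = int(np.floor(y1 * fh / img_h))
    fx2 = min(fw, int(np.ceil((x2 + 1) * fw / img_w)))
    fy2 = min(fh, int(np.ceil((y2 + 1) * fh / img_h)))
    # Always cover at least one cell, even for very small boxes.
    fx2, fy2 = max(fx2, fx1 + 1), max(fy2, fy1 + 1)
    region = feature_map[:, fy1:fy2, fx1:fx2]
    return region.mean(axis=(1, 2))


# Toy feature map with 2 channels on the 8x10 grid of a stride-32 backbone.
fmap = np.arange(2 * 8 * 10, dtype=float).reshape(2, 8, 10)
print(bbox_feature(fmap, (32, 32, 95, 63), (256, 320)))
```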
The features are extracted from the feature map of the classification model in a single pass based on the heatmap (no bounding box generation).
python3.6 heatmap_extract_feature.py
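Because the Grad-CAM heatmap is computed on the same grid as the feature map, it can be used directly as spatial pooling weights, with no box in between. The sketch below assumes a heatmap-weighted average over spatial positions, a plausible single-pass formulation rather than the repository's confirmed aggregation; heatmap_feature is a hypothetical name.

```python
import numpy as np


def heatmap_feature(feature_map: np.ndarray, cam: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Heatmap-weighted average of a (K, h, w) feature map -> (K,) vector.

    cam is a non-negative (h, w) Grad-CAM heatmap at feature-map resolution.
    """
    # Normalize the heatmap so the weights sum to (almost exactly) 1.
    weights = cam / (cam.sum() + eps)
    # For each channel, sum feature * weight over all spatial positions.
    return np.tensordot(feature_map, weights, axes=([1, 2], [0, 1]))


fmap = np.arange(8, dtype=float).reshape(2, 2, 2)  # channels [[0,1],[2,3]] and [[4,5],[6,7]]
cam = np.array([[1.0, 0.0], [0.0, 1.0]])
print(heatmap_feature(fmap, cam))
```

Compared with box pooling, this weights each location by how strongly it contributes to the class, so background inside a loose box contributes little.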
Code adapted and modified from pytorch-grad-cam.
The extracted features can be used for downstream tasks such as:
- Captioning
- Paper: Meshed-Memory Transformer for Image Captioning
- Official implementation code
- Interaction
- Paper: CogTree: Cognition Tree Loss for Unbiased Scene Graph Generation
- Official implementation code
If you have any questions or feedback about this project, feel free to contact me at winnie_pang@u.nus.edu.