Skip to content

Commit

Permalink
[CodeCamp2023-500]add large_image_demo (#10719)
Browse files Browse the repository at this point in the history
  • Loading branch information
YanxingLiu authored Aug 14, 2023
1 parent 19ed2f1 commit 60b29b3
Show file tree
Hide file tree
Showing 6 changed files with 519 additions and 4 deletions.
Binary file added demo/large_image.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
282 changes: 282 additions & 0 deletions demo/large_image_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Perform MMDET inference on large images (as satellite imagery) as:
```shell
wget -P checkpoint https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth # noqa: E501, E261.
python demo/large_image_demo.py \
demo/large_image.jpg \
configs/faster_rcnn/faster-rcnn_r101_fpn_2x_coco.py \
checkpoint/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth
```
"""

import os
import random
from argparse import ArgumentParser
from pathlib import Path

import mmcv
import numpy as np
from mmengine.config import Config, ConfigDict
from mmengine.logging import print_log
from mmengine.utils import ProgressBar

from mmdet.apis import inference_detector, init_detector

try:
from sahi.slicing import slice_image
except ImportError:
raise ImportError('Please run "pip install -U sahi" '
'to install sahi first for large image inference.')

from mmdet.registry import VISUALIZERS
from mmdet.utils.large_image import merge_results_by_nms, shift_predictions
from mmdet.utils.misc import get_file_list


def parse_args():
    """Parse command-line arguments for large-image sliced inference.

    Returns:
        argparse.Namespace: Parsed arguments with the image path, config,
        checkpoint and all slicing / merging / output options.
    """
    parser = ArgumentParser(
        description='Perform MMDET inference on large images.')
    parser.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    parser.add_argument(
        '--tta',
        action='store_true',
        help='Whether to use test time augmentation')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--patch-size', type=int, default=640, help='The size of patches')
    parser.add_argument(
        '--patch-overlap-ratio',
        type=float,
        default=0.25,
        help='Ratio of overlap between two patches')
    # Fixed help-string typo: 'threshould' -> 'threshold'.
    parser.add_argument(
        '--merge-iou-thr',
        type=float,
        default=0.25,
        help='IoU threshold for merging results')
    parser.add_argument(
        '--merge-nms-type',
        type=str,
        default='nms',
        help='NMS type for merging results')
    # Fixed help-string grammar: 'must greater than' -> 'must be greater than'.
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1,
        help='Batch size, must be greater than or equal to 1')
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Export debug results before merging')
    parser.add_argument(
        '--save-patch',
        action='store_true',
        help='Save the results of each patch. '
        'The `--debug` must be enabled.')
    args = parser.parse_args()
    return args


def main():
    """Run sliced ("SAHI-style") detection inference on large images.

    Workflow: load the config/checkpoint, optionally wrap the model for
    test-time augmentation, slice each input image into overlapping
    patches, run the detector on patch batches, then merge per-patch
    predictions back into full-image coordinates via NMS and save/show
    the visualized result. With ``--debug``, also exports the patch grid
    overlay and (with ``--save-patch``) each patch's own result image.
    """
    args = parse_args()

    config = args.config

    # Accept either a config file path or an already-built Config object.
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    # Drop backbone init_cfg so no pretrained backbone weights are
    # downloaded: the full checkpoint is loaded explicitly below.
    if 'init_cfg' in config.model.backbone:
        config.model.backbone.init_cfg = None

    if args.tta:
        assert 'tta_model' in config, 'Cannot find ``tta_model`` in config.' \
            " Can't use tta !"
        assert 'tta_pipeline' in config, 'Cannot find ``tta_pipeline`` ' \
            "in config. Can't use tta !"
        # Wrap the detector in the TTA model wrapper defined by the config.
        config.model = ConfigDict(**config.tta_model, module=config.model)
        # Walk through dataset wrappers (e.g. RepeatDataset) to reach the
        # innermost dataset config before swapping in the TTA pipeline.
        test_data_cfg = config.test_dataloader.dataset
        while 'dataset' in test_data_cfg:
            test_data_cfg = test_data_cfg['dataset']

        test_data_cfg.pipeline = config.tta_pipeline

    # TODO: TTA mode will error if cfg_options is not set.
    # This is an mmdet issue and needs to be fixed later.
    # build the model from a config file and a checkpoint file
    model = init_detector(
        config, args.checkpoint, device=args.device, cfg_options={})

    # NOTE(review): os.mkdir fails if args.out_dir is a nested path whose
    # parent does not exist — os.makedirs(..., exist_ok=True) would be
    # more robust; confirm intended behavior before changing.
    if not os.path.exists(args.out_dir) and not args.show:
        os.mkdir(args.out_dir)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # get file list (handles single file, directory, and URL inputs)
    files, source_type = get_file_list(args.img)

    # start detector inference
    print(f'Performing inference on {len(files)} images.... '
          'This may take a while.')
    progress_bar = ProgressBar(len(files))
    for file in files:
        # read image
        img = mmcv.imread(file)

        # arrange slices: cut the image into overlapping square patches
        height, width = img.shape[:2]
        sliced_image_object = slice_image(
            img,
            slice_height=args.patch_size,
            slice_width=args.patch_size,
            auto_slice_resolution=False,
            overlap_height_ratio=args.patch_overlap_ratio,
            overlap_width_ratio=args.patch_overlap_ratio,
        )
        # perform sliced inference: feed patches to the detector in
        # batches of args.batch_size until all patches are processed
        slice_results = []
        start = 0
        while True:
            # prepare batch slices
            end = min(start + args.batch_size, len(sliced_image_object))
            images = []
            for sliced_image in sliced_image_object.images[start:end]:
                images.append(sliced_image)

            # forward the model
            slice_results.extend(inference_detector(model, images))

            if end >= len(sliced_image_object):
                break
            start += args.batch_size

        # Build a flat output filename; for directory inputs, encode the
        # relative path by replacing '/' so results don't collide.
        if source_type['is_dir']:
            filename = os.path.relpath(file, args.img).replace('/', '_')
        else:
            filename = os.path.basename(file)

        img = mmcv.imconvert(img, 'bgr', 'rgb')
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # export debug images
        if args.debug:
            # export sliced image results
            name, suffix = os.path.splitext(filename)

            # Shift per-patch predictions into full-image coordinates
            # (no NMS yet) so overlaps remain visible for debugging.
            shifted_instances = shift_predictions(
                slice_results,
                sliced_image_object.starting_pixels,
                src_image_shape=(height, width))
            merged_result = slice_results[0].clone()
            merged_result.pred_instances = shifted_instances

            debug_file_name = name + '_debug' + suffix
            debug_out_file = None if args.show else os.path.join(
                args.out_dir, debug_file_name)
            visualizer.set_image(img.copy())

            # Draw the slicing grid: one xyxy box per patch, clipped to
            # stay inside the image bounds for drawing.
            debug_grids = []
            for starting_point in sliced_image_object.starting_pixels:
                start_point_x = starting_point[0]
                start_point_y = starting_point[1]
                end_point_x = start_point_x + args.patch_size
                end_point_y = start_point_y + args.patch_size
                debug_grids.append(
                    [start_point_x, start_point_y, end_point_x, end_point_y])
            debug_grids = np.array(debug_grids)
            debug_grids[:, 0::2] = np.clip(debug_grids[:, 0::2], 1,
                                           img.shape[1] - 1)
            debug_grids[:, 1::2] = np.clip(debug_grids[:, 1::2], 1,
                                           img.shape[0] - 1)

            # Random color + line style per patch so adjacent,
            # overlapping patches remain distinguishable.
            palette = np.random.randint(0, 256, size=(len(debug_grids), 3))
            palette = [tuple(c) for c in palette]
            line_styles = random.choices(['-', '-.', ':'], k=len(debug_grids))
            visualizer.draw_bboxes(
                debug_grids,
                edge_colors=palette,
                alpha=1,
                line_styles=line_styles)
            visualizer.draw_bboxes(
                debug_grids, face_colors=palette, alpha=0.15)

            # Label each patch with its index at its top-left corner.
            # NOTE(review): draw_texts is given ints, not strings —
            # presumably the visualizer stringifies them; confirm.
            visualizer.draw_texts(
                list(range(len(debug_grids))),
                debug_grids[:, :2] + 5,
                colors='w')

            visualizer.add_datasample(
                debug_file_name,
                visualizer.get_image(),
                data_sample=merged_result,
                draw_gt=False,
                show=args.show,
                wait_time=0,
                out_file=debug_out_file,
                pred_score_thr=args.score_thr,
            )

            if args.save_patch:
                debug_patch_out_dir = os.path.join(args.out_dir,
                                                   f'{name}_patch')
                for i, slice_result in enumerate(slice_results):
                    # NOTE(review): the literal '(unknown)' below looks
                    # like a corrupted/placeholder f-string segment —
                    # likely meant to embed the image name; verify
                    # against the upstream source.
                    patch_out_file = os.path.join(
                        debug_patch_out_dir,
                        f'(unknown)_slice_{i}_result.jpg')
                    image = mmcv.imconvert(sliced_image_object.images[i],
                                           'bgr', 'rgb')

                    visualizer.add_datasample(
                        'patch_result',
                        image,
                        data_sample=slice_result,
                        draw_gt=False,
                        show=False,
                        wait_time=0,
                        out_file=patch_out_file,
                        pred_score_thr=args.score_thr,
                    )

        # Merge all patch predictions into one full-image result,
        # de-duplicating overlap-region detections with NMS.
        image_result = merge_results_by_nms(
            slice_results,
            sliced_image_object.starting_pixels,
            src_image_shape=(height, width),
            nms_cfg={
                'type': args.merge_nms_type,
                'iou_threshold': args.merge_iou_thr
            })

        visualizer.add_datasample(
            filename,
            img,
            data_sample=image_result,
            draw_gt=False,
            show=args.show,
            wait_time=0,
            out_file=out_file,
            pred_score_thr=args.score_thr,
        )
        progress_bar.update()

    if not args.show or (args.debug and args.save_patch):
        print_log(
            f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


# Standard CLI entry-point guard: run inference only when executed
# directly, not when this module is imported.
if __name__ == '__main__':
    main()
47 changes: 45 additions & 2 deletions docs/en/user_guides/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Note: `inference_detector` only supports single-image inference for now.

## Demos

We also provide three demo scripts, implemented with high-level APIs and supporting functionality codes.
We also provide four demo scripts, implemented with high-level APIs and supporting functionality codes.
Source codes are available [here](https://github.com/open-mmlab/mmdetection/blob/main/demo).

### Image demo
Expand All @@ -115,7 +115,7 @@ python demo/image_demo.py demo/demo.jpg \
--device cpu
```

#### Webcam demo
### Webcam demo

This is a live demo from a webcam.

Expand Down Expand Up @@ -187,6 +187,49 @@ python demo/video_gpuaccel_demo.py demo/demo.mp4 \
--nvdecode --out result.mp4
```

### Large-image inference demo

This is a script for slicing inference on large images.

```
python demo/large_image_demo.py \
${IMG_PATH} \
${CONFIG_FILE} \
${CHECKPOINT_FILE} \
--device ${GPU_ID} \
--show \
--tta \
--score-thr ${SCORE_THR} \
--patch-size ${PATCH_SIZE} \
--patch-overlap-ratio ${PATCH_OVERLAP_RATIO} \
--merge-iou-thr ${MERGE_IOU_THR} \
--merge-nms-type ${MERGE_NMS_TYPE} \
--batch-size ${BATCH_SIZE} \
--debug \
--save-patch
```

Examples:

```shell
# inference without tta
wget -P checkpoint https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth

python demo/large_image_demo.py \
demo/large_image.jpg \
configs/faster_rcnn/faster-rcnn_r101_fpn_2x_coco.py \
checkpoint/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth

# inference with tta
wget -P checkpoint https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth

python demo/large_image_demo.py \
demo/large_image.jpg \
configs/retinanet/retinanet_r50_fpn_1x_coco.py \
checkpoint/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth --tta

```

## Multi-modal algorithm inference demo and evaluation

As multimodal vision algorithms continue to evolve, MMDetection has also supported such algorithms. This section demonstrates how to use the demo and eval scripts corresponding to multimodal algorithms using the GLIP algorithm and model as the example. Moreover, MMDetection integrated a [gradio_demo project](../../../projects/gradio_demo/), which allows developers to quickly play with all image input tasks in MMDetection on their local devices. Check the [document](../../../projects/gradio_demo/README.md) for more details.
Expand Down
Loading

0 comments on commit 60b29b3

Please sign in to comment.