-
Notifications
You must be signed in to change notification settings - Fork 381
/
inference.py
176 lines (153 loc) · 7.33 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import argparse
import scipy
import os
import numpy as np
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from scipy import ndimage
from tqdm import tqdm
from math import ceil
from glob import glob
from PIL import Image
import dataloaders
import models
from utils.helpers import colorize_mask
from collections import OrderedDict
def pad_image(img, target_size):
    """Zero-pad a 4-D tensor on the bottom/right so it reaches ``target_size``.

    Args:
        img: Tensor whose dims 2 and 3 are the spatial rows/cols to pad.
        target_size: ``(rows, cols)`` the spatial dims should reach. Dims that
            are already at least this large are left as-is (never cropped).

    Returns:
        The zero-padded tensor.
    """
    rows_to_pad = max(target_size[0] - img.shape[2], 0)
    cols_to_pad = max(target_size[1] - img.shape[3], 0)
    # F.pad lists padding for the last dim first: (left, right, top, bottom).
    return F.pad(img, (0, cols_to_pad, 0, rows_to_pad), "constant", 0)


def sliding_predict(model, image, num_classes, flip=True, overlap=1 / 3):
    """Predict class scores by running ``model`` over overlapping tiles.

    Each tile is roughly 1/2.5 of the image along each axis. Tiles at the
    bottom/right border are zero-padded to full tile size, and the padded part
    of the output is discarded. Scores from overlapping tiles are averaged per
    pixel.

    Args:
        model: Callable taking a ``(1, C, tile_h, tile_w)`` tensor. It must
            return ``(1, num_classes, h, w)`` with ``h >= tile_h`` and
            ``w >= tile_w``; the output is cropped back to the tile extent.
        image: Tensor of shape ``(1, C, H, W)``. The batch must be 1 because the
            output is squeezed into a single ``(num_classes, H, W)`` buffer.
        num_classes: Number of output channels produced by ``model``.
        flip: If True, also predict on the horizontally mirrored tile, un-mirror
            that prediction, and average it with the plain one.
        overlap: Fraction of a tile shared with its neighbour, in ``[0, 1)``.

    Returns:
        ``np.ndarray`` of shape ``(num_classes, H, W)`` with averaged scores.

    Raises:
        ValueError: If ``overlap`` is outside ``[0, 1)``.
    """
    if not 0 <= overlap < 1:
        # overlap < 0 would leave gaps between tiles; overlap >= 1 never advances.
        raise ValueError(f"overlap must be in [0, 1), got {overlap}")

    _, _, height, width = image.shape
    # Clamp to 1 so very small inputs do not produce a 0-sized tile
    # (which would give a 0 stride and a ZeroDivisionError below).
    tile_h = max(int(height // 2.5), 1)
    tile_w = max(int(width // 2.5), 1)
    tile_size = (tile_h, tile_w)

    # One stride per axis. Reusing the row stride for columns leaves columns
    # unvisited on tall images, so their count stays 0 and the division yields NaN.
    stride_h = ceil(tile_h * (1 - overlap))
    stride_w = ceil(tile_w * (1 - overlap))
    # Enough tiles that the last one reaches the bottom/right edge.
    num_rows = int(ceil((height - tile_h) / stride_h) + 1)
    num_cols = int(ceil((width - tile_w) / stride_w) + 1)

    total_predictions = np.zeros((num_classes, height, width))
    count_predictions = np.zeros((height, width))

    for row in range(num_rows):
        for col in range(num_cols):
            y_min, x_min = row * stride_h, col * stride_w
            y_max = min(y_min + tile_h, height)
            x_max = min(x_min + tile_w, width)

            tile = image[:, :, y_min:y_max, x_min:x_max]
            padded_tile = pad_image(tile, tile_size)
            padded_prediction = model(padded_tile)
            if flip:
                flipped_prediction = model(padded_tile.flip(-1)).flip(-1)
                padded_prediction = 0.5 * (flipped_prediction + padded_prediction)

            # Drop the rows/cols that exist only because of the zero padding.
            prediction = padded_prediction[:, :, :tile.shape[2], :tile.shape[3]]
            count_predictions[y_min:y_max, x_min:x_max] += 1
            total_predictions[:, y_min:y_max, x_min:x_max] += (
                prediction.detach().cpu().numpy().squeeze(0)
            )

    # Broadcast (H, W) counts over the class axis to average overlapping tiles.
    total_predictions /= count_predictions
    return total_predictions
def multi_scale_predict(model, image, scales, num_classes, device, flip=False):
    """Average class scores predicted at several input scales.

    For each scale, the image is resized with linear interpolation
    (``scipy.ndimage.zoom``, ``order=1``, no prefilter) and run through
    ``model`` on ``device``. The output is resized back to the input size
    (bilinear, ``align_corners=True``). The per-scale results are averaged.

    Args:
        model: Callable returning ``(1, num_classes, h, w)`` for a
            ``(1, C, h', w')`` input.
        image: Tensor of shape ``(1, C, H, W)``. The batch must be 1 because
            results are squeezed into a ``(num_classes, H, W)`` buffer.
        scales: Iterable of resize factors applied to both spatial dims.
        num_classes: Number of output channels produced by ``model``.
        device: Device the scaled image is moved to before calling ``model``.
        flip: If True, also average in the prediction on the horizontally
            mirrored image, after un-mirroring it.

    Returns:
        ``np.ndarray`` of shape ``(num_classes, H, W)``.

    Raises:
        ValueError: If ``scales`` is empty. Averaging over zero scales would
            divide by zero and silently return NaNs.
    """
    scales = list(scales)
    if not scales:
        raise ValueError("multi_scale_predict requires at least one scale")

    height, width = image.size(2), image.size(3)
    upsample = nn.Upsample(size=(height, width), mode='bilinear', align_corners=True)
    total_predictions = np.zeros((num_classes, height, width))

    # ndimage.zoom works on host arrays, so copy the input to CPU once.
    image_np = image.detach().cpu().numpy()
    for scale in scales:
        scaled_np = ndimage.zoom(image_np, (1.0, 1.0, float(scale), float(scale)),
                                 order=1, prefilter=False)
        scaled_img = torch.from_numpy(scaled_np).to(device)
        scaled_prediction = upsample(model(scaled_img).cpu())
        if flip:
            # scaled_img is already on `device`; flipping keeps it there.
            flipped_prediction = upsample(model(scaled_img.flip(-1)).cpu())
            scaled_prediction = 0.5 * (flipped_prediction.flip(-1) + scaled_prediction)
        total_predictions += scaled_prediction.detach().numpy().squeeze(0)

    total_predictions /= len(scales)
    return total_predictions
def save_images(image, mask, output_path, image_file, palette):
    """Write the colorized segmentation mask for ``image_file`` as a PNG.

    Args:
        image: The source image. Currently unused; kept so the call signature
            stays stable.
        mask: Per-pixel class indices, passed to ``colorize_mask``.
        output_path: Directory the PNG is written into. It must already exist.
        image_file: Path of the source image. The output is named
            ``<stem>.png``, where ``<stem>`` is the base name without its final
            extension.
        palette: Color palette forwarded to ``colorize_mask``.
    """
    # splitext removes only the last extension ("a.b.jpg" -> "a.b").
    # split('.')[0] would give "a", so distinct inputs could overwrite each other.
    stem = os.path.splitext(os.path.basename(image_file))[0]
    colorized_mask = colorize_mask(mask, palette)
    colorized_mask.save(os.path.join(output_path, stem + '.png'))
def _load_model(config, num_classes, checkpoint_path, device):
    """Build the model named in ``config['arch']`` and load its checkpoint.

    Accepts a bare state dict or a dict holding one under ``'state_dict'``.
    If the parameter names carry the ``'module.'`` prefix (as written by
    ``torch.nn.DataParallel``), the prefix is handled as follows:
    - on CUDA, the model is wrapped in DataParallel so the names match;
    - otherwise, the prefix is stripped from the keys.

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    model = getattr(models, config['arch']['type'])(num_classes, **config['arch']['args'])
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    prefix = 'module.'
    first_key = next(iter(checkpoint), '')
    if first_key.startswith(prefix) and not isinstance(model, torch.nn.DataParallel):
        if device.type == 'cuda':
            # GPU inference: wrap the model so its names match the checkpoint.
            model = torch.nn.DataParallel(model)
        else:
            # CPU inference: strip the prefix from the checkpoint instead.
            checkpoint = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, value)
                for key, value in checkpoint.items()
            )

    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    return model


def _predict(model, input_tensor, mode, scales, num_classes, device):
    """Return class scores for one image as a NumPy array, using ``mode``.

    ``'multiscale'`` and ``'sliding'`` return ``(num_classes, H, W)`` arrays.
    Any other mode does a single forward pass, whose output shape depends on
    the model (presumably the same layout; TODO confirm).
    """
    if mode == 'multiscale':
        return multi_scale_predict(model, input_tensor, scales, num_classes, device)
    if mode == 'sliding':
        # sliding_predict takes no device argument, so move the input here.
        # Otherwise a CUDA model would receive a CPU tensor.
        return sliding_predict(model, input_tensor.to(device), num_classes)
    prediction = model(input_tensor.to(device))
    return prediction.squeeze(0).cpu().numpy()


def main():
    """Segment every ``*.<extension>`` image in ``--images`` and save colorized masks to ``--output``."""
    args = parse_arguments()
    if args.images is None:
        raise ValueError('--images must point to the directory of images to segment')
    with open(args.config) as config_file:
        config = json.load(config_file)

    # Dataset used for training the model
    dataset_type = config['train_loader']['type']
    supported = ('VOC', 'COCO', 'CityScapes', 'ADE20K', 'DeepScene')
    if dataset_type not in supported:
        raise ValueError(f'Unsupported dataset type {dataset_type!r}; expected one of {supported}')
    scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
    if dataset_type == 'CityScapes':
        scales.append(2.25)

    loader = getattr(dataloaders, dataset_type)(**config['train_loader']['args'])
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(loader.MEAN, loader.STD)
    num_classes = loader.dataset.num_classes
    palette = loader.dataset.palette

    device = torch.device('cuda:0' if torch.cuda.device_count() > 0 else 'cpu')
    model = _load_model(config, num_classes, args.model, device)

    # Create the directory save_images actually writes to (not a fixed 'outputs').
    os.makedirs(args.output, exist_ok=True)

    image_files = sorted(glob(os.path.join(args.images, f'*.{args.extension}')))
    with torch.no_grad():
        for img_file in tqdm(image_files, ncols=100):
            with Image.open(img_file) as raw_image:
                image = raw_image.convert('RGB')
            input_tensor = normalize(to_tensor(image)).unsqueeze(0)
            scores = _predict(model, input_tensor, args.mode, scales, num_classes, device)
            # Softmax is monotonic, so argmax over the raw scores picks the same
            # class without the extra full-image pass.
            mask = scores.argmax(0)
            save_images(image, mask, args.output, img_file, palette)
def parse_arguments():
    """Parse the command-line options for the inference script.

    ``--images`` is required: without it main() cannot locate any input and
    would fail inside ``os.path.join``.
    """
    parser = argparse.ArgumentParser(description='Inference')
    parser.add_argument('-c', '--config', default='VOC', type=str,
                        help='The config used to train the model')
    parser.add_argument('-mo', '--mode', default='multiscale', type=str,
                        help='Mode used for prediction: "multiscale", "sliding", '
                             'or any other value for a single forward pass')
    parser.add_argument('-m', '--model', default='model_weights.pth', type=str,
                        help='Path to the .pth model checkpoint to be used in the prediction')
    parser.add_argument('-i', '--images', required=True, type=str,
                        help='Path to the images to be segmented')
    parser.add_argument('-o', '--output', default='outputs', type=str,
                        help='Output Path')
    parser.add_argument('-e', '--extension', default='jpg', type=str,
                        help='The extension of the images to be segmented')
    return parser.parse_args()
if __name__ == "__main__":
    # Run inference only when executed as a script, never on import.
    main()