-
Notifications
You must be signed in to change notification settings - Fork 9
/
colorize.py
88 lines (52 loc) · 2.16 KB
/
colorize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import argparse
import torch
import torchvision.transforms
import numpy as np
from skimage import io, color
from src import colnet
from src import dataset
from src import utils
def _load_net(model_path):
    """Rebuild a ColNet from a checkpoint file and return ``(net, classes)``.

    The checkpoint is expected to hold ``classes``, ``net_divisor`` and
    ``model_state_dict`` entries; the network is sized from the first two and
    its weights are restored from the third.

    NOTE(review): ``torch.load`` unpickles the file, so only pass checkpoints
    that come from a trusted source.
    """
    # map_location keeps every tensor on the CPU, so a GPU-saved checkpoint
    # can still be loaded on a machine without CUDA.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    classes = checkpoint['classes']
    net = colnet.ColNet(num_classes=len(classes),
                        net_divisor=checkpoint['net_divisor'])
    net.load_state_dict(checkpoint['model_state_dict'])
    return net, classes


def colorize(img_path, model, top_k=10):
    """Colorize one image with a trained ColNet checkpoint.

    The colorized result is saved as ``colorized-<basename of img_path>`` in
    the current working directory, and the *top_k* most probable class labels
    predicted by the network are printed with their probabilities.

    Args:
        img_path: Path to the input image.
        model: Path to a ``*.pt`` checkpoint (see ``_load_net``).
        top_k: Number of highest-probability labels to print. Defaults to 10;
            negative values are treated as 0.
    """
    net, classes = _load_net(model)

    # Per the CLI warning, images that are not 224x224 get randomly cropped
    # by RandomCrop(224) before the Lab conversion.
    composed_transforms = torchvision.transforms.Compose([
        dataset.HandleGrayscale(),
        dataset.RandomCrop(224),
        dataset.Rgb2LabNorm(),
        dataset.ToTensor(),
        dataset.SplitLab(),
    ])

    img = io.imread(img_path)
    img_name = os.path.basename(img_path)
    # Only the L (lightness) part is fed to the network; the ab part of the
    # input is not needed for inference.
    L, _ = composed_transforms(img)
    # Prepend a size-1 leading axis (presumably the batch dimension the
    # network expects — TODO confirm against ColNet.forward).
    L_tensor = torch.from_numpy(np.expand_dims(L, axis=0))

    net.eval()
    with torch.no_grad():
        ab_out, predicted = net(L_tensor)
        probs = torch.nn.functional.softmax(predicted, dim=1)[0].numpy()

    img_colorized = utils.net_out2rgb(L, ab_out[0])
    colorized_img_name = "colorized-" + img_name
    io.imsave(colorized_img_name, img_colorized)
    print("\nSaved image to: {}\n".format(colorized_img_name))

    probs_and_classes = sorted(zip(probs, classes), key=lambda x: x[0], reverse=True)
    print("Predicted labels: \n")
    # Clamp so a negative top_k prints nothing instead of "all but the last".
    for p, c in probs_and_classes[:max(top_k, 0)]:
        print("{:>7.2f}% \t{}".format(p * 100.0, c))
if __name__ == "__main__":
    # Command-line entry point: colorize IMAGE using the checkpoint MODEL.
    parser = argparse.ArgumentParser(description="A script to colorize a photo")
    parser.add_argument('image', help="Path to the image. RGB one will be converted to grayscale")
    parser.add_argument('model', help="Path to a *.pt model")
    args = parser.parse_args()

    print("[Warning] Only 224x224 images are supported. Otherwise an image will be randomly cropped")
    colorize(args.image, args.model)