-
Notifications
You must be signed in to change notification settings - Fork 50
/
patch_utils.py
109 lines (83 loc) · 3.62 KB
/
patch_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
def num_patches(output_img_dim=(3, 256, 256), sub_patch_dim=(64, 64)):
    """Count the non-overlapping patches fed to the PatchGAN discriminator.

    (Section 2.2.2 in the paper.) The paper describes three options:
      * PixelGAN: 1x1 patches (each pixel)
      * PatchGAN: NxN patches (non-overlapping blocks of the image)
      * ImageGAN: im_size x im_size (the full image)
    Example: a 4x4 image with a 2x2 patch size has 4 non-overlapping patches.

    Only whole patches are counted. Pixel rows or columns at the bottom or
    right edge that cannot fill a complete patch are ignored, because every
    discriminator input has the fixed shape returned below.

    :param output_img_dim: (channels, height, width) of the image to be cut.
    :param sub_patch_dim: (patch_height, patch_width).
    :return: tuple (number_of_patches, (channels, patch_height, patch_width)).
    """
    # Count whole patches per axis with floor division. Multiplying the float
    # quotients first (the old approach) produced a count matching no real
    # tiling when the image size is not a multiple of the patch size.
    patches_down = output_img_dim[1] // sub_patch_dim[0]
    patches_across = output_img_dim[2] // sub_patch_dim[1]
    nb_non_overlapping_patches = patches_down * patches_across

    # Input shape of a single patch for the patch discriminator.
    patch_disc_img_dim = (output_img_dim[0], sub_patch_dim[0], sub_patch_dim[1])
    return int(nb_non_overlapping_patches), patch_disc_img_dim
def extract_patches(images, sub_patch_dim):
    """Cut a batch of images into non-overlapping patches.

    The k-th element of the returned list holds the k-th patch of every image
    in the batch. For an input batch [im_1, ..., im_n] the output is
        [patch_1 of (im_1..im_n), patch_2 of (im_1..im_n), ..., patch_k of (im_1..im_n)]
    Patches are ordered row-major: left to right, then top to bottom.

    Only whole patches are produced. Pixel rows or columns at the bottom or
    right edge that cannot fill a complete patch are skipped, so every
    returned patch has exactly (patch_height, patch_width) spatial size.

    :param images: array shaped (num_images, im_channels, im_height, im_width).
    :param sub_patch_dim: (patch_height, patch_width), e.g. (30, 30).
    :return: list of float32 arrays, each shaped
        (num_images, im_channels, patch_height, patch_width).
    """
    im_height, im_width = images.shape[2:]
    patch_height, patch_width = sub_patch_dim

    # Top-left corners of whole patches, e.g. 0, 30, 60, ... for 30-px patches.
    # The stop bound excludes start positions whose patch would run past the
    # edge, which would otherwise yield smaller, ragged patches.
    y_starts = range(0, im_height - patch_height + 1, patch_height)
    x_starts = range(0, im_width - patch_width + 1, patch_width)

    all_patches = []
    for y in y_starts:
        for x in x_starts:
            # images is indexed [image, channel, row (height), col (width)].
            # Slicing rows and cols cuts the same window from every image at once.
            image_patches = images[:, :, y:y + patch_height, x:x + patch_width]
            all_patches.append(np.asarray(image_patches, dtype=np.float32))
    return all_patches
def get_disc_batch(X_original_batch, X_decoded_batch, generator_model, batch_counter, patch_dim,
                   label_smoothing=False, label_flipping=0):
    """Build one discriminator training batch, alternating fake and real images.

    An even ``batch_counter`` produces a batch of generated ("fake") images by
    calling ``generator_model.predict(X_decoded_batch)``. An odd counter uses
    the real ``X_original_batch``. Each image gets a two-column label:
    column 0 marks fake, column 1 marks real.

    :param X_original_batch: batch of real images, used when batch_counter is odd.
    :param X_decoded_batch: generator input batch, used when batch_counter is even.
    :param generator_model: object exposing ``predict`` (presumably a Keras
        model; confirm against the caller).
    :param batch_counter: step counter; its parity selects fake vs. real.
    :param patch_dim: (patch_height, patch_width) forwarded to extract_patches.
    :param label_smoothing: if True, real labels are drawn uniformly from
        [0.9, 1) instead of being exactly 1. Fake labels are never smoothed.
    :param label_flipping: probability of swapping the two label columns for
        the whole batch (0 disables flipping).
    :return: tuple (list of patch arrays from extract_patches, labels of shape
        (batch_size, 2)). Labels are float32 when label_smoothing is on,
        otherwise uint8.
    """
    # Smoothed labels are fractional. Assigning them into a uint8 array would
    # truncate every value in [0.9, 1) to 0, labelling real images as [0, 0].
    label_dtype = np.float32 if label_smoothing else np.uint8

    if batch_counter % 2 == 0:
        # Fake batch: images produced by the generator.
        X_disc = generator_model.predict(X_decoded_batch)
        y_disc = np.zeros((X_disc.shape[0], 2), dtype=label_dtype)
        y_disc[:, 0] = 1
    else:
        # Real batch: ground-truth images.
        X_disc = X_original_batch
        y_disc = np.zeros((X_disc.shape[0], 2), dtype=label_dtype)
        if label_smoothing:
            y_disc[:, 1] = np.random.uniform(low=0.9, high=1, size=y_disc.shape[0])
        else:
            y_disc[:, 1] = 1

    y_disc = _maybe_flip_labels(y_disc, label_flipping)

    # Cut every image in the batch into discriminator-sized patches.
    X_disc = extract_patches(images=X_disc, sub_patch_dim=patch_dim)
    return X_disc, y_disc


def _maybe_flip_labels(y_disc, label_flipping):
    """With probability ``label_flipping``, swap the fake/real label columns in place."""
    if label_flipping > 0 and np.random.binomial(1, label_flipping) > 0:
        # The right-hand fancy index makes a copy, so this is a true column swap.
        y_disc[:, [0, 1]] = y_disc[:, [1, 0]]
    return y_disc
def gen_batch(X1, X2, batch_size):
    """Endlessly yield aligned random mini-batches from two parallel arrays.

    Each step draws ``batch_size`` distinct row indices (without replacement)
    from ``range(X1.shape[0])`` and takes those same rows from both arrays, so
    the pairs stay matched. Each step's draw is independent of the others.

    :param X1: first array; its leading dimension defines the sampling range.
    :param X2: second array, indexed with the same rows (assumed to have at
        least as many rows as X1; this is not checked).
    :param batch_size: rows per batch. np.random.choice raises ValueError if
        this exceeds X1.shape[0].
    :yield: tuple (X1[rows], X2[rows]).
    """
    while True:
        picked = np.random.choice(X1.shape[0], batch_size, replace=False)
        yield tuple(source[picked] for source in (X1, X2))