-
Notifications
You must be signed in to change notification settings - Fork 162
/
models.py
100 lines (76 loc) · 2.73 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.utils.data
from torch import nn, optim
from padding_same_conv import Conv2d
def toTensor(img):
    """Turn a 4-D numpy image batch into a channel-first torch tensor.

    Args:
        img: 4-D numpy array; axis 3 is moved to position 1, i.e. it is
            assumed to be laid out as (N, H, W, C) -- confirm with callers.

    Returns:
        torch.Tensor with axes (N, C, H, W) and the same dtype as *img*.
        ``torch.from_numpy`` does not copy, so the tensor shares memory with
        *img*. Because it wraps a transposed view, the result is generally
        not contiguous.
    """
    nchw = img.transpose(0, 3, 1, 2)
    return torch.from_numpy(nchw)
def var_to_np(img_var):
    """Return the contents of a (possibly autograd-tracked) tensor as numpy.

    Args:
        img_var: torch.Tensor, on any device, with or without requires_grad.

    Returns:
        numpy.ndarray holding the tensor's values. For a CPU tensor the array
        shares memory with it; other devices are copied to host by ``.cpu()``.
    """
    # ``detach()`` replaces the legacy ``.data`` accessor. Both return a tensor
    # cut from the graph that shares storage. ``detach()`` also keeps
    # autograd's version counter, so in-place edits are detected by autograd
    # instead of silently producing wrong gradients.
    return img_var.detach().cpu().numpy()
class _ConvLayer(nn.Sequential):
    """Downsampling unit: 5x5 ``Conv2d`` with stride 2, then LeakyReLU(0.1).

    ``Conv2d`` is the project's ``padding_same_conv`` variant (presumably
    "same" padding, per the module name -- confirm there).
    """

    def __init__(self, input_features, output_features):
        super(_ConvLayer, self).__init__()
        # The submodule names become state_dict keys; keep them byte-identical
        # so existing checkpoints still load.
        layers = (
            ('conv2', Conv2d(input_features, output_features,
                             kernel_size=5, stride=2)),
            ('leakyrelu', nn.LeakyReLU(0.1, inplace=True)),
        )
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)
class _UpScale(nn.Sequential):
    """Upsampling unit: 3x3 ``Conv2d`` to 4x the target channels,
    LeakyReLU(0.1), then ``_PixelShuffler``.

    The pixel shuffler trades those 4x channels for a 2x2 spatial expansion,
    so the unit ends with ``output_features`` channels.
    """

    def __init__(self, input_features, output_features):
        super(_UpScale, self).__init__()
        # The submodule names become state_dict keys; keep them byte-identical.
        stages = (
            ('conv2_', Conv2d(input_features, output_features * 4,
                              kernel_size=3)),
            ('leakyrelu', nn.LeakyReLU(0.1, inplace=True)),
            ('pixelshuffler', _PixelShuffler()),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
class Flatten(nn.Module):
    """Collapse all non-batch dimensions: (N, d1, d2, ...) -> (N, d1*d2*...)."""

    def forward(self, input):
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # tensors (e.g. channels_last activations), while ``reshape`` returns
        # a view whenever possible and copies only when it has to.
        output = input.reshape(input.size(0), -1)
        return output
class Reshape(nn.Module):
    """Reshape a flat batch back into ``(batch, *shape)``.

    With the default ``shape=(1024, 4, 4)`` this undoes the flattening before
    the encoder's ``nn.Linear(1024, 1024 * 4 * 4)``: each row of 16384 values
    becomes a 1024-channel 4x4 map.

    Args:
        shape: optional sequence of ints giving the per-sample target shape.
            Omit it to keep the original hard-coded ``(1024, 4, 4)``.
    """

    # Class-level default: ``Reshape()`` behaves exactly as before, and
    # instances unpickled from before ``shape`` existed still find a value.
    shape = (1024, 4, 4)

    def __init__(self, shape=None):
        super(Reshape, self).__init__()
        if shape is not None:
            self.shape = tuple(shape)

    def forward(self, input):
        # Batch size is inferred (-1). ``reshape`` returns a view for
        # contiguous input and also accepts non-contiguous tensors.
        output = input.reshape(-1, *self.shape)  # channel * H * W per sample
        return output
class _PixelShuffler(nn.Module):
def forward(self, input):
batch_size, c, h, w = input.size()
rh, rw = (2, 2)
oh, ow = h * rh, w * rw
oc = c // (rh * rw)
out = input.view(batch_size, rh, rw, oc, h, w)
out = out.permute(0, 3, 4, 1, 5, 2).contiguous()
out = out.view(batch_size, oc, oh, ow) # channel first
return out
class Autoencoder(nn.Module):
    """One shared encoder trunk feeding one of two decoders.

    ``decoder_A`` and ``decoder_B`` have identical architecture but separate
    weights; ``forward`` picks which one to use.
    """

    def __init__(self):
        super(Autoencoder, self).__init__()

        def build_decoder():
            # 512 -> 256 -> 128 -> 64 channels (each _UpScale ends in a 2x
            # pixel shuffle), then a 5x5 conv down to 3 channels; Sigmoid
            # squashes the output into (0, 1).
            return nn.Sequential(
                _UpScale(512, 256),
                _UpScale(256, 128),
                _UpScale(128, 64),
                Conv2d(64, 3, kernel_size=5, padding=1),
                nn.Sigmoid(),
            )

        # Trunk: four stride-2 conv layers (3 -> 1024 channels), flatten, a
        # purely linear 1024-d bottleneck (no activation between the two
        # Linear layers), back to 1024x4x4, then one upscale to 512 channels.
        # The Linear input size 1024*4*4 requires 4x4 maps out of the conv
        # stack -- presumably 64x64 inputs; confirm against the data pipeline.
        self.encoder = nn.Sequential(
            _ConvLayer(3, 128),
            _ConvLayer(128, 256),
            _ConvLayer(256, 512),
            _ConvLayer(512, 1024),
            Flatten(),
            nn.Linear(1024 * 4 * 4, 1024),
            nn.Linear(1024, 1024 * 4 * 4),
            Reshape(),
            _UpScale(1024, 512),
        )
        # Registration order (encoder, A, B) fixes the parameters() /
        # state_dict order and the RNG draws used for weight init; keep it.
        self.decoder_A = build_decoder()
        self.decoder_B = build_decoder()

    def forward(self, x, select='A'):
        """Encode *x*, then decode with decoder_A if ``select == 'A'``.

        Any other value of *select* (not only 'B') routes to decoder_B.
        """
        decoder = self.decoder_A if select == 'A' else self.decoder_B
        return decoder(self.encoder(x))