RES_VAE.py — 145 lines (114 loc) · 4.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torch.nn as nn
import torch.utils.data
class ResDown(nn.Module):
    """
    Residual block for the encoder that halves the spatial resolution.

    Main path: strided conv -> BN -> ELU -> conv, bottlenecked to
    channel_out // 2 in the middle. Shortcut: a single strided conv that
    projects the input to channel_out so the two paths can be summed.
    """
    def __init__(self, channel_in, channel_out, kernel_size=3):
        super(ResDown, self).__init__()
        pad = kernel_size // 2  # "same" padding for odd kernel sizes
        self.conv1 = nn.Conv2d(channel_in, channel_out // 2, kernel_size, 2, pad)
        self.bn1 = nn.BatchNorm2d(channel_out // 2, eps=1e-4)
        self.conv2 = nn.Conv2d(channel_out // 2, channel_out, kernel_size, 1, pad)
        self.bn2 = nn.BatchNorm2d(channel_out, eps=1e-4)
        # Strided projection used as the residual shortcut.
        self.conv3 = nn.Conv2d(channel_in, channel_out, kernel_size, 2, pad)
        self.act_fnc = nn.ELU()

    def forward(self, x):
        shortcut = self.conv3(x)
        out = self.act_fnc(self.bn1(self.conv1(x)))
        out = self.conv2(out)
        # Sum before the final BN + activation, as in the rest of the file.
        return self.act_fnc(self.bn2(out + shortcut))
class ResUp(nn.Module):
    """
    Residual block for the decoder that upsamples by `scale_factor`.

    The input is nearest-neighbour upsampled first; then a bottlenecked
    conv -> BN -> ELU -> conv main path is summed with a single-conv
    shortcut before the final BN + activation.
    """
    def __init__(self, channel_in, channel_out, kernel_size=3, scale_factor=2):
        super(ResUp, self).__init__()
        pad = kernel_size // 2  # "same" padding for odd kernel sizes
        self.conv1 = nn.Conv2d(channel_in, channel_in // 2, kernel_size, 1, pad)
        self.bn1 = nn.BatchNorm2d(channel_in // 2, eps=1e-4)
        self.conv2 = nn.Conv2d(channel_in // 2, channel_out, kernel_size, 1, pad)
        self.bn2 = nn.BatchNorm2d(channel_out, eps=1e-4)
        # Channel projection used as the residual shortcut.
        self.conv3 = nn.Conv2d(channel_in, channel_out, kernel_size, 1, pad)
        self.up_nn = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.act_fnc = nn.ELU()

    def forward(self, x):
        up = self.up_nn(x)
        shortcut = self.conv3(up)
        out = self.act_fnc(self.bn1(self.conv1(up)))
        out = self.conv2(out)
        return self.act_fnc(self.bn2(out + shortcut))
class Encoder(nn.Module):
    """
    Encoder block.

    Built for a 3x64x64 image and will result in a latent vector of size z x 1 x 1.
    As the network is fully convolutional it will work for images LARGER than 64.
    For images sized 64 * n where n is a power of 2 (1, 2, 4, 8, etc.) the latent
    feature map size will be z x n x n.

    In .eval() mode the Encoder does not sample from the distribution and instead
    returns mu as the encoding vector; log_var is still computed and returned in
    both modes.  (The previous docstring incorrectly stated log_var would be None
    in eval mode — the forward pass always returns the computed log_var.)
    """
    def __init__(self, channels, ch=64, latent_channels=512):
        super(Encoder, self).__init__()
        # 7x7 stem keeps spatial size (stride 1, padding 3).
        self.conv_in = nn.Conv2d(channels, ch, 7, 1, 3)
        # Each ResDown halves the spatial resolution: 64 -> 32 -> 16 -> 8 -> 4.
        self.res_down_block1 = ResDown(ch, 2 * ch)
        self.res_down_block2 = ResDown(2 * ch, 4 * ch)
        self.res_down_block3 = ResDown(4 * ch, 8 * ch)
        self.res_down_block4 = ResDown(8 * ch, 16 * ch)
        # 4x4 convs collapse the remaining 4x4 map to 1x1 (for 64x64 input).
        self.conv_mu = nn.Conv2d(16 * ch, latent_channels, 4, 1)
        self.conv_log_var = nn.Conv2d(16 * ch, latent_channels, 4, 1)
        self.act_fnc = nn.ELU()

    def sample(self, mu, log_var):
        """Reparameterization trick: mu + eps * std with eps ~ N(0, 1)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return (encoding, mu, log_var); encoding is sampled only in training."""
        x = self.act_fnc(self.conv_in(x))
        x = self.res_down_block1(x)  # 32
        x = self.res_down_block2(x)  # 16
        x = self.res_down_block3(x)  # 8
        x = self.res_down_block4(x)  # 4
        mu = self.conv_mu(x)  # 1
        log_var = self.conv_log_var(x)  # 1
        if self.training:
            x = self.sample(mu, log_var)
        else:
            # Deterministic encoding at inference time.
            x = mu
        return x, mu, log_var
class Decoder(nn.Module):
    """
    Decoder block, built to mirror the Encoder: a transposed conv expands the
    latent map to 4x4, four ResUp blocks double the resolution each, and a
    final conv + tanh produces the image in [-1, 1].
    """
    def __init__(self, channels, ch=64, latent_channels=512):
        super(Decoder, self).__init__()
        # Expand the 1x1 latent map to 4x4 at 16*ch channels.
        self.conv_t_up = nn.ConvTranspose2d(latent_channels, ch * 16, 4, 1)
        self.res_up_block1 = ResUp(ch * 16, ch * 8)
        self.res_up_block2 = ResUp(ch * 8, ch * 4)
        self.res_up_block3 = ResUp(ch * 4, ch * 2)
        self.res_up_block4 = ResUp(ch * 2, ch)
        self.conv_out = nn.Conv2d(ch, channels, 3, 1, 1)
        self.act_fnc = nn.ELU()

    def forward(self, x):
        feat = self.act_fnc(self.conv_t_up(x))  # 4
        # Spatial size doubles at each block: 4 -> 8 -> 16 -> 32 -> 64.
        for block in (self.res_up_block1, self.res_up_block2,
                      self.res_up_block3, self.res_up_block4):
            feat = block(feat)
        return torch.tanh(self.conv_out(feat))
class VAE(nn.Module):
    """
    Res VAE network: composes the Encoder and Decoder blocks defined above.

    channel_in: number of channels of the input image.
    latent_channels: number of channels of the latent representation
        (for a 64x64 image this is the size of the latent vector).
    """
    def __init__(self, channel_in=3, ch=64, latent_channels=512):
        super(VAE, self).__init__()
        self.encoder = Encoder(channel_in, ch=ch, latent_channels=latent_channels)
        self.decoder = Decoder(channel_in, ch=ch, latent_channels=latent_channels)

    def forward(self, x):
        latent, mu, log_var = self.encoder(x)
        return self.decoder(latent), mu, log_var