-
Notifications
You must be signed in to change notification settings - Fork 3
/
layers.py
71 lines (59 loc) · 2.72 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch.nn as nn
class AutoEncoder(nn.Module):
    """Per-view MLP encoder: ``input_dim -> dims[0] -> ... -> dims[-1] -> feature_dim``.

    Every ``nn.Linear`` is followed by an ``nn.ReLU``, including the final one,
    so the encoded features are non-negative. Layers are registered on
    ``self.encoder`` as ``Linear<i>`` / ``relu<i>``.

    Args:
        input_dim: Width of the input features (last dimension of ``x``).
        feature_dim: Width of the encoded output features.
        dims: Hidden-layer widths, ordered from the input side. May be empty,
            in which case the encoder is a single ``Linear(input_dim,
            feature_dim)`` followed by a ReLU.
    """

    def __init__(self, input_dim, feature_dim, dims):
        super().__init__()
        widths = [input_dim, *dims, feature_dim]
        self.encoder = nn.Sequential()
        # Pairing consecutive widths covers every depth, including an empty
        # ``dims``. The Linear<i>/relu<i> naming scheme is kept so previously
        # saved state_dicts still load.
        for i, (in_features, out_features) in enumerate(zip(widths[:-1], widths[1:])):
            self.encoder.add_module(f'Linear{i}', nn.Linear(in_features, out_features))
            self.encoder.add_module(f'relu{i}', nn.ReLU())

    def forward(self, x):
        """Return the encoded features of ``x`` (last dim ``input_dim`` -> ``feature_dim``)."""
        return self.encoder(x)
class AutoDecoder(nn.Module):
    """Per-view MLP decoder mirroring ``AutoEncoder``.

    Maps ``feature_dim -> dims[-1] -> ... -> dims[0] -> input_dim``, i.e. the
    hidden widths are traversed in reverse. Every ``nn.Linear`` is followed by
    an ``nn.ReLU``, including the final one, so reconstructions are
    non-negative. Layers are registered on ``self.decoder`` as
    ``Linear<i>`` / ``relu<i>``.

    Args:
        input_dim: Width of the reconstructed output (the original view width).
        feature_dim: Width of the encoded features fed in.
        dims: Hidden-layer widths in encoder order (they are reversed here).
            May be empty, in which case the decoder is a single
            ``Linear(feature_dim, input_dim)`` followed by a ReLU.
    """

    def __init__(self, input_dim, feature_dim, dims):
        super().__init__()
        widths = [feature_dim, *list(dims)[::-1], input_dim]
        self.decoder = nn.Sequential()
        # Pairing consecutive widths covers every depth, including an empty
        # ``dims``. The Linear<i>/relu<i> naming scheme is kept so previously
        # saved state_dicts still load.
        for i, (in_features, out_features) in enumerate(zip(widths[:-1], widths[1:])):
            self.decoder.add_module(f'Linear{i}', nn.Linear(in_features, out_features))
            self.decoder.add_module(f'relu{i}', nn.ReLU())

    def forward(self, x):
        """Return the reconstruction of ``x`` (last dim ``feature_dim`` -> ``input_dim``)."""
        return self.decoder(x)
class CVCLNetwork(nn.Module):
    """Multi-view model: one AutoEncoder/AutoDecoder pair per view plus a shared label head.

    Args:
        num_views: Number of views; one encoder/decoder pair is built for each.
        input_sizes: Per-view input widths; ``input_sizes[v]`` sizes view ``v``.
        dims: Hidden widths handed unchanged to every encoder and decoder.
        dim_high_feature: Width of each encoder's output and each decoder's input.
        dim_low_feature: Width of the label head's intermediate layer.
        num_clusters: Output width of the label head.
    """

    def __init__(self, num_views, input_sizes, dims, dim_high_feature, dim_low_feature, num_clusters):
        super().__init__()
        # Encoder and decoder are created alternately, view by view. That order
        # determines how parameter initialisation consumes the RNG, so it is
        # kept fixed for seeded reproducibility.
        view_pairs = [
            (
                AutoEncoder(input_sizes[v], dim_high_feature, dims),
                AutoDecoder(input_sizes[v], dim_high_feature, dims),
            )
            for v in range(num_views)
        ]
        self.encoders = nn.ModuleList([enc for enc, _ in view_pairs])
        self.decoders = nn.ModuleList([dec for _, dec in view_pairs])
        # One head shared by all views: Linear -> Linear -> Softmax over dim 1.
        # NOTE(review): there is no activation between the two Linears, so
        # together they compose to a single affine map.
        self.label_learning_module = nn.Sequential(
            nn.Linear(dim_high_feature, dim_low_feature),
            nn.Linear(dim_low_feature, num_clusters),
            nn.Softmax(dim=1),
        )

    def forward(self, data_views):
        """Encode, label and reconstruct each view.

        ``data_views[v]`` is routed through ``self.encoders[v]`` and
        ``self.decoders[v]``. The label head is shared across views. Its
        Softmax is over dim 1, so each view presumably arrives as
        (batch, features) — TODO confirm against callers.

        Returns:
            A tuple ``(label_probs, reconstructions, features)`` of lists.
            Each list has one entry per view, in view order.
        """
        probs_by_view, recon_by_view, feats_by_view = [], [], []
        for v, view in enumerate(data_views):
            z = self.encoders[v](view)
            probs_by_view.append(self.label_learning_module(z))
            recon_by_view.append(self.decoders[v](z))
            feats_by_view.append(z)
        return probs_by_view, recon_by_view, feats_by_view