# -*- coding: utf-8 -*-
"""
Created on Tue Mar 22 10:43:29 2016
@author: Rob Romijnders
"""
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from tensorflow.contrib.rnn import LSTMCell
from sklearn.manifold import TSNE
from sklearn.decomposition import TruncatedSVD
def open_data(direc, ratio_train=0.8, dataset="ECG5000"):
"""Input:
direc: location of the UCR archive
ratio_train: ratio to split training and testset
dataset: name of the dataset in the UCR archive"""
datadir = direc + '/' + dataset + '/' + dataset
data_train = np.loadtxt(datadir + '_TRAIN', delimiter=',')
data_test_val = np.loadtxt(datadir + '_TEST', delimiter=',')[:-1]
data = np.concatenate((data_train, data_test_val), axis=0)
N, D = data.shape
ind_cut = int(ratio_train * N)
ind = np.random.permutation(N)
return data[ind[:ind_cut], 1:], data[ind[ind_cut:], 1:], data[ind[:ind_cut], 0], data[ind[ind_cut:], 0]
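# Illustrative call (the archive path below is a placeholder; the loader assumes the UCR layout
# '<direc>/<dataset>/<dataset>_TRAIN' and '<direc>/<dataset>/<dataset>_TEST'):
#   X_train, X_val, y_train, y_val = open_data('UCR_TS_Archive_2015', ratio_train=0.8, dataset='ECG5000')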
def plot_data(X_train, y_train, plot_row=5):
counts = dict(Counter(y_train))
num_classes = len(np.unique(y_train))
f, axarr = plt.subplots(plot_row, num_classes)
for c in np.unique(y_train): # Loops over classes, plot as columns
c = int(c)
ind = np.where(y_train == c)
ind_plot = np.random.choice(ind[0], size=plot_row)
for n in range(plot_row): # Loops over rows
axarr[n, c].plot(X_train[ind_plot[n], :])
            # Only show axes for the bottom row and the left column
if n == 0:
axarr[n, c].set_title('Class %.0f (%.0f)' % (c, counts[float(c)]))
if not n == plot_row - 1:
plt.setp([axarr[n, c].get_xticklabels()], visible=False)
if not c == 0:
plt.setp([axarr[n, c].get_yticklabels()], visible=False)
    f.subplots_adjust(hspace=0)  # No vertical space between subplots
    f.subplots_adjust(wspace=0)  # No horizontal space between subplots
plt.show()
return
def plot_z_run(z_run, label):
f1, ax1 = plt.subplots(2, 1)
    # First reduce the dimensionality with TruncatedSVD (a PCA-like linear projection)
PCA_model = TruncatedSVD(n_components=3).fit(z_run)
z_run_reduced = PCA_model.transform(z_run)
ax1[0].scatter(z_run_reduced[:, 0], z_run_reduced[:, 1], c=label, marker='*', linewidths=0)
ax1[0].set_title('PCA on z_run')
    # Then fit a t-SNE
tSNE_model = TSNE(verbose=2, perplexity=80, min_grad_norm=1E-12, n_iter=3000)
z_run_tsne = tSNE_model.fit_transform(z_run)
ax1[1].scatter(z_run_tsne[:, 0], z_run_tsne[:, 1], c=label, marker='*', linewidths=0)
ax1[1].set_title('tSNE on z_run')
plt.show()
return
class Model:
def __init__(self, config):
# Hyperparameters
num_layers = config['num_layers']
hidden_size = config['hidden_size']
max_grad_norm = config['max_grad_norm']
batch_size = config['batch_size']
sl = config['sl']
crd = config['crd']
num_l = config['num_l']
learning_rate = config['learning_rate']
self.sl = sl
self.batch_size = batch_size
# Nodes for the input variables
self.x = tf.placeholder("float", shape=[batch_size, sl], name='Input_data')
self.x_exp = tf.expand_dims(self.x, 1)
self.keep_prob = tf.placeholder("float")
with tf.variable_scope("Encoder"):
            # The encoder cell, multi-layered with dropout
cell_enc = tf.contrib.rnn.MultiRNNCell([LSTMCell(hidden_size) for _ in range(num_layers)])
cell_enc = tf.contrib.rnn.DropoutWrapper(cell_enc, output_keep_prob=self.keep_prob)
# Initial state
initial_state_enc = cell_enc.zero_state(batch_size, tf.float32)
# with tf.name_scope("Enc_2_lat") as scope:
# layer for mean of z
W_mu = tf.get_variable('W_mu', [hidden_size, num_l])
outputs_enc, _ = tf.contrib.rnn.static_rnn(cell_enc,
inputs=tf.unstack(self.x_exp, axis=2),
initial_state=initial_state_enc)
cell_output = outputs_enc[-1]
b_mu = tf.get_variable('b_mu', [num_l])
# For all intents and purposes, self.z_mu is the Tensor containing the hidden representations
# I got many questions over email about this. If you want to do visualization, clustering or subsequent
# classification, then use this z_mu
self.z_mu = tf.nn.xw_plus_b(cell_output, W_mu, b_mu, name='z_mu') # mu, mean, of latent space
            # Regularize the points in latent space toward zero mean and unit variance
            # (moments are taken over the latent dimensions of each sample, then averaged over the batch)
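            # Note: the expression mu^2 + var - log(var) - 1 below is minimized at mu = 0, var = 1;
            # it equals twice the KL divergence from N(mu, var) to the standard normal N(0, 1),
            # the usual VAE-style penalty that keeps the latent code well behaved.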
lat_mean, lat_var = tf.nn.moments(self.z_mu, axes=[1])
self.loss_lat_batch = tf.reduce_mean(tf.square(lat_mean) + lat_var - tf.log(lat_var) - 1)
with tf.name_scope("Lat_2_dec"):
# layer to generate initial state
W_state = tf.get_variable('W_state', [num_l, hidden_size])
b_state = tf.get_variable('b_state', [hidden_size])
            z_state = tf.nn.xw_plus_b(self.z_mu, W_state, b_state, name='z_state')  # initial decoder state, generated from z_mu
with tf.variable_scope("Decoder"):
# The decoder, also multi-layered
cell_dec = tf.contrib.rnn.MultiRNNCell([LSTMCell(hidden_size) for _ in range(num_layers)])
# Initial state
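            # Every layer's LSTM state tuple (c, h) is seeded with the same projection z_state of the
            # latent code, and the decoder inputs are all zeros, so the reconstruction is driven
            # entirely by the latent state.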
initial_state_dec = tuple([(z_state, z_state)] * num_layers)
dec_inputs = [tf.zeros([batch_size, 1])] * sl
# outputs_dec, _ = tf.nn.seq2seq.rnn_decoder(dec_inputs, initial_state_dec, cell_dec)
outputs_dec, _ = tf.contrib.rnn.static_rnn(cell_dec,
inputs=dec_inputs,
initial_state=initial_state_dec)
with tf.name_scope("Out_layer"):
            params_o = 2 * crd  # For every coordinate: one mean and one log-std
W_o = tf.get_variable('W_o', [hidden_size, params_o])
b_o = tf.get_variable('b_o', [params_o])
outputs = tf.concat(outputs_dec, axis=0) # tensor in [sl*batch_size,hidden_size]
h_out = tf.nn.xw_plus_b(outputs, W_o, b_o)
h_mu, h_sigma_log = tf.unstack(tf.reshape(h_out, [sl, batch_size, params_o]), axis=2)
h_sigma = tf.exp(h_sigma_log)
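            # h_mu and h_sigma are the per-timestep mean and standard deviation of a Normal
            # distribution over the input; the reconstruction loss below is the negative
            # log-likelihood of the (transposed) input under that distribution.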
dist = tf.contrib.distributions.Normal(h_mu, h_sigma)
px = dist.log_prob(tf.transpose(self.x))
loss_seq = -px
self.loss_seq = tf.reduce_mean(loss_seq)
with tf.name_scope("train"):
            # Use learning rate decay
global_step = tf.Variable(0, trainable=False)
lr = tf.train.exponential_decay(learning_rate, global_step, 1000, 0.1, staircase=False)
self.loss = self.loss_seq + self.loss_lat_batch
            # Route the gradients so that we can plot them in TensorBoard
tvars = tf.trainable_variables()
# We clip the gradients to prevent explosion
grads = tf.gradients(self.loss, tvars)
grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)
self.numel = tf.constant([[0]])
# And apply the gradients
optimizer = tf.train.AdamOptimizer(lr)
gradients = zip(grads, tvars)
self.train_step = optimizer.apply_gradients(gradients, global_step=global_step)
# for gradient, variable in gradients: #plot the gradient of each trainable variable
# if isinstance(gradient, ops.IndexedSlices):
# grad_values = gradient.values
# else:
# grad_values = gradient
#
# self.numel +=tf.reduce_sum(tf.size(variable))
# tf.summary.histogram(variable.name, variable)
# tf.summary.histogram(variable.name + "/gradients", grad_values)
# tf.summary.histogram(variable.name + "/gradient_norm", clip_ops.global_norm([grad_values]))
self.numel = tf.constant([[0]])
tf.summary.tensor_summary('lat_state', self.z_mu)
# Define one op to call all summaries
self.merged = tf.summary.merge_all()
# and one op to initialize the variables
self.init_op = tf.global_variables_initializer()
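

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original model file): shows how
# the pieces above could be wired together. The archive path and the
# hyperparameter values are illustrative assumptions, not the author's settings.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    X_train, X_val, y_train, y_val = open_data('UCR_TS_Archive_2015', ratio_train=0.8, dataset='ECG5000')
    config = {'num_layers': 2,         # depth of the encoder/decoder LSTM stacks
              'hidden_size': 90,       # LSTM state size
              'max_grad_norm': 5,      # gradient clipping threshold
              'batch_size': 128,
              'sl': X_train.shape[1],  # sequence length
              'crd': 1,                # one coordinate per timestep (univariate series)
              'num_l': 20,             # dimensionality of the latent space
              'learning_rate': 0.005}
    model = Model(config)
    with tf.Session() as sess:
        sess.run(model.init_op)
        # One training step on a random mini-batch
        ind = np.random.choice(X_train.shape[0], config['batch_size'], replace=False)
        feed = {model.x: X_train[ind], model.keep_prob: 0.8}
        loss, _ = sess.run([model.loss, model.train_step], feed_dict=feed)
        print('Loss on this batch: %.3f' % loss)
        # Extract the latent representation z_mu (no dropout) and visualize it
        z_run = sess.run(model.z_mu, feed_dict={model.x: X_train[ind], model.keep_prob: 1.0})
        plot_z_run(z_run, y_train[ind])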