-
Notifications
You must be signed in to change notification settings - Fork 435
/
layers.py
188 lines (149 loc) · 5.78 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from inits import *
import tensorflow as tf
# TF1 command-line flag module and its global FLAGS object, re-exported at
# module level (FLAGS itself is not referenced elsewhere in this file).
flags, FLAGS = tf.app.flags, tf.app.flags.FLAGS
# Module-wide counter per base layer name, used to auto-generate unique
# layer names such as "dense_1", "dense_2", ...
_LAYER_UIDS = {}


def get_layer_uid(layer_name=''):
    """Return the next unique integer ID for ``layer_name``.

    The first call for a given name returns 1; each later call returns the
    previous value plus one. State is kept in the module-level
    ``_LAYER_UIDS`` dictionary.
    """
    next_uid = _LAYER_UIDS.get(layer_name, 0) + 1
    _LAYER_UIDS[layer_name] = next_uid
    return next_uid
def sparse_dropout(x, keep_prob, noise_shape):
    """Dropout for sparse tensors.

    Each stored value of ``x`` is kept independently with probability
    ``keep_prob``; the surviving values are rescaled by ``1 / keep_prob``.

    Args:
        x: the tf.SparseTensor to thin out (passed to tf.sparse_retain).
        keep_prob: keep probability; callers in this file pass either a
            Python float or a scalar tensor (``1 - dropout``).
        noise_shape: shape of the uniform noise used as the keep mask;
            presumably the number of stored values in ``x`` (one mask entry
            per value, as tf.sparse_retain requires) -- confirm with caller.
    """
    # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob, else 0.
    keep_mask = tf.cast(tf.floor(keep_prob + tf.random_uniform(noise_shape)),
                        dtype=tf.bool)
    kept = tf.sparse_retain(x, keep_mask)
    return kept * (1. / keep_prob)
def dot(x, y, sparse=False):
    """Return ``x @ y``, dispatching on whether ``x`` is sparse.

    Args:
        x: left operand; must be a tf.SparseTensor when ``sparse`` is True.
        y: right operand.
        sparse: use tf.sparse_tensor_dense_matmul instead of tf.matmul.
    """
    matmul = tf.sparse_tensor_dense_matmul if sparse else tf.matmul
    return matmul(x, y)
class Layer(object):
    """Base layer class. Defines basic API for all layer objects.

    Implementation inspired by keras (http://keras.io).

    # Properties
        name: String, defines the variable scope of the layer.
        logging: Boolean, switches Tensorflow histogram logging on/off.
        vars: dict mapping variable names to the layer's variables.
        sparse_inputs: Boolean; when True the input histogram is skipped.

    # Methods
        _call(inputs): Defines computation graph of layer
            (i.e. takes input, returns output)
        __call__(inputs): Wrapper for _call()
        _log_vars(): Log all variables
    """

    def __init__(self, **kwargs):
        """Validate keyword arguments and initialise shared layer state.

        Keyword Args:
            name: layer name. When missing or empty, a name of the form
                ``<lowercased class name>_<uid>`` is generated via
                ``get_layer_uid``.
            logging: whether to emit histogram summaries (default False).

        Raises:
            TypeError: if any keyword other than ``name``/``logging`` is given.
        """
        allowed_kwargs = {'name', 'logging'}
        for kwarg in kwargs:
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, which would silently accept misspelled kwargs.
            if kwarg not in allowed_kwargs:
                raise TypeError('Invalid keyword argument: ' + kwarg)
        name = kwargs.get('name')
        if not name:
            layer = self.__class__.__name__.lower()
            name = layer + '_' + str(get_layer_uid(layer))
        self.name = name
        self.vars = {}
        self.logging = kwargs.get('logging', False)
        # Subclasses overwrite this when they accept sparse inputs.
        self.sparse_inputs = False

    def _call(self, inputs):
        """Layer computation; the base class is the identity."""
        return inputs

    def __call__(self, inputs):
        """Build the layer's ops inside a name scope, optionally logging."""
        with tf.name_scope(self.name):
            # NOTE(review): the input histogram is presumably skipped for sparse
            # inputs because tf.summary.histogram expects a dense tensor.
            if self.logging and not self.sparse_inputs:
                tf.summary.histogram(self.name + '/inputs', inputs)
            outputs = self._call(inputs)
            if self.logging:
                tf.summary.histogram(self.name + '/outputs', outputs)
            return outputs

    def _log_vars(self):
        """Emit one histogram summary per entry of ``self.vars``."""
        for var_name, var in self.vars.items():
            tf.summary.histogram(self.name + '/vars/' + var_name, var)
class Dense(Layer):
    """Dense (fully connected) layer: ``act(dropout(x) @ W [+ b])``."""

    def __init__(self, input_dim, output_dim, placeholders, dropout=0., sparse_inputs=False,
                 act=tf.nn.relu, bias=False, featureless=False, **kwargs):
        """Create the weight matrix (and optional bias) for this layer.

        Args:
            input_dim: number of rows of the weight matrix.
            output_dim: number of columns of the weight matrix / bias length.
            placeholders: dict supplying 'num_features_nonzero' and, when
                ``dropout`` is truthy, 'dropout'.
            dropout: only its truthiness is used; when truthy the dropout
                rate is taken from ``placeholders['dropout']``, otherwise 0.
            sparse_inputs: whether ``inputs`` to this layer are sparse.
            act: activation applied to the output.
            bias: whether to add a bias vector.
            featureless: stored on the layer but not consulted by ``_call``.
            **kwargs: forwarded to ``Layer`` (``name``, ``logging``).
        """
        super(Dense, self).__init__(**kwargs)
        self.dropout = placeholders['dropout'] if dropout else 0.
        self.act = act
        self.sparse_inputs = sparse_inputs
        self.featureless = featureless
        self.bias = bias
        # Passed to sparse_dropout as its noise shape.
        self.num_features_nonzero = placeholders['num_features_nonzero']

        with tf.variable_scope(self.name + '_vars'):
            self.vars['weights'] = glorot([input_dim, output_dim], name='weights')
            if self.bias:
                self.vars['bias'] = zeros([output_dim], name='bias')

        if self.logging:
            self._log_vars()

    def _call(self, inputs):
        """Apply dropout, the linear map, the optional bias, then ``act``."""
        keep_prob = 1 - self.dropout
        if self.sparse_inputs:
            x = sparse_dropout(inputs, keep_prob, self.num_features_nonzero)
        else:
            x = tf.nn.dropout(inputs, keep_prob)

        output = dot(x, self.vars['weights'], sparse=self.sparse_inputs)
        if self.bias:
            output += self.vars['bias']
        return self.act(output)
class GraphConvolution(Layer):
    """Graph convolution layer: ``act(sum_i S_i @ (dropout(x) @ W_i) [+ b])``.

    ``S_i`` is the i-th entry of ``placeholders['support']``; one weight
    matrix ``W_i`` is created per support entry.
    """

    def __init__(self, input_dim, output_dim, placeholders, dropout=0.,
                 sparse_inputs=False, act=tf.nn.relu, bias=False,
                 featureless=False, **kwargs):
        """Create one weight matrix per support entry (and optional bias).

        Args:
            input_dim: number of rows of each weight matrix.
            output_dim: number of columns of each weight matrix / bias length.
            placeholders: dict supplying 'support', 'num_features_nonzero'
                and, when ``dropout`` is truthy, 'dropout'.
            dropout: only its truthiness is used; when truthy the dropout
                rate is taken from ``placeholders['dropout']``, otherwise 0.
            sparse_inputs: whether ``inputs`` to this layer are sparse.
            act: activation applied to the output.
            bias: whether to add a bias vector.
            featureless: when True, ``_call`` multiplies each support by its
                weight matrix directly instead of by ``x @ W_i``.
            **kwargs: forwarded to ``Layer`` (``name``, ``logging``).
        """
        super(GraphConvolution, self).__init__(**kwargs)
        self.dropout = placeholders['dropout'] if dropout else 0.
        self.act = act
        self.support = placeholders['support']
        self.sparse_inputs = sparse_inputs
        self.featureless = featureless
        self.bias = bias
        # Passed to sparse_dropout as its noise shape.
        self.num_features_nonzero = placeholders['num_features_nonzero']

        num_supports = len(self.support)
        with tf.variable_scope(self.name + '_vars'):
            for idx in range(num_supports):
                weight_name = 'weights_' + str(idx)
                self.vars[weight_name] = glorot([input_dim, output_dim],
                                                name=weight_name)
            if self.bias:
                self.vars['bias'] = zeros([output_dim], name='bias')

        if self.logging:
            self._log_vars()

    def _call(self, inputs):
        """Dropout, per-support propagation, sum, optional bias, then ``act``.

        Also stores the pre-activation result on ``self.embedding``.
        """
        keep_prob = 1 - self.dropout
        if self.sparse_inputs:
            x = sparse_dropout(inputs, keep_prob, self.num_features_nonzero)
        else:
            x = tf.nn.dropout(inputs, keep_prob)

        def propagate(idx):
            # One term of the sum: S_idx @ (x @ W_idx), or S_idx @ W_idx when
            # featureless (the dropout result above is then left unused).
            weights = self.vars['weights_' + str(idx)]
            if self.featureless:
                pre_sup = weights
            else:
                pre_sup = dot(x, weights, sparse=self.sparse_inputs)
            return dot(self.support[idx], pre_sup, sparse=True)

        output = tf.add_n([propagate(idx) for idx in range(len(self.support))])
        if self.bias:
            output += self.vars['bias']
        # Pre-activation output, kept for callers that read layer.embedding.
        self.embedding = output
        return self.act(output)