# gibbs_pruning.py
import itertools
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K
import tensorflow_probability as tfp
from tensorflow.python.ops import nn


def tf_sample_gibbs(A, b, beta, N):
    """Naive sampling from p(x) = 1/Z*exp(-beta*(x^T*A*x + b^T*x))"""
    xs = K.constant(list(itertools.product([-1, 1], repeat=N)))  # 2^N x N tensor
    quad = -beta * K.sum(tf.tensordot(xs, A, axes=[[1], [0]]) * xs[:, :, None, None], axis=1)
    quad = quad - K.max(quad, axis=[0])[None, :, :]  # Put the highest quad logits around 0 to ensure precision when we add biases
    logits = quad - beta * tf.tensordot(xs, b, axes=[[1], [0]])
    logits = logits - K.max(logits, axis=[0])  # Same, tensorflow doesn't seem to work well with high logits
    rows = tf.random.categorical(K.transpose(K.reshape(logits, (2**N, -1))), 1)[:, 0]
    slices = tf.gather(xs, rows, axis=0)
    return K.reshape(K.transpose(slices), K.shape(b))
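

# Illustrative sketch, not part of the original file: how tf_sample_gibbs is
# meant to be called. The sizes below (3x3 kernels, 4 input channels, 8
# filters) are arbitrary assumptions chosen to keep 2^N small.
def _example_tf_sample_gibbs():
    n = 9  # 3*3 weights per kernel, so the sampler enumerates 2^9 = 512 states
    n_channels, n_filters = 4, 8
    # Coupling matrix A with zeroed diagonal and per-weight biases b, giving
    # one N-spin problem per (channel, filter) pair, mirroring train_mask() below
    A = K.constant(-1.0 * (np.ones((n, n)) - np.eye(n))[:, :, None, None]
                   * np.ones((1, 1, n_channels, n_filters)))
    b = K.constant(np.random.randn(n, n_channels, n_filters).astype('float32'))
    spins = tf_sample_gibbs(A, b, beta=5.0, N=n)  # +/-1 spins with the shape of b
    return (spins + 1) / 2  # convert spins to a 0/1 pruning mask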


class GibbsPrunedConv2D(layers.Conv2D):
    """2D convolution with Gibbs pruning.

    Inherits from keras.layers.Conv2D, so all Conv2D parameters are supported.
    See https://arxiv.org/abs/2006.04981 for full details. Note that a
    GibbsPruningAnnealer callback should be used to anneal beta. Also note that
    this implementation does not try to take advantage of pruning to improve
    efficiency; it is only intended for evaluating the effectiveness of the
    pruning method.

    Arguments:
    - filters, kernel_size, other Conv2D arguments: see the Conv2D documentation
    - p: Target pruning fraction, e.g. p=0.9 will converge to pruning 90% of
      kernel weights
    - hamiltonian: One of 'unstructured', 'kernel', or 'filter'. Controls which
      Hamiltonian is used to define the Gibbs distribution during training:
      unstructured pruning, or structured kernel-wise or filter-wise pruning.
    - c: The c parameter controlling the influence of the structured pruning
      terms in the Hamiltonian.
    - test_pruning_mode: One of 'gibbs', 'kernel', or 'filter'. Controls how
      pruning is done at test time. 'gibbs' samples a mask from the Gibbs
      distribution, whereas 'kernel' or 'filter' prunes the fraction p of
      kernels or filters with the lowest average magnitude. If the Gibbs
      pruning procedure has converged, these should generally produce the same
      results, but this parameter can ensure that the Gibbs distribution
      doesn't 'cheat' at test time by including more weights than p would
      normally allow.
    - mcmc_steps: The number of MCMC iterations used in chromatic sampling for
      filter-wise Gibbs pruning.
    """
    def __init__(self, filters, kernel_size, p=0.5, hamiltonian='unstructured',
                 c=1.0, test_pruning_mode='gibbs', mcmc_steps=50, **kwargs):
        self.p = p
        self.hamiltonian = hamiltonian
        self.c = c
        self.test_pruning_mode = test_pruning_mode
        self.mcmc_steps = mcmc_steps
        self.beta = tf.Variable(1.0, trainable=False, name='beta')  # This will be updated before training by the annealer
        super().__init__(filters, kernel_size, **kwargs)

    def build(self, input_shape):
        super().build(input_shape)
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        self.n_channels = int(input_shape[channel_axis])
        self.mask = K.zeros_like(self.kernel)

    def call(self, inputs):
        # This code appears in the Keras Conv layer, but is only compatible
        # with TF2. I'm not sure what situations it addresses, but it doesn't
        # seem necessary for the example code in this repo
        # # Check if the input_shape in call() is different from that in build().
        # # If they are different, recreate the _convolution_op to avoid the stateful
        # # behavior.
        # call_input_shape = inputs.get_shape()
        # recreate_conv_op = (
        #     call_input_shape[1:] != self._build_conv_op_input_shape[1:])
        # if recreate_conv_op:
        #     self._convolution_op = nn_ops.Convolution(
        #         call_input_shape,
        #         filter_shape=self.kernel.shape,
        #         dilation_rate=self.dilation_rate,
        #         strides=self.strides,
        #         padding=self._padding_op,
        #         data_format=self._conv_op_data_format)
        mask = K.in_train_phase(lambda: self.train_mask(), lambda: self.test_mask())
        # Track the realized fraction of pruned weights as a metric
        self.add_metric(1 - K.mean(mask), name='gp_mask_p', aggregation='mean')
        outputs = self._convolution_op(inputs, self.kernel * mask)
        if self.use_bias:
            if self.data_format == 'channels_first':
                outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
            else:
                outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def test_mask(self):
        W2 = self.kernel * self.kernel
        if self.test_pruning_mode == 'gibbs':
            return self.train_mask()
        elif self.test_pruning_mode == 'kernel':
            # Prune the fraction p of kernels with the lowest sum of squared weights
            kernel_sums = tf.reduce_sum(W2, axis=[0, 1])
            Qp = tfp.stats.percentile(kernel_sums, self.p * 100, interpolation='linear')
            return K.cast(kernel_sums >= Qp, 'float32')[None, None, :, :]
        elif self.test_pruning_mode == 'filter':
            # Prune the fraction p of filters with the lowest sum of squared weights
            filter_sums = tf.reduce_sum(W2, axis=[0, 1, 2])
            Qp = tfp.stats.percentile(filter_sums, self.p * 100, interpolation='linear')
            return K.cast(filter_sums >= Qp, 'float32')[None, None, None, :]
        else:
            raise ValueError("test_pruning_mode must be one of 'gibbs', 'kernel', or 'filter'")

    def train_mask(self):
        W2 = self.kernel * self.kernel
        n_filter_weights = np.prod(self.kernel_size)
        if self.hamiltonian == 'unstructured':
            # Each weight is pruned independently with probability P0, a logistic
            # function that falls from ~1 to ~0 as the squared weight crosses the
            # p-th percentile Qp
            Qp = tfp.stats.percentile(K.flatten(W2), self.p * 100, interpolation='linear')
            P0 = 1 / (1 + K.exp(self.beta * (W2 - Qp)))
            R = K.random_uniform(K.shape(P0))
            return K.cast(R > P0, 'float32')
        elif self.hamiltonian == 'kernel':
            # Prune kernels by finding A and b for the Hamiltonian H(x) = x^T*A*x +
            # b^T*x, and sampling directly for each kernel
            flat_W2 = K.reshape(W2, (n_filter_weights, self.n_channels, self.filters))
            Qp = tfp.stats.percentile(K.sum(flat_W2, axis=0) / n_filter_weights, self.p * 100, interpolation='linear')
            b = Qp - flat_W2
            A = -self.c * K.constant(np.ones((n_filter_weights, n_filter_weights, self.n_channels, self.filters)))
            A_mask = np.ones((n_filter_weights, n_filter_weights))
            np.fill_diagonal(A_mask, False)
            A = A * A_mask[:, :, None, None]
            M = K.reshape(tf_sample_gibbs(A, b, self.beta, n_filter_weights), K.shape(W2))
            return (M + 1) / 2
        elif self.hamiltonian == 'filter':
            # Prune filters with chromatic Gibbs sampling: the channels are split
            # into two blocks, and each block is resampled conditioned on the other
            # for mcmc_steps iterations
            flat_W2 = K.reshape(W2, (n_filter_weights, self.n_channels, self.filters))
            Qp = tfp.stats.percentile(tf.reduce_sum(flat_W2, axis=[0, 1]) / n_filter_weights / self.n_channels, self.p * 100, interpolation='linear')
            b = Qp - flat_W2
            A = -self.c * K.constant(np.ones((n_filter_weights, n_filter_weights, self.n_channels, self.filters)))
            A_mask = np.ones((n_filter_weights, n_filter_weights))
            np.fill_diagonal(A_mask, False)
            A = A * A_mask[:, :, None, None]
            # Start the chain from the deterministic mask that keeps the filters
            # above the percentile
            filt_avgs = tf.reduce_sum(flat_W2, axis=[0, 1]) / n_filter_weights / self.n_channels
            x_cvg = K.cast(filt_avgs > Qp, 'float32')
            colour_b = b - self.c * (self.n_channels // 2) * n_filter_weights * (x_cvg * 2 - 1)[None, None, :]
            split = self.n_channels // 2
            colour_b = colour_b[:, 0:split, :]
            for i in range(self.mcmc_steps):
                P0 = 1 / (1 + K.exp(-self.beta * colour_b))
                R = K.random_uniform(K.shape(P0))
                M0 = K.cast(R > P0, 'float32') * 2 - 1
                filter_sums = tf.reduce_sum(M0, axis=[0, 1])
                colour_b = b[:, split:, :] - self.c * filter_sums[None, None, :]
                P0 = 1 / (1 + K.exp(-self.beta * colour_b))
                R = K.random_uniform(K.shape(P0))
                M1 = K.cast(R > P0, 'float32') * 2 - 1
                filter_sums = tf.reduce_sum(M1, axis=[0, 1])
                colour_b = b[:, 0:split, :] - self.c * filter_sums[None, None, :]
            M = K.reshape(K.concatenate((M0, M1), axis=1), K.shape(W2))
            return (M + 1) / 2
        else:
            raise ValueError("hamiltonian must be one of 'unstructured', 'kernel', or 'filter'")

    def get_config(self):
        # Note: beta itself is not serialized; it is expected to be set by the
        # GibbsPruningAnnealer before training
        config = {
            'p': self.p,
            'hamiltonian': self.hamiltonian,
            'c': self.c,
            'test_pruning_mode': self.test_pruning_mode,
            'mcmc_steps': self.mcmc_steps,
        }
        base_config = super().get_config()
        return {**config, **base_config}

    def set_beta(self, beta):
        K.set_value(self.beta, beta)
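

# Illustrative sketch, not part of the original file: a tiny CNN that uses
# GibbsPrunedConv2D in place of Conv2D. The architecture, input size, and
# hyperparameters are arbitrary assumptions, not the configuration from the
# paper.
def _example_model(p=0.9, hamiltonian='kernel'):
    return keras.Sequential([
        GibbsPrunedConv2D(16, 3, p=p, hamiltonian=hamiltonian, padding='same',
                          activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D(),
        GibbsPrunedConv2D(32, 3, p=p, hamiltonian=hamiltonian, padding='same',
                          activation='relu'),
        layers.GlobalAveragePooling2D(),
        layers.Dense(10, activation='softmax'),
    ])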


class GibbsPruningAnnealer(keras.callbacks.Callback):
    """Callback for annealing the beta parameter of Gibbs pruning layers.

    Only one instance of the callback is needed for a network, and it will
    automatically anneal beta for all Gibbs pruning layers.

    Arguments:
    - beta_schedule: A list of beta values to use for each epoch. If the list
      is shorter than the number of training epochs, the last value in the
      list is used for the remaining epochs.
    - verbose: Default 0, set to 1 to print messages when beta is updated.
    """
    def __init__(self, beta_schedule, verbose=0):
        super().__init__()
        self.beta_schedule = beta_schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        beta = self.beta_schedule[epoch] if epoch < len(self.beta_schedule) else self.beta_schedule[-1]
        count = 0
        for layer in self.model.layers:
            if isinstance(layer, GibbsPrunedConv2D):
                count += 1
                layer.set_beta(beta)
        if self.verbose > 0:
            print(f'GibbsPruningAnnealer: set beta to {beta} in {count} layers')
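

# End-to-end sketch, not part of the original file: train the example model
# above on random data while annealing beta with GibbsPruningAnnealer. The
# beta schedule and data here are placeholder assumptions, not values from the
# paper.
if __name__ == '__main__':
    model = _example_model(p=0.9, hamiltonian='kernel')
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    x = np.random.rand(64, 32, 32, 3).astype('float32')
    y = np.random.randint(0, 10, size=(64,))
    annealer = GibbsPruningAnnealer([1.0, 10.0, 100.0, 1000.0], verbose=1)
    model.fit(x, y, epochs=4, batch_size=16, callbacks=[annealer])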