-
Notifications
You must be signed in to change notification settings - Fork 5
/
stn_nearest.py
228 lines (171 loc) · 7.2 KB
/
stn_nearest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
Thanks to voxelmorph: Learning-Based Image Registration, https://github.com/voxelmorph/voxelmorph for this code.
If you use this code, please cite the respective papers in their repo.
"""
# third party
import numpy as np
from keras import backend as K
from keras.layers import Layer
import tensorflow as tf
import itertools
import utils
class SpatialTransformer(Layer):
    """
    N-D Spatial Transformer Tensorflow / Keras Layer

    The Layer can handle both affine and dense transforms.
    Both transforms are meant to give a 'shift' from the current position.
    Therefore, a dense transform gives displacements (not absolute locations) at each voxel,
    and an affine transform gives the *difference* of the affine matrix from
    the identity matrix.

    If you find this function useful, please cite:
        Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
        Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
        MICCAI 2018.

    Originally, this code was based on voxelmorph code, which
    was in turn transformed to be dense with the help of (affine) STN code
    via https://github.com/kevinzakka/spatial-transformer-network

    Since then, we've re-written the code to be generalized to any
    dimensions, and along the way wrote grid and interpolation functions
    """

    def __init__(self,
                 interp_method='linear',
                 indexing='ij',
                 single_transform=False,
                 **kwargs):
        """
        Parameters:
            interp_method: 'linear' or 'nearest'
            single_transform: whether a single transform supplied for the whole batch
            indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian)
                'xy' indexing will have the first two entries of the flow
                (along last axis) flipped compared to 'ij' indexing
        """
        self.interp_method = interp_method
        # ndims / inshape are filled in by build() once input shapes are known
        self.ndims = None
        self.inshape = None
        # NOTE(review): single_transform and indexing are stored but never read
        # by call() in this file -- presumably kept for API compatibility with
        # the upstream voxelmorph layer; confirm before removing.
        self.single_transform = single_transform
        assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)"
        self.indexing = indexing
        # BUGFIX: super(self.__class__, ...) recurses infinitely if this layer
        # is ever subclassed; name the class explicitly instead.
        super(SpatialTransformer, self).__init__(**kwargs)

    def build(self, input_shape):
        """
        input_shape should be a list for two inputs:
            input1: image.
            input2: transform Tensor
                if affine:
                    should be a N x N+1 matrix
                    *or* a N*N+1 tensor (which will be reshape to N x (N+1) and an identity row added)
                if not affine:
                    should be a *vol_shape x N
        """
        if len(input_shape) > 2:
            # BUGFIX: adjacent string literals were missing a separating space
            raise Exception('Spatial Transformer must be called on a list of length 2. '
                            'First argument is the image, second is the transform.')

        # set up number of dimensions (input is [batch, *vol_shape, channels])
        self.ndims = len(input_shape[0]) - 2
        self.inshape = input_shape

        # confirm built
        self.built = True

    def call(self, inputs):
        """
        Parameters
            inputs: list with two entries (volume tensor, transform tensor)
        """
        # check shapes
        assert len(inputs) == 2, "inputs has to be len 2, found: %d" % len(inputs)
        vol = inputs[0]
        trf = inputs[1]

        # necessary for multi_gpu models: re-assert static shapes lost on split
        vol = K.reshape(vol, [-1, *self.inshape[0][1:]])
        trf = K.reshape(trf, [-1, *self.inshape[1][1:]])

        # apply the transform to each (volume, flow) pair in the batch
        return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32)

    def _single_transform(self, inputs):
        # inputs = (single volume, single displacement field)
        return transform(inputs[0], inputs[1], interp_method=self.interp_method)
def transform(vol, loc_shift, interp_method='linear'):
    """
    transform (interpolate) an N-D volume (features) given shifts at each
    location, in tensorflow.

    Essentially interpolates volume vol at locations determined by loc_shift.
    This is a spatial transform in the sense that at location [x] we now have
    the data from [x + shift], so we've moved data.

    Parameters:
        vol: volume with size vol_shape or [*vol_shape, nb_features]
        loc_shift: shift volume [*new_vol_shape, N]
        interp_method (default: 'linear'): accepted for API compatibility with
            the upstream voxelmorph transform; this nearest-neighbour variant
            ignores it -- interpn() always rounds to the closest voxel.

    Return:
        new interpolated volumes in the same size as loc_shift[0]

    Keywords:
        interpolation, sampler, resampler, linear, bilinear
    """
    # NOTE: the original computed volshape/nb_dims here but never used them
    # (dead code relying on TF1-only tf.Dimension); interpn derives everything
    # it needs from its arguments.
    return interpn(vol, loc_shift)
def interpn(vol, loc):
    """
    N-D gridded nearest-neighbour interpolation in tensorflow.

    vol can have more dimensions than loc[i], in which case loc[i] acts as a slice
    for the first dimensions.

    Parameters:
        vol: volume with size vol_shape or [*vol_shape, nb_features]
        loc: a tensor of size [*new_vol_shape, D] with the (float) sampling
            locations on vol's grid
            (NOTE: despite the upstream docstring, this variant takes no
            interp_method argument -- it always rounds to the nearest voxel)

    Returns:
        new interpolated volume of the same size as the entries in loc

    TODO:
        enable optional orig_grid - the original grid points.
        check out tf.contrib.resampler, only seems to work for 2D data
    """
    # since loc can be a list, nb_dims has to be based on vol.
    nb_dims = loc.shape[-1]
    if nb_dims != len(vol.shape[:-1]):
        raise Exception("Number of loc Tensors %d does not match volume dimension %d"
                        % (nb_dims, len(vol.shape[:-1])))

    # NOTE(review): given the check above always enforces
    # nb_dims == len(vol.shape) - 1, the two branches below can never trigger
    # for tensor loc; kept verbatim for safety / parity with upstream.
    if nb_dims > len(vol.shape):
        raise Exception("Loc dimension %d does not match volume dimension %d"
                        % (nb_dims, len(vol.shape)))

    if len(vol.shape) == nb_dims:
        vol = K.expand_dims(vol, -1)

    # float location tensor (rounded below for nearest-neighbour lookup)
    loc = tf.cast(loc, 'float32')

    # round to the closest voxel, then clip each axis into its valid range
    roundloc = tf.cast(tf.round(loc), 'int32')
    max_loc = [tf.cast(d - 1, 'int32') for d in vol.shape]
    roundloc = [tf.clip_by_value(roundloc[..., d], 0, max_loc[d]) for d in range(nb_dims)]

    # get values
    # tf stacking is slow, so instead of tf.gather_nd on stacked subscripts,
    # gather on the row-major linearized indices of the flattened volume
    idx = sub2ind(vol.shape[:-1], roundloc)
    interp_vol = tf.gather(tf.reshape(vol, [-1, vol.shape[-1]]), idx)

    return interp_vol
def prod_n(lst):
    """
    Return the product of all entries of lst.

    Generalized to return 1 (the multiplicative identity) for an empty list
    instead of raising IndexError; behavior is unchanged for non-empty input.
    """
    prod = 1
    for p in lst:
        prod = prod * p
    return prod
def sub2ind(siz, subs):
    """
    Convert N-D subscripts into linear indices assuming row-major (C-order)
    flattening, i.e. the last subscript varies fastest. This matches the
    flattening performed by tf.reshape / np.reshape in interpn().
    (The previous docstring claimed column-major order, but the strides
    computed below -- last axis gets stride 1 -- are row-major.)

    Parameters:
        siz: the N-D shape (sequence of per-axis sizes)
        subs: an N-long list of subscript arrays/tensors, one per axis

    Returns:
        linear indices: sum_i subs[i] * prod(siz[i+1:])
    """
    # subs is a list
    assert len(siz) == len(subs), 'found inconsistent siz and subs: %d %d' % (len(siz), len(subs))

    # k[i] = product of the last (i+1) axis sizes: the stride of the axis
    # that is i-th from the end
    k = np.cumprod(siz[::-1])

    # the last subscript has stride 1
    ndx = subs[-1]
    # accumulate the remaining subscripts from the second-to-last axis backwards
    for i, v in enumerate(subs[:-1][::-1]):
        ndx = ndx + v * k[i]

    return ndx