Skip to content

Commit

Permalink
Add option for padding when increasing size with Interpolation decora…
Browse files Browse the repository at this point in the history
…tor.

PiperOrigin-RevId: 673361777
  • Loading branch information
aleximmer authored and copybara-github committed Sep 27, 2024
1 parent c65e963 commit 0a9f0ff
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
17 changes: 16 additions & 1 deletion connectomics/volume/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,8 @@ def __init__(self,
"""
super().__init__(context_spec)
self._size = size
backends = ('scipy_map_coordinates', 'jax_map_coordinates', 'jax_resize')
backends = (
'scipy_map_coordinates', 'jax_map_coordinates', 'jax_resize', 'pad')
if backend not in backends:
raise ValueError(f'Unsupported backend: {backend} not in {backends}.')
self._backend = backend
Expand All @@ -820,6 +821,9 @@ def decorate(self, input_ts: ts.TensorStore) -> ts.TensorStore:
raise ValueError(
f'Length of `size` ({len(self._size)}) does not match ' +
f'dimensionality of input TensorStore ({input_ts.ndim}).')
if (any(new < old for new, old in zip(self._size, input_ts.shape))
and self._backend == 'pad'):
raise ValueError('Can only pad to increase size.')
inclusive_min = input_ts.schema.domain.inclusive_min
if inclusive_min != tuple([0 for _ in range(input_ts.ndim)]):
raise ValueError(
Expand Down Expand Up @@ -854,6 +858,17 @@ def read_fn(domain: ts.IndexDomain, array: np.ndarray,
for d, s in enumerate(self._size)]
array[...] = jax.image.resize(data, sub_size,
**self._interpolation_args)
elif self._backend == 'pad':
pad_width = []
for d in range(input_ts.ndim):
if d not in resize_dim:
pad_width.append((0, 0))
else:
difference = self._size[d] - data.shape[d]
left = difference // 2
pad_width.append((left, difference - left))
array[...] = np.pad(data, pad_width=pad_width,
**self._interpolation_args)
else:
array[...] = map_coordinates(data, np.mgrid[slices],
**self._interpolation_args)
Expand Down
25 changes: 25 additions & 0 deletions connectomics/volume/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,31 @@ def test_interpolation(self):
vc = dec.decorate(data_ts)
np.testing.assert_equal(vc[...].read().result(), expected_res)

def test_padding_interpolation(self):
data = np.array([
[1, 2],
[3, 4]
]).astype(np.float32)
data_ts = ts.open({
'driver': 'n5',
'kvstore': {'driver': 'memory'},
'metadata': {
'dataType': 'float32',
'dimensions': (2, 2),
'axes': ('x', 'y'),
},
'create': True,
'delete_existing': True,
}).result()
data_ts[...] = data

expected_res = np.zeros((4, 4)).astype(np.float32)
expected_res[1:3, 1:3] = data
dec = decorators.Interpolation(size=(4, 4), backend='pad',
constant_values=0)
vc = dec.decorate(data_ts)
np.testing.assert_equal(vc[...].read().result(), expected_res)

def test_multiply(self):
mask = np.zeros_like(self._data, dtype='float32')

Expand Down

0 comments on commit 0a9f0ff

Please sign in to comment.