-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
2 changed files
with
357 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
# coding=utf-8 | ||
# Copyright 2024 The Google Research Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""DataLoader that relies mainly on a tensorstore to load batches of data. | ||
The code is adapted from pygrain.DataLoader and is significantly faster when | ||
loading entire batches from tensorstore instead of the per-item convention with | ||
subsequent stacking using numpy. Parallelization of loading chunks is handled | ||
by tensorstore and the loader only requests batches using threading, which | ||
allows for tensorstore caching and has shown to be faster and more memory- | ||
efficient than using shared memory and multiprocessing. | ||
Sharding for distributed loading as well as sampling use pygrain. | ||
""" | ||
|
||
import collections | ||
from collections.abc import Iterator | ||
import dataclasses | ||
from typing import Any, Optional, Sequence, Type, TypeVar | ||
|
||
import grain.python as pygrain | ||
import numpy as np | ||
import tensorstore as ts | ||
|
||
from concurrent.futures import ThreadPoolExecutor | ||
|
||
T = TypeVar('T', bound='BatchMetadata') | ||
Batch = dict[str, Any] | Any | ||
|
||
# Import required dataclasses to make tensorloader a central import | ||
NoSharding = pygrain.NoSharding | ||
ShardOptions = pygrain.ShardOptions | ||
ShardByJaxProcess = pygrain.ShardByJaxProcess | ||
|
||
|
||
@dataclasses.dataclass(slots=True) | ||
class BatchMetadata: | ||
"""BatchMetadata contains metadata about a batch of records. | ||
BatchMetadata are usually created from a sequence of Metadata emitted by | ||
a pygrain Sample and contain read indices, which indicate steps, record_keys | ||
to read from the TensorSource and rng keys for eventual randomness required. | ||
""" | ||
|
||
indices: Sequence[int] | ||
rngs: Optional[Sequence[np.random.Generator]] = None | ||
|
||
@classmethod | ||
def from_entries( | ||
cls: Type[T], records: Sequence[pygrain.RecordMetadata] | ||
) -> T: | ||
indices = [record.record_key for record in records] | ||
rngs = [record.rng for record in records] | ||
return cls(indices, rngs) | ||
|
||
|
||
class TensorSource: | ||
"""TensorSource protocol that loads batches of data from a TensorStore. | ||
TODO(aleximmer): consider adding transforms or index transforms. | ||
""" | ||
|
||
def __len__(self) -> int: | ||
raise NotImplementedError() | ||
|
||
def __getitem__(self, metadata: BatchMetadata) -> Batch: | ||
raise NotImplementedError() | ||
|
||
@property | ||
def item_shape(self) -> dict[str, tuple[int, ...]]: | ||
"""Return shape of individual items.""" | ||
raise NotImplementedError() | ||
|
||
|
||
class BasicTensorSource(TensorSource): | ||
"""Tensor source where the leading dimension corresponds to data points.""" | ||
|
||
def __init__(self, ts_spec: dict[str, Any]): | ||
self._data = ts.open(ts_spec).result() | ||
|
||
def __len__(self) -> int: | ||
return self._data.shape[0] | ||
|
||
def __getitem__(self, metadata: BatchMetadata) -> Batch: | ||
return self._data[metadata.indices] | ||
|
||
|
||
class TensorLoader: | ||
"""TensorLoader loads batches from tensorstore data source. | ||
In comparison to grain, does not support operations but batches automatically. | ||
Tensorstore is significantly faster than numpy stacking for large arrays and | ||
can be optimized using custom chunk sizes that align with batches. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
tensor_source: TensorSource, | ||
batch_size: int = 1, | ||
sampler: pygrain.Sampler | None = None, | ||
num_epochs: int | None = None, | ||
shuffle: bool | None = None, | ||
seed: int | None = None, | ||
shard_options: ShardOptions | None = None, | ||
num_threads: int = 8, | ||
initial_batch: int = 0, | ||
): | ||
"""Loads and transforms input data. | ||
Args: | ||
tensor_source: Responsible for retrieving batches of records based on | ||
their indices. | ||
batch_size: Number of examples to jointly query from the data source. | ||
sampler: Custom sampler, defaults to pygrain.IndexSampler. | ||
num_epochs: Number of epochs to yield data for. Passed to sampler and | ||
defaults to 1. | ||
shuffle: Whether to randomly or sequentially index the tensor_source. | ||
Passed to sampler and defaults to False. | ||
seed: Random seed for sampler. | ||
shard_options: Options for how data should be sharded when using multiple | ||
machines (~ JAX processes) and data parallelism. | ||
num_threads: Number of threads for parallel prefetching of batches. | ||
initial_batch: Batch number to start from (use for checkpointing). | ||
""" | ||
super().__init__() | ||
self._tensor_source = tensor_source | ||
self._batch_size = batch_size | ||
if shard_options is None: | ||
shard_options = NoSharding() | ||
if sampler is not None: | ||
if any([num_epochs, shuffle, seed]): | ||
raise ValueError( | ||
'Cannot specify sampler and one of num_epochs, shuffle, and seed.' | ||
) | ||
assert not any([num_epochs, shuffle, seed]) | ||
self._sampler = sampler | ||
else: | ||
shuffle = False if shuffle is None else shuffle | ||
num_epochs = 1 if num_epochs is None else num_epochs | ||
self._sampler = pygrain.IndexSampler( | ||
num_records=len(tensor_source), | ||
shard_options=shard_options, | ||
shuffle=shuffle, | ||
num_epochs=num_epochs, | ||
seed=seed, | ||
) | ||
if num_threads <= 0: | ||
raise ValueError(f'num_threads must be positive: {num_threads}') | ||
self._num_threads = num_threads | ||
self._shard_options = shard_options | ||
if initial_batch < 0: | ||
raise ValueError(f'initial_batch must be positive: {initial_batch}') | ||
self.set_initial_batch(initial_batch) | ||
|
||
def set_initial_batch(self, initial_batch: int): | ||
# start negative and advance within __iter__ | ||
self._initial_step = ( | ||
initial_batch * self._batch_size | ||
- self._shard_options.shard_count | ||
+ self._shard_options.shard_index | ||
) | ||
|
||
def __iter__(self) -> Iterator[Batch]: | ||
"""Read sampled record indices to load and yield batches.""" | ||
next_index = self._initial_step + self._shard_options.shard_count | ||
buffer = collections.deque() | ||
|
||
def make_index_batch(next_index_: int) -> tuple[int, list[int]]: | ||
next_indices = [] | ||
for _ in range(self._batch_size): | ||
next_indices.append(next_index_) | ||
next_index_ += self._shard_options.shard_count | ||
return next_index_, next_indices | ||
|
||
def prefetch_elements(indices: Sequence[int]) -> Any: | ||
metadata = [] | ||
for i, index in enumerate(indices): | ||
try: | ||
metadata.append(self._sampler[index]) | ||
except IndexError as e: | ||
if i == 0 or self._shard_options.drop_remainder: | ||
raise e | ||
else: | ||
break | ||
batch_metadata = BatchMetadata.from_entries(metadata) | ||
data = self._tensor_source[batch_metadata] | ||
return data | ||
|
||
with ThreadPoolExecutor(self._num_threads) as executor: | ||
# Fill the buffer initially. | ||
while len(buffer) < self._num_threads: | ||
next_index, batch_indices = make_index_batch(next_index) | ||
buffer.append(executor.submit(prefetch_elements, batch_indices)) | ||
|
||
# Iterate until IndexError from the Sampler. | ||
while True: | ||
try: | ||
batch = buffer.popleft().result() | ||
except IndexError: | ||
return | ||
yield batch | ||
next_index, batch_indices = make_index_batch(next_index) | ||
buffer.append(executor.submit(prefetch_elements, batch_indices)) | ||
|
||
@property | ||
def tensor_source(self) -> TensorSource: | ||
return self._tensor_source |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# coding=utf-8 | ||
# Copyright 2024 The Google Research Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for tensorloader.""" | ||
|
||
from absl.testing import absltest # pylint: disable=g-multiple-import | ||
from absl.testing import parameterized # pylint: disable=g-multiple-import | ||
from connectomics.jax.inputs import tensorloader as tl | ||
import grain.python as pygrain | ||
import numpy as np | ||
import tensorstore as ts | ||
|
||
|
||
class TestTensorSource(tl.TensorSource): | ||
"""Testing source that works for in-memory non-persistent tensorstore.""" | ||
|
||
def __init__(self, data: ts.TensorStore): | ||
self._data = data | ||
|
||
def __len__(self) -> int: | ||
return self._data.shape[0] | ||
|
||
def __getitem__(self, metadata: tl.BatchMetadata) -> tl.Batch: | ||
return {'data': self._data[metadata.indices].read().result()} | ||
|
||
|
||
class TensorloaderTest(parameterized.TestCase): | ||
|
||
def setUp(self): | ||
super().setUp() | ||
ts_test_config = { | ||
'create': True, | ||
'driver': 'zarr', | ||
'dtype': 'float64', | ||
'metadata': {'shape': [16, 8, 2]}, | ||
'kvstore': {'driver': 'memory'}, | ||
} | ||
self.data = ts.open(ts_test_config).result() | ||
generator = np.random.RandomState(seed=4321) | ||
self.example_data = generator.randn(*self.data.shape).astype(np.float64) | ||
self.data[...] = self.example_data | ||
|
||
def test_completeness(self): | ||
loader = tl.TensorLoader( | ||
tensor_source=TestTensorSource(self.data), | ||
batch_size=4, | ||
num_epochs=1, | ||
shuffle=False, | ||
) | ||
data = np.vstack([batch['data'] for batch in loader]) | ||
np.testing.assert_array_equal(data, self.example_data) | ||
|
||
def test_incomplete_last_batch(self): | ||
loader = tl.TensorLoader( | ||
tensor_source=TestTensorSource(self.data), | ||
shard_options=tl.ShardOptions(0, 1, drop_remainder=False), | ||
batch_size=3, | ||
num_epochs=1, | ||
shuffle=False, | ||
) | ||
data = np.vstack([batch['data'] for batch in loader]) | ||
np.testing.assert_array_equal(data, self.example_data) | ||
|
||
def test_skip_last_batch(self): | ||
loader = tl.TensorLoader( | ||
tensor_source=TestTensorSource(self.data), | ||
shard_options=tl.ShardOptions(0, 1, drop_remainder=True), | ||
batch_size=3, | ||
num_epochs=1, | ||
shuffle=False, | ||
) | ||
data = np.vstack([batch['data'] for batch in loader]) | ||
np.testing.assert_array_equal(data, self.example_data[:-1]) | ||
|
||
def test_checkpointing(self): | ||
loader = tl.TensorLoader( | ||
tensor_source=TestTensorSource(self.data), | ||
batch_size=4, | ||
num_epochs=1, | ||
shuffle=False, | ||
initial_batch=2, | ||
) | ||
data = np.vstack([batch['data'] for batch in loader]) | ||
np.testing.assert_array_equal(data, self.example_data[8:]) | ||
|
||
def test_sharding_complement(self): | ||
base_config = dict( | ||
tensor_source=TestTensorSource(self.data), | ||
batch_size=4, | ||
num_epochs=1, | ||
) | ||
shard_options_a = tl.ShardOptions(shard_index=0, shard_count=2) | ||
shard_options_b = tl.ShardOptions(shard_index=1, shard_count=2) | ||
loader_a = tl.TensorLoader(**base_config, shard_options=shard_options_a) | ||
loader_b = tl.TensorLoader(**base_config, shard_options=shard_options_b) | ||
batches = [] | ||
for batch_a, batch_b in zip(loader_a, loader_b): | ||
batches.append(np.vstack([batch_a['data'], batch_b['data']])) | ||
data = np.vstack(batches) | ||
self.assertEqual(data.shape, self.example_data.shape) | ||
# test statistic. sampler separates regions and is not globally sequential. | ||
self.assertEqual(np.mean(data), np.mean(self.example_data)) | ||
|
||
def test_basic_source(self): | ||
ts_test_config = { | ||
'driver': 'array', | ||
'dtype': 'int32', | ||
'array': np.arange(16).reshape((4, 4)), | ||
} | ||
tsource = tl.BasicTensorSource(ts_test_config) | ||
metadata = tl.BatchMetadata(indices=[0, 1]) | ||
batch = tsource[metadata] | ||
np.testing.assert_array_equal(batch, np.arange(8).reshape((2, 4))) | ||
|
||
def test_external_sampler(self): | ||
shard_options = tl.ShardOptions(shard_index=0, shard_count=1) | ||
source = TestTensorSource(self.data) | ||
sampler = pygrain.SequentialSampler(len(source), shard_options) | ||
loader = tl.TensorLoader( | ||
tensor_source=source, shard_options=shard_options, sampler=sampler | ||
) | ||
data = np.vstack([batch['data'] for batch in loader]) | ||
np.testing.assert_array_equal(data, self.example_data) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |