Skip to content

Commit

Permalink
Tensor slice sampler (#329)
Browse files Browse the repository at this point in the history
* Create tfdata_sampler.py

Initial version of new tf.data.Dataset sampler.

* Refactor and clean up the tf data sampler.

* Add initial tests for tfdata_sampler

* Reformat TFDataSampler test file.

* Fix proto dep issue in github workflow tests. py 3.10 breaks with protobuf > 3.20.x

* Setting env var didn't work. Trying again with pinning the protobuf version to 3.20.1

* Check TF version before creating the tf dataset counter.

* Format file

* Remove as_numpy_iterator when creating the list of grouped datasets.

* Also move class_list filter to before the group_by function
* Apply the total_examples_per_class as a take() function on each
  grouped dataset
* Remove as much casting as possible from the dataset. Certain functions
  expect an int64 though and require casting.

* Refactor to move the filter by class list out of the window_group_by function.

* Add class list filter test.

* Move augment_fn and load_fn to before the repeat and batch functions.

This change means the aug and load functions apply per example now. This
will make it easier to apply random augmentations per example and is
more consistent with how we implemented it in the existing memory
sampler.

This change also improves the tests for all parts of the module.

* Add support for handling tuple and dict values for y.

This change adds support for passing a callable to parse the correct
class id element for batch sampling. By default y is assumed to be a 1D
tensor with the class ids and the function is lambda y:y. Otherwise we
accept an int or str and construct a parser to get the class id tensor.
  • Loading branch information
owenvallis authored May 5, 2023
1 parent 910fbb8 commit 8789737
Show file tree
Hide file tree
Showing 3 changed files with 503 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install coveralls
- name: Install dev packages
run: |
pip install ".[dev]"
- name: Install TF package
run: |
pip install tensorflow==${{ matrix.tf-version }}
# Fix proto dep issue in protobuf 4
pip install protobuf==3.20.*
- name: Lint with flake8
run: |
Expand Down
207 changes: 207 additions & 0 deletions tensorflow_similarity/samplers/tfdata_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
from __future__ import annotations

from collections.abc import Callable, Sequence

import tensorflow as tf


def filter_classes(
ds: tf.data.Dataset,
class_list: Sequence[int] | None = None,
y_parser: Callable = lambda y: y,
) -> tf.data.Dataset:
"""
Filters a dataset by class id.
Args:
ds: A `tf.data.Dataset` object.
class_list: An optional `Sequence` of integers representing the classes
to include in the dataset. If `None`, all classes are included.
y_parser: A callable function used to parse the class id from the y outputs.
Returns:
A `tf.data.Dataset` object filtered by class id.
"""

if class_list is not None:
class_list = tf.constant(class_list)
ds = ds.filter(lambda x, y, *args: tf.reduce_any(tf.equal(y_parser(y), class_list)))

return ds


def create_grouped_dataset(
ds: tf.data.Dataset,
window_size: int,
total_examples: int | None = None,
buffer_size: int | None = None,
y_parser: Callable = lambda y: y,
) -> list[tf.data.Dataset]:
"""
Creates a list of datasets grouped by class id.
Args:
ds: A `tf.data.Dataset` object.
window_size: An integer representing the dataset cardinality.
total_examples: An integer representing the maximum number of examples
to include in the dataset. If `None`, all examples are included.
buffer_size: An optional integer representing the size of the buffer
for shuffling. Default is None.
y_parser: A callable function used to parse the class id from the y outputs.
Returns:
A List of `tf.data.Dataset` objects grouped by class id.
"""
# NOTE: We need to cast the key_func as the group_op expects an int64.
grouped_by_cid = ds.group_by_window(
key_func=lambda x, y, *args: tf.cast(y_parser(y), dtype=tf.int64),
reduce_func=lambda key, ds: ds.batch(window_size),
window_size=window_size,
)

cid_datasets = []
for elem in grouped_by_cid:
cid_ds = tf.data.Dataset.from_tensor_slices(elem)
if total_examples is not None:
cid_ds = cid_ds.take(total_examples)
if buffer_size is not None:
cid_ds = cid_ds.shuffle(buffer_size)
cid_datasets.append(cid_ds.repeat())

return cid_datasets


def create_choices_dataset(num_classes: int, examples_per_class: int) -> tf.data.Dataset:
"""
Creates a dataset that generates random integers between 0 and `num_classes`.
Integers will be generated in contiguous blocks of size `examples_per_class`.
Integers are sampled without replacement and are not selected again until all
other interger values have been sampled.
Args:
num_classes: An integer representing the total number of classes in the dataset.
examples_per_class: An integer representing the number of examples per class.
Returns:
A `tf.data.Dataset` object representing the dataset with random choices.
"""
return (
tf.data.Dataset.range(num_classes)
.shuffle(num_classes)
.map(lambda z: [[z] * examples_per_class], name="repeat_cid")
.flat_map(tf.data.Dataset.from_tensor_slices)
.repeat()
)


def apply_augmenter_ds(ds: tf.data.Dataset, augmenter: Callable, warmup: int = 0) -> tf.data.Dataset:
"""
Applies an augmenter function to a dataset batch and optionally delays
applying the function for `warmup` number of examples.
Args:
ds: A `tf.data.Dataset` object.
augmenter: A callable function used to apply data augmentation to
individual examples within each batch. If `None`, no data
augmentation is applied.
warmup: An integer representing the number of examples to wait
before applying the data augmentation function.
Returns:
A `tf.data.Dataset` object with the applied augmenter.
"""
if not warmup:
return ds.map(augmenter, name="augmenter")

aug_ds = ds.map(augmenter, name="augmenter").skip(warmup)
tf_version_split = tf.__version__.split(".")
if int(tf_version_split[0]) >= 2 and int(tf_version_split[1]) >= 10:
count_ds = tf.data.Dataset.counter()
else:
count_ds = tf.data.experimental.Counter()

ds = tf.data.Dataset.choose_from_datasets(
[ds.take(warmup), aug_ds],
count_ds.map(lambda x: tf.cast(0, dtype=tf.int64) if x < warmup else tf.cast(1, dtype=tf.int64)),
)

return ds


def TFDataSampler(
ds: tf.data.Dataset,
classes_per_batch: int = 2,
examples_per_class_per_batch: int = 2,
class_list: Sequence[int] | None = None,
total_examples_per_class: int | None = None,
augmenter: Callable | None = None,
load_fn: Callable | None = None,
warmup: int = 0,
label_output: int | str | None = None,
) -> tf.data.Dataset:
"""
Returns a `tf.data.Dataset` object that generates batches of examples with
equal number of examples per class. The input dataset cardinality must be
finite and known.
Args:
ds: A `tf.data.Dataset` object representing the original dataset.
classes_per_batch: An integer specifying the number of classes per batch.
examples_per_class_per_batch: An integer specifying the number of examples
per class per batch.
class_list: An optional sequence of integers representing the class IDs
to include in the dataset. If `None`, all classes in the original
dataset will be used.
total_examples_per_class: An optional integer representing the total
number of examples per class to use. If `None`, all examples for
each class will be used.
augmenter: An optional function to apply data augmentation to each
example in a batch.
load_fn: An optional callable function for loading real examples from `x`.
It is useful for loading images from their corresponding file path
provided in `x` or similar situations.
warmup: An integer specifying the number of examples to use for unaugmented
warmup.
label_output: An optional integer or string representing the label output
used to create the balanced dataset batches. If `None`, y is assumed
to be a 1D integer tensor containing the class IDs. If `int`, y is
assumed to be a tuple of tensors with the class IDs in the element
specified by `label_output`. If `str`, y is assumed to be a dictionary
with the class IDs in the key specified by `label_output`.
Returnsk:
A `tf.data.Dataset` object representing the balanced dataset for few-shot learning tasks.
"""
if ds.cardinality() == tf.data.INFINITE_CARDINALITY:
raise ValueError("Dataset must be finite.")
if ds.cardinality() == tf.data.UNKNOWN_CARDINALITY:
raise ValueError("Dataset cardinality must be known.")

def y_parser(y):
return y if label_output is None else y[label_output]

window_size = ds.cardinality().numpy()
batch_size = examples_per_class_per_batch * classes_per_batch

ds = filter_classes(ds, class_list=class_list, y_parser=y_parser)
ds = create_grouped_dataset(
ds,
window_size=window_size,
total_examples=total_examples_per_class,
y_parser=y_parser,
)
choices_ds = create_choices_dataset(
len(ds),
examples_per_class=examples_per_class_per_batch,
)
ds = tf.data.Dataset.choose_from_datasets(ds, choices_ds)

if load_fn is not None:
ds = ds.map(load_fn, name="load_fn")

if augmenter is not None:
ds = apply_augmenter_ds(ds, augmenter, warmup)

ds = ds.repeat().batch(batch_size).prefetch(tf.data.AUTOTUNE)

return ds
Loading

0 comments on commit 8789737

Please sign in to comment.