diff --git a/tensorflow_similarity/samplers/tfdata_sampler.py b/tensorflow_similarity/samplers/tfdata_sampler.py index abbf40eb..52d8c5e5 100644 --- a/tensorflow_similarity/samplers/tfdata_sampler.py +++ b/tensorflow_similarity/samplers/tfdata_sampler.py @@ -100,9 +100,7 @@ def apply_augmenter_ds(ds: tf.data.Dataset, augmenter: Callable, warmup: int | N ds = tf.data.Dataset.choose_from_datasets( [ds, aug_ds], - count_ds.map( - lambda x: tf.cast(0, dtype=tf.dtypes.int64) if x < warmup else tf.cast(1, dtype=tf.dtypes.int64) - ), + count_ds.map(lambda x: tf.cast(0, dtype=tf.dtypes.int64) if x < warmup else tf.cast(1, dtype=tf.dtypes.int64)), ) return ds