diff --git a/README.md b/README.md
index 9231c556..9f34de69 100644
--- a/README.md
+++ b/README.md
@@ -16,19 +16,13 @@ With Tensorflow Similarity you can train two main types of models:
 
 ## What's new
 
-- [May 2022]: 0.16 major optimization release
-  * Cross-batch memory (XBM) loss added thank to @chjort
-  * Many self-supervised related improvement thanks to @dewball345
-  * Major layers and callback refactoring to make them faster and more flexible. E.g `EvalCallback()` now support splited validation.
-  For full changes see [the changelog](./releases.md)
-
-- [Jan 2022]: 0.15 self-supervised release
-  * Added support for self-supervised contrastive learning. Including SimCLR, SimSiam, and Barlow Twins. Checkout the in-depth [hello world notebook](examples/unsupervised_hello_world.ipynb) to get started.
-  * Soft Nearest Neighbor Loss added thanks to [Abhishar Sinha](https://github.com/abhisharsinha)
-  * Added GenerlizedMeanPooling2D support that improves similarity matching accuracy over GlobalMeanPooling2D.
-  * Numerous speed optimizations and general bug fixes.
-
-For previous changes and more details - see [the changelog](./releases.md)
+- [Mar 2023]: 0.17 more losses, more metrics, and a massive refactoring
+  * Added VicReg loss to the contrastive losses.
+  * Added metrics used in retrieval papers, such as Precision@K.
+  * Native support for distributed training, e.g. SimCLR now works correctly with distributed training.
+  * Initial support for multi-modal embeddings (CLIP).
+
+For more details and information on previous releases, see [the changelog](./releases.md)
 
 ## Getting Started
 
diff --git a/examples/unsupervised_hello_world.ipynb b/examples/unsupervised_hello_world.ipynb
index f24d3ed5..66d97da0 100644
--- a/examples/unsupervised_hello_world.ipynb
+++ b/examples/unsupervised_hello_world.ipynb
@@ -384,7 +384,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ALGORITHM = \"simsiam\" # @param [\"barlow\", \"simsiam\", \"simclr\", \"vigreg\"]"
+    "ALGORITHM = \"simsiam\" # @param [\"barlow\", \"simsiam\", \"simclr\", \"vicreg\"]"
    ]
   },
  {
diff --git a/tensorflow_similarity/losses/__init__.py b/tensorflow_similarity/losses/__init__.py
index e89a08b2..3e11740e 100644
--- a/tensorflow_similarity/losses/__init__.py
+++ b/tensorflow_similarity/losses/__init__.py
@@ -17,6 +17,7 @@
 """
 from .barlow import Barlow  # noqa
 from .circle_loss import CircleLoss  # noqa
+from .lifted_structure_loss import LiftedStructLoss  # noqa
 from .metric_loss import MetricLoss  # noqa
 from .multinegrank_loss import MultiNegativesRankLoss  # noqa
 from .multisim_loss import MultiSimilarityLoss  # noqa
diff --git a/tensorflow_similarity/losses/lifted_structure_loss.py b/tensorflow_similarity/losses/lifted_structure_loss.py
new file mode 100644
index 00000000..dc542d7f
--- /dev/null
+++ b/tensorflow_similarity/losses/lifted_structure_loss.py
@@ -0,0 +1,124 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Lifted Structured Loss
+    Deep Metric Learning via Lifted Structured Feature Embedding.
+    https://arxiv.org/abs/1511.06452
+"""
+from __future__ import annotations
+
+import tensorflow as tf
+
+from tensorflow_similarity import losses as tfsim_losses
+from tensorflow_similarity.algebra import build_masks
+from tensorflow_similarity.distances import Distance, distance_canonicalizer
+from tensorflow_similarity.types import FloatTensor, IntTensor
+
+from .metric_loss import MetricLoss
+from .utils import positive_distances
+
+
+def lifted_struct_loss(
+    labels: IntTensor,
+    embeddings: FloatTensor,
+    key_labels: IntTensor,
+    key_embeddings: FloatTensor,
+    distance: Distance,
+    positive_mining_strategy: str = "hard",
+    margin: float = 1.0,
+) -> FloatTensor:
+    """Computes the lifted structured loss for a batch of embeddings."""
+
+    # Compute pairwise distances
+    pairwise_distances = distance(embeddings, key_embeddings)
+
+    # Build masks for positive and negative pairs
+    positive_mask, negative_mask = build_masks(
+        query_labels=labels, key_labels=key_labels, batch_size=tf.shape(embeddings)[0]
+    )
+
+    # Get positive distances and indices
+    positive_dists, positive_indices = positive_distances(positive_mining_strategy, pairwise_distances, positive_mask)
+
+    # Reorder pairwise distances and negative mask based on positive indices
+    reordered_pairwise_distances = tf.gather(pairwise_distances, positive_indices, axis=1)
+    reordered_negative_mask = tf.gather(negative_mask, positive_indices, axis=1)
+
+    # Concatenate pairwise distances and negative masks along axis=1
+    concatenated_distances = tf.concat([pairwise_distances, reordered_pairwise_distances], axis=1)
+    concatenated_negative_mask = tf.concat([negative_mask, reordered_negative_mask], axis=1)
+    concatenated_negative_mask = tf.cast(concatenated_negative_mask, tf.float32)
+    # Compute the (margin - neg_dist) logsumexp values for each row (equation 4 in the paper)
+    neg_logsumexp = tfsim_losses.utils.logsumexp(margin - concatenated_distances, concatenated_negative_mask)
+
+    # Calculate the loss
+    j_values = neg_logsumexp + positive_dists
+
+    loss: FloatTensor = j_values / 2.0
+
+    return loss
+
+
+@tf.keras.utils.register_keras_serializable(package="Similarity")
+class LiftedStructLoss(MetricLoss):
+    """Computes the lifted structured loss in an online fashion.
+    This loss encourages the distances between pairs of embeddings with the
+    same label to be smaller than the distances between pairs of embeddings
+    with different labels.
+    See: https://arxiv.org/abs/1511.06452 for the original paper.
+    `y_true` must be a 1-D integer `Tensor` of shape (batch_size,).
+    Its values represent the classes associated with the examples as
+    **integer values**.
+    `y_pred` must be a 2-D float `Tensor` of L2-normalized embedding vectors.
+    You can use the layer `tensorflow_similarity.layers.L2Embedding()` as the
+    last layer of your model to ensure your model output is properly normalized.
+    """
+
+    def __init__(
+        self,
+        distance: Distance | str = "cosine",
+        positive_mining_strategy: str = "hard",
+        margin: float = 1.0,
+        name: str = "LiftedStructLoss",
+        **kwargs,
+    ):
+        """Initializes the LiftedStructLoss.
+        Args:
+            distance: Which distance function to use to compute the pairwise
+                distances between embeddings.
+            positive_mining_strategy: What mining strategy to use to select
+                an embedding from the same class. Defaults to 'hard'.
+                Available: {'easy', 'hard'}
+            margin: Use an explicit value for the margin term.
+            name: Loss name. Defaults to "LiftedStructLoss".
+        Raises:
+            ValueError: If the positive mining strategy is invalid.
+        """
+
+        # Distance canonicalization
+        distance = distance_canonicalizer(distance)
+        self.distance = distance
+
+        # Sanity checks
+        if positive_mining_strategy not in ["easy", "hard"]:
+            raise ValueError(f"Invalid positive mining strategy: {positive_mining_strategy}")
+
+        super().__init__(
+            lifted_struct_loss,
+            name=name,
+            distance=distance,
+            positive_mining_strategy=positive_mining_strategy,
+            margin=margin,
+            **kwargs,
+        )
diff --git a/tensorflow_similarity/samplers/tfdata_sampler.py b/tensorflow_similarity/samplers/tfdata_sampler.py
index 6a87ba2c..940b1107 100644
--- a/tensorflow_similarity/samplers/tfdata_sampler.py
+++ b/tensorflow_similarity/samplers/tfdata_sampler.py
@@ -102,7 +102,8 @@ def apply_augmenter_ds(ds: tf.data.Dataset, augmenter: Callable, warmup: int = 0
     Args:
         ds: A `tf.data.Dataset` object.
         augmenter: A callable function used to apply data augmentation to
-            individual examples. If `None`, no data augmentation is applied.
+            individual examples within each batch. If `None`, no data
+            augmentation is applied.
         warmup: An integer representing the number of examples to wait before
             applying the data augmentation function.
 
@@ -139,33 +140,37 @@ def TFDataSampler(
     label_output: int | str | None = None,
 ) -> tf.data.Dataset:
     """
-    Returns a `tf.data.Dataset` object that generates batches of examples with an
-    equal number of examples per class. The input dataset cardinality must be finite
-    and known.
+    Returns a `tf.data.Dataset` object that generates batches of examples with
+    an equal number of examples per class. The input dataset cardinality must
+    be finite and known.
 
     Args:
         ds: A `tf.data.Dataset` object representing the original dataset.
         classes_per_batch: An integer specifying the number of classes per batch.
         examples_per_class_per_batch: An integer specifying the number of
            examples per class per batch.
-        class_list: An optional sequence of integers representing the class IDs to
-            include in the dataset. If `None`, all classes in the original dataset
-            will be used.
-        total_examples_per_class: An optional integer representing the total number
-            of examples per class to use. If `None`, all examples for each class will
-            be used.
-        augmenter: An optional function to apply data augmentation to each example.
-        load_fn: An optional callable function for loading real examples from `x`. It
-            is useful for loading images from their corresponding file path provided
-            in `x` or similar situations.
+        class_list: An optional sequence of integers representing the class IDs
+            to include in the dataset. If `None`, all classes in the original
+            dataset will be used.
+        total_examples_per_class: An optional integer representing the total
+            number of examples per class to use. If `None`, all examples for
+            each class will be used.
+        augmenter: An optional function to apply data augmentation to each
+            example in a batch.
+        load_fn: An optional callable function for loading real examples from `x`.
+            It is useful for loading images from their corresponding file path
+            provided in `x` or similar situations.
         warmup: An integer specifying the number of examples to use for
            unaugmented warmup.
-        label_output: An optional integer or string representing the label output used
-            to create the balanced dataset batches. If `None`, y is assumed to be a 1D
-            integer tensor containing the class IDs.
-
+        label_output: An optional integer or string representing the label output
+            used to create the balanced dataset batches. If `None`, y is assumed
+            to be a 1D integer tensor containing the class IDs. If `int`, y is
+            assumed to be a tuple of tensors with the class IDs in the element
+            specified by `label_output`. If `str`, y is assumed to be a dictionary
+            with the class IDs in the key specified by `label_output`.
+
     Returns:
-        A `tf.data.Dataset` object representing the balanced dataset batches.
+        A `tf.data.Dataset` object that yields the balanced dataset batches.
 
     Raises:
         ValueError: If `ds` is an infinite dataset or the cardinality is unknown.
diff --git a/tests/losses/test_lifted_structure_loss.py b/tests/losses/test_lifted_structure_loss.py
new file mode 100644
index 00000000..c244bfb8
--- /dev/null
+++ b/tests/losses/test_lifted_structure_loss.py
@@ -0,0 +1,67 @@
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow.keras.losses import Reduction
+from tensorflow.python.framework import combinations
+
+from tensorflow_similarity import losses
+
+from . import utils
+
+
+@combinations.generate(combinations.combine(mode=["graph", "eager"]))
+class TestLiftedStructLoss(tf.test.TestCase, parameterized.TestCase):
+    def test_config(self):
+        lifted_obj = losses.LiftedStructLoss(
+            reduction=Reduction.SUM,
+            name="lifted_loss",
+        )
+        self.assertEqual(lifted_obj.distance.name, "cosine")
+        self.assertEqual(lifted_obj.name, "lifted_loss")
+        self.assertEqual(lifted_obj.reduction, Reduction.SUM)
+
+    @parameterized.named_parameters(
+        {"testcase_name": "_fixed_margin", "margin": 1.1, "expected_loss": 157.68167},
+    )
+    def test_all_correct_unweighted(self, margin, expected_loss):
+        """Tests the summed loss on a perfectly predicted batch."""
+        y_true, y_preds = utils.generate_perfect_test_batch()
+
+        lifted_obj = losses.LiftedStructLoss(reduction=Reduction.SUM, margin=margin)
+        loss = lifted_obj(y_true, y_preds)
+        self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
+
+    @parameterized.named_parameters(
+        {"testcase_name": "_fixed_margin", "margin": 1.0, "expected_loss": 187.37393},
+    )
+    def test_all_mismatch_unweighted(self, margin, expected_loss):
+        """Tests the summed loss on a fully mismatched batch."""
+        y_true, y_preds = utils.generate_bad_test_batch()
+
+        lifted_obj = losses.LiftedStructLoss(reduction=Reduction.SUM, margin=margin)
+        loss = lifted_obj(y_true, y_preds)
+        self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
+
+    @parameterized.named_parameters(
+        {"testcase_name": "_fixed_margin", "margin": 1.0, "expected_loss": 2.927718},
+    )
+    def test_no_reduction(self, margin, expected_loss):
+        """Tests the per-example loss values with reduction disabled."""
+        y_true, y_preds = utils.generate_bad_test_batch()
+
+        lifted_obj = losses.LiftedStructLoss(reduction=Reduction.NONE, margin=margin)
+        loss = lifted_obj(y_true, y_preds)
+        loss = self.evaluate(loss)
+        expected_loss = self.evaluate(tf.fill(y_true.shape, expected_loss))
+        self.assertArrayNear(loss, expected_loss, 0.001)
+
+    @parameterized.named_parameters(
+        {"testcase_name": "_fixed_margin", "margin": 1.0, "expected_loss": 2.414156913757324},
+    )
+    def test_sum_reduction(self, margin, expected_loss):
+        """Tests that SUM reduction scales the per-example loss by the batch size."""
+        y_true, y_preds = utils.generate_perfect_test_batch()
+
+        lifted_obj = losses.LiftedStructLoss(reduction=Reduction.SUM, margin=margin)
+        loss = lifted_obj(y_true, y_preds)
+        expected_loss = y_true.shape[0] * expected_loss
+        self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
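
Usage note (not part of the diff): the new `LiftedStructLoss` drops into the standard TF Similarity training workflow like any other `MetricLoss`. A minimal sketch, assuming a `SimilarityModel` with a `MetricEmbedding` output layer; the model architecture and input shape below are illustrative placeholders, not part of this change:

```python
# Minimal sketch: compiling a model with the new LiftedStructLoss.
# The architecture and input shape are illustrative placeholders.
import tensorflow as tf
import tensorflow_similarity as tfsim

# Small embedding model; MetricEmbedding L2-normalizes its output,
# matching the normalization the loss docstring expects.
inputs = tf.keras.layers.Input(shape=(28, 28, 1))
x = tf.keras.layers.Flatten()(inputs)
x = tf.keras.layers.Dense(64, activation="relu")(x)
outputs = tfsim.layers.MetricEmbedding(32)(x)
model = tfsim.models.SimilarityModel(inputs, outputs)

# y_true is a 1-D integer tensor of class IDs; cosine distance and a
# margin of 1.0 are the defaults shown in the loss signature above.
model.compile(
    optimizer="adam",
    loss=tfsim.losses.LiftedStructLoss(distance="cosine", margin=1.0),
)
```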
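Similarly, the expanded `label_output` documentation in `TFDataSampler` maps to usage like the following sketch. The data is synthetic, and it assumes `TFDataSampler` is exported from `tensorflow_similarity.samplers`:

```python
# Sketch: balanced batches when y is a dict and the class IDs live
# under one key. Synthetic data; illustrative only.
import tensorflow as tf
from tensorflow_similarity.samplers import TFDataSampler

x = tf.random.uniform((100, 8))
y = {
    "class_id": tf.random.uniform((100,), maxval=10, dtype=tf.int32),
    "other_label": tf.zeros((100,)),
}
ds = tf.data.Dataset.from_tensor_slices((x, y))  # finite, known cardinality

# Each batch holds 4 classes x 2 examples; label_output="class_id" tells
# the sampler which element of y carries the class IDs.
balanced_ds = TFDataSampler(
    ds,
    classes_per_batch=4,
    examples_per_class_per_batch=2,
    label_output="class_id",
)
```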