Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge master into dev #348

Merged
merged 17 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,13 @@ With Tensorflow Similarity you can train two main types of models:

## What's new

- [May 2022]: 0.16 major optimization release
* Cross-batch memory (XBM) loss added thank to @chjort
* Many self-supervised related improvement thanks to @dewball345
* Major layers and callback refactoring to make them faster and more flexible. E.g `EvalCallback()` now support splited validation.
For full changes see [the changelog](./releases.md)

- [Jan 2022]: 0.15 self-supervised release
* Added support for self-supervised contrastive learning. Including SimCLR, SimSiam, and Barlow Twins. Checkout the in-depth [hello world notebook](examples/unsupervised_hello_world.ipynb) to get started.
* Soft Nearest Neighbor Loss added thanks to [Abhishar Sinha](https://github.com/abhisharsinha)
* Added GenerlizedMeanPooling2D support that improves similarity matching accuracy over GlobalMeanPooling2D.
* Numerous speed optimizations and general bug fixes.

For previous changes and more details - see [the changelog](./releases.md)
- [Mar 2023]: 0.17 more losses and metric and massive refactoring
* Added VicReg Loss to contrastive losses.
* Added metrics used in retrieval papers such as Precision@K
* Native support for distributed training e.g SimClr now works correctly with distributed training.
* Multi-modal embedding initial support (CLIP)

For more details and previous releases information - see [the changelog](./releases.md)

## Getting Started

Expand Down
2 changes: 1 addition & 1 deletion examples/unsupervised_hello_world.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@
"metadata": {},
"outputs": [],
"source": [
"ALGORITHM = \"simsiam\" # @param [\"barlow\", \"simsiam\", \"simclr\", \"vigreg\"]"
"ALGORITHM = \"simsiam\" # @param [\"barlow\", \"simsiam\", \"simclr\", \"vicreg\"]"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions tensorflow_similarity/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
from .barlow import Barlow # noqa
from .circle_loss import CircleLoss # noqa
from .lifted_structure_loss import LiftedStructLoss # noqa
from .metric_loss import MetricLoss # noqa
from .multinegrank_loss import MultiNegativesRankLoss # noqa
from .multisim_loss import MultiSimilarityLoss # noqa
Expand Down
124 changes: 124 additions & 0 deletions tensorflow_similarity/losses/lifted_structure_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lifted Structured Loss
Deep Metric Learning via Lifted Structured Feature Embedding.
https://arxiv.org/abs/1511.06452
"""
from __future__ import annotations

import tensorflow as tf

from tensorflow_similarity import losses as tfsim_losses
from tensorflow_similarity.algebra import build_masks
from tensorflow_similarity.distances import Distance, distance_canonicalizer
from tensorflow_similarity.types import FloatTensor, IntTensor

from .metric_loss import MetricLoss
from .utils import positive_distances


def lifted_struct_loss(
labels: IntTensor,
embeddings: FloatTensor,
key_labels: IntTensor,
key_embeddings: FloatTensor,
distance: Distance,
positive_mining_strategy: str = "hard",
margin: float = 1.0,
) -> FloatTensor:
"""Lifted Struct loss computations"""

# Compute pairwise distances
pairwise_distances = distance(embeddings, key_embeddings)

# Build masks for positive and negative pairs
positive_mask, negative_mask = build_masks(
query_labels=labels, key_labels=key_labels, batch_size=tf.shape(embeddings)[0]
)

# Get positive distances and indices
positive_dists, positive_indices = positive_distances(positive_mining_strategy, pairwise_distances, positive_mask)

# Reorder pairwise distances and negative mask based on positive indices
reordered_pairwise_distances = tf.gather(pairwise_distances, positive_indices, axis=1)
reordered_negative_mask = tf.gather(negative_mask, positive_indices, axis=1)

# Concatenate pairwise distances and negative masks along axis=1
concatenated_distances = tf.concat([pairwise_distances, reordered_pairwise_distances], axis=1)
concatenated_negative_mask = tf.concat([negative_mask, reordered_negative_mask], axis=1)
concatenated_negative_mask = tf.cast(concatenated_negative_mask, tf.float32)
# Compute (margin - neg_dist) logsum_exp values for each row (equation 4 in the paper)
neg_logsumexp = tfsim_losses.utils.logsumexp(margin - concatenated_distances, concatenated_negative_mask)

# Calculate the loss
j_values = neg_logsumexp + positive_dists

loss: FloatTensor = j_values / 2.0

return loss


@tf.keras.utils.register_keras_serializable(package="Similarity")
class LiftedStructLoss(MetricLoss):
"""Computes the lifted structured loss in an online fashion.
This loss encourages the positive distances between a pair of embeddings
with the same labels to be smaller than the negative distances between pair
of embeddings of different labels.
See: https://arxiv.org/abs/1511.06452 for the original paper.
`y_true` must be a 1-D integer `Tensor` of shape (batch_size,).
It's values represent the classes associated with the examples as
**integer values**.
`y_pred` must be 2-D float `Tensor` of L2 normalized embedding vectors.
You can use the layer `tensorflow_similarity.layers.L2Embedding()` as the
last layer of your model to ensure your model output is properly normalized.
"""

def __init__(
self,
distance: Distance | str = "cosine",
positive_mining_strategy: str = "hard",
margin: float = 1.0,
name: str = "LiftedStructLoss",
**kwargs,
):
"""Initializes the LiftedStructLoss.
Args:
distance: Which distance function to use to compute the pairwise
distances between embeddings.
positive_mining_strategy: What mining strategy to use to select
embedding from the same class. Defaults to 'hard'.
Available: {'easy', 'hard'}
margin: Use an explicit value for the margin term.
name: Loss name. Defaults to "LiftedStructLoss".
Raises:
ValueError: Invalid positive mining strategy.
"""

# distance canonicalization
distance = distance_canonicalizer(distance)
self.distance = distance

# sanity checks
if positive_mining_strategy not in ["easy", "hard"]:
raise ValueError("Invalid positive mining strategy")

super().__init__(
lifted_struct_loss,
name=name,
distance=distance,
positive_mining_strategy=positive_mining_strategy,
margin=margin,
**kwargs,
)
42 changes: 23 additions & 19 deletions tensorflow_similarity/samplers/tfdata_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def apply_augmenter_ds(ds: tf.data.Dataset, augmenter: Callable, warmup: int = 0
Args:
ds: A `tf.data.Dataset` object.
augmenter: A callable function used to apply data augmentation to
individual examples. If `None`, no data augmentation is applied.
individual examples within each batch. If `None`, no data
augmentation is applied.
warmup: An integer representing the number of examples to wait
before applying the data augmentation function.

Expand Down Expand Up @@ -139,33 +140,36 @@ def TFDataSampler(
label_output: int | str | None = None,
) -> tf.data.Dataset:
"""
Returns a `tf.data.Dataset` object that generates batches of examples with an
equal number of examples per class. The input dataset cardinality must be finite
and known.
Returns a `tf.data.Dataset` object that generates batches of examples with
equal number of examples per class. The input dataset cardinality must be
finite and known.

Args:
ds: A `tf.data.Dataset` object representing the original dataset.
classes_per_batch: An integer specifying the number of classes per batch.
examples_per_class_per_batch: An integer specifying the number of examples
per class per batch.
class_list: An optional sequence of integers representing the class IDs to
include in the dataset. If `None`, all classes in the original dataset
will be used.
total_examples_per_class: An optional integer representing the total number
of examples per class to use. If `None`, all examples for each class will
be used.
augmenter: An optional function to apply data augmentation to each example.
load_fn: An optional callable function for loading real examples from `x`. It
is useful for loading images from their corresponding file path provided
in `x` or similar situations.
class_list: An optional sequence of integers representing the class IDs
to include in the dataset. If `None`, all classes in the original
dataset will be used.
total_examples_per_class: An optional integer representing the total
number of examples per class to use. If `None`, all examples for
each class will be used.
augmenter: An optional function to apply data augmentation to each
example in a batch.
load_fn: An optional callable function for loading real examples from `x`.
It is useful for loading images from their corresponding file path
provided in `x` or similar situations.
warmup: An integer specifying the number of examples to use for unaugmented
warmup.
label_output: An optional integer or string representing the label output used
to create the balanced dataset batches. If `None`, y is assumed to be a 1D
integer tensor containing the class IDs.

label_output: An optional integer or string representing the label output
used to create the balanced dataset batches. If `None`, y is assumed
to be a 1D integer tensor containing the class IDs. If `int`, y is
assumed to be a tuple of tensors with the class IDs in the element
specified by `label_output`. If `str`, y is assumed to be a dictionary
with the class IDs in the key specified by `label_output`.
Returns:
A `tf.data.Dataset` object representing the balanced dataset batches.
A `tf.data.Dataset` object representing the balanced dataset for few-shot learning tasks.

Raises:
ValueError: If `ds` is an infinite dataset or the cardinality is unknown.
Expand Down
67 changes: 67 additions & 0 deletions tests/losses/test_lifted_structure_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import tensorflow as tf
from absl.testing import parameterized
from tensorflow.keras.losses import Reduction
from tensorflow.python.framework import combinations

from tensorflow_similarity import losses

from . import utils


@combinations.generate(combinations.combine(mode=["graph", "eager"]))
class TestLiftedStructLoss(tf.test.TestCase, parameterized.TestCase):
def test_config(self):
lifted_obj = losses.LiftedStructLoss(
reduction=Reduction.SUM,
name="lifted_loss",
)
self.assertEqual(lifted_obj.distance.name, "cosine")
self.assertEqual(lifted_obj.name, "lifted_loss")
self.assertEqual(lifted_obj.reduction, Reduction.SUM)

@parameterized.named_parameters(
{"testcase_name": "_fixed_margin", "margin": 1.1, "expected_loss": 157.68167},
)
def test_all_correct_unweighted(self, margin, expected_loss):
"""Tests the LiftedStructLoss with different parameters."""
y_true, y_preds = utils.generate_perfect_test_batch()

lifted_obj = losses.LiftedStructLoss(reduction=Reduction.SUM, margin=margin)
loss = lifted_obj(y_true, y_preds)
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)

@parameterized.named_parameters(
{"testcase_name": "_fixed_margin", "margin": 1.0, "expected_loss": 187.37393},
)
def test_all_mismatch_unweighted(self, margin, expected_loss):
"""Tests the LiftedStructLoss with different parameters."""
y_true, y_preds = utils.generate_bad_test_batch()

lifted_obj = losses.LiftedStructLoss(reduction=Reduction.SUM, margin=margin)
loss = lifted_obj(y_true, y_preds)
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)

@parameterized.named_parameters(
{"testcase_name": "_fixed_margin", "margin": 1.0, "expected_loss": 2.927718},
)
def test_no_reduction(self, margin, expected_loss):
"""Tests the LiftedStructLoss with different parameters."""
y_true, y_preds = utils.generate_bad_test_batch()

lifted_obj = losses.LiftedStructLoss(reduction=Reduction.NONE, margin=margin)
loss = lifted_obj(y_true, y_preds)
loss = self.evaluate(loss)
expected_loss = self.evaluate(tf.fill(y_true.shape, expected_loss))
self.assertArrayNear(loss, expected_loss, 0.001)

@parameterized.named_parameters(
{"testcase_name": "_fixed_margin", "margin": 1.0, "expected_loss": 2.414156913757324},
)
def test_sum_reduction(self, margin, expected_loss):
"""Tests the LiftedStructLoss with different parameters."""
y_true, y_preds = utils.generate_perfect_test_batch()

lifted_obj = losses.LiftedStructLoss(reduction=Reduction.SUM, margin=margin)
loss = lifted_obj(y_true, y_preds)
expected_loss = y_true.shape[0] * expected_loss
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
Loading