Merge master into dev (#348)
Merging master back into dev for consistency.
owenvallis authored Aug 11, 2023
1 parent cb35ce4 commit 96b89a9
Showing 6 changed files with 223 additions and 33 deletions.
20 changes: 7 additions & 13 deletions README.md
@@ -16,19 +16,13 @@ With Tensorflow Similarity you can train two main types of models:

## What's new

- [May 2022]: 0.16 major optimization release
* Cross-batch memory (XBM) loss added thanks to @chjort
* Many self-supervised-related improvements thanks to @dewball345
* Major layers and callback refactoring to make them faster and more flexible. E.g. `EvalCallback()` now supports split validation.
For full changes see [the changelog](./releases.md)

- [Jan 2022]: 0.15 self-supervised release
* Added support for self-supervised contrastive learning, including SimCLR, SimSiam, and Barlow Twins. Check out the in-depth [hello world notebook](examples/unsupervised_hello_world.ipynb) to get started.
* Soft Nearest Neighbor Loss added thanks to [Abhishar Sinha](https://github.com/abhisharsinha)
* Added GeneralizedMeanPooling2D support that improves similarity matching accuracy over GlobalMeanPooling2D.
* Numerous speed optimizations and general bug fixes.

For previous changes and more details, see [the changelog](./releases.md)
- [Mar 2023]: 0.17 more losses and metrics, and massive refactoring
* Added VicReg loss to the contrastive losses.
* Added metrics used in retrieval papers, such as Precision@K.
* Native support for distributed training, e.g. SimCLR now works correctly with distributed training.
* Initial support for multi-modal embeddings (CLIP).

For more details and previous release information, see [the changelog](./releases.md)

## Getting Started

2 changes: 1 addition & 1 deletion examples/unsupervised_hello_world.ipynb
@@ -384,7 +384,7 @@
"metadata": {},
"outputs": [],
"source": [
"ALGORITHM = \"simsiam\" # @param [\"barlow\", \"simsiam\", \"simclr\", \"vigreg\"]"
"ALGORITHM = \"simsiam\" # @param [\"barlow\", \"simsiam\", \"simclr\", \"vicreg\"]"
]
},
{
1 change: 1 addition & 0 deletions tensorflow_similarity/losses/__init__.py
@@ -17,6 +17,7 @@
"""
from .barlow import Barlow # noqa
from .circle_loss import CircleLoss # noqa
from .lifted_structure_loss import LiftedStructLoss # noqa
from .metric_loss import MetricLoss # noqa
from .multinegrank_loss import MultiNegativesRankLoss # noqa
from .multisim_loss import MultiSimilarityLoss # noqa
124 changes: 124 additions & 0 deletions tensorflow_similarity/losses/lifted_structure_loss.py
@@ -0,0 +1,124 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lifted Structured Loss
Deep Metric Learning via Lifted Structured Feature Embedding.
https://arxiv.org/abs/1511.06452
"""
from __future__ import annotations

import tensorflow as tf

from tensorflow_similarity import losses as tfsim_losses
from tensorflow_similarity.algebra import build_masks
from tensorflow_similarity.distances import Distance, distance_canonicalizer
from tensorflow_similarity.types import FloatTensor, IntTensor

from .metric_loss import MetricLoss
from .utils import positive_distances


def lifted_struct_loss(
labels: IntTensor,
embeddings: FloatTensor,
key_labels: IntTensor,
key_embeddings: FloatTensor,
distance: Distance,
positive_mining_strategy: str = "hard",
margin: float = 1.0,
) -> FloatTensor:
"""Lifted Struct loss computations"""

# Compute pairwise distances
pairwise_distances = distance(embeddings, key_embeddings)

# Build masks for positive and negative pairs
positive_mask, negative_mask = build_masks(
query_labels=labels, key_labels=key_labels, batch_size=tf.shape(embeddings)[0]
)

# Get positive distances and indices
positive_dists, positive_indices = positive_distances(positive_mining_strategy, pairwise_distances, positive_mask)

# Reorder pairwise distances and negative mask based on positive indices
reordered_pairwise_distances = tf.gather(pairwise_distances, positive_indices, axis=1)
reordered_negative_mask = tf.gather(negative_mask, positive_indices, axis=1)

# Concatenate pairwise distances and negative masks along axis=1
concatenated_distances = tf.concat([pairwise_distances, reordered_pairwise_distances], axis=1)
concatenated_negative_mask = tf.concat([negative_mask, reordered_negative_mask], axis=1)
concatenated_negative_mask = tf.cast(concatenated_negative_mask, tf.float32)
# Compute (margin - neg_dist) logsum_exp values for each row (equation 4 in the paper)
neg_logsumexp = tfsim_losses.utils.logsumexp(margin - concatenated_distances, concatenated_negative_mask)

# Calculate the loss
j_values = neg_logsumexp + positive_dists
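    # Eq. 4 of the paper: J is the logsumexp of (margin - distance) over the
    # negatives, plus the mined positive distance, computed per anchor row.
    # The division by 2 below likely averages the two per-anchor views of each
    # positive pair; unlike the paper, this variant appears to return J / 2
    # directly rather than the hinged, squared max(0, J)^2 form.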

loss: FloatTensor = j_values / 2.0

return loss


@tf.keras.utils.register_keras_serializable(package="Similarity")
class LiftedStructLoss(MetricLoss):
"""Computes the lifted structured loss in an online fashion.
This loss encourages the positive distances between a pair of embeddings
with the same labels to be smaller than the negative distances between pair
of embeddings of different labels.
See: https://arxiv.org/abs/1511.06452 for the original paper.
`y_true` must be a 1-D integer `Tensor` of shape (batch_size,).
It's values represent the classes associated with the examples as
**integer values**.
`y_pred` must be 2-D float `Tensor` of L2 normalized embedding vectors.
You can use the layer `tensorflow_similarity.layers.L2Embedding()` as the
last layer of your model to ensure your model output is properly normalized.
"""

def __init__(
self,
distance: Distance | str = "cosine",
positive_mining_strategy: str = "hard",
margin: float = 1.0,
name: str = "LiftedStructLoss",
**kwargs,
):
"""Initializes the LiftedStructLoss.
Args:
distance: Which distance function to use to compute the pairwise
distances between embeddings.
            positive_mining_strategy: Which mining strategy to use to select
                embeddings from the same class. Defaults to 'hard'.
                Available: {'easy', 'hard'}
margin: Use an explicit value for the margin term.
name: Loss name. Defaults to "LiftedStructLoss".
Raises:
ValueError: Invalid positive mining strategy.
"""

# distance canonicalization
distance = distance_canonicalizer(distance)
self.distance = distance

# sanity checks
if positive_mining_strategy not in ["easy", "hard"]:
raise ValueError("Invalid positive mining strategy")

super().__init__(
lifted_struct_loss,
name=name,
distance=distance,
positive_mining_strategy=positive_mining_strategy,
margin=margin,
**kwargs,
)
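
For readers trying out the new loss, here is a minimal usage sketch (not part of the commit; the toy labels and embeddings are illustrative):

```python
import tensorflow as tf

from tensorflow_similarity.losses import LiftedStructLoss

# Toy batch: 8 L2-normalized 16-d embeddings, 4 classes x 2 examples each,
# matching the 1-D integer y_true / 2-D float y_pred contract documented above.
y_true = tf.constant([0, 0, 1, 1, 2, 2, 3, 3], dtype=tf.int32)
y_pred = tf.math.l2_normalize(tf.random.normal((8, 16)), axis=1)

loss_fn = LiftedStructLoss(distance="cosine", positive_mining_strategy="hard", margin=1.0)
print(float(loss_fn(y_true, y_pred)))  # scalar under the default reduction
```

Since `LiftedStructLoss` extends `MetricLoss` (a Keras loss), the same object should also be usable as `loss=` in `model.compile()`.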
42 changes: 23 additions & 19 deletions tensorflow_similarity/samplers/tfdata_sampler.py
Expand Up @@ -102,7 +102,8 @@ def apply_augmenter_ds(ds: tf.data.Dataset, augmenter: Callable, warmup: int = 0
Args:
ds: A `tf.data.Dataset` object.
augmenter: A callable function used to apply data augmentation to
individual examples. If `None`, no data augmentation is applied.
individual examples within each batch. If `None`, no data
augmentation is applied.
warmup: An integer representing the number of examples to wait
before applying the data augmentation function.
@@ -139,33 +140,36 @@ def TFDataSampler(
label_output: int | str | None = None,
) -> tf.data.Dataset:
"""
Returns a `tf.data.Dataset` object that generates batches of examples with an
equal number of examples per class. The input dataset cardinality must be finite
and known.
    Returns a `tf.data.Dataset` object that generates batches of examples with
    an equal number of examples per class. The input dataset cardinality must
    be finite and known.
Args:
ds: A `tf.data.Dataset` object representing the original dataset.
classes_per_batch: An integer specifying the number of classes per batch.
examples_per_class_per_batch: An integer specifying the number of examples
per class per batch.
class_list: An optional sequence of integers representing the class IDs to
include in the dataset. If `None`, all classes in the original dataset
will be used.
total_examples_per_class: An optional integer representing the total number
of examples per class to use. If `None`, all examples for each class will
be used.
augmenter: An optional function to apply data augmentation to each example.
load_fn: An optional callable function for loading real examples from `x`. It
is useful for loading images from their corresponding file path provided
in `x` or similar situations.
class_list: An optional sequence of integers representing the class IDs
to include in the dataset. If `None`, all classes in the original
dataset will be used.
total_examples_per_class: An optional integer representing the total
number of examples per class to use. If `None`, all examples for
each class will be used.
augmenter: An optional function to apply data augmentation to each
example in a batch.
load_fn: An optional callable function for loading real examples from `x`.
It is useful for loading images from their corresponding file path
provided in `x` or similar situations.
warmup: An integer specifying the number of examples to use for unaugmented
warmup.
label_output: An optional integer or string representing the label output used
to create the balanced dataset batches. If `None`, y is assumed to be a 1D
integer tensor containing the class IDs.
label_output: An optional integer or string representing the label output
used to create the balanced dataset batches. If `None`, y is assumed
to be a 1D integer tensor containing the class IDs. If `int`, y is
assumed to be a tuple of tensors with the class IDs in the element
specified by `label_output`. If `str`, y is assumed to be a dictionary
with the class IDs in the key specified by `label_output`.
Returns:
A `tf.data.Dataset` object representing the balanced dataset batches.
A `tf.data.Dataset` object representing the balanced dataset for few-shot learning tasks.
Raises:
ValueError: If `ds` is an infinite dataset or the cardinality is unknown.
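
To make the updated docstring concrete, a hedged usage sketch (toy tensors; it assumes `TFDataSampler` is exported from `tensorflow_similarity.samplers`, and the printed shapes are the expected rather than verified output):

```python
import tensorflow as tf

from tensorflow_similarity.samplers import TFDataSampler

# Toy dataset: 100 examples over 10 classes; y is a 1-D integer tensor of
# class IDs and the dataset cardinality is finite and known, as required.
x = tf.random.uniform((100, 4))
y = tf.constant([i % 10 for i in range(100)], dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((x, y))

# Each batch holds 4 classes x 8 examples per class = 32 balanced examples.
balanced_ds = TFDataSampler(
    ds,
    classes_per_batch=4,
    examples_per_class_per_batch=8,
)

for batch_x, batch_y in balanced_ds.take(1):
    print(batch_x.shape, batch_y.shape)  # expected: (32, 4) (32,)
```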
67 changes: 67 additions & 0 deletions tests/losses/test_lifted_structure_loss.py
@@ -0,0 +1,67 @@
import tensorflow as tf
from absl.testing import parameterized
from tensorflow.keras.losses import Reduction
from tensorflow.python.framework import combinations

from tensorflow_similarity import losses

from . import utils


@combinations.generate(combinations.combine(mode=["graph", "eager"]))
class TestLiftedStructLoss(tf.test.TestCase, parameterized.TestCase):
def test_config(self):
lifted_obj = losses.LiftedStructLoss(
reduction=Reduction.SUM,
name="lifted_loss",
)
self.assertEqual(lifted_obj.distance.name, "cosine")
self.assertEqual(lifted_obj.name, "lifted_loss")
self.assertEqual(lifted_obj.reduction, Reduction.SUM)

@parameterized.named_parameters(
{"testcase_name": "_fixed_margin", "margin": 1.1, "expected_loss": 157.68167},
)
def test_all_correct_unweighted(self, margin, expected_loss):
"""Tests the LiftedStructLoss with different parameters."""
y_true, y_preds = utils.generate_perfect_test_batch()

lifted_obj = losses.LiftedStructLoss(reduction=Reduction.SUM, margin=margin)
loss = lifted_obj(y_true, y_preds)
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)

@parameterized.named_parameters(
{"testcase_name": "_fixed_margin", "margin": 1.0, "expected_loss": 187.37393},
)
def test_all_mismatch_unweighted(self, margin, expected_loss):
"""Tests the LiftedStructLoss with different parameters."""
y_true, y_preds = utils.generate_bad_test_batch()

lifted_obj = losses.LiftedStructLoss(reduction=Reduction.SUM, margin=margin)
loss = lifted_obj(y_true, y_preds)
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)

@parameterized.named_parameters(
{"testcase_name": "_fixed_margin", "margin": 1.0, "expected_loss": 2.927718},
)
def test_no_reduction(self, margin, expected_loss):
"""Tests the LiftedStructLoss with different parameters."""
y_true, y_preds = utils.generate_bad_test_batch()

lifted_obj = losses.LiftedStructLoss(reduction=Reduction.NONE, margin=margin)
loss = lifted_obj(y_true, y_preds)
loss = self.evaluate(loss)
expected_loss = self.evaluate(tf.fill(y_true.shape, expected_loss))
self.assertArrayNear(loss, expected_loss, 0.001)

@parameterized.named_parameters(
{"testcase_name": "_fixed_margin", "margin": 1.0, "expected_loss": 2.414156913757324},
)
def test_sum_reduction(self, margin, expected_loss):
"""Tests the LiftedStructLoss with different parameters."""
y_true, y_preds = utils.generate_perfect_test_batch()

lifted_obj = losses.LiftedStructLoss(reduction=Reduction.SUM, margin=margin)
loss = lifted_obj(y_true, y_preds)
expected_loss = y_true.shape[0] * expected_loss
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
