Skip to content

Commit

Permalink
Added SNR distance. Closes #64 (#205)
Browse files Browse the repository at this point in the history
Added Signal-to-Noise Ratio distance metric as defined in
[Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning](https://arxiv.org/abs/1904.02616)
  • Loading branch information
abhisharsinha authored Dec 21, 2021
1 parent be59fcc commit 681f100
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
49 changes: 47 additions & 2 deletions tensorflow_similarity/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def call(self, embeddings: FloatTensor) -> FloatTensor:
class SquaredEuclideanDistance(Distance):
"""Compute pairwise squared Euclidean distance.
The [Sequared Euclidean Distance](https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance) is
The [Squared Euclidean Distance](https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance) is
a distance that varies from 0 (similar) to infinity (dissimilar).
"""
def __init__(self):
Expand Down Expand Up @@ -214,13 +214,58 @@ def call(self, embeddings: FloatTensor) -> FloatTensor:
return distances


@tf.keras.utils.register_keras_serializable(package="Similarity")
class SNRDistance(Distance):
"""
Computes pairwise SNR distances between embeddings.
The [Signal-to-Noise Ratio distance](https://arxiv.org/abs/1904.02616)
is the ratio of noise variance to the feature variance.
"""
def __init__(self):
"Init SNR distance"
super().__init__('snr')

@tf.function
def call(self, embeddings: FloatTensor) -> FloatTensor:
"""Compute pairwise snr distances for a given batch of embeddings.
SNR(i, j): anchor i and compared feature j
SNR(i,j) may not be equal to SNR(j, i)
Args:
embeddings: Embeddings to compute the pairwise one.
Returns:
FloatTensor: Pairwise distance tensor.
"""
# Calculating feature variance for each example
embed_mean = tf.math.reduce_mean(embeddings, axis=1)
embed_square = tf.math.square(embeddings)
embed_sq_mean = tf.math.reduce_mean(embed_square, axis=1)
anchor_var = embed_sq_mean - tf.square(embed_mean)

# Calculating pairwise noise variances
x_rs = tf.reshape(embeddings, shape=[tf.shape(embeddings)[0], -1])
delta = tf.expand_dims(x_rs, axis=1) - tf.expand_dims(x_rs, axis=0)
delta_mean = tf.math.reduce_mean(delta, axis=2)
delta_sq = tf.math.square(delta)
delta_sq_mean = tf.math.reduce_mean(delta_sq, axis=2)
noise_var = delta_sq_mean - tf.square(delta_mean)

distances: FloatTensor = tf.divide(noise_var,
tf.expand_dims(anchor_var, axis=1))

return distances


# List of implemented distances
DISTANCES = [
InnerProductSimilarity(),
EuclideanDistance(),
SquaredEuclideanDistance(),
ManhattanDistance(),
CosineDistance()
CosineDistance(),
SNRDistance()
]


Expand Down
27 changes: 27 additions & 0 deletions tests/test_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tensorflow_similarity.distances import CosineDistance, InnerProductSimilarity
from tensorflow_similarity.distances import EuclideanDistance
from tensorflow_similarity.distances import ManhattanDistance
from tensorflow_similarity.distances import SNRDistance
from tensorflow_similarity.distances import distance_canonicalizer
from tensorflow_similarity.distances import DISTANCES

Expand Down Expand Up @@ -131,3 +132,29 @@ def test_innerprod():
d = InnerProductSimilarity()
vals = d(a)
assert tf.round(tf.reduce_sum(vals)) == 65


def test_snr_dist():
"""
Comparing SNRDistance with simple loop based implementation
of SNR distance.
"""
num_inputs = 3
dims = 5
x = np.random.uniform(0, 1, (num_inputs, dims))

# Computing SNR distance values using loop
snr_pairs = []
for i in range(num_inputs):
row = []
for j in range(num_inputs):
dist = np.var(x[i]-x[j])/np.var(x[i])
row.append(dist)
snr_pairs.append(row)
snr_pairs = np.array(snr_pairs)

x = tf.convert_to_tensor(x)
snr_distances = SNRDistance()(x).numpy()
assert np.all(snr_distances >= 0)
diff = snr_distances - snr_pairs
assert np.all(np.abs(diff) < 1e-4)

0 comments on commit 681f100

Please sign in to comment.