-
Notifications
You must be signed in to change notification settings - Fork 7
/
loss.py
21 lines (19 loc) · 1021 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.distances import CosineSimilarity, DotProductSimilarity
loss_fn = losses.MultiSimilarityLoss(alpha=1.0, beta=50, base=0.0, distance=DotProductSimilarity())
miner = miners.MultiSimilarityMiner(epsilon=0.1, distance=CosineSimilarity())
# The loss function call (this method will be called at each training iteration)
def loss_function(descriptors, labels):
# we mine the pairs/triplets if there is an online mining strategy
if miner is not None:
miner_outputs = miner(descriptors, labels)
loss = loss_fn(descriptors, labels, miner_outputs)
# calculate the % of trivial pairs/triplets
# which do not contribute in the loss value
nb_samples = descriptors.shape[0]
nb_mined = len(set(miner_outputs[0].detach().cpu().numpy()))
batch_acc = 1.0 - (nb_mined/nb_samples)
else: # no online mining
loss = loss_fn(descriptors, labels)
batch_acc = 0.0
return loss