-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
57 lines (42 loc) · 1.54 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
loss.py
Mar 4 2023
Gabriel Moreira
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable
class ProtoLoss(nn.Module):
    """Prototypical-network episode loss.

    Each call to ``forward`` receives one episode: ``shot * way`` support
    embeddings followed by ``way * query`` query embeddings. Support rows are
    turned into one prototype per class, queries are scored with the negated
    ``distance_fn`` against every prototype, and the loss is cross-entropy
    against labels implied by the episode layout. The accuracy counts of the
    most recent call are kept and returned by ``scores``.

    Args:
        shot: Support examples per class.
        way: Number of classes per episode.
        query: Query examples per class.
        distance_fn: Called as ``distance_fn(x_query, x_prototypes)``; its
            negation is used as the logits, so it presumably returns a
            ``(n_query, way)`` distance matrix — TODO confirm against callers.
        centroid_fn: Called on the support tensor of shape ``(shot, way, D)``
            when ``shot > 1``; presumably reduces over dim 0 to ``(way, D)``
            (e.g. a mean) — TODO confirm against callers.
        device: Device the episode labels are created on.
    """

    def __init__(self,
                 shot: int,
                 way: int,
                 query: int,
                 distance_fn: Callable,
                 centroid_fn: Callable,
                 device: str = 'cuda'):
        super().__init__()
        self.shot = shot
        self.way = way
        self.p = self.shot * self.way  # number of support rows at the head of x
        self.query = query
        self.distance_fn = distance_fn
        self.centroid_fn = centroid_fn
        # Query row j belongs to class j % way: [0..way-1, 0..way-1, ...].
        # A non-persistent buffer follows Module.to()/.cuda() without being
        # added to state_dict (a plain attribute would be left behind).
        self.register_buffer(
            'label',
            torch.arange(self.way).repeat(self.query).to(device),
            persistent=False,
        )
        # Store scores to compute accuracies
        self.t = None   # number of queries scored in the last forward
        self.tc = None  # number of those predicted correctly

    def forward(self, x: torch.Tensor, target: torch.Tensor = None) -> torch.Tensor:
        """Compute the loss for one episode.

        Args:
            x: Episode embeddings. The first ``shot * way`` rows are support
                samples, laid out so that row ``i * way + c`` is the i-th shot
                of class ``c``; the remaining ``way * query`` rows are queries,
                row ``j`` belonging to class ``j % way``.
            target: Ignored; labels are implied by the episode layout. Kept
                so ``loss(output, target)`` style call sites keep working.

        Returns:
            Scalar cross-entropy loss over the query rows.

        Raises:
            ValueError: If ``x`` does not have exactly
                ``shot * way + way * query`` rows.
        """
        n_query = self.way * self.query
        expected = self.p + n_query
        if x.shape[0] != expected:
            raise ValueError(
                f"ProtoLoss expected {expected} rows ({self.p} support + "
                f"{n_query} query), got {x.shape[0]}")

        x_shot, x_query = x[:self.p, ...], x[self.p:, ...]
        x_shot = x_shot.reshape((self.shot, self.way, -1))
        if self.shot > 1:
            x_prototypes = self.centroid_fn(x_shot)
        else:
            # A single shot is its own prototype: (1, way, D) -> (way, D).
            x_prototypes = x_shot.squeeze(0)

        logits = -self.distance_fn(x_query, x_prototypes)
        # No-op when the labels already live on the logits' device.
        label = self.label.to(logits.device)
        loss = F.cross_entropy(logits, label)

        self.tc = (torch.argmax(logits, dim=-1) == label).sum().item()
        self.t = logits.shape[0]
        return loss

    def scores(self):
        """Return ``(n_correct, n_total)`` from the most recent forward pass.

        Both values are ``None`` until ``forward`` has been called.
        """
        return self.tc, self.t