model.py
import torch
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(self, base_model, num_classes, method):
        super().__init__()
        self.base_model = base_model
        self.num_classes = num_classes
        self.method = method
        self.linear = nn.Linear(base_model.config.hidden_size, num_classes)
        self.dropout = nn.Dropout(0.5)
        # Fine-tune the whole backbone rather than freezing it.
        for param in base_model.parameters():
            param.requires_grad_(True)

    def forward(self, inputs):
        raw_outputs = self.base_model(**inputs)
        hiddens = raw_outputs.last_hidden_state
        # The [CLS] hidden state serves as the sentence representation.
        cls_feats = hiddens[:, 0, :]
        if self.method in ['ce', 'scl']:
            # Cross-entropy / supervised contrastive training: classify
            # the [CLS] feature with a plain linear head.
            label_feats = None
            predicts = self.linear(self.dropout(cls_feats))
        else:
            # Label-feature branch: the hidden states at positions
            # 1..num_classes (right after [CLS]) act as per-example
            # class representations, and the logits are the dot products
            # between the sentence feature and each label feature.
            label_feats = hiddens[:, 1:self.num_classes + 1, :]
            predicts = torch.einsum('bd,bcd->bc', cls_feats, label_feats)
        outputs = {
            'predicts': predicts,
            'cls_feats': cls_feats,
            'label_feats': label_feats,
        }
        return outputs
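
A minimal usage sketch, not part of the original file: it assumes the Hugging Face transformers library, a bert-base-uncased backbone, and 'dualcl' as an illustrative method name for the else branch (any value outside ['ce', 'scl'] selects it). Prepending one word per class to the input text is an assumption about how the label positions 1..num_classes get populated.

import torch
from transformers import AutoModel, AutoTokenizer

backbone = AutoModel.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# 'dualcl' is an assumed method name, not confirmed by this file.
model = Transformer(backbone, num_classes=2, method='dualcl')

# Assumption: one label word per class is prepended to the text so that
# hiddens[:, 1:num_classes+1, :] lines up with the class tokens.
text = 'positive negative this movie was great'
inputs = tokenizer(text, return_tensors='pt')

model.eval()
with torch.no_grad():
    outputs = model(inputs)
print(outputs['predicts'].shape)  # torch.Size([1, 2])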