gru_entitiylibrary.py
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
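

# The classifier below encodes padded sentence batches and padded label-name
# sequences with two separate LSTMs that share a single embedding table,
# projects the final hidden state of each into a common space with linear
# layers, and returns the cosine similarity between every sentence vector
# and every label vector.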
class LSTMClassifier(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, label_size, batch_size):
        super(LSTMClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.label_size = label_size
        # One embedding table shared by sentence tokens and label tokens.
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Separate LSTM encoders for sentences and for label sequences.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.lstm1 = nn.LSTM(embedding_dim, hidden_dim)
        # Linear projections into the space where sentences and labels are compared.
        self.hidden2hidden1 = nn.Linear(hidden_dim, hidden_dim)
        self.hidden2hidden1_label = nn.Linear(hidden_dim, hidden_dim)
        self.cos = nn.CosineSimilarity(dim=-1)
        self.dropout = nn.Dropout(0.5)

    def last_timestep(self, unpacked, lengths):
        # Index of the last output for each sequence.
        idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
                                               unpacked.size(2)).unsqueeze(1)
        if torch.cuda.is_available():
            idx = idx.cuda()
        return unpacked.gather(1, idx).squeeze()

    def init_hidden(self):
        # (h0, c0) initialise the sentence LSTM; (h1, c1) initialise the label LSTM,
        # whose "batch" dimension is the number of labels.
        if torch.cuda.is_available():
            h0 = Variable(torch.zeros(1, self.batch_size, self.hidden_dim), requires_grad=False).cuda()
            c0 = Variable(torch.zeros(1, self.batch_size, self.hidden_dim), requires_grad=False).cuda()
            h1 = Variable(torch.zeros(1, self.label_size, self.hidden_dim), requires_grad=False).cuda()
            c1 = Variable(torch.zeros(1, self.label_size, self.hidden_dim), requires_grad=False).cuda()
        else:
            h0 = Variable(torch.zeros(1, self.batch_size, self.hidden_dim), requires_grad=False)
            c0 = Variable(torch.zeros(1, self.batch_size, self.hidden_dim), requires_grad=False)
            h1 = Variable(torch.zeros(1, self.label_size, self.hidden_dim), requires_grad=False)
            c1 = Variable(torch.zeros(1, self.label_size, self.hidden_dim), requires_grad=False)
        return (h0, c0), (h1, c1)

    def forward(self, sentence, lengths, label_input, label_seq_input):
        (h0, c0), (h1, c1) = self.init_hidden()

        # Encode the padded sentence batch.
        embeds = self.word_embeddings(sentence)
        # embeds = self.dropout(embeds)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embeds, lengths, batch_first=True)
        lstm_out, (h0, c0) = self.lstm(packed, (h0, c0))
        unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        # Get the outputs from the last *non-masked* timestep for each sentence.
        last_outputs = self.last_timestep(unpacked, unpacked_len)
        last_outputs = self.dropout(last_outputs)
        hidden_1 = self.hidden2hidden1(last_outputs)

        # Encode the padded label sequences with the second LSTM.
        label_embeds = self.word_embeddings(label_input)
        label_embeds = self.dropout(label_embeds)
        label_packed = torch.nn.utils.rnn.pack_padded_sequence(label_embeds, label_seq_input, batch_first=True)
        label_lstm_out, (h1, c1) = self.lstm1(label_packed, (h1, c1))
        label_unpacked, label_unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(label_lstm_out, batch_first=True)
        # Get the outputs from the last *non-masked* timestep for each label sequence.
        label_last_outputs = self.last_timestep(label_unpacked, label_unpacked_len)
        label_last_outputs = self.dropout(label_last_outputs)
        # label_hidden_1 = self.hidden2hidden1(label_last_outputs)
        # label_hidden_1 = self.relu1(label_hidden_1)
        label_hidden_1 = self.hidden2hidden1_label(label_last_outputs)

        # Cosine similarity of every sentence vector against every label vector:
        # (batch, 1, hidden) vs. (1, labels, hidden) -> (batch, labels).
        score = self.cos(hidden_1.unsqueeze(-2), label_hidden_1.unsqueeze(0))
        # y = self.hidden2label(hidden_1)
        return score
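

# ---------------------------------------------------------------------------
# Minimal usage sketch. The sizes and tensors below are illustrative
# assumptions chosen only to show how the module is called; they are not
# values from the original project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, embedding_dim, hidden_dim = 1000, 50, 64
    label_size, batch_size, max_len, max_label_len = 5, 4, 12, 3

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LSTMClassifier(embedding_dim, hidden_dim, vocab_size, label_size, batch_size).to(device)

    # Padded sentence batch, sorted by descending length as pack_padded_sequence expects.
    sentence = torch.randint(1, vocab_size, (batch_size, max_len)).to(device)
    lengths = torch.tensor([12, 10, 7, 5])  # sequence lengths stay on the CPU

    # One padded token sequence per label, also sorted by descending length.
    label_input = torch.randint(1, vocab_size, (label_size, max_label_len)).to(device)
    label_lengths = torch.tensor([3, 3, 2, 2, 1])

    scores = model(sentence, lengths, label_input, label_lengths)
    print(scores.shape)  # (batch_size, label_size) cosine-similarity scores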