-
Notifications
You must be signed in to change notification settings - Fork 2
/
model.py
62 lines (46 loc) · 1.7 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.autograd import Variable
import pickle
import pdb
from utils import length2mask
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    Token ids are embedded with ``nn.Embedding``, run through a
    ``batch_first`` LSTM, and the resulting hidden states are handed to a
    pluggable ``decoder`` object that owns the output layer and the loss.
    """

    def __init__(self, ntoken, ninp, nhid, nlayers, decoder, dropout=0.5):
        """Build the model.

        Args:
            ntoken: vocabulary size (number of rows in the embedding table).
            ninp: embedding dimension, i.e. the LSTM input size.
            nhid: LSTM hidden size.
            nlayers: number of stacked LSTM layers.
            decoder: object providing ``init_weights()``,
                ``forward_with_loss(hidden, target)`` and
                ``forward_all(output, length)``.
            dropout: dropout probability applied to the embeddings, to the
                LSTM output and (only when ``nlayers > 1``) between LSTM
                layers.
        """
        super(RNNModel, self).__init__()
        # Plain attributes are set before init_weights() so that an
        # overriding init_weights can rely on them.
        self.ntoken = ntoken
        self.nhid = nhid
        self.nlayers = nlayers
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        # nn.LSTM applies its dropout only *between* stacked layers; with a
        # single layer the argument has no effect and PyTorch warns about it.
        rnn_dropout = dropout if nlayers > 1 else 0
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=rnn_dropout,
                           batch_first=True)
        self.decoder = decoder
        self.init_weights()

    def init_weights(self):
        """Fill the embedding table uniformly in [-0.1, 0.1] and let the
        decoder initialise its own parameters."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.init_weights()

    def forward(self, input, length=None):
        """Return the dropout-regularised LSTM outputs for ``input``.

        ``input`` holds token ids; given ``batch_first=True`` it is presumably
        shaped (batch, seq) -- TODO confirm with callers. The returned tensor
        has ``nhid`` features in its last dimension.

        ``length`` is accepted for interface compatibility but not used: the
        padded batch is fed to the LSTM unpacked and padded positions are
        discarded afterwards (see ``loss``). For this unidirectional LSTM that
        presumably relies on padding being at the end of each sequence.
        """
        emb = self.drop(self.encoder(input))
        output, _hidden = self.rnn(emb)
        return self.drop(output)

    def loss(self, data):
        """Compute the decoder loss over the non-padding positions of a batch.

        Args:
            data: an ``(input, target, length)`` triple; ``target`` must be
                maskable by the same per-position mask as ``input``.

        Returns:
            Whatever ``self.decoder.forward_with_loss`` returns for the
            selected hidden states and targets.
        """
        input, target, length = data
        rnn_output = self(input, length)
        # Discard padded time steps. length2mask presumably yields a
        # (batch, seq) mask that is truthy at real tokens -- TODO confirm.
        # masked_select flattens its result, so regroup it into rows of nhid.
        mask = length2mask(length)
        rnn_output = rnn_output.masked_select(
            mask.unsqueeze(dim=2).expand_as(rnn_output)
        ).view(-1, self.nhid)
        target = target.masked_select(mask)
        return self.decoder.forward_with_loss(rnn_output, target)

    def forward_all(self, data, length):
        """Run the LSTM on ``data`` and return ``self.decoder.forward_all``
        applied to the resulting outputs and ``length``."""
        output = self(data, length)
        return self.decoder.forward_all(output, length)