-
Notifications
You must be signed in to change notification settings - Fork 25
/
LSTMCell.py
47 lines (34 loc) · 1.47 KB
/
LSTMCell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.rnn import *
class LayerNorm(nn.Module):
    """Learnable affine normalization over the last dimension of the input.

    Computes ``gamma * (x - mean) / (std + eps) + beta``. ``mean`` and ``std``
    are taken over the last dimension. ``std`` is ``Tensor.std``'s default
    (Bessel-corrected) estimate. Note that ``eps`` is added to the standard
    deviation, not to the variance, which differs from ``torch.nn.LayerNorm``.

    Args:
        features: Size of the last dimension; shape of ``gamma`` and ``beta``.
        eps: Small constant added to the standard deviation for stability.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        # Registration order (gamma, then beta) is preserved for state_dict.
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension and apply the affine map."""
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        centered = x - mu
        denom = sigma + self.eps
        # Same evaluation order as before: (gamma * centered) / denom + beta.
        return self.gamma * centered / denom + self.beta
class LSTMCell(RNNCellBase):
    """LSTM cell with layer normalization on gate pre-activations and cell state.

    The input projection and the recurrent projection are each passed through
    ``LayerNorm`` and then summed to form the four gate pre-activations. The
    new cell state is layer-normalized before the output ``tanh``. The cell
    state that is returned is the un-normalized one.

    Args:
        input_size: Size of the last dimension of ``input``.
        hidden_size: Size of the hidden and cell states.
        bias: Whether the two ``nn.Linear`` projections use a bias term.
        dropout: Probability given to ``nn.Dropout``.
            NOTE(review): ``self.drop`` is built but never applied in
            ``forward``, so this argument currently has no effect.
    """

    def __init__(self, input_size, hidden_size, bias=True, dropout=0):
        # Deliberately skip RNNCellBase.__init__. In current PyTorch it
        # requires (input_size, hidden_size, bias, num_chunks) and registers
        # weight_ih/weight_hh/bias_* parameters this cell never uses.
        # super(RNNCellBase, self) resolves to nn.Module.__init__ through the
        # MRO, so the class stays an RNNCellBase without those parameters.
        super(RNNCellBase, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Each projection emits all four gate pre-activations at once.
        self.ih = nn.Sequential(
            nn.Linear(input_size, 4 * hidden_size, bias),
            LayerNorm(4 * hidden_size),
        )
        self.hh = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size, bias),
            LayerNorm(4 * hidden_size),
        )
        self.c_norm = LayerNorm(hidden_size)
        self.drop = nn.Dropout(dropout)  # NOTE(review): unused in forward.

    def forward(self, input, hidden=None):
        """Advance the cell by one step.

        Args:
            input: Tensor whose last dimension is ``input_size``.
            hidden: ``(hx, cx)`` pair whose last dimension is ``hidden_size``.
                If ``None``, both states start as zeros shaped like ``input``
                with the last dimension replaced by ``hidden_size``.

        Returns:
            ``(hy, cy)``: the new hidden state and the new (un-normalized)
            cell state.
        """
        if hidden is None:
            state_shape = tuple(input.shape[:-1]) + (self.hidden_size,)
            zeros = input.new_zeros(state_shape)
            hidden = (zeros, zeros)
        hx, cx = hidden
        gates = self.ih(input) + self.hh(hx)
        # Split along the feature (last) dim; same as dim 1 for 2-D input.
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, -1)
        ingate = ingate.sigmoid()
        forgetgate = forgetgate.sigmoid()
        cellgate = cellgate.tanh()
        outgate = outgate.sigmoid()
        cy = forgetgate * cx + ingate * cellgate
        # Only the copy fed to the output nonlinearity is normalized.
        hy = outgate * self.c_norm(cy).tanh()
        return hy, cy