Skip to content

Commit

Permalink
provide auxiliary inputs to mclstm
Browse files Browse the repository at this point in the history
  • Loading branch information
hoedt committed Jan 14, 2021
1 parent ab6fc75 commit b7522fb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions stable_nalu/layer/generalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def __init__(self, input_size, hidden_size, unit_name, writer=None, **kwags):
def reset_parameters(self):
self.cell.reset_parameters()

def forward(self, x_t, h_tm1):
return self.cell(x_t, h_tm1)
def forward(self, *args):
return self.cell(*args)

def extra_repr(self):
return 'input_size={}, hidden_size={}, unit_name={}'.format(
Expand Down
4 changes: 2 additions & 2 deletions stable_nalu/layer/mclstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def reset_parameters(self):
self.junction.reset_parameters()
self.redistribution.reset_parameters()

def forward(self, xt_m, c):
xt_a = xt_m.new_ones((len(xt_m), self.aux_input_size))
def forward(self, xt_m, xt_a, c):
j = self.junction(xt_a)
r = self.redistribution(xt_a)
o = self.out_gate(xt_a)
Expand All @@ -39,6 +38,7 @@ def forward(self, xt_m, c):
m_sys = torch.matmul(c.unsqueeze(-2), r).squeeze(-2)
m_new = m_in + m_sys
return o * m_new, (1 - o) * m_new
# return m_new, m_new


def get_redistribution(kind: str,
Expand Down
8 changes: 7 additions & 1 deletion stable_nalu/network/sequential_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,17 @@ def _forward_trainable_accumulator(self, x):
else:
h_tm1 = self.zero_state.repeat(x.size(0), 1)

if self.unit_name == 'MCLSTM':
auxiliaries = x.new_ones(*x.shape[:2], 1)
auxiliaries[:, -1] = -1

for t in range(x.size(1)):
x_t = x[:, t]
l_t = self.image2label(x_t)

if self.nac_mul == 'none' or self.nac_mul == 'mnac':
if self.unit_name == 'MCLSTM':
h_t = self.recurent_cell(l_t, auxiliaries[:, t], h_tm1)
elif self.nac_mul == 'none' or self.nac_mul == 'mnac':
h_t = self.recurent_cell(l_t, h_tm1)
elif self.nac_mul == 'normal':
h_t = torch.exp(self.recurent_cell(
Expand Down

0 comments on commit b7522fb

Please sign in to comment.