diff --git a/stable_nalu/layer/generalized.py b/stable_nalu/layer/generalized.py index dce7525..1d1a305 100644 --- a/stable_nalu/layer/generalized.py +++ b/stable_nalu/layer/generalized.py @@ -206,8 +206,8 @@ def __init__(self, input_size, hidden_size, unit_name, writer=None, **kwags): def reset_parameters(self): self.cell.reset_parameters() - def forward(self, x_t, h_tm1): - return self.cell(x_t, h_tm1) + def forward(self, *args): + return self.cell(*args) def extra_repr(self): return 'input_size={}, hidden_size={}, unit_name={}'.format( diff --git a/stable_nalu/layer/mclstm.py b/stable_nalu/layer/mclstm.py index 4720392..8a806da 100644 --- a/stable_nalu/layer/mclstm.py +++ b/stable_nalu/layer/mclstm.py @@ -29,8 +29,7 @@ def reset_parameters(self): self.junction.reset_parameters() self.redistribution.reset_parameters() - def forward(self, xt_m, c): - xt_a = xt_m.new_ones((len(xt_m), self.aux_input_size)) + def forward(self, xt_m, xt_a, c): j = self.junction(xt_a) r = self.redistribution(xt_a) o = self.out_gate(xt_a) @@ -39,6 +38,7 @@ def forward(self, xt_m, c): m_sys = torch.matmul(c.unsqueeze(-2), r).squeeze(-2) m_new = m_in + m_sys return o * m_new, (1 - o) * m_new + # return m_new, m_new def get_redistribution(kind: str, diff --git a/stable_nalu/network/sequential_mnist.py b/stable_nalu/network/sequential_mnist.py index 88e281f..c4c590c 100644 --- a/stable_nalu/network/sequential_mnist.py +++ b/stable_nalu/network/sequential_mnist.py @@ -72,11 +72,17 @@ def _forward_trainable_accumulator(self, x): else: h_tm1 = self.zero_state.repeat(x.size(0), 1) + if self.unit_name == 'MCLSTM': + auxiliaries = x.new_ones(*x.shape[:2], 1) + auxiliaries[:, -1] = -1 + for t in range(x.size(1)): x_t = x[:, t] l_t = self.image2label(x_t) - if self.nac_mul == 'none' or self.nac_mul == 'mnac': + if self.unit_name == 'MCLSTM': + h_t = self.recurent_cell(l_t, auxiliaries[:, t], h_tm1) + elif self.nac_mul == 'none' or self.nac_mul == 'mnac': h_t = self.recurent_cell(l_t, h_tm1) elif self.nac_mul == 'normal': h_t = torch.exp(self.recurent_cell(