diff --git a/stable_nalu/network/simple_function_recurrent.py b/stable_nalu/network/simple_function_recurrent.py
index 892d49c..de24731 100644
--- a/stable_nalu/network/simple_function_recurrent.py
+++ b/stable_nalu/network/simple_function_recurrent.py
@@ -32,9 +32,9 @@ def __init__(self, unit_name, input_size=10, writer=None, **kwargs):
                                              name='recurrent_layer',
                                              **kwargs)
         self.output_layer = GeneralizedLayer(self.hidden_size, 1,
-                                             'linear'
-                                             if unit_name in {'GRU', 'LSTM', 'MCLSTM', 'RNN-tanh', 'RNN-ReLU'}
-                                             else unit_name,
+                                             'linear',
+                                             # if unit_name in {'GRU', 'LSTM', 'MCLSTM', 'RNN-tanh', 'RNN-ReLU'}
+                                             # else unit_name,
                                              writer=self.writer,
                                              name='output_layer',
                                              **kwargs)
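
Note on the hunk above: before the patch, the output layer was only forced to 'linear' for the listed recurrent cell types, and any other unit name was reused as the read-out; after the patch the read-out is unconditionally 'linear', with the old condition left commented out. A minimal standalone sketch of the two selection rules follows (it does not import stable_nalu; the unit name 'NALU' is an illustrative non-recurrent example, not taken from the patch):

    # Sketch of the read-out selection before and after the patch.
    RECURRENT_UNITS = {'GRU', 'LSTM', 'MCLSTM', 'RNN-tanh', 'RNN-ReLU'}

    def output_unit_before(unit_name):
        # Pre-patch: recurrent cells get a linear read-out; any other
        # unit name is passed through and reused as the read-out layer.
        return 'linear' if unit_name in RECURRENT_UNITS else unit_name

    def output_unit_after(unit_name):
        # Post-patch: the read-out is always linear.
        return 'linear'

    assert output_unit_before('LSTM') == output_unit_after('LSTM') == 'linear'
    assert output_unit_before('NALU') == 'NALU'   # 'NALU' is illustrative
    assert output_unit_after('NALU') == 'linear'  # behaviour changed here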