From afffe0652d0569121ccaf854bedc72c8465de7ba Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Sun, 7 Apr 2024 17:15:42 +0800 Subject: [PATCH 01/10] :sparkles: Enhance LSTM and RNNBase --- python/paddle/nn/layer/rnn.py | 80 ++++++++++++++++++++++------- test/dygraph_to_static/test_lstm.py | 19 +++++-- test/rnn/rnn_numpy.py | 70 +++++++++++++++++-------- test/rnn/test_rnn_nets.py | 33 +++++++++++- test/rnn/test_rnn_nets_static.py | 61 ++++++++++++++++++++-- 5 files changed, 217 insertions(+), 46 deletions(-) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index aca8b66e6ad3d..e9ce8d18b926c 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -902,6 +902,9 @@ class LSTMCell(RNNCellBase): Parameters: input_size (int): The input size. hidden_size (int): The hidden size. + proj_size (int, optional): If specified, the output hidden state + will be projected to `proj_size`. `proj_size` must be smaller than + `hidden_size`. Default: None. weight_ih_attr(ParamAttr, optional): The parameter attribute for `weight_ih`. Default: None. weight_hh_attr(ParamAttr, optional): The parameter attribute for @@ -915,7 +918,8 @@ class LSTMCell(RNNCellBase): Variables: - **weight_ih** (Parameter): shape (4 * hidden_size, input_size), input to hidden weight, which corresponds to the concatenation of :math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula. - - **weight_hh** (Parameter): shape (4 * hidden_size, hidden_size), hidden to hidden weight, which corresponds to the concatenation of :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula. + - **weight_hh** (Parameter): shape (4 * hidden_size, hidden_size), hidden to hidden weight, which corresponds to the concatenation of :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula. If proj_size was specified, the shape will be (4 * hidden_size, proj_size). + - **weight_ho** (Parameter, optional): shape (hidden_size, proj_size), project the hidden state. - **bias_ih** (Parameter): shape (4 * hidden_size, ), input to hidden bias, which corresponds to the concatenation of :math:`b_{ii}, b_{if}, b_{ig}, b_{io}` in the formula. - **bias_hh** (Parameter): shape (4 * hidden_size, ), hidden to hidden bias, which corresponds to the concatenation of :math:`b_{hi}, b_{hf}, b_{hg}, b_{ho}` in the formula. @@ -925,7 +929,8 @@ class LSTMCell(RNNCellBase): Returns: - **outputs** (Tensor): shape `[batch_size, hidden_size]`, the output, corresponding to :math:`h_{t}` in the formula. - - **states** (tuple): a tuple of two tensors, each of shape `[batch_size, hidden_size]`, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula. + - **states** (tuple): a tuple of two tensors, each of shape `[batch_size, hidden_size]`, if proj_size is specified, output shape of the first element will be `[batch_size, proj_size]` + the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula. Notes: All the weights and bias are initialized with `Uniform(-std, std)` by @@ -958,6 +963,7 @@ def __init__( self, input_size, hidden_size, + proj_size=0, weight_ih_attr=None, weight_hh_attr=None, bias_ih_attr=None, @@ -969,6 +975,14 @@ def __init__( raise ValueError( f"hidden_size of {self.__class__.__name__} must be greater than 0, but now equals to {hidden_size}" ) + if proj_size and proj_size < 0: + raise ValueError( + f"proj_size of {self.__class__.__name__} must be greater than 0, but now equals to {hidden_size}" + ) + + if proj_size and proj_size >= hidden_size: + raise ValueError("proj_size must be smaller than hidden_size") + std = 1.0 / math.sqrt(hidden_size) if weight_ih_attr is not False: self.weight_ih = self.create_parameter( @@ -985,13 +999,13 @@ def __init__( self.weight_ih.stop_gradient = True if weight_hh_attr is not False: self.weight_hh = self.create_parameter( - (4 * hidden_size, hidden_size), + (4 * hidden_size, proj_size or hidden_size), weight_hh_attr, default_initializer=I.Uniform(-std, std), ) else: self.weight_hh = self.create_parameter( - (4 * hidden_size, hidden_size), + (4 * hidden_size, proj_size or hidden_size), None, default_initializer=I.Constant(1.0), ) @@ -1027,6 +1041,14 @@ def __init__( ) self.bias_hh.stop_gradient = True + self.proj_size = proj_size + if proj_size: + self.weight_ho = self.create_parameter( + (proj_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std), + ) + self.hidden_size = hidden_size self.input_size = input_size self._gate_activation = F.sigmoid @@ -1050,6 +1072,8 @@ def forward(self, inputs, states=None): o = self._gate_activation(chunked_gates[3]) c = f * pre_cell + i * self._activation(chunked_gates[2]) h = o * self._activation(c) + if self.proj_size: + h = paddle.matmul(h, self.weight_ho, transpose_y=True) return h, (h, c) @@ -1061,7 +1085,7 @@ def state_shape(self): automatically inserted into shape). These two shapes correspond to :math:`h_{t-1}` and :math:`c_{t-1}` separately. """ - return ((self.hidden_size,), (self.hidden_size,)) + return ((self.hidden_size,), (self.proj_size or self.hidden_size,)) def extra_repr(self): return '{input_size}, {hidden_size}'.format(**self.__dict__) @@ -1428,6 +1452,7 @@ def __init__( mode, input_size, hidden_size, + proj_size=0, num_layers=1, direction="forward", time_major=False, @@ -1455,28 +1480,40 @@ def __init__( "bias_hh_attr": bias_hh_attr, } + self.proj_size = proj_size + if proj_size: + assert mode == 'LSTM' + if mode == "LSTM": rnn_cls = LSTMCell + kwargs["proj_size"] = proj_size elif mode == "GRU": rnn_cls = GRUCell + elif mode == "RNN_RELU": + rnn_cls = SimpleRNNCell + kwargs["activation"] = 'relu' + elif mode == "RNN_TANH": + rnn_cls = SimpleRNNCell + kwargs["activation"] = 'tanh' else: rnn_cls = SimpleRNNCell kwargs["activation"] = self.activation + in_size = proj_size or hidden_size if direction in ["forward"]: is_reverse = False cell = rnn_cls(input_size, hidden_size, **kwargs) self.append(RNN(cell, is_reverse, time_major)) - for i in range(1, num_layers): - cell = rnn_cls(hidden_size, hidden_size, **kwargs) + for _ in range(1, num_layers): + cell = rnn_cls(in_size, hidden_size, **kwargs) self.append(RNN(cell, is_reverse, time_major)) elif direction in bidirectional_list: cell_fw = rnn_cls(input_size, hidden_size, **kwargs) cell_bw = rnn_cls(input_size, hidden_size, **kwargs) self.append(BiRNN(cell_fw, cell_bw, time_major)) - for i in range(1, num_layers): - cell_fw = rnn_cls(2 * hidden_size, hidden_size, **kwargs) - cell_bw = rnn_cls(2 * hidden_size, hidden_size, **kwargs) + for _ in range(1, num_layers): + cell_fw = rnn_cls(2 * in_size, hidden_size, **kwargs) + cell_bw = rnn_cls(2 * in_size, hidden_size, **kwargs) self.append(BiRNN(cell_fw, cell_bw, time_major)) else: raise ValueError( @@ -1652,11 +1689,8 @@ def forward(self, inputs, initial_states=None, sequence_length=None): batch_index = 1 if self.time_major else 0 dtype = inputs.dtype if initial_states is None: - state_shape = ( - self.num_layers * self.num_directions, - -1, - self.hidden_size, - ) + state_shape = (self.num_layers * self.num_directions, -1) + dims = ([self.proj_size or self.hidden_size], [self.hidden_size]) fill_shape = list(state_shape) if inputs.shape[batch_index] > 0: @@ -1665,8 +1699,10 @@ def forward(self, inputs, initial_states=None, sequence_length=None): fill_shape[1] = paddle.shape(inputs)[batch_index].item() initial_states = tuple( [ - paddle.full(shape=fill_shape, fill_value=0, dtype=dtype) - for _ in range(self.state_components) + paddle.full( + shape=fill_shape + dims[i], fill_value=0, dtype=dtype + ) + for i in range(self.state_components) ] ) else: @@ -1826,6 +1862,7 @@ def __init__( mode, input_size, hidden_size, + 0, # proj_size num_layers, direction, time_major, @@ -1883,6 +1920,9 @@ class LSTM(RNNBase): dropout (float, optional): The dropout probability. Dropout is applied to the input of each layer except for the first layer. The range of dropout from 0 to 1. Defaults to 0. + proj_size (int, optional): If specified, the output hidden state of each layer + will be projected to `proj_size`. `proj_size` must be smaller than `hidden_size`. + Default: 0. weight_ih_attr (ParamAttr, optional): The parameter attribute for `weight_ih` of each cell. Default: None. weight_hh_attr (ParamAttr, optional): The parameter attribute for @@ -1903,7 +1943,8 @@ class LSTM(RNNBase): - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence. - - **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1. + - **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If `proj_size` is specified, the last dimension of h will be proj_size. + Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1. Variables: - **weight_ih_l[k]**: the learnable input-hidden weights of the k-th layer. If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise, the shape is `[hidden_size, num_directions * hidden_size]`. @@ -1942,6 +1983,7 @@ def __init__( direction="forward", time_major=False, dropout=0.0, + proj_size=0, weight_ih_attr=None, weight_hh_attr=None, bias_ih_attr=None, @@ -1952,6 +1994,7 @@ def __init__( "LSTM", input_size, hidden_size, + proj_size, num_layers, direction, time_major, @@ -2071,6 +2114,7 @@ def __init__( "GRU", input_size, hidden_size, + 0, # proj_size num_layers, direction, time_major, diff --git a/test/dygraph_to_static/test_lstm.py b/test/dygraph_to_static/test_lstm.py index 990dab4d3f21f..df6439c936885 100644 --- a/test/dygraph_to_static/test_lstm.py +++ b/test/dygraph_to_static/test_lstm.py @@ -30,10 +30,14 @@ class LSTMLayer(nn.Layer): - def __init__(self, in_channels, hidden_size): + def __init__(self, in_channels, hidden_size, proj_size=None): super().__init__() self.cell = nn.LSTM( - in_channels, hidden_size, direction='bidirectional', num_layers=2 + in_channels, + hidden_size, + direction='bidirectional', + num_layers=2, + proj_size=proj_size, ) def forward(self, x): @@ -42,9 +46,9 @@ def forward(self, x): class Net(nn.Layer): - def __init__(self, in_channels, hidden_size): + def __init__(self, in_channels, hidden_size, proj_size=None): super().__init__() - self.lstm = LSTMLayer(in_channels, hidden_size) + self.lstm = LSTMLayer(in_channels, hidden_size, proj_size=proj_size) def forward(self, x): x = self.lstm(x) @@ -128,6 +132,13 @@ def test_save_with_training(self): self.save_in_eval(with_training=True) +class TestLstmWithProjsize(TestLstm): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.net = Net(12, 8, 4) + self.inputs = paddle.zeros((2, 10, 12)) + + class LinearNet(nn.Layer): def __init__(self): super().__init__() diff --git a/test/rnn/rnn_numpy.py b/test/rnn/rnn_numpy.py index ec110ed79b746..116ec504e068a 100644 --- a/test/rnn/rnn_numpy.py +++ b/test/rnn/rnn_numpy.py @@ -162,7 +162,13 @@ def forward(self, inputs, hx=None): class LSTMCell(LayerMixin): def __init__( - self, input_size, hidden_size, weight=True, bias=True, dtype="float64" + self, + input_size, + hidden_size, + weight=True, + bias=True, + dtype="float64", + proj_size=None, ): self.input_size = input_size self.hidden_size = hidden_size @@ -175,17 +181,26 @@ def __init__( -std, std, (4 * hidden_size, input_size) ).astype(dtype) self.weight_hh = np.random.uniform( - -std, std, (4 * hidden_size, hidden_size) + -std, std, (4 * hidden_size, proj_size or hidden_size) ).astype(dtype) else: self.weight_ih = np.ones((4 * hidden_size, input_size)).astype( dtype ) - self.weight_hh = np.ones((4 * hidden_size, hidden_size)).astype( - dtype - ) + self.weight_hh = np.ones( + (4 * hidden_size, proj_size or hidden_size) + ).astype(dtype) + self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_hh'] = self.weight_hh + + self.proj_size = proj_size + if proj_size: + self.weight_ho = np.random.uniform( + -std, std, (proj_size, hidden_size) + ).astype(dtype) + self.parameters['weight_ho'] = self.weight_ho + if bias: self.bias_ih = np.random.uniform( -std, std, (4 * hidden_size) @@ -224,6 +239,9 @@ def forward(self, inputs, hx=None): c = f * pre_cell + i * np.tanh(chunked_gates[2]) h = o * np.tanh(c) + if self.proj_size: + h = np.matmul(h, self.weight_ho.T) + return h, (h, c) @@ -428,21 +446,18 @@ def forward(self, inputs, initial_states=None, sequence_length=None): batch_size = inputs.shape[batch_index] dtype = inputs.dtype if initial_states is None: - state_shape = ( - self.num_layers * self.num_directions, - batch_size, - self.hidden_size, - ) + state_shape = (self.num_layers * self.num_directions, batch_size) + proj_size = self.proj_size if hasattr(self, 'proj_size') else None + dims = ((proj_size or self.hidden_size,), (self.hidden_size,)) if self.state_components == 1: - initial_states = np.zeros(state_shape, dtype) + initial_states = np.zeros(state_shape + dims[0], dtype) else: initial_states = tuple( [ - np.zeros(state_shape, dtype) - for _ in range(self.state_components) + np.zeros(state_shape + dims[i], dtype) + for i in range(self.state_components) ] ) - states = split_states( initial_states, self.num_directions == 2, self.state_components ) @@ -532,24 +547,38 @@ def __init__( dropout=0.0, time_major=False, dtype="float64", + proj_size=None, ): super().__init__() bidirectional_list = ["bidirectional", "bidirect"] + in_size = proj_size or hidden_size if direction in ["forward"]: is_reverse = False - cell = LSTMCell(input_size, hidden_size, dtype=dtype) + cell = LSTMCell( + input_size, hidden_size, dtype=dtype, proj_size=proj_size + ) self.append(RNN(cell, is_reverse, time_major)) for i in range(1, num_layers): - cell = LSTMCell(hidden_size, hidden_size, dtype=dtype) + cell = LSTMCell( + in_size, hidden_size, dtype=dtype, proj_size=proj_size + ) self.append(RNN(cell, is_reverse, time_major)) elif direction in bidirectional_list: - cell_fw = LSTMCell(input_size, hidden_size, dtype=dtype) - cell_bw = LSTMCell(input_size, hidden_size, dtype=dtype) + cell_fw = LSTMCell( + input_size, hidden_size, dtype=dtype, proj_size=proj_size + ) + cell_bw = LSTMCell( + input_size, hidden_size, dtype=dtype, proj_size=proj_size + ) self.append(BiRNN(cell_fw, cell_bw, time_major)) for i in range(1, num_layers): - cell_fw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype) - cell_bw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype) + cell_fw = LSTMCell( + 2 * in_size, hidden_size, dtype=dtype, proj_size=proj_size + ) + cell_bw = LSTMCell( + 2 * in_size, hidden_size, dtype=dtype, proj_size=proj_size + ) self.append(BiRNN(cell_fw, cell_bw, time_major)) else: raise ValueError( @@ -564,6 +593,7 @@ def __init__( self.time_major = time_major self.num_layers = num_layers self.state_components = 2 + self.proj_size = proj_size class GRU(RNNMixin): diff --git a/test/rnn/test_rnn_nets.py b/test/rnn/test_rnn_nets.py index 734dcae0fde56..a245553697663 100644 --- a/test/rnn/test_rnn_nets.py +++ b/test/rnn/test_rnn_nets.py @@ -227,7 +227,9 @@ def test_with_initial_state(self): x = np.random.randn(12, 4, 16) if not self.time_major: x = np.transpose(x, [1, 0, 2]) - prev_h = np.random.randn(2 * self.num_directions, 4, 32) + prev_h = np.random.randn( + 2 * self.num_directions, 4, getattr(self, "proj_size", 32) + ) prev_c = np.random.randn(2 * self.num_directions, 4, 32) y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) @@ -289,6 +291,35 @@ def runTest(self): self.test_predict() +class TestLSTMWithProjSize(TestLSTM): + def setUp(self): + # Since `set_device` is global, set `set_device` in `setUp` rather than + # `__init__` to avoid using an error device set by another test case. + place = paddle.set_device(self.place) + paddle.disable_static(place) + rnn1 = LSTM( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + proj_size=8, + ) + rnn2 = paddle.nn.LSTM( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + proj_size=8, + ) + convert_params_for_net(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + self.proj_size = 8 + + def predict_test_util(place, mode, stop_gradient=True): place = paddle.set_device(place) paddle.seed(123) diff --git a/test/rnn/test_rnn_nets_static.py b/test/rnn/test_rnn_nets_static.py index 20b8a7975e8c2..1471e2f9333f9 100644 --- a/test/rnn/test_rnn_nets_static.py +++ b/test/rnn/test_rnn_nets_static.py @@ -394,7 +394,9 @@ def test_with_initial_state(self): x = np.random.randn(12, 4, 16) if not self.time_major: x = np.transpose(x, [1, 0, 2]) - prev_h = np.random.randn(2 * self.num_directions, 4, 32) + prev_h = np.random.randn( + 2 * self.num_directions, 4, getattr(self, "proj_size", 32) + ) prev_c = np.random.randn(2 * self.num_directions, 4, 32) y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) @@ -408,7 +410,11 @@ def test_with_initial_state(self): ) init_h = paddle.static.data( "init_h", - [2 * self.num_directions, -1, 32], + [ + 2 * self.num_directions, + -1, + getattr(self, "proj_size", 32), + ], dtype=paddle.framework.get_default_dtype(), ) init_c = paddle.static.data( @@ -505,13 +511,62 @@ def runTest(self): self.test_with_input_lengths() +class TestLSTMWithProjSize(TestLSTM): + def setUp(self): + # Since `set_device` is global, set `set_device` in `setUp` rather than + # `__init__` to avoid using an error device set by another test case. + place = paddle.set_device(self.place) + rnn1 = LSTM( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + proj_size=8, + ) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.base.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.LSTM( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + proj_size=8, + ) + + exe = paddle.static.Executor(place) + scope = paddle.base.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_net_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + self.proj_size = 8 + + self.place = place + self.executor = exe + self.scope = scope + + def load_tests(loader, tests, pattern): suite = unittest.TestSuite() devices = ["cpu", "gpu"] if paddle.base.is_compiled_with_cuda() else ["cpu"] for direction in ["forward", "bidirectional", "bidirect"]: for time_major in [True, False]: for device in devices: - for test_class in [TestSimpleRNN, TestLSTM, TestGRU]: + for test_class in [ + TestSimpleRNN, + TestLSTM, + TestGRU, + TestLSTMWithProjSize, + ]: suite.addTest(test_class(time_major, direction, device)) return suite From fdee58b4a9e954104a1aae8738d4514649141ce7 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 8 Apr 2024 14:29:23 +0800 Subject: [PATCH 02/10] fix ci coverage --- test/dygraph_to_static/test_lstm.py | 9 ++++++++ test/rnn/rnn_numpy.py | 12 ++++++++--- test/rnn/test_rnn_nets.py | 32 +++++++++++++++++++++++++---- test/rnn/test_rnn_nets_static.py | 19 +++++++++++++++-- 4 files changed, 63 insertions(+), 9 deletions(-) diff --git a/test/dygraph_to_static/test_lstm.py b/test/dygraph_to_static/test_lstm.py index df6439c936885..1ffdfc1c7bd16 100644 --- a/test/dygraph_to_static/test_lstm.py +++ b/test/dygraph_to_static/test_lstm.py @@ -138,6 +138,15 @@ def setUp(self): self.net = Net(12, 8, 4) self.inputs = paddle.zeros((2, 10, 12)) + def test_error(self): + # proj_size < 0 + with self.assertRaises(ValueError): + nn.LSTM(4, 4, 4, proj_size=-1) + + # proj_size >= hidden_size + with self.assertRaises(ValueError): + nn.LSTM(4, 4, 4, proj_size=20) + class LinearNet(nn.Layer): def __init__(self): diff --git a/test/rnn/rnn_numpy.py b/test/rnn/rnn_numpy.py index 116ec504e068a..c564dd9b05760 100644 --- a/test/rnn/rnn_numpy.py +++ b/test/rnn/rnn_numpy.py @@ -49,7 +49,7 @@ def __init__( self.bias = bias if nonlinearity == 'RNN_TANH': self.nonlinearity = np.tanh - else: + elif nonlinearity == 'RNN_RELU': self.nonlinearity = lambda x: np.maximum(x, 0.0) self.parameters = {} @@ -516,10 +516,16 @@ def __init__( self.append(BiRNN(cell_fw, cell_bw, time_major)) for i in range(1, num_layers): cell_fw = SimpleRNNCell( - 2 * hidden_size, hidden_size, nonlinearity, dtype=dtype + 2 * hidden_size, + hidden_size, + nonlinearity=nonlinearity, + dtype=dtype, ) cell_bw = SimpleRNNCell( - 2 * hidden_size, hidden_size, nonlinearity, dtype=dtype + 2 * hidden_size, + hidden_size, + nonlinearity=nonlinearity, + dtype=dtype, ) self.append(BiRNN(cell_fw, cell_bw, time_major)) else: diff --git a/test/rnn/test_rnn_nets.py b/test/rnn/test_rnn_nets.py index a245553697663..f23509a4b387a 100644 --- a/test/rnn/test_rnn_nets.py +++ b/test/rnn/test_rnn_nets.py @@ -27,12 +27,15 @@ class TestSimpleRNN(unittest.TestCase): - def __init__(self, time_major=True, direction="forward", place="cpu"): + def __init__( + self, time_major=True, direction="forward", place="cpu", mode='RNN_TANH' + ): super().__init__("runTest") self.time_major = time_major self.direction = direction self.num_directions = 2 if direction in bidirectional_list else 1 self.place = place + self.mode = mode def setUp(self): # Since `set_device` is global, set `set_device` in `setUp` rather than @@ -40,10 +43,20 @@ def setUp(self): place = paddle.set_device(self.place) paddle.disable_static(place) rnn1 = SimpleRNN( - 16, 32, 2, time_major=self.time_major, direction=self.direction + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + nonlinearity=self.mode, ) rnn2 = paddle.nn.SimpleRNN( - 16, 32, 2, time_major=self.time_major, direction=self.direction + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + activation=self.mode[4:].lower(), ) convert_params_for_net(rnn1, rnn2) @@ -394,8 +407,19 @@ def load_tests(loader, tests, pattern): for direction in ["forward", "bidirectional", "bidirect"]: for time_major in [True, False]: for device in devices: - for test_class in [TestSimpleRNN, TestLSTM, TestGRU]: + for test_class in [ + TestSimpleRNN, + TestLSTM, + TestGRU, + TestLSTMWithProjSize, + ]: suite.addTest(test_class(time_major, direction, device)) + if test_class == TestSimpleRNN: + suite.addTest( + test_class( + time_major, direction, device, mode="RNN_RELU" + ) + ) return suite diff --git a/test/rnn/test_rnn_nets_static.py b/test/rnn/test_rnn_nets_static.py index 1471e2f9333f9..68bd978bc398b 100644 --- a/test/rnn/test_rnn_nets_static.py +++ b/test/rnn/test_rnn_nets_static.py @@ -29,19 +29,27 @@ class TestSimpleRNN(unittest.TestCase): - def __init__(self, time_major=True, direction="forward", place="cpu"): + def __init__( + self, time_major=True, direction="forward", place="cpu", mode="RNN_TANH" + ): super().__init__("runTest") self.time_major = time_major self.direction = direction self.num_directions = 2 if direction in bidirectional_list else 1 self.place = place + self.mode = mode def setUp(self): # Since `set_device` is global, set `set_device` in `setUp` rather than # `__init__` to avoid using an error device set by another test case. place = paddle.set_device(self.place) rnn1 = SimpleRNN( - 16, 32, 2, time_major=self.time_major, direction=self.direction + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + nonlinearity=self.mode, ) mp = paddle.static.Program() @@ -54,6 +62,7 @@ def setUp(self): 2, time_major=self.time_major, direction=self.direction, + activation=self.mode[4:].lower(), ) exe = paddle.static.Executor(place) @@ -568,6 +577,12 @@ def load_tests(loader, tests, pattern): TestLSTMWithProjSize, ]: suite.addTest(test_class(time_major, direction, device)) + if test_class == TestSimpleRNN: + suite.addTest( + test_class( + time_major, direction, device, mode="RNN_RELU" + ) + ) return suite From d35c28de8b17dc126bc778048d1f638d9cddd8d0 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Wed, 10 Apr 2024 23:40:03 +0800 Subject: [PATCH 03/10] adjust position of proj_size --- python/paddle/nn/layer/rnn.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index e9ce8d18b926c..b1f26896b381a 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -902,9 +902,6 @@ class LSTMCell(RNNCellBase): Parameters: input_size (int): The input size. hidden_size (int): The hidden size. - proj_size (int, optional): If specified, the output hidden state - will be projected to `proj_size`. `proj_size` must be smaller than - `hidden_size`. Default: None. weight_ih_attr(ParamAttr, optional): The parameter attribute for `weight_ih`. Default: None. weight_hh_attr(ParamAttr, optional): The parameter attribute for @@ -913,6 +910,9 @@ class LSTMCell(RNNCellBase): `bias_ih`. Default: None. bias_hh_attr (ParamAttr, optional): The parameter attribute for the `bias_hh`. Default: None. + proj_size (int, optional): If specified, the output hidden state + will be projected to `proj_size`. `proj_size` must be smaller than + `hidden_size`. Default: None. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -963,11 +963,11 @@ def __init__( self, input_size, hidden_size, - proj_size=0, weight_ih_attr=None, weight_hh_attr=None, bias_ih_attr=None, bias_hh_attr=None, + proj_size=0, name=None, ): super().__init__() @@ -975,12 +975,12 @@ def __init__( raise ValueError( f"hidden_size of {self.__class__.__name__} must be greater than 0, but now equals to {hidden_size}" ) - if proj_size and proj_size < 0: + if proj_size < 0: raise ValueError( f"proj_size of {self.__class__.__name__} must be greater than 0, but now equals to {hidden_size}" ) - if proj_size and proj_size >= hidden_size: + if proj_size >= hidden_size: raise ValueError("proj_size must be smaller than hidden_size") std = 1.0 / math.sqrt(hidden_size) @@ -1452,7 +1452,6 @@ def __init__( mode, input_size, hidden_size, - proj_size=0, num_layers=1, direction="forward", time_major=False, @@ -1461,6 +1460,7 @@ def __init__( weight_hh_attr=None, bias_ih_attr=None, bias_hh_attr=None, + proj_size=0, ): super().__init__() bidirectional_list = ["bidirectional", "bidirect"] @@ -1862,7 +1862,6 @@ def __init__( mode, input_size, hidden_size, - 0, # proj_size num_layers, direction, time_major, @@ -1871,6 +1870,7 @@ def __init__( weight_hh_attr, bias_ih_attr, bias_hh_attr, + 0, # proj_size ) @@ -1920,9 +1920,6 @@ class LSTM(RNNBase): dropout (float, optional): The dropout probability. Dropout is applied to the input of each layer except for the first layer. The range of dropout from 0 to 1. Defaults to 0. - proj_size (int, optional): If specified, the output hidden state of each layer - will be projected to `proj_size`. `proj_size` must be smaller than `hidden_size`. - Default: 0. weight_ih_attr (ParamAttr, optional): The parameter attribute for `weight_ih` of each cell. Default: None. weight_hh_attr (ParamAttr, optional): The parameter attribute for @@ -1931,6 +1928,9 @@ class LSTM(RNNBase): `bias_ih` of each cells. Default: None. bias_hh_attr (ParamAttr, optional): The parameter attribute for the `bias_hh` of each cells. Default: None. + proj_size (int, optional): If specified, the output hidden state of each layer + will be projected to `proj_size`. `proj_size` must be smaller than `hidden_size`. + Default: 0. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -1983,18 +1983,17 @@ def __init__( direction="forward", time_major=False, dropout=0.0, - proj_size=0, weight_ih_attr=None, weight_hh_attr=None, bias_ih_attr=None, bias_hh_attr=None, + proj_size=0, name=None, ): super().__init__( "LSTM", input_size, hidden_size, - proj_size, num_layers, direction, time_major, @@ -2003,6 +2002,7 @@ def __init__( weight_hh_attr, bias_ih_attr, bias_hh_attr, + proj_size, ) @@ -2114,7 +2114,6 @@ def __init__( "GRU", input_size, hidden_size, - 0, # proj_size num_layers, direction, time_major, @@ -2123,4 +2122,5 @@ def __init__( weight_hh_attr, bias_ih_attr, bias_hh_attr, + 0, # proj_size ) From 6c57349ef127814db32923783074c03cecfc2426 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Thu, 11 Apr 2024 11:09:48 +0800 Subject: [PATCH 04/10] fix proj_size to number --- test/dygraph_to_static/test_lstm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/dygraph_to_static/test_lstm.py b/test/dygraph_to_static/test_lstm.py index 1ffdfc1c7bd16..31c4b37aea712 100644 --- a/test/dygraph_to_static/test_lstm.py +++ b/test/dygraph_to_static/test_lstm.py @@ -30,7 +30,7 @@ class LSTMLayer(nn.Layer): - def __init__(self, in_channels, hidden_size, proj_size=None): + def __init__(self, in_channels, hidden_size, proj_size=0): super().__init__() self.cell = nn.LSTM( in_channels, @@ -46,7 +46,7 @@ def forward(self, x): class Net(nn.Layer): - def __init__(self, in_channels, hidden_size, proj_size=None): + def __init__(self, in_channels, hidden_size, proj_size=0): super().__init__() self.lstm = LSTMLayer(in_channels, hidden_size, proj_size=proj_size) From 814b2ac73e15268ee6b6c5c9717244089ba92d1e Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Thu, 11 Apr 2024 20:27:48 +0800 Subject: [PATCH 05/10] reshape weight_ho --- python/paddle/nn/layer/rnn.py | 8 ++++---- test/rnn/rnn_numpy.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index b1f26896b381a..f5dd120073cdc 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -1042,9 +1042,9 @@ def __init__( self.bias_hh.stop_gradient = True self.proj_size = proj_size - if proj_size: + if proj_size > 0: self.weight_ho = self.create_parameter( - (proj_size, hidden_size), + (hidden_size, proj_size), weight_hh_attr, default_initializer=I.Uniform(-std, std), ) @@ -1072,8 +1072,8 @@ def forward(self, inputs, states=None): o = self._gate_activation(chunked_gates[3]) c = f * pre_cell + i * self._activation(chunked_gates[2]) h = o * self._activation(c) - if self.proj_size: - h = paddle.matmul(h, self.weight_ho, transpose_y=True) + if self.proj_size > 0: + h = paddle.matmul(h, self.weight_ho) return h, (h, c) diff --git a/test/rnn/rnn_numpy.py b/test/rnn/rnn_numpy.py index c564dd9b05760..c5e651230a4b4 100644 --- a/test/rnn/rnn_numpy.py +++ b/test/rnn/rnn_numpy.py @@ -197,7 +197,7 @@ def __init__( self.proj_size = proj_size if proj_size: self.weight_ho = np.random.uniform( - -std, std, (proj_size, hidden_size) + -std, std, (hidden_size, proj_size) ).astype(dtype) self.parameters['weight_ho'] = self.weight_ho @@ -240,7 +240,7 @@ def forward(self, inputs, hx=None): h = o * np.tanh(c) if self.proj_size: - h = np.matmul(h, self.weight_ho.T) + h = np.matmul(h, self.weight_ho) return h, (h, c) From 691c16af0dc74b7cdf03bf9fa1350319b95dbfc7 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Thu, 11 Apr 2024 21:35:14 +0800 Subject: [PATCH 06/10] update --- python/paddle/nn/layer/rnn.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index f5dd120073cdc..c61531a51b3d4 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -1481,7 +1481,7 @@ def __init__( } self.proj_size = proj_size - if proj_size: + if proj_size > 0: assert mode == 'LSTM' if mode == "LSTM": @@ -1689,10 +1689,8 @@ def forward(self, inputs, initial_states=None, sequence_length=None): batch_index = 1 if self.time_major else 0 dtype = inputs.dtype if initial_states is None: - state_shape = (self.num_layers * self.num_directions, -1) dims = ([self.proj_size or self.hidden_size], [self.hidden_size]) - - fill_shape = list(state_shape) + fill_shape = [self.num_layers * self.num_directions, -1] if inputs.shape[batch_index] > 0: fill_shape[1] = inputs.shape[batch_index] else: From b683660e3e41b90b6a74bf1d26525effd870b5e4 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Fri, 12 Apr 2024 16:01:35 +0800 Subject: [PATCH 07/10] update docstring --- python/paddle/nn/layer/rnn.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index c61531a51b3d4..d25ac170bcb60 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -893,6 +893,11 @@ class LSTMCell(RNNCellBase): y_{t} & = h_{t} + If `proj_size` is specified, the dimension of hidden state :math:`h_{t}` will be projected to `proj_size`: + + .. math:: + h_{t} & = h_{t}W_{proj_size} + where :math:`\sigma` is the sigmoid function, and * is the elementwise multiplication operator. @@ -928,9 +933,9 @@ class LSTMCell(RNNCellBase): - **states** (list|tuple, optional): a list/tuple of two tensors, each of shape `[batch_size, hidden_size]`, the previous hidden state, corresponding to :math:`h_{t-1}, c_{t-1}` in the formula. When states is None, zero state is used. Defaults to None. Returns: - - **outputs** (Tensor): shape `[batch_size, hidden_size]`, the output, corresponding to :math:`h_{t}` in the formula. - - **states** (tuple): a tuple of two tensors, each of shape `[batch_size, hidden_size]`, if proj_size is specified, output shape of the first element will be `[batch_size, proj_size]` - the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula. + - **outputs** (Tensor): shape `[batch_size, hidden_size]`, the output, corresponding to :math:`h_{t}` in the formula. If `proj_size` is specified, output shape will be `[batch_size, proj_size]`. + - **states** (tuple): a tuple of two tensors, each of shape `[batch_size, hidden_size]`, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula. + If `proj_size` is specified, shape of :math:`h_{t}` will be `[batch_size, proj_size]`. Notes: All the weights and bias are initialized with `Uniform(-std, std)` by @@ -1899,6 +1904,11 @@ class LSTM(RNNBase): y_{t} & = h_{t} + If `proj_size` is specified, the dimension of hidden state :math:`h_{t}` will be projected to `proj_size`: + + .. math:: + h_{t} & = h_{t}W_{proj_size} + where :math:`\sigma` is the sigmoid function, and * is the elementwise multiplication operator. @@ -1939,8 +1949,7 @@ class LSTM(RNNBase): Returns: - - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence. - + - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`. If `proj_size` is specified, shape will be `[time_major, batch_size, num_directions * proj_size]`. If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence. - **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If `proj_size` is specified, the last dimension of h will be proj_size. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1. From b5c2a7d692ea26712c5c9add303cbec0f2d2673d Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Wed, 17 Apr 2024 11:26:16 +0800 Subject: [PATCH 08/10] update docstring --- python/paddle/nn/layer/rnn.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index d25ac170bcb60..270c2186be0aa 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -896,7 +896,7 @@ class LSTMCell(RNNCellBase): If `proj_size` is specified, the dimension of hidden state :math:`h_{t}` will be projected to `proj_size`: .. math:: - h_{t} & = h_{t}W_{proj_size} + h_{t} & = h_{t}W_{proj\_size} where :math:`\sigma` is the sigmoid function, and * is the elementwise multiplication operator. @@ -933,8 +933,8 @@ class LSTMCell(RNNCellBase): - **states** (list|tuple, optional): a list/tuple of two tensors, each of shape `[batch_size, hidden_size]`, the previous hidden state, corresponding to :math:`h_{t-1}, c_{t-1}` in the formula. When states is None, zero state is used. Defaults to None. Returns: - - **outputs** (Tensor): shape `[batch_size, hidden_size]`, the output, corresponding to :math:`h_{t}` in the formula. If `proj_size` is specified, output shape will be `[batch_size, proj_size]`. - - **states** (tuple): a tuple of two tensors, each of shape `[batch_size, hidden_size]`, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula. + - **outputs** (Tensor). Shape `[batch_size, hidden_size]`, the output, corresponding to :math:`h_{t}` in the formula. If `proj_size` is specified, output shape will be `[batch_size, proj_size]`. + - **states** (tuple). A tuple of two tensors, each of shape `[batch_size, hidden_size]`, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula. If `proj_size` is specified, shape of :math:`h_{t}` will be `[batch_size, proj_size]`. Notes: @@ -1907,7 +1907,7 @@ class LSTM(RNNBase): If `proj_size` is specified, the dimension of hidden state :math:`h_{t}` will be projected to `proj_size`: .. math:: - h_{t} & = h_{t}W_{proj_size} + h_{t} & = h_{t}W_{proj\_size} where :math:`\sigma` is the sigmoid function, and * is the elementwise multiplication operator. @@ -1949,8 +1949,8 @@ class LSTM(RNNBase): Returns: - - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`. If `proj_size` is specified, shape will be `[time_major, batch_size, num_directions * proj_size]`. If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence. - - **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If `proj_size` is specified, the last dimension of h will be proj_size. + - **outputs** (Tensor). The output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`. If `proj_size` is specified, shape will be `[time_major, batch_size, num_directions * proj_size]`. If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence. + - **final_states** (tuple). The final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If `proj_size` is specified, the last dimension of h will be proj_size. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1. Variables: From dbb1d63039206403f75c061330c27f17c880afa3 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Wed, 17 Apr 2024 16:05:21 +0800 Subject: [PATCH 09/10] try to fix docstring --- python/paddle/nn/layer/rnn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 270c2186be0aa..e1390320f8ffd 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -896,6 +896,7 @@ class LSTMCell(RNNCellBase): If `proj_size` is specified, the dimension of hidden state :math:`h_{t}` will be projected to `proj_size`: .. math:: + h_{t} & = h_{t}W_{proj\_size} where :math:`\sigma` is the sigmoid function, and * is the elementwise @@ -1907,6 +1908,7 @@ class LSTM(RNNBase): If `proj_size` is specified, the dimension of hidden state :math:`h_{t}` will be projected to `proj_size`: .. math:: + h_{t} & = h_{t}W_{proj\_size} where :math:`\sigma` is the sigmoid function, and * is the elementwise From 7e8377f2468d76fc716f4efd204cef8bd6f961c2 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Wed, 17 Apr 2024 16:06:55 +0800 Subject: [PATCH 10/10] try to fix docstring --- python/paddle/nn/layer/rnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index e1390320f8ffd..5888afc638eea 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -897,7 +897,7 @@ class LSTMCell(RNNCellBase): .. math:: - h_{t} & = h_{t}W_{proj\_size} + h_{t} = h_{t}W_{proj\_size} where :math:`\sigma` is the sigmoid function, and * is the elementwise multiplication operator. @@ -1909,7 +1909,7 @@ class LSTM(RNNBase): .. math:: - h_{t} & = h_{t}W_{proj\_size} + h_{t} = h_{t}W_{proj\_size} where :math:`\sigma` is the sigmoid function, and * is the elementwise multiplication operator.