[Hackathon 6th No.24] Functionality enhancement for paddle.nn.LSTM/RNNBase - part #63284

Merged: 14 commits, Apr 19, 2024
80 changes: 62 additions & 18 deletions python/paddle/nn/layer/rnn.py
@@ -910,12 +910,16 @@ class LSTMCell(RNNCellBase):
`bias_ih`. Default: None.
bias_hh_attr (ParamAttr, optional): The parameter attribute for the
`bias_hh`. Default: None.
proj_size (int, optional): If specified, the output hidden state
will be projected to `proj_size`. `proj_size` must be smaller than
`hidden_size`. Default: 0.
name (str, optional): Name for the operation (optional, default is
None). For more information, please refer to :ref:`api_guide_Name`.

Variables:
- **weight_ih** (Parameter): shape (4 * hidden_size, input_size), input to hidden weight, which corresponds to the concatenation of :math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula.
- **weight_hh** (Parameter): shape (4 * hidden_size, hidden_size), hidden to hidden weight, which corresponds to the concatenation of :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula.
- **weight_hh** (Parameter): shape (4 * hidden_size, hidden_size), hidden to hidden weight, which corresponds to the concatenation of :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula. If `proj_size` is specified, the shape will be (4 * hidden_size, proj_size).
- **weight_ho** (Parameter, optional): shape (proj_size, hidden_size), projection weight that maps the hidden state from `hidden_size` to `proj_size`.
- **bias_ih** (Parameter): shape (4 * hidden_size, ), input to hidden bias, which corresponds to the concatenation of :math:`b_{ii}, b_{if}, b_{ig}, b_{io}` in the formula.
- **bias_hh** (Parameter): shape (4 * hidden_size, ), hidden to hidden bias, which corresponds to the concatenation of :math:`b_{hi}, b_{hf}, b_{hg}, b_{ho}` in the formula.

@@ -925,7 +929,8 @@ class LSTMCell(RNNCellBase):

Returns:
- **outputs** (Tensor): shape `[batch_size, hidden_size]`, the output, corresponding to :math:`h_{t}` in the formula.
- **states** (tuple): a tuple of two tensors, each of shape `[batch_size, hidden_size]`, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula.
- **states** (tuple): a tuple of two tensors, each of shape `[batch_size, hidden_size]`, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula.
If `proj_size` is specified, the shape of the first element will be `[batch_size, proj_size]` instead.

Notes:
All the weights and bias are initialized with `Uniform(-std, std)` by
@@ -962,13 +967,22 @@ def __init__(
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
proj_size=0,
name=None,
):
super().__init__()
if hidden_size <= 0:
raise ValueError(
f"hidden_size of {self.__class__.__name__} must be greater than 0, but now equals to {hidden_size}"
)
if proj_size < 0:
raise ValueError(
f"proj_size of {self.__class__.__name__} must be greater than or equal to 0, but now equals to {proj_size}"
)

if proj_size >= hidden_size:
raise ValueError("proj_size must be smaller than hidden_size")

std = 1.0 / math.sqrt(hidden_size)
if weight_ih_attr is not False:
self.weight_ih = self.create_parameter(
@@ -985,13 +999,13 @@ def __init__(
self.weight_ih.stop_gradient = True
if weight_hh_attr is not False:
self.weight_hh = self.create_parameter(
(4 * hidden_size, hidden_size),
(4 * hidden_size, proj_size or hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std),
)
else:
self.weight_hh = self.create_parameter(
(4 * hidden_size, hidden_size),
(4 * hidden_size, proj_size or hidden_size),
None,
default_initializer=I.Constant(1.0),
)
@@ -1027,6 +1041,14 @@ def __init__(
)
self.bias_hh.stop_gradient = True

self.proj_size = proj_size
if proj_size:
self.weight_ho = self.create_parameter(
(proj_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std),
)

self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
@@ -1050,6 +1072,8 @@ def forward(self, inputs, states=None):
o = self._gate_activation(chunked_gates[3])
c = f * pre_cell + i * self._activation(chunked_gates[2])
h = o * self._activation(c)
if self.proj_size:
h = paddle.matmul(h, self.weight_ho, transpose_y=True)

return h, (h, c)
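
The projection step only narrows the emitted hidden state h; the cell state c keeps hidden_size. A minimal sketch of that shape arithmetic, assuming the (proj_size, hidden_size) layout of weight_ho used above (values are arbitrary):

import paddle

batch_size, hidden_size, proj_size = 4, 32, 16
h = paddle.randn([batch_size, hidden_size])         # hidden state before projection
weight_ho = paddle.randn([proj_size, hidden_size])  # same layout as self.weight_ho above
h_proj = paddle.matmul(h, weight_ho, transpose_y=True)
print(h_proj.shape)  # [4, 16]: h now has width proj_size, c is untouched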

@@ -1061,7 +1085,7 @@ def state_shape(self):
automatically inserted into shape). These two shapes correspond
to :math:`h_{t-1}` and :math:`c_{t-1}` separately.
"""
return ((self.hidden_size,), (self.hidden_size,))
return ((self.hidden_size,), (self.proj_size or self.hidden_size,))

def extra_repr(self):
return '{input_size}, {hidden_size}'.format(**self.__dict__)
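
A short usage sketch of the projected cell, assuming the LSTMCell signature added in this diff (parameter values are illustrative):

import paddle
import paddle.nn as nn

cell = nn.LSTMCell(input_size=16, hidden_size=32, proj_size=8)

x = paddle.randn([4, 16])        # [batch_size, input_size]
prev_h = paddle.zeros([4, 8])    # h carries proj_size once projection is enabled
prev_c = paddle.zeros([4, 32])   # c keeps hidden_size
y, (h, c) = cell(x, (prev_h, prev_c))
print(y.shape, h.shape, c.shape)  # [4, 8] [4, 8] [4, 32]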
@@ -1436,6 +1460,7 @@ def __init__(
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
proj_size=0,
):
super().__init__()
bidirectional_list = ["bidirectional", "bidirect"]
@@ -1455,28 +1480,40 @@ def __init__(
"bias_hh_attr": bias_hh_attr,
}

self.proj_size = proj_size
if proj_size:
assert mode == 'LSTM'

if mode == "LSTM":
rnn_cls = LSTMCell
kwargs["proj_size"] = proj_size
elif mode == "GRU":
rnn_cls = GRUCell
elif mode == "RNN_RELU":
rnn_cls = SimpleRNNCell
kwargs["activation"] = 'relu'
elif mode == "RNN_TANH":
rnn_cls = SimpleRNNCell
kwargs["activation"] = 'tanh'
else:
rnn_cls = SimpleRNNCell
kwargs["activation"] = self.activation

in_size = proj_size or hidden_size
if direction in ["forward"]:
is_reverse = False
cell = rnn_cls(input_size, hidden_size, **kwargs)
self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers):
cell = rnn_cls(hidden_size, hidden_size, **kwargs)
for _ in range(1, num_layers):
cell = rnn_cls(in_size, hidden_size, **kwargs)
self.append(RNN(cell, is_reverse, time_major))
elif direction in bidirectional_list:
cell_fw = rnn_cls(input_size, hidden_size, **kwargs)
cell_bw = rnn_cls(input_size, hidden_size, **kwargs)
self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers):
cell_fw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
cell_bw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
for _ in range(1, num_layers):
cell_fw = rnn_cls(2 * in_size, hidden_size, **kwargs)
cell_bw = rnn_cls(2 * in_size, hidden_size, **kwargs)
self.append(BiRNN(cell_fw, cell_bw, time_major))
else:
raise ValueError(
@@ -1652,11 +1689,8 @@ def forward(self, inputs, initial_states=None, sequence_length=None):
batch_index = 1 if self.time_major else 0
dtype = inputs.dtype
if initial_states is None:
state_shape = (
self.num_layers * self.num_directions,
-1,
self.hidden_size,
)
state_shape = (self.num_layers * self.num_directions, -1)
dims = ([self.proj_size or self.hidden_size], [self.hidden_size])

fill_shape = list(state_shape)
if inputs.shape[batch_index] > 0:
@@ -1665,8 +1699,10 @@
fill_shape[1] = paddle.shape(inputs)[batch_index].item()
initial_states = tuple(
[
paddle.full(shape=fill_shape, fill_value=0, dtype=dtype)
for _ in range(self.state_components)
paddle.full(
shape=fill_shape + dims[i], fill_value=0, dtype=dtype
)
for i in range(self.state_components)
]
)
else:
@@ -1834,6 +1870,7 @@ def __init__(
weight_hh_attr,
bias_ih_attr,
bias_hh_attr,
0, # proj_size
)


@@ -1891,6 +1928,9 @@ class LSTM(RNNBase):
`bias_ih` of each cells. Default: None.
bias_hh_attr (ParamAttr, optional): The parameter attribute for the
`bias_hh` of each cells. Default: None.
proj_size (int, optional): If specified, the output hidden state of each layer
will be projected to `proj_size`. `proj_size` must be smaller than `hidden_size`.
Default: 0.
name (str, optional): Name for the operation (optional, default is
None). For more information, please refer to :ref:`api_guide_Name`.

@@ -1903,7 +1943,8 @@ class LSTM(RNNBase):

- **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence.

- **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1.
- **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If `proj_size` is specified, the last dimension of h will be proj_size.
Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1.

Variables:
- **weight_ih_l[k]**: the learnable input-hidden weights of the k-th layer. If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise, the shape is `[hidden_size, num_directions * hidden_size]`.
@@ -1946,6 +1987,7 @@ def __init__(
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
proj_size=0,
name=None,
):
super().__init__(
@@ -1960,6 +2002,7 @@
weight_hh_attr,
bias_ih_attr,
bias_hh_attr,
proj_size,
)


@@ -2079,4 +2122,5 @@ def __init__(
weight_hh_attr,
bias_ih_attr,
bias_hh_attr,
0, # proj_size
)
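
Putting the rnn.py changes together, a usage sketch of paddle.nn.LSTM with the new proj_size argument (parameter values are illustrative; the expected shapes follow from the changes above):

import paddle
import paddle.nn as nn

lstm = nn.LSTM(
    input_size=16,
    hidden_size=32,
    num_layers=2,
    direction='bidirectional',
    proj_size=8,   # must satisfy 0 < proj_size < hidden_size; LSTM mode only
)

x = paddle.randn([4, 10, 16])   # [batch_size, time_steps, input_size]
y, (h, c) = lstm(x)

print(y.shape)  # [4, 10, 16] -> num_directions * proj_size
print(h.shape)  # [4, 4, 8]   -> [num_layers * num_directions, batch_size, proj_size]
print(c.shape)  # [4, 4, 32]  -> cell states keep hidden_size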
28 changes: 24 additions & 4 deletions test/dygraph_to_static/test_lstm.py
@@ -30,10 +30,14 @@


class LSTMLayer(nn.Layer):
def __init__(self, in_channels, hidden_size):
def __init__(self, in_channels, hidden_size, proj_size=0):
super().__init__()
self.cell = nn.LSTM(
in_channels, hidden_size, direction='bidirectional', num_layers=2
in_channels,
hidden_size,
direction='bidirectional',
num_layers=2,
proj_size=proj_size,
)

def forward(self, x):
@@ -42,9 +46,9 @@ def forward(self, x):


class Net(nn.Layer):
def __init__(self, in_channels, hidden_size):
def __init__(self, in_channels, hidden_size, proj_size=0):
super().__init__()
self.lstm = LSTMLayer(in_channels, hidden_size)
self.lstm = LSTMLayer(in_channels, hidden_size, proj_size=proj_size)

def forward(self, x):
x = self.lstm(x)
@@ -128,6 +132,22 @@ def test_save_with_training(self):
self.save_in_eval(with_training=True)


class TestLstmWithProjsize(TestLstm):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.net = Net(12, 8, 4)
self.inputs = paddle.zeros((2, 10, 12))

def test_error(self):
# proj_size < 0
with self.assertRaises(ValueError):
nn.LSTM(4, 4, 4, proj_size=-1)

# proj_size >= hidden_size
with self.assertRaises(ValueError):
nn.LSTM(4, 4, 4, proj_size=20)
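
The same validation paths can also be exercised directly; a quick sketch mirroring the values in test_error above:

import paddle.nn as nn

for bad in (-1, 20):
    try:
        nn.LSTM(4, 4, 4, proj_size=bad)   # input_size=4, hidden_size=4, num_layers=4
    except ValueError as e:
        print("rejected proj_size =", bad, "->", e)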


class LinearNet(nn.Layer):
def __init__(self):
super().__init__()