
Commit

fix bugs.
xiemoyuan committed Feb 18, 2021
1 parent 222c07a commit 0724c6b
Showing 1 changed file with 72 additions and 23 deletions.
95 changes: 72 additions & 23 deletions python/paddle/nn/layer/transformer.py
@@ -83,6 +83,35 @@ def _convert_param_attr_to_list(param_attr, n):
     return param_attrs
 
 
+def _convert_attention_mask(attn_mask, dtype):
+    """
+    Convert the attention mask to the target dtype we expect.
+
+    Parameters:
+        attn_mask (Tensor, optional): A tensor used in multi-head attention
+            to prevent attention to some unwanted positions, usually the
+            paddings or the subsequent positions. It is a tensor with shape
+            broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
+            When the data type is bool, the unwanted positions have `False`
+            values and the others have `True` values. When the data type is
+            int, the unwanted positions have 0 values and the others have 1
+            values. When the data type is float, the unwanted positions have
+            `-INF` values and the others have 0 values. It can be None when
+            nothing needs to be prevented from being attended to. Default None.
+        dtype (VarType): The target type of `attn_mask` we expect.
+
+    Returns:
+        Tensor: A Tensor with the same shape as the input `attn_mask`, with data type `dtype`.
+    """
+    if attn_mask is not None and attn_mask.dtype != dtype:
+        attn_mask_dtype = convert_dtype(attn_mask.dtype)
+        if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype:
+            attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9
+        else:
+            attn_mask = paddle.cast(attn_mask, dtype)
+    return attn_mask
+
+
 class MultiHeadAttention(Layer):
     """
     Attention maps queries and a set of key-value pairs to outputs, and
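
For readers of this diff, a minimal sketch of what the new helper does for each mask dtype (assuming `_convert_attention_mask` is importable from paddle.nn.layer.transformer, where this commit defines it):

    import paddle
    from paddle.nn.layer.transformer import _convert_attention_mask

    # bool mask: False marks blocked positions; conversion turns them into a
    # large negative additive bias (-1e9) that softmax maps to ~0 weight.
    bool_mask = paddle.to_tensor([[True, True, False]])
    print(_convert_attention_mask(bool_mask, paddle.float32))   # [[0., 0., -1e9]]

    # int mask: 0 marks blocked positions; same arithmetic as the bool case.
    int_mask = paddle.to_tensor([[1, 1, 0]])
    print(_convert_attention_mask(int_mask, paddle.float32))    # [[0., 0., -1e9]]

    # float mask: values are already additive, so it is only cast if needed.
    float_mask = paddle.to_tensor([[0., 0., -1e9]])
    print(_convert_attention_mask(float_mask, paddle.float32))  # unchanged

    # None passes through untouched.
    assert _convert_attention_mask(None, paddle.float32) is None
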
@@ -334,8 +363,8 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
                 paddings or the subsequent positions. It is a tensor with shape
                 broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                 When the data type is bool, the unwanted positions have `False`
-                values and the others have 'True' values. When the data type is
-                int, the unwanted positions have 0 values and the others have 0
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
                 values. When the data type is float, the unwanted positions have
                 `-INF` values and the others have 0 values. It can be None when
                 nothing needs to be prevented from being attended to. Default None.
@@ -378,11 +407,7 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
             x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
         if attn_mask is not None:
             # Support bool or int mask
-            attn_mask_dtype = convert_dtype(attn_mask.dtype)
-            if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype:
-                attn_mask = (paddle.cast(attn_mask, product.dtype) - 1.0) * 1e9
-            else:
-                attn_mask = paddle.cast(attn_mask, product.dtype)
+            attn_mask = _convert_attention_mask(attn_mask, product.dtype)
             product = product + attn_mask
         weights = F.softmax(product)
         if self.dropout:
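
With the conversion factored out, masking in `MultiHeadAttention.forward` reduces to a single addition before the softmax, as the hunk above shows. A hedged sketch of the effect (shapes and values are illustrative, not taken from the commit):

    import paddle
    import paddle.nn.functional as F

    head_dim = 16
    q = paddle.rand([2, 4, 8, head_dim])  # [batch_size, n_head, seq_len, head_dim]
    k = paddle.rand([2, 4, 8, head_dim])

    # A float mask blocking the last key position: 0 keeps, -1e9 blocks.
    attn_mask = paddle.concat(
        [paddle.zeros([1, 1, 1, 7]), paddle.full([1, 1, 1, 1], -1e9)], axis=-1)

    product = paddle.matmul(q, k, transpose_y=True) * head_dim**-0.5
    product = product + attn_mask  # broadcasts over batch, heads and queries
    weights = F.softmax(product)   # the last column of weights is ~0
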
@@ -519,8 +544,8 @@ def forward(self, src, src_mask=None, cache=None):
                 paddings or the subsequent positions. It is a tensor with shape
                 broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                 When the data type is bool, the unwanted positions have `False`
-                values and the others have 'True' values. When the data type is
-                int, the unwanted positions have 0 values and the others have 0
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
                 values. When the data type is float, the unwanted positions have
                 `-INF` values and the others have 0 values. It can be None when
                 nothing needs to be prevented from being attended to. Default None.
@@ -538,6 +563,8 @@ def forward(self, src, src_mask=None, cache=None):
                 incremental length. See `MultiHeadAttention.gen_cache` and \
                 `MultiHeadAttention.forward` for more details.
         """
+        src_mask = _convert_attention_mask(src_mask, src.dtype)
+
         residual = src
         if self.normalize_before:
             src = self.norm1(src)
@@ -634,8 +661,8 @@ def forward(self, src, src_mask=None, cache=None):
                 paddings or the subsequent positions. It is a tensor with shape
                 broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                 When the data type is bool, the unwanted positions have `False`
-                values and the others have 'True' values. When the data type is
-                int, the unwanted positions have 0 values and the others have 0
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
                 values. When the data type is float, the unwanted positions have
                 `-INF` values and the others have 0 values. It can be None when
                 nothing needs to be prevented from being attended to. Default None.
@@ -653,6 +680,8 @@ def forward(self, src, src_mask=None, cache=None):
                 See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
                 for more details.
         """
+        src_mask = _convert_attention_mask(src_mask, src.dtype)
+
         output = src
         new_caches = []
         for i, mod in enumerate(self.layers):
@@ -822,8 +851,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                 the subsequent positions. It is a tensor with shape broadcasted
                 to `[batch_size, n_head, target_length, target_length]`.
                 When the data type is bool, the unwanted positions have `False`
-                values and the others have 'True' values. When the data type is
-                int, the unwanted positions have 0 values and the others have 0
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
                 values. When the data type is float, the unwanted positions have
                 `-INF` values and the others have 0 values. It can be None when
                 nothing needs to be prevented from being attended to. Default None.
@@ -832,8 +861,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                 usually the paddings. It is a tensor with shape broadcasted to
                 `[batch_size, n_head, target_length, source_length]`. When the
                 data type is bool, the unwanted positions have `False` values
-                and the others have 'True' values. When the data type is int,
-                the unwanted positions have 0 values and the others have 0
+                and the others have `True` values. When the data type is int,
+                the unwanted positions have 0 values and the others have 1
                 values. When the data type is float, the unwanted positions have
                 `-INF` values and the others have 0 values. It can be None when
                 nothing needs to be prevented from being attended to. Default None.
@@ -853,6 +882,9 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                 See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
                 for more details.
         """
+        tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
+        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
+
         residual = tgt
         if self.normalize_before:
             tgt = self.norm1(tgt)
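
Typical constructions for the two decoder masks documented above might look as follows (a sketch; the shapes and padding pattern are assumptions, not from the commit):

    import paddle

    batch_size, tgt_len, src_len = 2, 6, 10

    # Causal tgt_mask: True on and below the diagonal, so position i may only
    # attend to positions j <= i; broadcastable over batch and heads.
    tgt_mask = paddle.cast(paddle.tril(paddle.ones([tgt_len, tgt_len])), 'bool')
    tgt_mask = tgt_mask.reshape([1, 1, tgt_len, tgt_len])

    # Padding memory_mask: 0 marks padded source positions (here the last two).
    memory_mask = paddle.concat(
        [paddle.ones([batch_size, 1, 1, src_len - 2], dtype='int64'),
         paddle.zeros([batch_size, 1, 1, 2], dtype='int64')], axis=-1)
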
@@ -977,8 +1009,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                 the subsequent positions. It is a tensor with shape broadcasted
                 to `[batch_size, n_head, target_length, target_length]`. When
                 the data type is bool, the unwanted positions have `False`
-                values and the others have 'True' values. When the data type is
-                int, the unwanted positions have 0 values and the others have 0
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
                 values. When the data type is float, the unwanted positions have
                 `-INF` values and the others have 0 values. It can be None when
                 nothing needs to be prevented from being attended to. Default None.
@@ -987,8 +1019,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                 usually the paddings. It is a tensor with shape broadcasted to
                 `[batch_size, n_head, target_length, source_length]`. When the
                 data type is bool, the unwanted positions have `False` values
-                and the others have 'True' values. When the data type is int,
-                the unwanted positions have 0 values and the others have 0
+                and the others have `True` values. When the data type is int,
+                the unwanted positions have 0 values and the others have 1
                 values. When the data type is float, the unwanted positions have
                 `-INF` values and the others have 0 values. It can be None when
                 nothing needs to be prevented from being attended to. Default None.
@@ -1006,6 +1038,9 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                 See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
                 for more details.
         """
+        tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
+        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
+
         output = tgt
         new_caches = []
         for i, mod in enumerate(self.layers):
@@ -1244,13 +1279,23 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
             memory (Tensor): The output of Transformer encoder. It is a tensor
                 with shape `[batch_size, source_length, d_model]`. The data type
                 should be float32 or float64.
+            src_mask (Tensor, optional): A tensor used in multi-head attention
+                to prevent attention to some unwanted positions, usually the
+                paddings or the subsequent positions. It is a tensor with shape
+                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
+                When the data type is bool, the unwanted positions have `False`
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing needs to be prevented from being attended to. Default None.
             tgt_mask (Tensor, optional): A tensor used in self attention
                 to prevent attention to some unwanted positions, usually
                 the subsequent positions. It is a tensor with shape broadcasted
                 to `[batch_size, n_head, target_length, target_length]`. When
                 the data type is bool, the unwanted positions have `False`
-                values and the others have 'True' values. When the data type is
-                int, the unwanted positions have 0 values and the others have 0
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
                 values. When the data type is float, the unwanted positions have
                 `-INF` values and the others have 0 values. It can be None when
                 nothing needs to be prevented from being attended to. Default None.
@@ -1259,8 +1304,8 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
                 usually the paddings. It is a tensor with shape broadcasted to
                 `[batch_size, n_head, target_length, source_length]`. When the
                 data type is bool, the unwanted positions have `False` values
-                and the others have 'True' values. When the data type is int,
-                the unwanted positions have 0 values and the others have 0
+                and the others have `True` values. When the data type is int,
+                the unwanted positions have 0 values and the others have 1
                 values. When the data type is float, the unwanted positions have
                 `-INF` values and the others have 0 values. It can be None when
                 nothing needs to be prevented from being attended to. Default None.
@@ -1269,7 +1314,11 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
             Tensor: It is a tensor that has the same shape and data type \
                 as `tgt`, representing the output of Transformer decoder.
         """
+        src_mask = _convert_attention_mask(src_mask, src.dtype)
         memory = self.encoder(src, src_mask=src_mask)
+
+        tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
+        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
         output = self.decoder(
             tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
         return output
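
Putting it together, a minimal end-to-end sketch of the public API this change affects (the hyper-parameters and shapes are illustrative assumptions):

    import paddle
    import paddle.nn as nn

    model = nn.Transformer(
        d_model=32, nhead=4, num_encoder_layers=2, num_decoder_layers=2)
    src = paddle.rand([2, 10, 32])  # [batch_size, source_length, d_model]
    tgt = paddle.rand([2, 7, 32])   # [batch_size, target_length, d_model]

    # With this fix, a bool causal mask works directly: False blocks attention.
    tgt_mask = paddle.cast(paddle.tril(paddle.ones([7, 7])), 'bool')
    output = model(src, tgt, src_mask=None, tgt_mask=tgt_mask, memory_mask=None)
    print(output.shape)  # [2, 7, 32]
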

1 comment on commit 0724c6b

@paddle-bot-old

Congratulations! Your pull request passed all required CI. You can ask reviewer(s) to approve and merge. 🎉
