Skip to content

Commit

Permalink
[BUGFIX] BART and mBART support 2D attention mask from tokenizer (#1637)
Browse files Browse the repository at this point in the history
* fix: bart and mbart support 2d attention mask

* fix: bart and mbart support 2d attention mask
  • Loading branch information
gongenlei authored Jan 26, 2022
1 parent 82b1cc4 commit 32d01fa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion paddlenlp/transformers/bart/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,12 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
attention_mask = paddle.cast(
input_ids == self.pad_token_id,
dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4
attention_mask.stop_gradient = True
# For 2D attention_mask from tokenizer
elif attention_mask.ndim == 2:
attention_mask = paddle.unsqueeze(
attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
attention_mask = (1.0 - attention_mask) * -1e4
attention_mask.stop_gradient = True

encoder_output = self.encoder(encoder_input, src_mask=attention_mask)
return encoder_output
Expand Down
7 changes: 6 additions & 1 deletion paddlenlp/transformers/mbart/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,12 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
attention_mask = paddle.cast(
input_ids == self.pad_token_id,
dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4
attention_mask.stop_gradient = True
# For 2D attention_mask from tokenizer
elif attention_mask.ndim == 2:
attention_mask = paddle.unsqueeze(
attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
attention_mask = (1.0 - attention_mask) * -1e4
attention_mask.stop_gradient = True

encoder_output = self.encoder(encoder_input, src_mask=attention_mask)
return encoder_output
Expand Down

0 comments on commit 32d01fa

Please sign in to comment.