diff --git a/paddlenlp/transformers/bart/modeling.py b/paddlenlp/transformers/bart/modeling.py
index b288ebb502de..955995d1f9b1 100644
--- a/paddlenlp/transformers/bart/modeling.py
+++ b/paddlenlp/transformers/bart/modeling.py
@@ -214,7 +214,12 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
             attention_mask = paddle.cast(
                 input_ids == self.pad_token_id,
                 dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4
-            attention_mask.stop_gradient = True
+        # For 2D attention_mask from tokenizer
+        elif attention_mask.ndim == 2:
+            attention_mask = paddle.unsqueeze(
+                attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
+            attention_mask = (1.0 - attention_mask) * -1e4
+        attention_mask.stop_gradient = True
 
         encoder_output = self.encoder(encoder_input, src_mask=attention_mask)
         return encoder_output
diff --git a/paddlenlp/transformers/mbart/modeling.py b/paddlenlp/transformers/mbart/modeling.py
index 7cb6aa3e8fe9..245e8a43545b 100644
--- a/paddlenlp/transformers/mbart/modeling.py
+++ b/paddlenlp/transformers/mbart/modeling.py
@@ -286,7 +286,12 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
             attention_mask = paddle.cast(
                 input_ids == self.pad_token_id,
                 dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4
-            attention_mask.stop_gradient = True
+        # For 2D attention_mask from tokenizer
+        elif attention_mask.ndim == 2:
+            attention_mask = paddle.unsqueeze(
+                attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
+            attention_mask = (1.0 - attention_mask) * -1e4
+        attention_mask.stop_gradient = True
 
         encoder_output = self.encoder(encoder_input, src_mask=attention_mask)
         return encoder_output
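
Note (not part of the patch): a minimal sketch of what the new elif branch computes. Given a 2D padding mask as typically returned by a tokenizer (1 for real tokens, 0 for padding), the branch broadcasts it to shape [batch_size, 1, 1, seq_len] and turns padding positions into a large negative additive bias for the encoder's attention scores. The example tensor below is illustrative only.

    import paddle

    # Illustrative 2D mask: one sequence of length 5 with two padding tokens.
    attention_mask = paddle.to_tensor([[1, 1, 1, 0, 0]])

    # Same arithmetic as the added elif branch in BartEncoder/MBartEncoder.forward.
    mask = paddle.unsqueeze(
        attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
    mask = (1.0 - mask) * -1e4

    print(mask.shape)  # [1, 1, 1, 5]: zeros at real tokens, -10000.0 at padding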