-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 4th No.102】给AutoConverter增加新的模型组网的支持 AlbertModel #5626
Merged
Merged
Changes from 3 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
6f01e2e
[Add]Add AlbertModel to AutoConverter
megemini 13e935d
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
megemini e0e60f5
[Change]Add classifier.bias to suppress warning
megemini 6deaf28
[Fix]Fix AlbertForQuestionAnswering name mapping
megemini c0b2eb1
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
megemini c4bdcd1
[Fix]Fix conflict
megemini 5b3d809
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
megemini af80b45
[Fix]Fix lint style
megemini File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,11 +13,14 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import random | ||
import tempfile | ||
import unittest | ||
|
||
import numpy as np | ||
import paddle | ||
from paddle import Tensor | ||
from parameterized import parameterized_class | ||
from parameterized import parameterized, parameterized_class | ||
|
||
from paddlenlp.transformers import ( | ||
AlbertConfig, | ||
|
@@ -30,7 +33,7 @@ | |
AlbertPretrainedModel, | ||
) | ||
|
||
from ...testing_utils import slow | ||
from ...testing_utils import require_package, slow | ||
from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask | ||
|
||
|
||
|
@@ -337,6 +340,148 @@ def test_model_from_pretrained(self): | |
self.assertIsNotNone(model) | ||
|
||
|
||
class AlbertModelCompatibilityTest(unittest.TestCase): | ||
model_id = "hf-internal-testing/tiny-random-AlbertModel" | ||
|
||
@require_package("transformers", "torch") | ||
def test_albert_converter(self): | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
# 1. create input | ||
input_ids = np.random.randint(100, 200, [1, 20]) | ||
|
||
# 2. forward the paddle model | ||
from paddlenlp.transformers import AlbertModel | ||
|
||
paddle_model = AlbertModel.from_pretrained(self.model_id, from_hf_hub=True, cache_dir=tempdir) | ||
paddle_model.eval() | ||
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0] | ||
|
||
# 3. forward the torch model | ||
import torch | ||
from transformers import AlbertModel | ||
|
||
torch_model = AlbertModel.from_pretrained(self.model_id, cache_dir=tempdir) | ||
torch_model.eval() | ||
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0] | ||
|
||
# 4. compare results | ||
self.assertTrue( | ||
np.allclose( | ||
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(), | ||
torch_logit.detach().cpu().reshape([-1])[:9].numpy(), | ||
rtol=1e-4, | ||
) | ||
) | ||
|
||
@require_package("transformers", "torch") | ||
def test_albert_converter_from_local_dir_with_enable_torch(self): | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
# 1. forward the torch model | ||
from transformers import AlbertModel | ||
|
||
torch_model = AlbertModel.from_pretrained(self.model_id) | ||
torch_model.save_pretrained(tempdir) | ||
|
||
# 2. forward the paddle model | ||
from paddlenlp.transformers import AlbertModel, model_utils | ||
|
||
model_utils.ENABLE_TORCH_CHECKPOINT = False | ||
|
||
with self.assertRaises(ValueError) as error: | ||
AlbertModel.from_pretrained(tempdir) | ||
self.assertIn("conversion is been disabled" in str(error.exception)) | ||
model_utils.ENABLE_TORCH_CHECKPOINT = True | ||
|
||
@require_package("transformers", "torch") | ||
def test_albert_converter_from_local_dir(self): | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
|
||
# 1. create commmon input | ||
input_ids = np.random.randint(100, 200, [1, 20]) | ||
|
||
# 2. forward the torch model | ||
import torch | ||
from transformers import AlbertModel | ||
|
||
torch_model = AlbertModel.from_pretrained(self.model_id) | ||
torch_model.eval() | ||
torch_model.save_pretrained(tempdir) | ||
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0] | ||
|
||
# 2. forward the paddle model | ||
from paddlenlp.transformers import AlbertModel | ||
|
||
paddle_model = AlbertModel.from_pretrained(tempdir) | ||
paddle_model.eval() | ||
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0] | ||
|
||
self.assertTrue( | ||
np.allclose( | ||
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(), | ||
torch_logit.detach().cpu().reshape([-1])[:9].numpy(), | ||
rtol=1e-4, | ||
) | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
("AlbertModel",), | ||
# ("AlbertForMaskedLM",), TODO: need to tie weights | ||
# ("AlbertForPretraining",), TODO: need to tie weights | ||
Comment on lines
+429
to
+430
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 等 #5623 合入之后就可以用 tie_weights 了。 |
||
("AlbertForMultipleChoice",), | ||
# ("AlbertForQuestionAnswering",), TODO: transformers NOT add the last pool layer before qa_outputs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可通过 architectures 来控制 pooler 的映射。 |
||
("AlbertForSequenceClassification",), | ||
("AlbertForTokenClassification",), | ||
] | ||
) | ||
@require_package("transformers", "torch") | ||
def test_albert_classes_from_local_dir(self, class_name, pytorch_class_name=None): | ||
pytorch_class_name = pytorch_class_name or class_name | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
|
||
# 1. create commmon input | ||
input_ids = np.random.randint(100, 200, [1, 20]) | ||
|
||
# 2. forward the torch model | ||
import torch | ||
import transformers | ||
|
||
torch_model_class = getattr(transformers, pytorch_class_name) | ||
torch_model = torch_model_class.from_pretrained(self.model_id) | ||
torch_model.eval() | ||
|
||
if "MultipleChoice" in class_name: | ||
# construct input for MultipleChoice Model | ||
torch_model.config.num_choices = random.randint(2, 10) | ||
input_ids = ( | ||
paddle.to_tensor(input_ids) | ||
.unsqueeze(1) | ||
.expand([-1, torch_model.config.num_choices, -1]) | ||
.cpu() | ||
.numpy() | ||
) | ||
|
||
torch_model.save_pretrained(tempdir) | ||
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0] | ||
|
||
# 3. forward the paddle model | ||
from paddlenlp import transformers | ||
|
||
paddle_model_class = getattr(transformers, class_name) | ||
paddle_model = paddle_model_class.from_pretrained(tempdir) | ||
paddle_model.eval() | ||
|
||
paddle_logit = paddle_model(paddle.to_tensor(input_ids), return_dict=False)[0] | ||
|
||
self.assertTrue( | ||
np.allclose( | ||
paddle_logit.detach().cpu().reshape([-1])[:9].numpy(), | ||
torch_logit.detach().cpu().reshape([-1])[:9].numpy(), | ||
atol=1e-3, | ||
) | ||
) | ||
|
||
|
||
class AlbertModelIntegrationTest(unittest.TestCase): | ||
@slow | ||
def test_inference_no_head_absolute_embedding(self): | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个变量定义在这里没有用到,故可以删除。