Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unittest]add tinybert unittest #2992

Merged
merged 7 commits into from
Aug 11, 2022

Conversation

wj-Mcat
Copy link
Contributor

@wj-Mcat wj-Mcat commented Aug 8, 2022

PR types

New features

PR changes

Models

Description

add tinybert unittest

@wj-Mcat wj-Mcat changed the title [Unittest]update tinybert unittest [Unittest]add tinybert unittest Aug 8, 2022
@@ -330,7 +346,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None):
if attention_mask is None:
attention_mask = paddle.unsqueeze(
(input_ids == self.pad_token_id).astype(
self.pooler.dense.weight.dtype) * -1e4,
self.pooler.dense.weight.dtype) * 0e4,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个先不要改动,当前attention_mask is None的时候咱们确实会和HF行为不一致,而且目前还需要保留这个。attention_mask相关内容也会在 #2005 里处理

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不好意思,这个应该是键盘vim快捷键误触导致的,我已经回滚了。

Args:
embedding (nn.Embedding): the new embedding value
"""
self.tinybert.embeddings.word_embeddings = embedding
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照基类中的实现,这个是要改TinyBERTModel而不是这些下游任务模型的吧,参考HF中BERT

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

关于这点,我阅读了下代码,确实是在PreTrainedModel这个类中有实现 get_input_embeddings 这个方法,核心还是调用base_model的get_input_embeddings的方法。

我一会儿也把其他几个单测中的也修改一下。

import unittest
from typing import Optional, Tuple
from dataclasses import dataclass, fields, Field
from dataclasses_json import dataclass_json
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不要引入更多依赖了,看这个也是造成CI单测失败的原因

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯,我也看到了,其他的几个单测我也顺便都调整了。

class TinyBertModelIntegrationTest(unittest.TestCase):

# @slow
def test_inference_no_attention(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个和下面的slow为什么去掉了呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我后续加上,因为我在本地测试时,这个会直接skip掉,所以为了测试它的有效性,我都是注释掉的。

tests/transformers/tinybert/test_tokenizer.py Show resolved Hide resolved
self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
'''

def test_pretrained_model_lists(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个单测的意图是要测什么呢,看HF没有

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是为了测试tokenizer当中pretrained_vocab_files_mapmax_model_input_sizes的配置是否正确。只不过HF的这个是在基类(test_tokenizer_common.py)里面,我是写在派生类里面来。

写在派生类的原因:tinybert没有max_model_input_sizes属性(其实是绝大部分的tokenizer都没有这个属性)导致单测failed,所以我要重写方法修改一下测试逻辑。

后想一下,这个逻辑其实是没有意义的,我觉得可以删掉,或者判断没有max_model_input_sizes属性的话,就直接skip掉单测也是可以的。 你觉得如何? @guoshengCS

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_model_input_sizes这个倒可以加上,HF大部分是加上了max_model_input_sizes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好,这个我马上加上

@wj-Mcat
Copy link
Contributor Author

wj-Mcat commented Aug 10, 2022

ping @guoshengCS

@guoshengCS guoshengCS merged commit 6952b91 into PaddlePaddle:develop Aug 11, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants