-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Unittest]add tinybert unittest #2992
Conversation
@@ -330,7 +346,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None): | |||
if attention_mask is None: | |||
attention_mask = paddle.unsqueeze( | |||
(input_ids == self.pad_token_id).astype( | |||
self.pooler.dense.weight.dtype) * -1e4, | |||
self.pooler.dense.weight.dtype) * 0e4, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个先不要改动,当前attention_mask is None
的时候咱们确实会和HF行为不一致,而且目前还需要保留这个。attention_mask相关内容也会在 #2005 里处理
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不好意思,这个应该是键盘vim快捷键误触导致的,我已经回滚了。
Args: | ||
embedding (nn.Embedding): the new embedding value | ||
""" | ||
self.tinybert.embeddings.word_embeddings = embedding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
按照基类中的实现,这个是要改TinyBERTModel而不是这些下游任务模型的吧,参考HF中BERT
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
关于这点,我阅读了下代码,确实是在PreTrainedModel
这个类中有实现 get_input_embeddings
这个方法,核心还是调用base_model
的get_input_embeddings的方法。
我一会儿也把其他几个单测中的也修改一下。
import unittest | ||
from typing import Optional, Tuple | ||
from dataclasses import dataclass, fields, Field | ||
from dataclasses_json import dataclass_json |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不要引入更多依赖了,看这个也是造成CI单测失败的原因
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯,我也看到了,其他的几个单测我也顺便都调整了。
class TinyBertModelIntegrationTest(unittest.TestCase): | ||
|
||
# @slow | ||
def test_inference_no_attention(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个和下面的slow为什么去掉了呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我后续加上,因为我在本地测试时,这个会直接skip掉,所以为了测试它的有效性,我都是注释掉的。
self.assertListEqual(tokens_without_spe_char_p, expected_tokens) | ||
''' | ||
|
||
def test_pretrained_model_lists(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个单测的意图是要测什么呢,看HF没有
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是为了测试tokenizer当中pretrained_vocab_files_map
和max_model_input_sizes
的配置是否正确。只不过HF的这个是在基类(test_tokenizer_common.py)里面,我是写在派生类里面来。
写在派生类的原因:tinybert没有max_model_input_sizes属性(其实是绝大部分的tokenizer都没有这个属性)导致单测failed,所以我要重写方法修改一下测试逻辑。
后想一下,这个逻辑其实是没有意义的,我觉得可以删掉,或者判断没有max_model_input_sizes属性的话,就直接skip掉单测也是可以的。 你觉得如何? @guoshengCS
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_model_input_sizes
这个倒可以加上,HF大部分是加上了max_model_input_sizes
的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好,这个我马上加上
ping @guoshengCS |
PR types
New features
PR changes
Models
Description
add tinybert unittest