Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unittest] Add unittest for RoBERTa, ALBERT and ERNIE #2972

Merged
merged 20 commits into from
Aug 18, 2022

Conversation

yingyibiao
Copy link
Contributor

@yingyibiao yingyibiao commented Aug 5, 2022

PR types

unittest

PR changes

Models

Description

Add unittest for RoBERTa and ALBERT

Copy link
Contributor

@guoshengCS guoshengCS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

另外看有些新增的test文件是空的,是还没有完成是吗

@@ -406,6 +407,8 @@ def __init__(
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.add_prefix_space = add_prefix_space

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看着并不会被用到,看HF是否这里少了对prepare_for_tokenization方法的重写

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -363,6 +363,7 @@ def __init__(
merges_file,
errors='replace',
max_len=None,
add_prefix_space=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

另外这个参数也放到special token后面吧,一是和HF一致,二是更好的照顾兼容性

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
add_prefix_space=False,
max_len=None,
special_tokens=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_len和special_tokens看着都没有使用,而且HF也没有,没有用的话就去掉吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

paddlenlp/transformers/roberta/tokenizer.py Show resolved Hide resolved
AlbertModel,
)
from tests.transformers.test_modeling_common import ids_tensor, random_attention_mask, ModelTesterMixin
from tests.testing_utils import slow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照上次统一的结论,对tests下的内容使用相对import

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@yingyibiao yingyibiao changed the title Add unittest for RoBERTa Add unittest for RoBERTa and ALBERT Aug 10, 2022
@yingyibiao yingyibiao changed the title Add unittest for RoBERTa and ALBERT [Unittest] Add unittest for RoBERTa, ALBERT and ERNIE Aug 11, 2022
for t in text)))
return self.convert_tokens_to_ids(tokens)
else:
return self.convert_tokens_to_ids(text)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里之前没改留了TODO好像就是因为兼容性问题,还得关注下这里是否有CI挂的,之前是skep序列标注任务训练和预测报错 #2063

**tokenizer.added_tokens_encoder)
vocab = tokenizer.get_vocab()
# vocab = dict(tokenizer.vocab._token_to_idx,
# **tokenizer.added_tokens_encoder)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删掉这个吧,不用了是吧,上面那样也是这样的就删掉吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@guoshengCS
Copy link
Contributor

看CI里这个制品详情里只有一个,这个是否是符合预期的呢 @zjjlivein
image

@yingyibiao yingyibiao merged commit 869dd96 into PaddlePaddle:develop Aug 18, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants