Skip to content

Commit

Permalink
Fix for issue #147 for transformers>4.17.0 (#148)
Browse files Browse the repository at this point in the history
Fix for issue #147 for transformers>4.17.0
  • Loading branch information
shirley-wu authored Jan 26, 2023
1 parent cb582ed commit a45819b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bert_score/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from packaging import version
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm
from transformers import (AutoModel, AutoTokenizer, BertConfig, GPT2Tokenizer,
from transformers import (AutoModel, AutoTokenizer, BertConfig, GPT2Tokenizer, RobertaTokenizer,
RobertaConfig, XLMConfig, XLNetConfig)
from transformers import __version__ as trans_version

Expand Down Expand Up @@ -190,7 +190,7 @@ def sent_encode(tokenizer, sent):
sent = sent.strip()
if sent == "":
return tokenizer.build_inputs_with_special_tokens([])
elif isinstance(tokenizer, GPT2Tokenizer):
elif isinstance(tokenizer, GPT2Tokenizer) or isinstance(tokenizer, RobertaTokenizer):
# for RoBERTa and GPT-2
if version.parse(trans_version) >= version.parse("4.0.0"):
return tokenizer.encode(
Expand Down

0 comments on commit a45819b

Please sign in to comment.