diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index a27b57ab9..fdb16eea2 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -278,11 +278,11 @@ def _reconstruct_word_ids_from_subtokens(embedding, tokens: list[str], subtokens
 
     special_tokens = []
     # check if special tokens exist to circumvent error message
-    if embedding.tokenizer._bos_token:
+    if embedding.tokenizer.bos_token is not None:
         special_tokens.append(embedding.tokenizer.bos_token)
-    if embedding.tokenizer._cls_token:
+    if embedding.tokenizer.cls_token is not None:
         special_tokens.append(embedding.tokenizer.cls_token)
-    if embedding.tokenizer._sep_token:
+    if embedding.tokenizer.sep_token is not None:
         special_tokens.append(embedding.tokenizer.sep_token)
 
     # iterate over subtokens and reconstruct tokens
@@ -1354,9 +1354,10 @@ def from_params(cls, params):
 
     def to_params(self):
         config_dict = self.model.config.to_dict()
 
-        # do not switch the attention implementation upon reload.
-        config_dict["attn_implementation"] = self.model.config._attn_implementation
-        config_dict.pop("_attn_implementation_autoset", None)
+        if hasattr(self.model.config, "_attn_implementation"):
+            # do not switch the attention implementation upon reload.
+            config_dict["attn_implementation"] = self.model.config._attn_implementation
+            config_dict.pop("_attn_implementation_autoset", None)
 
         super_params = super().to_params()
diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py
index 4f5cb8573..f4171fdb2 100644
--- a/flair/models/tars_model.py
+++ b/flair/models/tars_model.py
@@ -383,7 +383,7 @@ def __init__(
 
         # transformer separator
         self.separator = str(self.tars_embeddings.tokenizer.sep_token)
-        if self.tars_embeddings.tokenizer._bos_token:
+        if self.tars_embeddings.tokenizer.bos_token is not None:
             self.separator += str(self.tars_embeddings.tokenizer.bos_token)
 
         self.prefix = prefix
@@ -718,9 +718,11 @@ def __init__(
         )
 
         # transformer separator
-        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
-        if self.tars_embeddings.tokenizer._bos_token:
-            self.separator += str(self.tars_embeddings.tokenizer.bos_token)
+        self.separator = (
+            self.tars_embeddings.tokenizer.sep_token if self.tars_embeddings.tokenizer.sep_token is not None else ""
+        )
+        if self.tars_embeddings.tokenizer.bos_token is not None:
+            self.separator += self.tars_embeddings.tokenizer.bos_token
 
         self.prefix = prefix
         self.num_negative_labels_to_sample = num_negative_labels_to_sample
diff --git a/requirements.txt b/requirements.txt
index bb5ecafd4..2704114ac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,6 +20,6 @@ tabulate>=0.8.10
 torch>=1.5.0,!=1.8
 tqdm>=4.63.0
 transformer-smaller-training-vocab>=0.2.3
-transformers[sentencepiece]>=4.18.0,<5.0.0
+transformers[sentencepiece]>=4.25.0,<5.0.0
 wikipedia-api>=0.5.7
 bioc<3.0.0,>=2.0.0
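
Note (reviewer sketch, not part of the patch): the private _bos_token / _cls_token /
_sep_token attributes are an implementation detail of the transformers tokenizer classes
and are not guaranteed to exist across versions, whereas the public bos_token / cls_token /
sep_token properties return None when the token is unset; that is the behavior the explicit
"is not None" checks above rely on. Likewise, config._attn_implementation only exists on
newer transformers configs, hence the hasattr guard in to_params. The Python sketch below
illustrates the tokenizer behavior; "bert-base-uncased" is only an assumed checkpoint for
illustration (it defines CLS/SEP tokens but no BOS token).

    # Illustrative sketch only -- not part of the patch.
    from transformers import AutoTokenizer

    # Assumption: bert-base-uncased, which has [CLS] and [SEP] special tokens
    # but no BOS token, so tokenizer.bos_token is None.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    special_tokens = []
    # Explicit `is not None` checks on the public properties work regardless of
    # how the tokenizer stores its special tokens internally, unlike truthiness
    # checks on the private _bos_token attribute.
    if tokenizer.bos_token is not None:
        special_tokens.append(tokenizer.bos_token)
    if tokenizer.cls_token is not None:
        special_tokens.append(tokenizer.cls_token)
    if tokenizer.sep_token is not None:
        special_tokens.append(tokenizer.sep_token)

    print(special_tokens)  # expected: ['[CLS]', '[SEP]']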