Skip to content

Commit

Permalink
Merge pull request #3569 from MattGPT-ai/mattgpt.perf.opt-check-dict-has-items
Browse files Browse the repository at this point in the history

perf: optimize dictionary items check
  • Loading branch information
helpmefindaname authored Nov 28, 2024
2 parents 9a962cb + 8f934a4 commit e53db6d
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 20 deletions.
14 changes: 7 additions & 7 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,18 @@ def get_idx_for_items(self, items: list[str]) -> list[int]:
return list(results)

def get_items(self) -> list[str]:
items = []
for item in self.idx2item:
items.append(item.decode("UTF-8"))
return items
return [item.decode("UTF-8") for item in self.idx2item]

def __len__(self) -> int:
return len(self.idx2item)

def get_item_for_index(self, idx):
def get_item_for_index(self, idx: int) -> str:
return self.idx2item[idx].decode("UTF-8")

def set_start_stop_tags(self):
def has_item(self, item: str) -> bool:
return item.encode("utf-8") in self.item2idx

def set_start_stop_tags(self) -> None:
self.add_item("<START>")
self.add_item("<STOP>")

Expand Down Expand Up @@ -1659,7 +1659,7 @@ def make_label_dictionary(
unked_count += count

if len(label_dictionary.idx2item) == 0 or (
len(label_dictionary.idx2item) == 1 and "<unk>" in label_dictionary.get_items()
len(label_dictionary.idx2item) == 1 and label_dictionary.has_item("<unk>")
):
log.error(f"ERROR: You specified label_type='{label_type}' which is not in this dataset!")
contained_labels = ", ".join(
Expand Down
2 changes: 1 addition & 1 deletion flair/models/lemmatizer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def predict(

for t_id, token in enumerate(tokens_in_batch):
predicted_lemma = "".join(
self.char_dictionary.get_item_for_index(idx) if idx != self.end_index else ""
self.char_dictionary.get_item_for_index(int(idx)) if idx != self.end_index else ""
for idx in predicted[t_id]
)
token.set_label(typename=label_name, value=predicted_lemma)
Expand Down
4 changes: 1 addition & 3 deletions flair/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,9 +877,7 @@ def predict(
filtered_indices = []
has_unknown_label = False
for idx, dp in enumerate(data_points):
if all(
label in self.label_dictionary.get_items() for label in self._get_label_of_datapoint(dp)
):
if all(self.label_dictionary.has_item(label) for label in self._get_label_of_datapoint(dp)):
filtered_indices.append(idx)
else:
has_unknown_label = True
Expand Down
18 changes: 9 additions & 9 deletions tests/test_corpus_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def test_tagged_corpus_make_vocab_dictionary():
vocab = corpus.make_vocab_dictionary(max_tokens=2, min_freq=-1)

assert len(vocab) == 3
assert "<unk>" in vocab.get_items()
assert "training" in vocab.get_items()
assert "." in vocab.get_items()
assert vocab.has_item("<unk>")
assert vocab.has_item("training")
assert vocab.has_item(".")

vocab = corpus.make_vocab_dictionary(max_tokens=-1, min_freq=-1)

Expand All @@ -121,9 +121,9 @@ def test_tagged_corpus_make_vocab_dictionary():
vocab = corpus.make_vocab_dictionary(max_tokens=-1, min_freq=2)

assert len(vocab) == 3
assert "<unk>" in vocab.get_items()
assert "training" in vocab.get_items()
assert "." in vocab.get_items()
assert vocab.has_item("<unk>")
assert vocab.has_item("training")
assert vocab.has_item(".")


def test_label_set_confidence():
Expand Down Expand Up @@ -153,9 +153,9 @@ def test_tagged_corpus_make_label_dictionary():
label_dict = corpus.make_label_dictionary("label", add_unk=True)

assert len(label_dict) == 3
assert "<unk>" in label_dict.get_items()
assert "class_1" in label_dict.get_items()
assert "class_2" in label_dict.get_items()
assert label_dict.has_item("<unk>")
assert label_dict.has_item("class_1")
assert label_dict.has_item("class_2")

with pytest.warns(DeprecationWarning): # test to make sure the warning comes, but function works
corpus.make_tag_dictionary("label")
Expand Down

0 comments on commit e53db6d

Please sign in to comment.