docs: Added comprehensive performance comparison for NER models (#268)
* added performance metrics of fine-grained ner
* fixing domain table and adding csv for it
* new table within tags for fine-grained models
* added info box on label transfer in the ner evaluation
* added ner-fine-grained tag performance csv
* added link to dansk
* added figure with domain perf
* added additional performance metrics on fine-grained ner, and sota domain eval
* removed wasabi import
* work in progress
* Added outline for new NER docs
* Updated ignore
* Added performance tables
* saving progress
* saving progress
* save progress
* save progress
* chore: clean up unused files
* docs: Added Ner performance comparisons
* style: linting
* style: linting
* chore: cleanup before squashing
* style: Run ruff and added type hints
* docs: fix formatting
* ci: Update test to not use editable mode
* debugging ci: Checking if it is windows specific
* ci: ensure it is the correct python version running
* ci experiment
* style: linting
* docs: formatting admonition
* ci: re-enable windows
* ci remove multiprocessing from ci
* ci: debug seems like the python instance is wrong
* ci: remove caching
* docs: Updated notebook
* docs: avoid having to install spacy model for running ci

---------

Co-authored-by: Emil Trenckner Jessen <emil.tj@hotmail.com>
Co-authored-by: github-actions <github-actions@github.com>
1 parent 6a5fe57 · commit 3f5d3b0
Showing 21 changed files with 3,620 additions and 532 deletions.
```diff
@@ -167,3 +167,6 @@ _build/*
 # tutorials:
 tutorials/BenchmarkFairness.ipynb
 tutorials/*.py
+
+
+temp_files/*
```
Empty file.
@@ -0,0 +1,120 @@
```python
import augmenty
import spacy
from dacy.datasets import danish_names, female_names, male_names, muslim_names


def get_gender_bias_augmenters() -> dict:
    """Create augmenters that replace PER entities with names from different name lists."""
    # define the augmentation patterns
    patterns = [
        ["first_name"],
        ["first_name", "last_name"],
        ["first_name", "last_name", "last_name"],
    ]

    # define the person tag used by the augmenters
    person_tag = "PER"

    # define all augmenters
    dk_aug = augmenty.load(
        "per_replace_v1",
        patterns=patterns,
        names=danish_names(),
        level=1,
        person_tag=person_tag,
        replace_consistency=True,
    )

    muslim_aug = augmenty.load(
        "per_replace_v1",
        patterns=patterns,
        names=muslim_names(),
        level=1,
        person_tag=person_tag,
        replace_consistency=True,
    )

    male_aug = augmenty.load(
        "per_replace_v1",
        patterns=patterns,
        names=male_names(),
        level=1,
        person_tag=person_tag,
        replace_consistency=True,
    )

    fem_aug = augmenty.load(
        "per_replace_v1",
        patterns=patterns,
        names=female_names(),
        level=1,
        person_tag=person_tag,
        replace_consistency=True,
    )

    bias_augmenters = {
        "Danish Names": dk_aug,
        "Muslim Names": muslim_aug,
        "Male Names": male_aug,
        "Female Names": fem_aug,
    }
    return bias_augmenters


def get_robustness_augmenters(prob: float = 0.05) -> dict:
    """Create augmenters simulating spelling errors, casing, synonyms, spacing, and historical spelling."""
    # spelling-error augmentations
    char_swap_aug = augmenty.load("char_swap_v1", level=prob)
    tok_swap_aug = augmenty.load("token_swap_v1", level=prob)
    keystroke_aug = augmenty.load(
        "keystroke_error_v1",
        level=prob,
        keyboard="da_qwerty_v1",
    )
    start_casing_aug = augmenty.load("random_starting_case_v1", level=prob)

    sim_spelling_error_aug = augmenty.combine(
        [char_swap_aug, tok_swap_aug, keystroke_aug],
    )
    inconsistent_casing_aug = start_casing_aug

    # synonym augmentations
    wordnet_aug = augmenty.load("wordnet_synonym_v1", level=prob, lang="da")
    nlp = spacy.load("da_core_news_lg")
    emb_aug = augmenty.load("word_embedding_v1", level=prob, nlp=nlp)

    synonym_aug = augmenty.combine([wordnet_aug, emb_aug])

    # spacing augmentations
    remove_spacing_augmenter = augmenty.load("remove_spacing_v1", level=prob)
    spacing_insertion_augmenter = augmenty.load(
        "spacing_insertion_v1",
        level=prob,
    )

    spacing_aug = augmenty.combine(
        [remove_spacing_augmenter, spacing_insertion_augmenter],
    )

    # historical spelling augmentations
    upper_noun_aug = augmenty.load("da_historical_noun_casing_v1", level=1)
    æøå_aug = augmenty.load("da_æøå_replace_v1", level=1)

    hist_spelling_aug = augmenty.combine([upper_noun_aug, æøå_aug])

    return {
        "Spelling Error": sim_spelling_error_aug,
        "Inconsistent Casing": inconsistent_casing_aug,
        "Synonym replacement": synonym_aug,
        "Inconsistent Spacing": spacing_aug,
        "Historical Spelling": hist_spelling_aug,
    }
```
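For orientation, a minimal usage sketch (not part of the commit) of how one of these augmenters can be applied to raw text via `augmenty.texts`; the sentence and `level` value are illustrative:

```python
# Illustrative only: apply one robustness-style augmenter to raw text.
# The sentence and level are made up; augmenty.texts yields augmented strings.
import augmenty
import spacy

nlp = spacy.blank("da")
keystroke_aug = augmenty.load("keystroke_error_v1", level=0.1, keyboard="da_qwerty_v1")
texts = ["Hans flyttede til København sidste år."]
augmented = list(augmenty.texts(texts, augmenter=keystroke_aug, nlp=nlp))
print(augmented[0])  # the sentence with simulated keystroke errors
```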
@@ -0,0 +1,98 @@
```python
import random
from typing import Any, Dict, List

import augmenty
import catalogue
import numpy as np
import spacy
from datasets import load_dataset
from spacy.tokens import Doc
from spacy.training import Example

from .augmentations import get_gender_bias_augmenters, get_robustness_augmenters

datasets = catalogue.create("dacy", "datasets")


@datasets.register("dane")
def dane() -> Dict[str, List[Example]]:
    from dacy.datasets import dane as _dane

    train, dev, test = _dane(splits=["train", "dev", "test"])  # type: ignore
    nlp_da = spacy.blank("da")

    datasets = {}
    for nam, split in zip(["train", "dev", "test"], [train, dev, test]):  # type: ignore
        examples = list(split(nlp_da))
        datasets[nam] = examples

    return datasets


def augment_dataset(
    dataset: str,
    augmenters: dict,
    n_rep: int = 20,
    split: str = "test",
) -> List[Example]:
    """Apply each augmenter to a registered dataset split, repeating n_rep times."""
    # fix seeds for reproducibility
    random.seed(42)
    np.random.seed(42)

    nlp_da = spacy.blank("da")
    _dataset = datasets.get(dataset)
    ds_split = _dataset()[split]
    docs = [example.reference for example in ds_split]

    if not Doc.has_extension("meta"):
        Doc.set_extension("meta", default={}, force=True)

    # augment
    aug_docs = []
    for aug_name, aug in augmenters.items():
        for i in range(n_rep):
            _aug_docs = list(augmenty.docs(docs, augmenter=aug, nlp=nlp_da))
            for doc in _aug_docs:
                doc._.meta["augmenter"] = aug_name
                doc._.meta["n_rep"] = i
            aug_docs.extend(_aug_docs)

    # convert to examples
    examples = [Example(doc, doc) for doc in aug_docs]
    return examples


@datasets.register("gender_bias_dane")
def dane_gender_bias() -> Dict[str, List[Example]]:
    return {"test": augment_dataset("dane", augmenters=get_gender_bias_augmenters())}


@datasets.register("robustness_dane")
def dane_robustness() -> Dict[str, List[Example]]:
    return {"test": augment_dataset("dane", augmenters=get_robustness_augmenters())}


@datasets.register("dansk")
def dansk(**kwargs: Any) -> Dict[str, List[Example]]:
    splits = ["train", "dev", "test"]

    if not Doc.has_extension("meta"):
        Doc.set_extension("meta", default={}, force=True)

    nlp = spacy.blank("da")

    def convert_to_doc(example: Dict) -> Doc:
        doc = Doc(nlp.vocab).from_json(example)
        # set metadata
        for k in ["dagw_source", "dagw_domain", "dagw_source_full"]:
            doc._.meta[k] = example[k]
        return doc

    dataset = {}
    for split in splits:
        ds = load_dataset("chcaa/DANSK", split=split, **kwargs)
        docs = [convert_to_doc(example) for example in ds]  # type: ignore
        examples = [Example(doc, doc) for doc in docs]
        dataset[split] = examples

    return dataset
```
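A hedged sketch (not in the diff) of how the `catalogue` registry above might be consumed; it assumes the DaNE corpus can be downloaded and that `da_core_news_lg` is installed, since `get_robustness_augmenters` loads it internally:

```python
# Sketch: resolve a registered dataset loader by name and inspect the
# metadata written by augment_dataset().
loader = datasets.get("robustness_dane")  # catalogue registry lookup
splits = loader()                         # {"test": [Example, ...]}
example = splits["test"][0]
print(example.reference._.meta)  # e.g. {"augmenter": "Spelling Error", "n_rep": 0}
```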
@@ -0,0 +1,50 @@
```python
"""List of models used for testing."""

from functools import partial

import dacy
import spacy
from spacy.language import Language


def scandiner_loader() -> Language:
    """Load ScandiNER via DaCy's "dacy/ner" pipeline component."""
    scandiner = spacy.blank("da")
    scandiner.add_pipe("dacy/ner")
    return scandiner


def spacy_wrap_loader(mdl: str) -> Language:
    """Wrap a Hugging Face token-classification model as a spaCy NER pipeline."""
    daner_base = spacy.blank("da")
    config = {"model": {"name": mdl}, "predictions_to": ["ents"]}
    daner_base.add_pipe("token_classification_transformer", config=config)
    return daner_base


# model name -> zero-argument loader returning a spaCy Language
MODELS = {
    "saattrupdan/nbailab-base-ner-scandi": scandiner_loader,
    "da_dacy_large_trf-0.2.0": partial(dacy.load, "da_dacy_large_trf-0.2.0"),
    "da_dacy_medium_trf-0.2.0": partial(dacy.load, "da_dacy_medium_trf-0.2.0"),
    "da_dacy_small_trf-0.2.0": partial(dacy.load, "da_dacy_small_trf-0.2.0"),
    "da_dacy_large_ner_fine_grained-0.1.0": partial(
        dacy.load,
        "da_dacy_large_ner_fine_grained-0.1.0",
    ),
    "da_dacy_medium_ner_fine_grained-0.1.0": partial(
        dacy.load,
        "da_dacy_medium_ner_fine_grained-0.1.0",
    ),
    "da_dacy_small_ner_fine_grained-0.1.0": partial(
        dacy.load,
        "da_dacy_small_ner_fine_grained-0.1.0",
    ),
    "alexandrainst/da-ner-base": partial(
        spacy_wrap_loader,
        "alexandrainst/da-ner-base",
    ),
    "da_core_news_trf-3.5.0": partial(spacy.load, "da_core_news_trf"),
    "da_core_news_lg-3.5.0": partial(spacy.load, "da_core_news_lg"),
    "da_core_news_md-3.5.0": partial(spacy.load, "da_core_news_md"),
    "da_core_news_sm-3.5.0": partial(spacy.load, "da_core_news_sm"),
}
```
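A brief usage sketch (not part of the commit): every value in `MODELS` is a zero-argument callable, so a benchmark can instantiate models lazily. The sentence is illustrative, and the call assumes `dacy` is installed and that the ScandiNER weights can be fetched:

```python
# Illustrative only: instantiate one loader and run NER on a sample sentence.
nlp = MODELS["saattrupdan/nbailab-base-ner-scandi"]()
doc = nlp("Mette Frederiksen besøgte Aarhus Universitet.")
print([(ent.text, ent.label_) for ent in doc.ents])
```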
@@ -0,0 +1,13 @@ | ||
# models required for performance benchmarks | ||
https://github.com/explosion/spacy-models/releases/download/da_core_news_sm-3.5.0/da_core_news_sm-3.5.0-py3-none-any.whl | ||
https://github.com/explosion/spacy-models/releases/download/da_core_news_md-3.5.0/da_core_news_md-3.5.0-py3-none-any.whl | ||
https://github.com/explosion/spacy-models/releases/download/da_core_news_lg-3.5.0/da_core_news_lg-3.5.0-py3-none-any.whl | ||
https://github.com/explosion/spacy-models/releases/download/da_core_news_trf-3.5.0/da_core_news_trf-3.5.0-py3-none-any.whl | ||
# https://huggingface.co/chcaa/da_dacy_medium_ner_fine_grained/resolve/main/da_dacy_medium_ner_fine_grained-any-py3-none-any.whl | ||
# https://huggingface.co/chcaa/da_dacy_large_ner_fine_grained/resolve/main/da_dacy_large_ner_fine_grained-any-py3-none-any.whl | ||
# https://huggingface.co/chcaa/da_dacy_small_ner_fine_grained/resolve/main/da_dacy_small_ner_fine_grained-any-py3-none-any.whl | ||
|
||
dacy>=2.6.0 | ||
altair>=4.1.0 | ||
datasets>=1.14.0 | ||
augmenty[all]>=1.3.7 |