diff --git a/comparisons/exact_match/exact_match.py b/comparisons/exact_match/exact_match.py index b43413315..08c300137 100644 --- a/comparisons/exact_match/exact_match.py +++ b/comparisons/exact_match/exact_match.py @@ -45,9 +45,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class ExactMatch(evaluate.EvaluationModule): +class ExactMatch(evaluate.Comparison): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.ComparisonInfo( module_type="comparison", description=_DESCRIPTION, citation=_CITATION, diff --git a/comparisons/mcnemar/mcnemar.py b/comparisons/mcnemar/mcnemar.py index 17a8bf469..86b85b5e3 100644 --- a/comparisons/mcnemar/mcnemar.py +++ b/comparisons/mcnemar/mcnemar.py @@ -61,9 +61,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class McNemar(evaluate.EvaluationModule): +class McNemar(evaluate.Comparison): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.ComparisonInfo( module_type="comparison", description=_DESCRIPTION, citation=_CITATION, diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx index 06d1dc3ac..9bcba5b94 100644 --- a/docs/source/package_reference/main_classes.mdx +++ b/docs/source/package_reference/main_classes.mdx @@ -2,10 +2,24 @@ ## EvaluationModuleInfo +The base class `EvaluationModuleInfo` implements the logic for the subclasses `MetricInfo`, `ComparisonInfo`, and `MeasurementInfo`. + [[autodoc]] evaluate.EvaluationModuleInfo +[[autodoc]] evaluate.MetricInfo + +[[autodoc]] evaluate.ComparisonInfo + +[[autodoc]] evaluate.MeasurementInfo + ## EvaluationModule -The base class `Metric` implements a Metric backed by one or several [`Dataset`]. +The base class `EvaluationModule` implements the logic for the subclasses `Metric`, `Comparison`, and `Measurement`. 
+ +[[autodoc]] evaluate.EvaluationModule + +[[autodoc]] evaluate.Metric + +[[autodoc]] evaluate.Comparison -[[autodoc]] evaluate.EvaluationModule \ No newline at end of file +[[autodoc]] evaluate.Measurement \ No newline at end of file diff --git a/measurements/perplexity/perplexity.py b/measurements/perplexity/perplexity.py index 1fab3342a..d10bbb41f 100644 --- a/measurements/perplexity/perplexity.py +++ b/measurements/perplexity/perplexity.py @@ -85,9 +85,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Perplexity(evaluate.EvaluationModule): +class Perplexity(evaluate.Measurement): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MeasurementInfo( module_type="measurement", description=_DESCRIPTION, citation=_CITATION, diff --git a/measurements/text_duplicates/text_duplicates.py b/measurements/text_duplicates/text_duplicates.py index 372d33a0e..14a4dccfb 100644 --- a/measurements/text_duplicates/text_duplicates.py +++ b/measurements/text_duplicates/text_duplicates.py @@ -52,12 +52,12 @@ def get_hash(example): return hashlib.md5(example.strip().encode("utf-8")).hexdigest() @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class TextDuplicates(evaluate.EvaluationModule): +class TextDuplicates(evaluate.Measurement): """This measurement returns the duplicate strings contained in the input(s).""" def _info(self): - # TODO: Specifies the evaluate.EvaluationModuleInfo object - return evaluate.EvaluationModuleInfo( + # TODO: Specifies the evaluate.MeasurementInfo object + return evaluate.MeasurementInfo( # This is the description that will appear on the modules page. module_type="measurement", description=_DESCRIPTION, diff --git a/measurements/word_count/word_count.py b/measurements/word_count/word_count.py index 5393aa030..485a6dfd4 100644 --- a/measurements/word_count/word_count.py +++ b/measurements/word_count/word_count.py @@ -39,12 +39,12 @@ _CITATION = "" @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class WordCount(evaluate.EvaluationModule): +class WordCount(evaluate.Measurement): """This measurement returns the total number of words and the number of unique words in the input string(s).""" def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MeasurementInfo( # This is the description that will appear on the modules page. module_type="measurement", description=_DESCRIPTION, diff --git a/measurements/word_length/word_length.py b/measurements/word_length/word_length.py index f7565bef7..2a56b748f 100644 --- a/measurements/word_length/word_length.py +++ b/measurements/word_length/word_length.py @@ -50,12 +50,12 @@ """ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class WordLength(evaluate.EvaluationModule): +class WordLength(evaluate.Measurement): """This measurement returns the average number of words in the input string(s).""" def _info(self): - # TODO: Specifies the evaluate.EvaluationModuleInfo object - return evaluate.EvaluationModuleInfo( + # TODO: Specifies the evaluate.MeasurementInfo object + return evaluate.MeasurementInfo( # This is the description that will appear on the modules page. 
module_type="measurement", description=_DESCRIPTION, diff --git a/metrics/accuracy/accuracy.py b/metrics/accuracy/accuracy.py index e8f8c3e72..aa5a07328 100644 --- a/metrics/accuracy/accuracy.py +++ b/metrics/accuracy/accuracy.py @@ -78,9 +78,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Accuracy(evaluate.EvaluationModule): +class Accuracy(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/bertscore/bertscore.py b/metrics/bertscore/bertscore.py index 5f8235152..ba985575c 100644 --- a/metrics/bertscore/bertscore.py +++ b/metrics/bertscore/bertscore.py @@ -98,9 +98,9 @@ def filter_log(record): @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class BERTScore(evaluate.EvaluationModule): +class BERTScore(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, homepage="https://github.com/Tiiiger/bert_score", diff --git a/metrics/bleu/bleu.py b/metrics/bleu/bleu.py index 5c8d1c400..38a10c3b3 100644 --- a/metrics/bleu/bleu.py +++ b/metrics/bleu/bleu.py @@ -85,9 +85,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Bleu(evaluate.EvaluationModule): +class Bleu(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/bleurt/bleurt.py b/metrics/bleurt/bleurt.py index 715c0de5e..b47f8d284 100644 --- a/metrics/bleurt/bleurt.py +++ b/metrics/bleurt/bleurt.py @@ -77,10 +77,10 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class BLEURT(evaluate.EvaluationModule): +class BLEURT(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, homepage="https://github.com/google-research/bleurt", diff --git a/metrics/cer/cer.py b/metrics/cer/cer.py index fbc1a3329..c5f4a9072 100644 --- a/metrics/cer/cer.py +++ b/metrics/cer/cer.py @@ -116,9 +116,9 @@ def process_list(self, inp: List[str]): @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class CER(evaluate.EvaluationModule): +class CER(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/chrf/chrf.py b/metrics/chrf/chrf.py index 32a12155d..77da699e3 100644 --- a/metrics/chrf/chrf.py +++ b/metrics/chrf/chrf.py @@ -124,14 +124,14 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class ChrF(evaluate.EvaluationModule): +class ChrF(evaluate.Metric): def _info(self): if version.parse(scb.__version__) < version.parse("1.4.12"): raise ImportWarning( "To use `sacrebleu`, the module `sacrebleu>=1.4.12` is required, and the current version of `sacrebleu` doesn't match this condition.\n" 'You can install it with `pip install "sacrebleu>=1.4.12"`.' 
) - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, homepage="https://github.com/mjpost/sacreBLEU#chrf--chrf", diff --git a/metrics/code_eval/code_eval.py b/metrics/code_eval/code_eval.py index ecd0f173f..0885712e6 100644 --- a/metrics/code_eval/code_eval.py +++ b/metrics/code_eval/code_eval.py @@ -132,9 +132,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class CodeEval(evaluate.EvaluationModule): +class CodeEval(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( # This is the description that will appear on the metrics page. description=_DESCRIPTION, citation=_CITATION, diff --git a/metrics/comet/comet.py b/metrics/comet/comet.py index 23dd74fb0..0465ec644 100644 --- a/metrics/comet/comet.py +++ b/metrics/comet/comet.py @@ -107,10 +107,10 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class COMET(evaluate.EvaluationModule): +class COMET(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, homepage="https://unbabel.github.io/COMET/html/index.html", diff --git a/metrics/competition_math/competition_math.py b/metrics/competition_math/competition_math.py index fc8e12c78..9a82eb40b 100644 --- a/metrics/competition_math/competition_math.py +++ b/metrics/competition_math/competition_math.py @@ -64,11 +64,11 @@ @datasets.utils.file_utils.add_end_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class CompetitionMathMetric(evaluate.EvaluationModule): +class CompetitionMathMetric(evaluate.Metric): """Accuracy metric for the MATH dataset.""" def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/coval/coval.py b/metrics/coval/coval.py index 07a9428f0..f1518b958 100644 --- a/metrics/coval/coval.py +++ b/metrics/coval/coval.py @@ -270,9 +270,9 @@ def check_gold_parse_annotation(key_lines): @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Coval(evaluate.EvaluationModule): +class Coval(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/cuad/cuad.py b/metrics/cuad/cuad.py index ec1932d5c..21aec5af4 100644 --- a/metrics/cuad/cuad.py +++ b/metrics/cuad/cuad.py @@ -69,9 +69,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class CUAD(evaluate.EvaluationModule): +class CUAD(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/exact_match/exact_match.py b/metrics/exact_match/exact_match.py index 6abb20f40..d8c499b37 100644 --- a/metrics/exact_match/exact_match.py +++ b/metrics/exact_match/exact_match.py @@ -84,9 +84,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class ExactMatch(evaluate.EvaluationModule): +class ExactMatch(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/f1/f1.py 
b/metrics/f1/f1.py index f80dd1b5a..7a27ca984 100644 --- a/metrics/f1/f1.py +++ b/metrics/f1/f1.py @@ -97,9 +97,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class F1(evaluate.EvaluationModule): +class F1(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/frugalscore/frugalscore.py b/metrics/frugalscore/frugalscore.py index 65d7330c1..b6f7ee5f3 100644 --- a/metrics/frugalscore/frugalscore.py +++ b/metrics/frugalscore/frugalscore.py @@ -55,9 +55,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class FRUGALSCORE(evaluate.EvaluationModule): +class FRUGALSCORE(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/glue/glue.py b/metrics/glue/glue.py index d62248148..8c607bf1e 100644 --- a/metrics/glue/glue.py +++ b/metrics/glue/glue.py @@ -103,7 +103,7 @@ def pearson_and_spearman(preds, labels): @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Glue(evaluate.EvaluationModule): +class Glue(evaluate.Metric): def _info(self): if self.config_name not in [ "sst2", @@ -124,7 +124,7 @@ def _info(self): '["sst2", "mnli", "mnli_mismatched", "mnli_matched", ' '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]' ) - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/google_bleu/google_bleu.py b/metrics/google_bleu/google_bleu.py index d55e4e066..adcc0a31f 100644 --- a/metrics/google_bleu/google_bleu.py +++ b/metrics/google_bleu/google_bleu.py @@ -19,7 +19,7 @@ from nltk.translate import gleu_score import evaluate -from evaluate import EvaluationModuleInfo +from evaluate import MetricInfo from .tokenizer_13a import Tokenizer13a @@ -125,9 +125,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class GoogleBleu(evaluate.EvaluationModule): - def _info(self) -> EvaluationModuleInfo: - return evaluate.EvaluationModuleInfo( +class GoogleBleu(evaluate.Metric): + def _info(self) -> MetricInfo: + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/indic_glue/indic_glue.py b/metrics/indic_glue/indic_glue.py index a53fb0b31..03afd1700 100644 --- a/metrics/indic_glue/indic_glue.py +++ b/metrics/indic_glue/indic_glue.py @@ -103,7 +103,7 @@ def precision_at_10(en_sentvecs, in_sentvecs): @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class IndicGlue(evaluate.EvaluationModule): +class IndicGlue(evaluate.Metric): def _info(self): if self.config_name not in [ "wnli", @@ -126,7 +126,7 @@ def _info(self): '"cvit-mkb-clsr", "iitp-mr", "iitp-pr", "actsa-sc", "md", ' '"wiki-ner"]' ) - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/mae/mae.py b/metrics/mae/mae.py index eedd7d27d..e973bc5ff 100644 --- a/metrics/mae/mae.py +++ b/metrics/mae/mae.py @@ -82,9 +82,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Mae(evaluate.EvaluationModule): 
+class Mae(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/mahalanobis/mahalanobis.py b/metrics/mahalanobis/mahalanobis.py index ff41576d0..a2cad4996 100644 --- a/metrics/mahalanobis/mahalanobis.py +++ b/metrics/mahalanobis/mahalanobis.py @@ -58,9 +58,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Mahalanobis(evaluate.EvaluationModule): +class Mahalanobis(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/matthews_correlation/matthews_correlation.py b/metrics/matthews_correlation/matthews_correlation.py index 2d0ef98bc..295886be3 100644 --- a/metrics/matthews_correlation/matthews_correlation.py +++ b/metrics/matthews_correlation/matthews_correlation.py @@ -80,9 +80,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class MatthewsCorrelation(evaluate.EvaluationModule): +class MatthewsCorrelation(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/mauve/mauve.py b/metrics/mauve/mauve.py index c8c571985..9969135a7 100644 --- a/metrics/mauve/mauve.py +++ b/metrics/mauve/mauve.py @@ -86,9 +86,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Mauve(evaluate.EvaluationModule): +class Mauve(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, homepage="https://github.com/krishnap25/mauve", diff --git a/metrics/mean_iou/mean_iou.py b/metrics/mean_iou/mean_iou.py index be5ef39b4..421a261f4 100644 --- a/metrics/mean_iou/mean_iou.py +++ b/metrics/mean_iou/mean_iou.py @@ -274,9 +274,9 @@ def mean_iou( @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class MeanIoU(evaluate.EvaluationModule): +class MeanIoU(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/meteor/meteor.py b/metrics/meteor/meteor.py index 46499084c..3dcffdde1 100644 --- a/metrics/meteor/meteor.py +++ b/metrics/meteor/meteor.py @@ -83,9 +83,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Meteor(evaluate.EvaluationModule): +class Meteor(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/mse/mse.py b/metrics/mse/mse.py index bae059d4a..fb695bfde 100644 --- a/metrics/mse/mse.py +++ b/metrics/mse/mse.py @@ -86,9 +86,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Mse(evaluate.EvaluationModule): +class Mse(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/pearsonr/pearsonr.py b/metrics/pearsonr/pearsonr.py index b30cde50f..5ed0e7620 100644 
--- a/metrics/pearsonr/pearsonr.py +++ b/metrics/pearsonr/pearsonr.py @@ -84,9 +84,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Pearsonr(evaluate.EvaluationModule): +class Pearsonr(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/perplexity/perplexity.py b/metrics/perplexity/perplexity.py index 5a8b4573e..b636e12ca 100644 --- a/metrics/perplexity/perplexity.py +++ b/metrics/perplexity/perplexity.py @@ -85,9 +85,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Perplexity(evaluate.EvaluationModule): +class Perplexity(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( module_type="metric", description=_DESCRIPTION, citation=_CITATION, diff --git a/metrics/precision/precision.py b/metrics/precision/precision.py index e5db1c658..4b35aa7e4 100644 --- a/metrics/precision/precision.py +++ b/metrics/precision/precision.py @@ -103,9 +103,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Precision(evaluate.EvaluationModule): +class Precision(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/recall/recall.py b/metrics/recall/recall.py index c7ebe97e2..8522cfcf6 100644 --- a/metrics/recall/recall.py +++ b/metrics/recall/recall.py @@ -93,9 +93,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Recall(evaluate.EvaluationModule): +class Recall(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/rl_reliability/rl_reliability.py b/metrics/rl_reliability/rl_reliability.py index 08f566431..34a9c4570 100644 --- a/metrics/rl_reliability/rl_reliability.py +++ b/metrics/rl_reliability/rl_reliability.py @@ -82,14 +82,14 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class RLReliability(evaluate.EvaluationModule): +class RLReliability(evaluate.Metric): """Computes the RL Reliability Metrics.""" def _info(self): if self.config_name not in ["online", "offline"]: raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""") - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( module_type="metric", description=_DESCRIPTION, citation=_CITATION, diff --git a/metrics/roc_auc/roc_auc.py b/metrics/roc_auc/roc_auc.py index f90156019..604f93ab3 100644 --- a/metrics/roc_auc/roc_auc.py +++ b/metrics/roc_auc/roc_auc.py @@ -143,9 +143,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class ROCAUC(evaluate.EvaluationModule): +class ROCAUC(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/rouge/rouge.py b/metrics/rouge/rouge.py index 7a8de9880..9546ce01c 100644 --- a/metrics/rouge/rouge.py +++ b/metrics/rouge/rouge.py @@ -81,9 +81,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class 
Rouge(evaluate.EvaluationModule): +class Rouge(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/sacrebleu/sacrebleu.py b/metrics/sacrebleu/sacrebleu.py index 21290138b..6e756f4d4 100644 --- a/metrics/sacrebleu/sacrebleu.py +++ b/metrics/sacrebleu/sacrebleu.py @@ -103,14 +103,14 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Sacrebleu(evaluate.EvaluationModule): +class Sacrebleu(evaluate.Metric): def _info(self): if version.parse(scb.__version__) < version.parse("1.4.12"): raise ImportWarning( "To use `sacrebleu`, the module `sacrebleu>=1.4.12` is required, and the current version of `sacrebleu` doesn't match this condition.\n" 'You can install it with `pip install "sacrebleu>=1.4.12"`.' ) - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, homepage="https://github.com/mjpost/sacreBLEU", diff --git a/metrics/sari/sari.py b/metrics/sari/sari.py index 11a01d2e9..7d021184d 100644 --- a/metrics/sari/sari.py +++ b/metrics/sari/sari.py @@ -258,9 +258,9 @@ def normalize(sentence, lowercase: bool = True, tokenizer: str = "13a", return_s @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Sari(evaluate.EvaluationModule): +class Sari(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/seqeval/seqeval.py b/metrics/seqeval/seqeval.py index f9e55877f..252d16b6c 100644 --- a/metrics/seqeval/seqeval.py +++ b/metrics/seqeval/seqeval.py @@ -100,9 +100,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Seqeval(evaluate.EvaluationModule): +class Seqeval(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, homepage="https://github.com/chakki-works/seqeval", diff --git a/metrics/spearmanr/spearmanr.py b/metrics/spearmanr/spearmanr.py index 6dfc08676..3be1743e7 100644 --- a/metrics/spearmanr/spearmanr.py +++ b/metrics/spearmanr/spearmanr.py @@ -97,9 +97,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Spearmanr(evaluate.EvaluationModule): +class Spearmanr(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/squad/squad.py b/metrics/squad/squad.py index bbd339bcf..84658b125 100644 --- a/metrics/squad/squad.py +++ b/metrics/squad/squad.py @@ -66,9 +66,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Squad(evaluate.EvaluationModule): +class Squad(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/squad_v2/squad_v2.py b/metrics/squad_v2/squad_v2.py index 2e3663aac..cb9ba1ae8 100644 --- a/metrics/squad_v2/squad_v2.py +++ b/metrics/squad_v2/squad_v2.py @@ -88,9 +88,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class SquadV2(evaluate.EvaluationModule): +class 
SquadV2(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/super_glue/super_glue.py b/metrics/super_glue/super_glue.py index bd8a329db..993a003b3 100644 --- a/metrics/super_glue/super_glue.py +++ b/metrics/super_glue/super_glue.py @@ -145,7 +145,7 @@ def evaluate_multirc(ids_preds, labels): @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class SuperGlue(evaluate.EvaluationModule): +class SuperGlue(evaluate.Metric): def _info(self): if self.config_name not in [ "boolq", @@ -164,7 +164,7 @@ def _info(self): "You should supply a configuration name selected in " '["boolq", "cb", "copa", "multirc", "record", "rte", "wic", "wsc", "wsc.fixed", "axb", "axg",]' ) - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/ter/ter.py b/metrics/ter/ter.py index 028dc727e..4adb9986c 100644 --- a/metrics/ter/ter.py +++ b/metrics/ter/ter.py @@ -151,14 +151,14 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Ter(evaluate.EvaluationModule): +class Ter(evaluate.Metric): def _info(self): if version.parse(scb.__version__) < version.parse("1.4.12"): raise ImportWarning( "To use `sacrebleu`, the module `sacrebleu>=1.4.12` is required, and the current version of `sacrebleu` doesn't match this condition.\n" 'You can install it with `pip install "sacrebleu>=1.4.12"`.' ) - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, homepage="http://www.cs.umd.edu/~snover/tercom/", diff --git a/metrics/trec_eval/trec_eval.py b/metrics/trec_eval/trec_eval.py index ac2dc3183..462374e01 100644 --- a/metrics/trec_eval/trec_eval.py +++ b/metrics/trec_eval/trec_eval.py @@ -68,11 +68,11 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class TRECEval(evaluate.EvaluationModule): +class TRECEval(evaluate.Metric): """Compute TREC evaluation scores.""" def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( module_type="metric", description=_DESCRIPTION, citation=_CITATION, diff --git a/metrics/wer/wer.py b/metrics/wer/wer.py index 8a49c4f10..214d5b22e 100644 --- a/metrics/wer/wer.py +++ b/metrics/wer/wer.py @@ -75,9 +75,9 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class WER(evaluate.EvaluationModule): +class WER(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/wiki_split/wiki_split.py b/metrics/wiki_split/wiki_split.py index d52f76c4e..be83681b7 100644 --- a/metrics/wiki_split/wiki_split.py +++ b/metrics/wiki_split/wiki_split.py @@ -321,9 +321,9 @@ def compute_sacrebleu( @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class WikiSplit(evaluate.EvaluationModule): +class WikiSplit(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/xnli/xnli.py b/metrics/xnli/xnli.py index a81c389d3..cc631d55f 100644 --- a/metrics/xnli/xnli.py +++ b/metrics/xnli/xnli.py @@ -67,9 
+67,9 @@ def simple_accuracy(preds, labels): @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class Xnli(evaluate.EvaluationModule): +class Xnli(evaluate.Metric): def _info(self): - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/metrics/xtreme_s/xtreme_s.py b/metrics/xtreme_s/xtreme_s.py index 19ba65e16..b4c052fc5 100644 --- a/metrics/xtreme_s/xtreme_s.py +++ b/metrics/xtreme_s/xtreme_s.py @@ -219,14 +219,14 @@ def compute_score(preds, labels, score_type="wer"): @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class XtremeS(evaluate.EvaluationModule): +class XtremeS(evaluate.Metric): def _info(self): if self.config_name not in _CONFIG_NAMES: raise KeyError(f"You should supply a configuration name selected in {_CONFIG_NAMES}") pred_type = "int64" if self.config_name in ["fleurs-lang_id", "minds14"] else "string" - return evaluate.EvaluationModuleInfo( + return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, diff --git a/src/evaluate/__init__.py b/src/evaluate/__init__.py index 8edcc85c2..74550677d 100644 --- a/src/evaluate/__init__.py +++ b/src/evaluate/__init__.py @@ -28,10 +28,10 @@ from .evaluator import Evaluator, TextClassificationEvaluator, evaluator from .hub import push_to_hub -from .info import EvaluationModuleInfo +from .info import ComparisonInfo, EvaluationModuleInfo, MeasurementInfo, MetricInfo from .inspect import inspect_evaluation_module, list_evaluation_modules from .loading import load -from .module import CombinedEvaluations, EvaluationModule, combine +from .module import CombinedEvaluations, Comparison, EvaluationModule, Measurement, Metric from .saving import save from .utils import * from .utils import gradio, logging diff --git a/src/evaluate/info.py b/src/evaluate/info.py index 250598b5c..14111deb1 100644 --- a/src/evaluate/info.py +++ b/src/evaluate/info.py @@ -33,9 +33,10 @@ @dataclass class EvaluationModuleInfo: - """Information about a metric. + """Base class to store information about an evaluation, shared by `MetricInfo`, `ComparisonInfo`, + and `MeasurementInfo`. - `EvaluationModuleInfo` documents a metric, including its name, version, and features. + `EvaluationModuleInfo` documents an evaluation, including its name, version, and features. See the constructor arguments and properties for a full list. Note: Not all fields are known on construction and may be updated later. @@ -52,10 +53,10 @@ class EvaluationModuleInfo: reference_urls: List[str] = field(default_factory=list) streamable: bool = False format: Optional[str] = None - module_type: str = "metric" + module_type: str = "metric" # deprecate this in the future # Set later by the builder - metric_name: Optional[str] = None + module_name: Optional[str] = None config_name: Optional[str] = None experiment_id: Optional[str] = None @@ -98,3 +99,42 @@ def from_directory(cls, metric_info_dir) -> "EvaluationModuleInfo": def from_dict(cls, metric_info_dict: dict) -> "EvaluationModuleInfo": field_names = {f.name for f in dataclasses.fields(cls)} return cls(**{k: v for k, v in metric_info_dict.items() if k in field_names}) + + +@dataclass +class MetricInfo(EvaluationModuleInfo): + """Information about a metric. + + `MetricInfo` documents a metric, including its name, version, and features. + See the constructor arguments and properties for a full list. 
+ + Note: Not all fields are known on construction and may be updated later. + """ + + module_type: str = "metric" + + +@dataclass +class ComparisonInfo(EvaluationModuleInfo): + """Information about a comparison. + + `ComparisonInfo` documents a comparison, including its name, version, and features. + See the constructor arguments and properties for a full list. + + Note: Not all fields are known on construction and may be updated later. + """ + + module_type: str = "comparison" + + +@dataclass +class MeasurementInfo(EvaluationModuleInfo): + """Information about a measurement. + + `MeasurementInfo` documents a measurement, including its name, version, and features. + See the constructor arguments and properties for a full list. + + Note: Not all fields are known on construction and may be updated later. + """ + + module_type: str = "measurement" diff --git a/src/evaluate/module.py b/src/evaluate/module.py index b2b7cc9a0..1efafbb41 100644 --- a/src/evaluate/module.py +++ b/src/evaluate/module.py @@ -710,6 +710,69 @@ def _enforce_nested_string_type(self, schema, obj): raise TypeError(f"Expected type str but got {type(obj)}.") +class Metric(EvaluationModule): + """A Metric is the base class and common API for all metrics. + + Args: + config_name (``str``): This is used to define a hash specific to a metric computation script and prevents the metric's data + from being overridden when the metric loading script is modified. + keep_in_memory (:obj:`bool`): keep all predictions and references in memory. Not possible in distributed settings. + cache_dir (``str``): Path to a directory in which temporary prediction/references data will be stored. + The data directory should be located on a shared file-system in distributed setups. + num_process (``int``): specify the total number of nodes in a distributed setting. + This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1). + process_id (``int``): specify the id of the current process in a distributed setup (between 0 and num_process-1). + This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1). + seed (:obj:`int`, optional): If specified, this will temporarily set numpy's random seed when :func:`evaluate.Metric.compute` is run. + experiment_id (``str``): A specific experiment id. This is used if several distributed evaluations share the same file system. + This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1). + max_concurrent_cache_files (``int``): Max number of concurrent metric cache files (default 10000). + timeout (``Union[int, float]``): Timeout in seconds for distributed setting synchronization. + """ + + +class Comparison(EvaluationModule): + """A Comparison is the base class and common API for all comparisons. + + Args: + config_name (``str``): This is used to define a hash specific to a comparison computation script and prevents the comparison's data + from being overridden when the comparison loading script is modified. + keep_in_memory (:obj:`bool`): keep all predictions and references in memory. Not possible in distributed settings. + cache_dir (``str``): Path to a directory in which temporary prediction/references data will be stored. + The data directory should be located on a shared file-system in distributed setups. + num_process (``int``): specify the total number of nodes in a distributed setting. 
+ This is useful to compute comparisons in distributed setups (in particular non-additive comparisons). + process_id (``int``): specify the id of the current process in a distributed setup (between 0 and num_process-1). + This is useful to compute comparisons in distributed setups (in particular non-additive comparisons). + seed (:obj:`int`, optional): If specified, this will temporarily set numpy's random seed when :func:`evaluate.Comparison.compute` is run. + experiment_id (``str``): A specific experiment id. This is used if several distributed evaluations share the same file system. + This is useful to compute comparisons in distributed setups (in particular non-additive comparisons). + max_concurrent_cache_files (``int``): Max number of concurrent comparison cache files (default 10000). + timeout (``Union[int, float]``): Timeout in seconds for distributed setting synchronization. + """ + + +class Measurement(EvaluationModule): + """A Measurement is the base class and common API for all measurements. + + Args: + config_name (``str``): This is used to define a hash specific to a measurement computation script and prevents the measurement's data + from being overridden when the measurement loading script is modified. + keep_in_memory (:obj:`bool`): keep all predictions and references in memory. Not possible in distributed settings. + cache_dir (``str``): Path to a directory in which temporary prediction/references data will be stored. + The data directory should be located on a shared file-system in distributed setups. + num_process (``int``): specify the total number of nodes in a distributed setting. + This is useful to compute measurements in distributed setups (in particular non-additive measurements). + process_id (``int``): specify the id of the current process in a distributed setup (between 0 and num_process-1). + This is useful to compute measurements in distributed setups (in particular non-additive measurements). + seed (:obj:`int`, optional): If specified, this will temporarily set numpy's random seed when :func:`evaluate.Measurement.compute` is run. + experiment_id (``str``): A specific experiment id. This is used if several distributed evaluations share the same file system. + This is useful to compute measurements in distributed setups (in particular non-additive measurements). + max_concurrent_cache_files (``int``): Max number of concurrent measurement cache files (default 10000). + timeout (``Union[int, float]``): Timeout in seconds for distributed setting synchronization. 
+ """ + + class CombinedEvaluations: def __init__(self, evaluation_modules, force_prefix=False): from .loading import load # avoid circular imports diff --git a/templates/{{ cookiecutter.module_slug }}/{{ cookiecutter.module_slug }}.py b/templates/{{ cookiecutter.module_slug }}/{{ cookiecutter.module_slug }}.py index 9973b3b17..578ebd2e8 100644 --- a/templates/{{ cookiecutter.module_slug }}/{{ cookiecutter.module_slug }}.py +++ b/templates/{{ cookiecutter.module_slug }}/{{ cookiecutter.module_slug }}.py @@ -58,14 +58,14 @@ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class {{ cookiecutter.module_class_name }}(evaluate.EvaluationModule): +class {{ cookiecutter.module_class_name }}(evaluate.{{ cookiecutter.module_type | capitalize}}): """TODO: Short description of my evaluation module.""" def _info(self): # TODO: Specifies the evaluate.EvaluationModuleInfo object - return evaluate.EvaluationModuleInfo( + return evaluate.{{ cookiecutter.module_type | capitalize}}Info( # This is the description that will appear on the modules page. - module_type="{{ cookiecutter.module_type }}", + module_type="{{ cookiecutter.module_type}}", description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION,