Check if all docs have domain attribute (#267)
* `--list -t wmt23 -l cs-uk` should print just cs-uk, not other language pairs

When `-l` is omitted, `--list` still prints all the language pairs for that test set.

Motivation:
Originally, `--list` showed just the list of language pairs, so there was no reason
to call it with `-l`. Now it lists all the **fields** for a given language pair
and is relatively slow (it has to parse the XML files), so it makes sense
to restrict the listing to a single language pair.

* Allow `--detail` and `--subset` to be used with the new XML test sets as well

* Fix lint errors

* Fix CI errors

* Fix domain field can be None

---------

Co-authored-by: Martin Popel <popel@ufal.mff.cuni.cz>
jkawamoto and martinpopel committed Jul 23, 2024
1 parent 77ca43d commit dc301a5
Showing 4 changed files with 96 additions and 56 deletions.
29 changes: 13 additions & 16 deletions sacrebleu/dataset/wmt_xml.py
@@ -26,10 +26,12 @@ def _unwrap_wmt21_or_later(raw_file):
This script is adapted from https://github.com/wmt-conference/wmt-format-tools
:param raw_file: The raw xml file to unwrap.
:return: Dictionary which contains the following fields:
:return: Dictionary which contains the following fields
(each a list with values for each sentence):
- `src`: The source sentences.
- `docid`: ID indicating which document the sentences belong to.
- `origlang`: The original language of the document.
- `domain`: Domain of the document.
- `ref:{translator}`: The references produced by each translator.
- `ref`: An alias for the references from the first translator.
"""
@@ -60,13 +62,8 @@ def _unwrap_wmt21_or_later(raw_file):

systems = defaultdict(list)

src_sent_count, doc_count = 0, 0
src_sent_count, doc_count, seen_domain = 0, 0, False
for doc in tree.getroot().findall(".//doc"):
docid = doc.attrib["id"]
origlang = doc.attrib["origlang"]
# present wmt22++
domain = doc.attrib.get("domain", None)

# Skip the testsuite
if "testsuite" in doc.attrib:
continue
Expand Down Expand Up @@ -104,17 +101,17 @@ def get_sents(doc):
src.append(src_sents[seg_id])
for system_name in hyps.keys():
systems[system_name].append(hyps[system_name][seg_id])
docids.append(docid)
orig_langs.append(origlang)
if domain is not None:
domains.append(domain)
docids.append(doc.attrib["id"])
orig_langs.append(doc.attrib["origlang"])
# The "domain" attribute is missing in WMT21 and WMT22
domains.append(doc.get("domain"))
seen_domain = doc.get("domain") is not None
src_sent_count += 1

data = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
if len(domains):
data["domain"] = domains

return data
fields = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
if seen_domain:
fields["domain"] = domains
return fields

def _get_langpair_path(self, langpair):
"""
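For readers skimming the diff, here is a minimal, self-contained sketch of the gating pattern introduced above: `domain` values are collected for every document (possibly `None`), and the `domain` field is only added to the returned dict when the attribute was actually seen. The toy per-document attribute dicts below are hypothetical stand-ins for the parsed XML, and the real function appends one entry per sentence rather than per document.

```python
from typing import Dict, List

def collect_fields(docs: List[dict]) -> Dict[str, list]:
    """Toy stand-in for the field collection above: `docs` is a list of
    per-document attribute dicts (hypothetical input, not real XML)."""
    docids, origlangs, domains = [], [], []
    seen_domain = False
    for attrs in docs:
        docids.append(attrs["id"])
        origlangs.append(attrs["origlang"])
        # Always append, even if None, so the lists stay aligned per document.
        domains.append(attrs.get("domain"))
        seen_domain = attrs.get("domain") is not None

    fields = {"docid": docids, "origlang": origlangs}
    # Only expose `domain` when the attribute was actually present.
    if seen_domain:
        fields["domain"] = domains
    return fields

# WMT21-style docs without a domain attribute yield no `domain` key.
print(collect_fields([{"id": "d1", "origlang": "en"}]))
# {'docid': ['d1'], 'origlang': ['en']}
print(collect_fields([{"id": "d1", "origlang": "en", "domain": "news"}]))
# {'docid': ['d1'], 'origlang': ['en'], 'domain': ['news']}
```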
3 changes: 1 addition & 2 deletions sacrebleu/sacrebleu.py
@@ -242,8 +242,7 @@ def main():

if args.list:
if args.test_set:
langpairs = get_langpairs_for_testset(args.test_set)
for pair in langpairs:
for pair in [args.langpair] if args.langpair else get_langpairs_for_testset(args.test_set):
fields = DATASETS[args.test_set].fieldnames(pair)
print(f'{pair}: {", ".join(fields)}')
else:
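A hedged sketch of what the new `--list` behaviour amounts to: with `-l` given, only that pair's fields are printed; without it, every pair of the test set is listed (which is slower, since the fields have to be read from the XML). `PAIRS` and `FIELDS` below are hypothetical stand-ins for `get_langpairs_for_testset()` and `DATASETS[test_set].fieldnames()`.

```python
from typing import Dict, List, Optional

# Hypothetical stand-ins for the real test set metadata.
PAIRS: List[str] = ["cs-uk", "de-en", "en-de"]
FIELDS: Dict[str, List[str]] = {p: ["src", "ref", "docid", "origlang"] for p in PAIRS}

def list_fields(langpair: Optional[str]) -> None:
    # New behaviour: restrict the (slow) field listing to one pair when -l is given.
    for pair in [langpair] if langpair else PAIRS:
        print(f'{pair}: {", ".join(FIELDS[pair])}')

list_fields("cs-uk")   # prints only the cs-uk line
list_fields(None)      # prints every pair of the test set
```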
110 changes: 77 additions & 33 deletions sacrebleu/utils.py
@@ -1,3 +1,4 @@
import itertools
import json
import os
import re
@@ -488,14 +489,18 @@ def get_available_testsets_for_langpair(langpair: str) -> List[str]:


def get_available_origlangs(test_sets, langpair) -> List[str]:
"""Return a list of origlang values in according to the raw SGM files."""
"""Return a list of origlang values according to the raw XML/SGM files."""
if test_sets is None:
return []

origlangs = set()
for test_set in test_sets.split(','):
dataset = DATASETS[test_set]
rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
from .dataset.wmt_xml import WMTXMLDataset
if isinstance(dataset, WMTXMLDataset):
for origlang in dataset._unwrap_wmt21_or_later(rawfile)['origlang']:
origlangs.add(origlang)
if rawfile.endswith('.sgm'):
with smart_open(rawfile) as fin:
for line in fin:
@@ -505,48 +510,84 @@ def get_available_origlangs(test_sets, langpair) -> List[str]:
return sorted(list(origlangs))


def get_available_subsets(test_sets, langpair) -> List[str]:
"""Return a list of domain values according to the raw XML files and domain/country values from the SGM files."""
if test_sets is None:
return []

subsets = set()
for test_set in test_sets.split(','):
dataset = DATASETS[test_set]
from .dataset.wmt_xml import WMTXMLDataset
if isinstance(dataset, WMTXMLDataset):
rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
fields = dataset._unwrap_wmt21_or_later(rawfile)
if 'domain' in fields:
subsets |= set(fields['domain'])
elif test_set in SUBSETS:
subsets |= set("country:" + v.split("-")[0] for v in SUBSETS[test_set].values())
subsets |= set(v.split("-")[1] for v in SUBSETS[test_set].values())
return sorted(list(subsets))

def filter_subset(systems, test_sets, langpair, origlang, subset=None):
"""Filter sentences with a given origlang (or subset) according to the raw SGM files."""
if origlang is None and subset is None:
return systems
if test_sets is None or langpair is None:
raise ValueError('Filtering for --origlang or --subset needs a test (-t) and a language pair (-l).')

if subset is not None and subset.startswith('country:'):
subset = subset[8:]

re_origlang = re.compile(r'.* origlang="([^"]+)".*\n')
re_id = re.compile(r'.* docid="([^"]+)".*\n')

indices_to_keep = []

for test_set in test_sets.split(','):
dataset = DATASETS[test_set]
rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
if not rawfile.endswith('.sgm'):
raise Exception(f'--origlang and --subset supports only *.sgm files, not {rawfile!r}')
if subset is not None:
if test_set not in SUBSETS:
raise Exception('No subset annotation available for test set ' + test_set)
doc_to_tags = SUBSETS[test_set]
number_sentences_included = 0
with smart_open(rawfile) as fin:
include_doc = False
for line in fin:
if line.startswith('<doc '):
if origlang is None:
include_doc = True
from .dataset.wmt_xml import WMTXMLDataset
if isinstance(dataset, WMTXMLDataset):
fields = dataset._unwrap_wmt21_or_later(rawfile)
domains = fields['domain'] if 'domain' in fields else itertools.repeat(None)
for doc_origlang, doc_domain in zip(fields['origlang'], domains):
if origlang is None:
include_doc = True
else:
if origlang.startswith('non-'):
include_doc = doc_origlang != origlang[4:]
else:
doc_origlang = re_origlang.sub(r'\1', line)
if origlang.startswith('non-'):
include_doc = doc_origlang != origlang[4:]
include_doc = doc_origlang == origlang
if subset is not None and (doc_domain is None or not re.search(subset, doc_domain)):
include_doc = False
indices_to_keep.append(include_doc)
elif rawfile.endswith('.sgm'):
doc_to_tags = {}
if subset is not None:
if test_set not in SUBSETS:
raise Exception('No subset annotation available for test set ' + test_set)
doc_to_tags = SUBSETS[test_set]
with smart_open(rawfile) as fin:
include_doc = False
for line in fin:
if line.startswith('<doc '):
if origlang is None:
include_doc = True
else:
include_doc = doc_origlang == origlang

if subset is not None:
doc_id = re_id.sub(r'\1', line)
if not re.search(subset, doc_to_tags.get(doc_id, '')):
include_doc = False
if line.startswith('<seg '):
indices_to_keep.append(include_doc)
number_sentences_included += 1 if include_doc else 0
doc_origlang = re_origlang.sub(r'\1', line)
if origlang.startswith('non-'):
include_doc = doc_origlang != origlang[4:]
else:
include_doc = doc_origlang == origlang

if subset is not None:
doc_id = re_id.sub(r'\1', line)
if not re.search(subset, doc_to_tags.get(doc_id, '')):
include_doc = False
if line.startswith('<seg '):
indices_to_keep.append(include_doc)
else:
raise Exception(f'--origlang and --subset supports only WMT *.xml and *.sgm files, not {rawfile!r}')
return [[sentence for sentence, keep in zip(sys, indices_to_keep) if keep] for sys in systems]
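The function above decides inclusion once per segment, accumulating one boolean per sentence in `indices_to_keep`, and then applies that single mask to the system output and every reference. Below is a minimal sketch of that pattern, using a hypothetical list of `(origlang, domain)` metadata per segment instead of real XML/SGM parsing; note the real code matches `--subset` with `re.search` rather than a substring test.

```python
from typing import List, Optional

def filter_by_mask(systems: List[List[str]], keep: List[bool]) -> List[List[str]]:
    # Same masking idea as the return statement of filter_subset():
    # every stream (system output, each reference) is filtered with one mask.
    return [[sent for sent, k in zip(stream, keep) if k] for stream in systems]

# Hypothetical metadata: one (origlang, domain) pair per segment.
meta = [("en", "news"), ("cs", "social"), ("en", "social")]

def build_mask(origlang: Optional[str], subset: Optional[str]) -> List[bool]:
    keep = []
    for doc_origlang, doc_domain in meta:
        include = origlang is None or (
            doc_origlang != origlang[4:] if origlang.startswith("non-")
            else doc_origlang == origlang)
        if subset is not None and (doc_domain is None or subset not in doc_domain):
            include = False
        keep.append(include)
    return keep

sys_out = ["hyp1", "hyp2", "hyp3"]
ref = ["ref1", "ref2", "ref3"]
print(filter_by_mask([sys_out, ref], build_mask("en", "social")))
# [['hyp3'], ['ref3']]
```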


@@ -565,8 +606,9 @@ def print_subset_results(metrics, full_system, full_refs, args):
subsets = [None]
if args.subset is not None:
subsets += [args.subset]
elif all(t in SUBSETS for t in args.test_set.split(',')):
subsets += COUNTRIES + DOMAINS
else:
subsets += get_available_subsets(args.test_set, args.langpair)

for subset in subsets:
system, *refs = filter_subset(
[full_system, *full_refs], args.test_set, args.langpair, origlang, subset)
@@ -575,9 +617,11 @@ def print_subset_results(metrics, full_system, full_refs, args):
continue

key = f'origlang={origlang}'
if subset in COUNTRIES:
key += f' country={subset}'
elif subset in DOMAINS:
if subset is None:
key += ' domain=ALL'
elif subset.startswith('country:'):
key += f' country={subset[8:]}'
else:
key += f' domain={subset}'

for metric in metrics.values():
@@ -592,4 +636,4 @@ def print_subset_results(metrics, full_system, full_refs, args):
print(f'{key}: sentences={n_system:<6} {score.name:<{max_metric_width}} = {score.score:.{w}f}')

# import at the end to avoid circular import
from .dataset import DATASETS, SUBSETS, DOMAINS, COUNTRIES # noqa: E402
from .dataset import DATASETS, SUBSETS # noqa: E402
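One convention worth calling out in the utils.py changes: `get_available_subsets()` reports SGM country tags with a `country:` prefix, `filter_subset()` strips that prefix before matching, and `print_subset_results()` turns it back into a `country=...` label, reporting everything else as `domain=...` (with `domain=ALL` for the unfiltered case). A small sketch of that labeling, with hypothetical subset values:

```python
from typing import Optional

def subset_label(subset: Optional[str]) -> str:
    # Mirrors the key construction in print_subset_results():
    # None -> whole test set, 'country:XX' -> country, anything else -> domain.
    if subset is None:
        return "domain=ALL"
    if subset.startswith("country:"):
        return f"country={subset[8:]}"
    return f"domain={subset}"

for s in [None, "country:GB", "news"]:  # hypothetical subset values
    print(subset_label(s))
# domain=ALL
# country=GB
# domain=news
```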
10 changes: 5 additions & 5 deletions test.sh
@@ -74,16 +74,16 @@ declare -A EXPECTED
EXPECTED["${CMD} -t wmt16,wmt17 -l en-fi --echo ref | ${CMD} -b -w 4 -t wmt16/B,wmt17/B -l en-fi"]=53.7432
EXPECTED["${CMD} -t wmt16,wmt17 -l en-fi --echo ref | ${CMD} -b -w 4 -t wmt16/B,wmt17/B -l en-fi --origlang=en"]=18.9054
EXPECTED["${CMD} -t wmt17 -l en-fi --echo ref | ${CMD} -b -t wmt17/B -l en-fi --detail"]="55.6
origlang=en : sentences=1502 BLEU = 21.4
origlang=fi : sentences=1500 BLEU = 100.0"
origlang=en domain=ALL : sentences=1502 BLEU = 21.4
origlang=fi domain=ALL : sentences=1500 BLEU = 100.0"
EXPECTED["${CMD} -t wmt18,wmt19 -l en-de --echo=src | ${CMD} -t wmt18,wmt19 -l en-de -b --detail"]="3.6
origlang=de : sentences=1498 BLEU = 3.6
origlang=en : sentences=3497 BLEU = 3.5
origlang=de domain=ALL : sentences=1498 BLEU = 3.6
origlang=en domain=ALL : sentences=3497 BLEU = 3.5
origlang=en domain=business : sentences=241 BLEU = 3.4
origlang=en country=EU : sentences=265 BLEU = 2.5
origlang=en country=GB : sentences=913 BLEU = 3.1
origlang=en country=OTHER : sentences=801 BLEU = 2.5
origlang=en country=US : sentences=1518 BLEU = 4.2
origlang=en domain=business : sentences=241 BLEU = 3.4
origlang=en domain=crime : sentences=570 BLEU = 3.6
origlang=en domain=entertainment : sentences=322 BLEU = 5.1
origlang=en domain=politics : sentences=959 BLEU = 3.0
