diff --git a/.gitignore b/.gitignore index f5790c7..0b53a2b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +*.egg-info +*.pytest_cache + # Byte-compiled / optimized / DLL files /tmp __pycache__/ diff --git a/CHANGELOG.md b/CHANGELOG.md index c0e8cd6..d047185 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,18 @@ # Change Log -## v0.3.0 - WIP +## v0.3.1 - 20211028 + +- Add support for recipes; list-recipe get-recipe subcommands added +- add support for viewing stats of dataset; words, chars, segs +- FIX url for UN dev and test sets (source was updated so we updated too) +- Multilingual experiment support; ISO 639-3 code `mul` implies multilingual; e.g. mul-eng or eng-mul +- `--dev` accepts multiple datasets, and merges it (useful for multilingual experiments) +- tar files are extracted before read (performance improvements) +- setup.py: version and descriptions accessed via regex + +--- + +## v0.3.0 - 20211021 > Big Changes: BCP-47, data compression diff --git a/README.md b/README.md index 37ee6ad..ac3eb91 100644 --- a/README.md +++ b/README.md @@ -240,6 +240,23 @@ $ tree data/deu-eng/ └── train.stats.json ``` +## Recipes + +> Since v0.3.1 + +Recipe is a set of datasets nominated for train, dev, and tests, and are meant to improve reproducibility of experiments. +Recipes are loaded from +1. Default: [`mtdata/recipe/recipes.yml`](mtdata/recipe/recipes.yml) from source code +2. Cache dir: `$MTDATA/mtdata.recipe.yml` where `$MTDATA` has default of `~/.mtdata` +3. Current dir: `$PWD/mtdata.recipe.yml` + +See [`mtdata/recipe/recipes.yml`](mtdata/recipe/recipes.yml) for format and examples. + +```bash +mtdata list-recipe # see all recipes +mtdata get-recipe -ri -o # get recipe, recreate dataset +``` + ## Language Name Standardization ### ISO 639 3 Internally, all language codes are mapped to ISO-639 3 codes. @@ -287,6 +304,9 @@ print(iso3_code('eNgLIsH', fail_error=True)) # case doesnt matter ``` ### BCP-47 + +> Since v0.3.0 + We used ISO 639-3 from the beginning, however, we soon faced the limitation that ISO 639-3 cannot distinguish script and region variants of language. So we have upgraded to BCP-47 like language tags in `v0.3.0`. * BCP47 uses two-letter codes to some and three-letter codes to the rest, we use three-letter codes to all languages. @@ -305,9 +325,9 @@ Our tags are of form `xxx_Yyyy_ZZ` where Notes: * Region is preserved when available and left blank when unavailable * Script `Yyyy` is forcibly suppressed in obvious cases. E.g. `eng` is written using `Latn` script, writing `eng-Latn` is just awkward to read as `Latn` is default we suppress `Latn` script for English. On the other hand a language like `Kannada` is written using `Knda` script (`kan-Knda` -> `kan`), but occasionally written using `Latn` script, so `kan-Latn` is not suppressed. 
-* The information about whats default script is obtained from IANA language code registry +* The information about the default script of a language is obtained from the IANA language code registry +* Language code `mul` stands for _multiple languages_, and is used as a placeholder for multilingual datasets (see `mul-eng` used for many-to-English dataset recipes in [mtdata/recipe/recipes.yml](mtdata/recipe/recipes.yml)) - #### Example: To inspect parsing/mapping, use `python -m mtdata.iso.bcp47 ` diff --git a/mtdata/__init__.py b/mtdata/__init__.py index 3948fc2..9053434 100644 --- a/mtdata/__init__.py +++ b/mtdata/__init__.py @@ -3,27 +3,24 @@ # Author: Thamme Gowda [tg (at) isi (dot) edu] # Created: 4/4/20 -__version__ = '0.3.0' +__version__ = '0.3.1' __description__ = 'mtdata is a tool to download datasets for machine translation' __author__ = 'Thamme Gowda' import logging as log from pathlib import Path import os +import enlighten +from ruamel.yaml import YAML - +yaml = YAML() debug_mode = False _log_format = '%(asctime)s %(module)s.%(funcName)s:%(lineno)s %(levelname)s:: %(message)s' log.basicConfig(level=log.INFO, datefmt='%Y-%m-%d %H:%M:%S', format=_log_format) cache_dir = Path(os.environ.get('MTDATA', '~/.mtdata')).expanduser() cached_index_file = cache_dir / f'mtdata.index.{__version__}.pkl' - -try: - import enlighten - pbar_man = enlighten.get_manager() -except ImportError as e: - log.warning("enlighten package maybe required. please run 'pip install englighten'") - log.warning(e) +FILE_LOCK_TIMEOUT = 2 * 60 * 60 # 2 hours +pbar_man = enlighten.get_manager() class MTDataException(Exception): diff --git a/mtdata/__main__.py b/mtdata/__main__.py index f13fa7f..1ba0236 100644 --- a/mtdata/__main__.py +++ b/mtdata/__main__.py @@ -2,12 +2,15 @@ # # Author: Thamme Gowda [tg (at) isi (dot) edu] # Created: 4/4/20 -import errno -from mtdata import main, log -if __name__ == '__main__': +def main(): + from mtdata import main try: main.main() except BrokenPipeError as e: # this happens when piped to '| head' which aborts pipe after limit. And that is okay.
pass + + +if __name__ == '__main__': + main() diff --git a/mtdata/cache.py b/mtdata/cache.py index 129fb4c..b88fa20 100644 --- a/mtdata/cache.py +++ b/mtdata/cache.py @@ -8,8 +8,9 @@ from dataclasses import dataclass from pathlib import Path from mtdata.index import Entry -from mtdata import log, __version__, pbar_man, MTDataException +from mtdata import log, __version__, pbar_man, MTDataException, FILE_LOCK_TIMEOUT from mtdata.utils import ZipPath, TarPath +from mtdata.parser import Parser from typing import List, Union import portalocker @@ -19,7 +20,6 @@ import requests import math -MAX_TIMEOUT = 2 * 60 * 60 # 2 hours headers = {'User-Agent': f'mtdata downloader {__version__}; cURL and wget like.'} @@ -43,6 +43,46 @@ def get_entry(self, entry: Entry, fix_missing=True) -> Union[Path, List[Path]]: local = self.get_local_in_paths(path=local, entry=entry) return local + def get_stats(self, entry: Entry): + path = self.get_entry(entry) + parser = Parser(path, ext=entry.in_ext or None, ent=entry) + count, skips, noise = 0, 0, 0 + toks = [0, 0] + chars = [0, 0] + for rec in parser.read_segs(): + if len(rec) < 2 or not rec[0] or not rec[1]: + skips += 1 + continue + if entry.is_noisy(seg1=rec[0], seg2=rec[1]): + noise += 1 + skips += 1 + continue + count += 1 + s1, s2 = rec[:2] # get the first two recs + chars[0] += len(s1) + chars[1] += len(s2) + s1_tok, s2_tok = s1.split(), s2.split() + toks[0] += len(s1_tok) + toks[1] += len(s2_tok) + + l1, l2 = entry.did.langs + l1, l2 = l1.lang, l2.lang + assert count > 0, f'No valid records are found for {entry.did}' + if l2 < l1: + l1, l2 = l2, l1 + toks = toks[1], toks[0] + chars = chars[1], chars[0] + return { + 'id': str(entry.did), + 'segs': count, + 'segs_err': skips, + 'segs_noise': noise, + f'{l1}_toks': toks[0], + f'{l2}_toks': toks[1], + f'{l1}_chars': chars[0], + f'{l2}_chars': chars[0] + } + def get_flag_file(self, file: Path): return file.with_name(file.name + '._valid') @@ -74,20 +114,18 @@ def opus_xces_format(self, entry, fix_missing=True) -> List[Path]: l2_path = self.get_local_path(l2_url, fix_missing=fix_missing) return [align_file, l1_path, l2_path] - def get_local_in_paths(self, path:Path, entry: Entry,): + def get_local_in_paths(self, path: Path, entry: Entry,): in_paths = entry.in_paths if zipfile.is_zipfile(path): with zipfile.ZipFile(path) as root: in_paths = self.match_globs(names=root.namelist(), globs=in_paths) return [ZipPath(path, p) for p in in_paths] # stdlib is buggy, so I made a workaround elif tarfile.is_tarfile(path): - with tarfile.open(path, encoding='utf-8') as root: - in_paths = self.match_globs(names=root.getnames(), globs=in_paths) return [TarPath(path, p) for p in in_paths] else: raise Exception(f'Unable to read {entry.did}; the file is neither zip nor tar') - def download(self, url: str, save_at: Path): + def download(self, url: str, save_at: Path, timeout=(5, 10)): valid_flag = self.get_flag_file(save_at) lock_file = valid_flag.with_suffix("._lock") if valid_flag.exists() and save_at.exists(): @@ -95,18 +133,18 @@ def download(self, url: str, save_at: Path): save_at.parent.mkdir(parents=True, exist_ok=True) log.info(f"Acquiring lock on {lock_file}") - with portalocker.Lock(lock_file, 'w', timeout=MAX_TIMEOUT) as fh: + with portalocker.Lock(lock_file, 'w', timeout=FILE_LOCK_TIMEOUT) as fh: # check if downloaded by other parallel process if valid_flag.exists() and save_at.exists(): return save_at - log.info(f"Downloading {url} --> {save_at}") - resp = requests.get(url=url, allow_redirects=True, headers=headers, 
stream=True) + log.info(f"GET {url} → {save_at}") + resp = requests.get(url=url, allow_redirects=True, headers=headers, stream=True, timeout=timeout) assert resp.status_code == 200, resp.status_code buf_size = 2 ** 10 n_buffers = math.ceil(int(resp.headers.get('Content-Length', '0')) / buf_size) or None desc = url - if len(desc) > 40: - desc = desc[:30] + '...' + desc[-10:] + if len(desc) > 60: + desc = desc[:30] + '...' + desc[-28:] with pbar_man.counter(color='green', total=n_buffers, unit='KiB', leave=False, desc=f"{desc}") as pbar, open(save_at, 'wb', buffering=2**24) as out: for chunk in resp.iter_content(chunk_size=buf_size): diff --git a/mtdata/data.py b/mtdata/data.py index 3deb126..9b02a7e 100644 --- a/mtdata/data.py +++ b/mtdata/data.py @@ -6,7 +6,8 @@ from pathlib import Path from mtdata import log, pbar_man, cache_dir as CACHE_DIR, MTDataException from mtdata.cache import Cache -from mtdata.index import INDEX, Entry, DatasetId, LangPair +from mtdata.index import INDEX +from mtdata.entry import Entry, DatasetId, LangPair from mtdata.iso.bcp47 import bcp47, BCP47Tag from mtdata.parser import Parser from mtdata.utils import IO @@ -24,7 +25,7 @@ def __init__(self, dir: Path, langs: LangPair, cache_dir: Path, drop_train_noise drop_test_noise=False, drop_dupes=False, drop_tests=False, compress=False): self.dir = dir self.langs = langs - assert len(langs) == 2, 'Only parallel datasets are supported for now and expect two langs' + assert len(langs) == 2, f'Only parallel datasets are supported for now and expected two langs; {langs}' self.cache = Cache(cache_dir) self.train_parts_dir = dir / 'train-parts' # will be merged later @@ -50,7 +51,7 @@ def resolve_entries(cls, dids: List[DatasetId]): @classmethod def prepare(cls, langs, out_dir: Path, train_dids: Optional[List[DatasetId]] = None, - test_dids: Optional[List[DatasetId]] = None, dev_did: Optional[DatasetId] = None, + test_dids: Optional[List[DatasetId]] = None, dev_dids: Optional[List[DatasetId]] = None, cache_dir: Path = CACHE_DIR, merge_train=False, drop_noise: Tuple[bool, bool] = (True, False), compress=False, drop_dupes=False, drop_tests=False): drop_train_noise, drop_test_noise = drop_noise @@ -58,31 +59,40 @@ def prepare(cls, langs, out_dir: Path, train_dids: Optional[List[DatasetId]] = N assert train_dids or test_dids, 'Either train_names or test_names should be given' # First, resolve and check if they exist before going to process them. 
# Fail early for typos in name - train_entries, test_entries = None, None - if test_dids: - test_entries = cls.resolve_entries(test_dids) - if train_dids: - train_entries = cls.resolve_entries(train_dids) + all_dids = (train_dids or []) + (dev_dids or []) + (test_dids or []) + cls.resolve_entries(all_dids) dataset = cls(dir=out_dir, langs=langs, cache_dir=cache_dir, drop_train_noise=drop_train_noise, drop_test_noise=drop_test_noise, drop_dupes=drop_dupes, drop_tests=drop_tests) - if test_entries: # tests are smaller so quicker; no merging needed - dataset.add_test_entries(test_entries) - if dev_did: - dev_entry = cls.resolve_entries([dev_did])[0] - dataset.add_dev_entry(dev_entry) - if train_entries: # this might take some time - dataset.add_train_entries(train_entries, merge_train=merge_train, compress=compress) + dev_entries, test_entries = [], [] + if test_dids: # tests are smaller so quicker; no merging needed + test_entries = cls.resolve_entries(test_dids) + dataset.add_test_entries(test_entries) + if dev_dids: + dev_entries = cls.resolve_entries(dev_dids) + dataset.add_dev_entries(dev_entries) + if train_dids: # this might take some time + train_entries = cls.resolve_entries(train_dids) + drop_hashes = None + if drop_tests: + pair_files = [] + for ent in dev_entries + test_entries: + p1, p2 = dataset.get_paths(dataset.tests_dir, ent) + if BCP47Tag.check_compat_swap(langs, ent.did.langs, fail_on_incompat=True)[1]: + p1, p2 = p2, p1 # swap + pair_files.append((p1, p2)) + test_pair_hash, test_seg_hash = dataset.hash_all_bitexts(pair_files) + drop_hashes = test_pair_hash | test_seg_hash # set union + dataset.add_train_entries(train_entries, merge_train=merge_train, compress=compress, + drop_hashes=drop_hashes) return dataset - def hash_all_held_outs(self): - lang1, lang2 = self.langs - paired_files = self.find_bitext_pairs(self.tests_dir, lang1, lang2) + def hash_all_bitexts(self, paired_files): paired_hashes = set() seg_hashes = set() - for name, (if1, if2) in paired_files.items(): + for if1, if2 in paired_files: for seg1, seg2 in self.read_parallel(if1, if2): paired_hashes.add(hash((seg1, seg2))) paired_hashes.add(hash((seg2, seg1))) @@ -90,16 +100,22 @@ def hash_all_held_outs(self): seg_hashes.add(hash(seg2)) return paired_hashes, seg_hashes - def add_train_entries(self, entries, merge_train=False, compress=False): + def add_train_entries(self, entries, merge_train=False, compress=False, drop_hashes=None): self.add_parts(self.train_parts_dir, entries, drop_noise=self.drop_train_noise, compress=compress, desc='Training sets') if not merge_train: return lang1, lang2 = self.langs - paired_files = self.find_bitext_pairs(self.train_parts_dir, lang1, lang2) + # paired_files = self.find_bitext_pairs(self.train_parts_dir, lang1, lang2) + paired_files = {} + for ent in entries: + e1, e2 = self.get_paths(self.train_parts_dir, ent) + _, swapped = BCP47Tag.check_compat_swap(self.langs, ent.did.langs, fail_on_incompat=True) + if swapped: + e1, e2 = e2, e1 + paired_files[str(ent.did)] = e1, e2 log.info(f"Going to merge {len(paired_files)} files as one train file") - compress_ext = f'.{DEF_COMPRESS}' if compress else '' l1_ext = f'{lang1}{compress_ext}' l2_ext = f'{lang2}{compress_ext}' @@ -112,9 +128,6 @@ def add_train_entries(self, entries, merge_train=False, compress=False): test_overlap_skips=coll.defaultdict(int), selected=coll.defaultdict(int)) train_hashes = set() - tests_pair_hashes, tests_seg_hashes = set(), set() - if self.drop_tests: - tests_pair_hashes, tests_seg_hashes = 
self.hash_all_held_outs() with IO.writer(of1) as w1, IO.writer(of2) as w2, IO.writer(of3) as w3: with pbar_man.counter(color='green', total=len(paired_files), unit='it', desc="Merging", @@ -124,9 +137,8 @@ def add_train_entries(self, entries, merge_train=False, compress=False): counts['total'][name] += 1 if self.drop_dupes or self.drop_tests: hash_val = hash((seg1, seg2)) - if self.drop_tests and (hash_val in tests_pair_hashes - or hash(seg1) in tests_seg_hashes - or hash(seg2) in tests_seg_hashes): + if drop_hashes and (hash_val in drop_hashes or hash(seg1) in drop_hashes + or hash(seg2) in drop_hashes): counts['test_overlap_skips'][name] += 1 continue if self.drop_dupes: @@ -188,34 +200,63 @@ def read_parallel(cls, file1: Path, file2: Path): def add_test_entries(self, entries): self.add_parts(self.tests_dir, entries, drop_noise=self.drop_test_noise, desc='Held-out sets') - if len(entries) <= 4: + if len(entries) <= 20: for i, entry in enumerate(entries, start=1): self.link_to_part(entry, self.tests_dir, f"test{i}") def link_to_part(self, entry, data_dir, link_name): """Create link such as test, dev""" + l1, l2 = self.langs l1_path, l2_path = self.get_paths(data_dir, entry) l1_path, l2_path = l1_path.relative_to(self.dir), l2_path.relative_to(self.dir) - l1_link = self.dir / f'{link_name}.{self.langs[0]}' - l2_link = self.dir / f'{link_name}.{self.langs[1]}' + l1_link = self.dir / f'{link_name}.{l1.lang}' + l2_link = self.dir / f'{link_name}.{l2.lang}' for lnk in [l1_link, l2_link]: - lnk.unlink(missing_ok=True) - if BCP47Tag.are_compatible(self.langs[0], entry.did.langs[0]): - assert not BCP47Tag.are_compatible(self.langs[0], entry.did.langs[1]) - # cool! no swapping needed - elif BCP47Tag.are_compatible(self.langs[0], entry.did.langs[1]): - l1_path, l2_path = l2_path, l1_path # swapped - else: - raise Exception("This should not be happening! 
:(") - + if lnk.exists(): + lnk.unlink() + compat, swapped = BCP47Tag.check_compat_swap(self.langs, entry.did.langs) + if not compat: + raise Exception(f"Unable to unify language IDs: {self.langs} x {entry.did.langs}") + if swapped: + l1_path, l2_path = l2_path, l1_path l1_link.symlink_to(l1_path) l2_link.symlink_to(l2_path) - def add_dev_entry(self, entry): - n_good, n_bad = self.add_part(self.tests_dir, entry, drop_noise=self.drop_test_noise) - log.info(f"{entry.did} : found {n_good:} segments and {n_bad:} errors") - # create a link - self.link_to_part(entry, self.tests_dir, "dev") + def add_dev_entries(self, entries): + assert entries + for entry in entries: + n_good, n_bad = self.add_part(self.tests_dir, entry, drop_noise=self.drop_test_noise) + log.info(f"{entry.did} : found {n_good:} segments and {n_bad:} errors") + if len(entries) == 1: + # create a link to the only one + self.link_to_part(entries[0], self.tests_dir, "dev") + else: + l1, l2 = self.langs + out_paths = (self.dir / f'dev.{l1.lang}', self.dir / f'dev.{l2.lang}') + in_paths = [] + for ent in entries: + e1, e2 = self.get_paths(self.tests_dir, ent) + compat, swapped = BCP47Tag.check_compat_swap(self.langs, ent.did.langs) + if not compat: + raise Exception(f"Unable to unify language IDs {self.langs} x {ent.did.langs}") + if swapped: + e1, e2 = e2, e1 + in_paths.append((e1, e2)) + self.cat_bitexts(in_paths=in_paths, out_paths=out_paths) + + def cat_bitexts(self, in_paths:List[Tuple[Path, Path]], out_paths: Tuple[Path, Path]): + of1, of2 = out_paths + of1.parent.mkdir(exist_ok=True) + of2.parent.mkdir(exist_ok=True) + with pbar_man.counter(color='green', total=len(in_paths), unit='it', desc="Merging") as pbar,\ + IO.writer(of1) as w1, IO.writer(of2) as w2: + for if1, if2 in in_paths: + assert if1.exists() + assert if2.exists() + for seg1, seg2 in self.read_parallel(if1, if2): + w1.write(seg1 + '\n') + w2.write(seg2 + '\n') + pbar.update() def add_parts(self, dir_path, entries, drop_noise=False, compress=False, desc=None, fail_on_error=False): with pbar_man.counter(color='blue', leave=False, total=len(entries), unit='it', desc=desc, @@ -230,8 +271,7 @@ def add_parts(self, dir_path, entries, drop_noise=False, compress=False, desc=No log.error(f"Unable to add {ent.did}: {e}") if fail_on_error: raise e - else: - log.warning(e) + log.warning(e) @classmethod def get_paths(cls, dir_path: Path, entry: Entry, compress=False) -> Tuple[Path, Path]: @@ -245,6 +285,9 @@ def get_paths(cls, dir_path: Path, entry: Entry, compress=False) -> Tuple[Path, return l1, l2 def add_part(self, dir_path: Path, entry: Entry, drop_noise=False, compress=False): + flag_file = dir_path / f'.valid.{entry.did}' + if flag_file.exists(): + return -1, -1 path = self.cache.get_entry(entry) # swap = entry.is_swap(self.langs) parser = Parser(path, ext=entry.in_ext or None, ent=entry) @@ -267,8 +310,6 @@ def add_part(self, dir_path: Path, entry: Entry, drop_noise=False, compress=Fals if not sent1 or not sent2: skips += 1 continue - # if swap: - # sent2, sent1 = sent1, sent2 sent1 = sent1.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ') sent2 = sent2.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ') f1.write(f'{sent1}\n') @@ -281,6 +322,7 @@ def add_part(self, dir_path: Path, entry: Entry, drop_noise=False, compress=Fals if noise > 0: log.info(f"{entry}: Noise : {noise:,}/{count:,} => {100 * noise / count:.4f}%") log.info(f"wrote {count} lines to {l1} == {l2}") + flag_file.touch() return count, skips diff --git a/mtdata/entry.py b/mtdata/entry.py index 
94d7aec..881ff76 100644 --- a/mtdata/entry.py +++ b/mtdata/entry.py @@ -5,6 +5,7 @@ from typing import Tuple, List, Optional, Set, Union from dataclasses import dataclass, field +from mtdata import log from mtdata.iso.bcp47 import BCP47Tag, bcp47 DID_DELIM = '-' # I wanted to use ":", but Windows, they dont like ":" in path! :( @@ -12,6 +13,19 @@ LangPair = Tuple[BCP47Tag, BCP47Tag] +def lang_pair(string) -> LangPair: + parts = string.strip().split('-') + if len(parts) != 2: + msg = f'expected value of form "xxx-yyz" eg "deu-eng"; given {string}' + raise Exception(msg) + std_codes = (bcp47(parts[0]), bcp47(parts[1])) + std_form = '-'.join(str(lang) for lang in std_codes) + if std_form != string: + log.info(f"Suggestion: Use codes {std_form} instead of {string}." + f" Let's make a little space for all languages of our planet 😢.") + return std_codes + + @dataclass(frozen=True) class DatasetId: group: str @@ -117,41 +131,3 @@ def is_noisy(self, seg1, seg2) -> bool: class JW300Entry(Entry): url: Tuple[str, str, str] # (align.xml, src.xml, tgt.xml) - -@dataclass -class Experiment: - langs: Tuple[BCP47Tag, BCP47Tag] # (lang1 , lang2) lang1 -> lang2 - train: List[Entry] # training should be merged from all these - tests: List[Entry] # multiple tests; one of them can be validation set - papers: Set['Paper'] = field(default_factory=set) - - def __post_init__(self): - if any(not isinstance(lang, BCP47Tag) for lang in self.langs): - self.langs = tuple(bcp47(l) for l in self.langs) - for t in self.tests: - assert t - for t in self.train: - assert t - - @classmethod - def make(cls, index, langs: Tuple[str, str], train: List[str], tests: List[str]): - train = [index.get_entry(name, langs) for name in train] - tests = [index.get_entry(name, langs) for name in tests] - return cls(langs, train=train, tests=tests) - - -@dataclass(eq=False) # see for hash related issues: https://stackoverflow.com/a/52390734/1506477 -class Paper: # or Article - - name: str # author1-etal-year - title: str # title - url: str # Paper url to be sure - cite: str # bibtex would be nice to display - experiments: List[Experiment] - - langs: Set[Tuple[str, str]] = None - - def __post_init__(self): - self.langs = self.langs or set(exp.langs for exp in self.experiments) - for exp in self.experiments: - exp.papers.add(self) diff --git a/mtdata/index/__init__.py b/mtdata/index/__init__.py index dc6620f..3b19a98 100644 --- a/mtdata/index/__init__.py +++ b/mtdata/index/__init__.py @@ -11,7 +11,7 @@ from pybtex.database import parse_file as parse_bib_file from mtdata import log, cached_index_file, __version__ -from mtdata.entry import Entry, Paper, DatasetId, LangPair +from mtdata.entry import Entry, DatasetId from mtdata.iso.bcp47 import bcp47, BCP47Tag REFS_FILE = Path(__file__).parent / 'refs.bib' @@ -96,11 +96,6 @@ def add_entry(self, entry: Entry): assert key not in self.entries, f'{key} is a duplicate' self.entries[key] = entry - def add_paper(self, paper: Paper): - assert isinstance(paper, Paper) - assert paper.name not in self.papers, f'{paper.name} is a duplicate' - self.papers[paper.name] = paper - def get_entries(self): return self.entries.values() diff --git a/mtdata/index/unitednations.py b/mtdata/index/unitednations.py index 010cba6..e166d2d 100644 --- a/mtdata/index/unitednations.py +++ b/mtdata/index/unitednations.py @@ -6,9 +6,11 @@ from mtdata.index import Index, Entry, DatasetId import itertools + def load_all(index: Index): cite = index.ref_db.get_bibtex('ziemski-etal-2016-united') url = 
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.testsets.tar.gz" + url = "https://drive.google.com/uc?export=download&id=13GI1F1hvwpMUGBSa0QC6ov4eE57GC_Zx" # they changed it! langs = ['en', 'ar', 'fr', 'es', 'ru', 'zh'] for split in ['dev', 'test']: for l1, l2 in itertools.combinations(langs, 2): diff --git a/mtdata/iso/bcp47.py b/mtdata/iso/bcp47.py index 5991854..67fdc03 100644 --- a/mtdata/iso/bcp47.py +++ b/mtdata/iso/bcp47.py @@ -12,10 +12,12 @@ import json from collections import namedtuple from pathlib import Path -from typing import Optional, Union +from typing import Optional, Union, Tuple from functools import lru_cache from mtdata.iso import iso3_code +MULTI_LANG = 'mul' # multilang + def load_json(path: Path): assert path.exists() @@ -65,6 +67,40 @@ def are_compatible(cls, tag1, tag2): tag1 = tag1 if isinstance(tag1, cls) else bcp47(tag1) return tag1.is_compatible(tag2) + @classmethod + def check_compat_swap(cls, pair1: Tuple['BCP47Tag', 'BCP47Tag'], pair2: Tuple['BCP47Tag', 'BCP47Tag'], + fail_on_incompat=False) -> Tuple[bool, bool]: + a, b = pair1 + aa, bb = pair2 + # we cant support multiling on both sides + assert not (a.lang == MULTI_LANG and b.lang == MULTI_LANG), f'Multilingual on both side is not supported' + assert not (aa.lang == MULTI_LANG and bb.lang == MULTI_LANG), f'Multilingual on both side is not supported' + compat = False + swap = False + if a.is_compatible(aa): + if b.is_compatible(bb) or b.lang == MULTI_LANG: + assert not a.is_compatible(bb), f'{pair1} x {pair2} is ambiguous' + assert not b.is_compatible(aa), f'{pair1} x {pair2} is ambiguous' + compat = True + swap = False + elif a.is_compatible(bb): + if b.is_compatible(aa) or b.lang == MULTI_LANG: + assert not a.is_compatible(aa) # it wont be as already checked in prior case if + assert not b.is_compatible(bb), f'{pair1} x {pair2} is ambiguous' + compat = True + swap = True + elif a.lang == MULTI_LANG: + if b.is_compatible(bb): + compat = True + swap = False + elif b.is_compatible(aa): + compat = True, + swap = True + if not compat and fail_on_incompat: + raise Exception(f'Unable to match langs : {pair1} x {pair2}') + # else False, False + return compat, swap + class BCP47Parser: diff --git a/mtdata/main.py b/mtdata/main.py index 21012db..4f7cc59 100644 --- a/mtdata/main.py +++ b/mtdata/main.py @@ -4,13 +4,13 @@ # Created: 4/4/20 import argparse from pathlib import Path +from typing import List from collections import defaultdict import mtdata from mtdata import log, __version__, cache_dir as CACHE_DIR, cached_index_file -from mtdata.entry import DatasetId, LangPair +from mtdata.entry import DatasetId, lang_pair from mtdata.utils import IO -from mtdata.iso.bcp47 import bcp47 def list_data(langs, names, not_names=None, full=False): @@ -24,7 +24,7 @@ def list_data(langs, names, not_names=None, full=False): print(f"Total {len(entries)} entries") -def get_data(langs, out_dir, train_dids=None, test_dids=None, dev_did=None, merge_train=False, compress=False, +def get_data(langs, out_dir, train_dids=None, test_dids=None, dev_dids=None, merge_train=False, compress=False, drop_dupes=False, drop_tests=False, **kwargs): if kwargs: log.warning(f"Args are ignored: {kwargs}") @@ -32,15 +32,12 @@ def get_data(langs, out_dir, train_dids=None, test_dids=None, dev_did=None, merg assert train_dids or test_dids, 'Required --train or --test or both' dataset = Dataset.prepare( langs, train_dids=train_dids, test_dids=test_dids, out_dir=out_dir, - dev_did=dev_did, cache_dir=CACHE_DIR, merge_train=merge_train, 
compress=compress, + dev_dids=dev_dids, cache_dir=CACHE_DIR, merge_train=merge_train, compress=compress, drop_dupes=drop_dupes, drop_tests=drop_tests) cli_sig = f'-l {"-".join(str(l) for l in langs)}' - if train_dids: - cli_sig += f' -tr {" ".join(str(d) for d in train_dids)}' - if test_dids: - cli_sig += f' -ts {" ".join(str(d) for d in test_dids)}' - if dev_did: - cli_sig += f' -dv {dev_did}' + for flag, dids in [('-tr', train_dids), ('-ts', test_dids), ('-dv', dev_dids)]: + if dids: + cli_sig += f' {flag} {" ".join(map(str, dids))}' for flag, val in [('--merge', merge_train), ('--compress', compress), ('-dd', drop_dupes), ('-dt', drop_tests)]: if val: cli_sig += ' ' + flag @@ -75,8 +72,32 @@ def generate_report(langs, names, not_names=None, format='plain'): print(f'{key}\t{val:,}') -def list_experiments(args): - raise Exception("Not implemented yet") +def list_recipes(): + from mtdata.recipe import print_all, RECIPES + log.info(f"Found {len(RECIPES)} recipes") + print_all(RECIPES.values()) + + +def get_recipe(recipe_id, out_dir: Path, compress=False, drop_dupes=False, drop_tests=False, **kwargs): + if kwargs: + log.warning(f"Args are ignored: {kwargs}") + from mtdata.recipe import RECIPES + recipe = RECIPES.get(recipe_id) + if not recipe: + raise ValueError(f'recipe {recipe_id} not found. See "mtdata list-recipe"') + + get_data(langs=recipe.langs, train_dids=recipe.train, dev_dids=recipe.dev, test_dids=recipe.test, + merge_train=True, out_dir=out_dir, compress=compress, drop_dupes=drop_dupes, drop_tests=drop_tests) + + +def show_stats(*dids: DatasetId): + from mtdata.index import INDEX as index + from mtdata.cache import Cache + cache = Cache(CACHE_DIR) + for did in dids: + entry = index[did] + stats = cache.get_stats(entry) + print(stats) class MyFormatter(argparse.ArgumentDefaultsHelpFormatter): @@ -87,19 +108,6 @@ def _split_lines(self, text, width: int): return super()._split_lines(text, width) -def lang_pair(string) -> LangPair: - parts = string.split('-') - if len(parts) != 2: - msg = f'expected value of form "xxx-yyz" eg "deu-eng"; given {string}' - raise argparse.ArgumentTypeError(msg) - std_codes = (bcp47(parts[0]), bcp47(parts[1])) - std_form = '-'.join(str(lang) for lang in std_codes) - if std_form != string: - log.info(f"Suggestion: Use codes {std_form} instead of {string}." - f" Let's make a little space for all languages of our planet 😢.") - return std_codes - - def add_boolean_arg(parser: argparse.ArgumentParser, name, dest=None, default=False, help=''): group = parser.add_mutually_exclusive_group() dest = dest or name @@ -121,8 +129,9 @@ def parse_args(): help='''R| "list" - List the available entries "get" - Downloads the entry files and prepares them for experiment -"list-exp" - List the (well) known papers and datasets used in their experiments -"get-exp" - Get the datasets used in the specified experiment from "list-exp" +"list-recipe" - List the (well) known papers and dataset recipes used in their experiments +"get-recipe" - Get the datasets used in the specified experiment from "list-recipe" +"stats" - Get stats of dataset" ''') list_p = sub_ps.add_parser('list', formatter_class=MyFormatter) @@ -147,25 +156,35 @@ def parse_args(): e.g. "-ts Statmt-newstest_deen-2019-deu-eng Statmt-newstest_deen-2020-deu-eng ". You may also use shell expansion if your shell supports it. e.g. 
"-ts Statmt-newstest_deen-20{19,20}-deu-eng" ''') - get_p.add_argument('-dv', '--dev', metavar='ID', dest='dev_did', type=DatasetId.parse, required=False, - help='''R|Dataset to be used for development (aka validation). + get_p.add_argument('-dv', '--dev', metavar='ID', dest='dev_dids', type=DatasetId.parse, nargs='*', + help='''R|Dataset to be used for development (aka validation). e.g. "-dev Statmt-newstest_deen-2017-deu-eng"''') add_boolean_arg(get_p, 'merge', dest='merge_train', default=False, help='Merge train into a single file') - get_p.add_argument(f'--compress', action='store_true', default=False, help="Keep the files compressed") - get_p.add_argument('-dd', f'--dedupe', '--drop-dupes', dest='drop_dupes', action='store_true', default=False, - help="Remove duplicate (src, tgt) pairs in training (if any); valid when --merge. " - "Not recommended for large datasets. ") - get_p.add_argument('-dt', f'--drop-tests', dest='drop_tests', action='store_true', default=False, - help="Remove dev/test sentences from training sets (if any); valid when --merge") - get_p.add_argument('-o', '--out', dest='out_dir', type=Path, required=True, help='Output directory name') + + def add_getter_args(parser): + parser.add_argument(f'--compress', action='store_true', default=False, help="Keep the files compressed") + parser.add_argument('-dd', f'--dedupe', '--drop-dupes', dest='drop_dupes', action='store_true', default=False, + help="Remove duplicate (src, tgt) pairs in training (if any); valid when --merge. " + "Not recommended for large datasets. ") + parser.add_argument('-dt', f'--drop-tests', dest='drop_tests', action='store_true', default=False, + help="Remove dev/test sentences from training sets (if any); valid when --merge") + parser.add_argument('-o', '--out', dest='out_dir', type=Path, required=True, help='Output directory name') + + add_getter_args(get_p) report_p = sub_ps.add_parser('report', formatter_class=MyFormatter) - report_p.add_argument('-l', '--langs', metavar='L1-L2', type=lang_pair, - help='Language pairs; e.g.: deu-eng') - report_p.add_argument('-n', '--names', metavar='NAME', nargs='*', - help='Name of dataset set; eg europarl_v9.') + report_p.add_argument('-l', '--langs', metavar='L1-L2', type=lang_pair, help='Language pairs; e.g.: deu-eng') + report_p.add_argument('-n', '--names', metavar='NAME', nargs='*', help='Name of dataset set; eg europarl_v9.') report_p.add_argument('-nn', '--not-names', metavar='NAME', nargs='*', help='Exclude these names') + listr_p = sub_ps.add_parser('list-recipe', formatter_class=MyFormatter) + getr_p = sub_ps.add_parser('get-recipe', formatter_class=MyFormatter) + getr_p.add_argument('-ri', '--recipe-id', type=str, help='Recipe ID', required=True) + add_getter_args(getr_p) + + stats_p = sub_ps.add_parser('stats', formatter_class=MyFormatter) + stats_p.add_argument('did', nargs='+', type=DatasetId.parse, help="Show stats of dataset IDs") + args = p.parse_args() if args.verbose: log.getLogger().setLevel(level=log.DEBUG) @@ -179,13 +198,16 @@ def main(): bak_file = cached_index_file.with_suffix(".bak") log.info(f"Invalidate index: {cached_index_file} -> {bak_file}") cached_index_file.rename(bak_file) - if args.task == 'list': list_data(args.langs, args.names, not_names=args.not_names, full=args.full) elif args.task == 'get': get_data(**vars(args)) - elif args.task == 'list_exp': - list_experiments(args) + elif args.task == 'list-recipe': + list_recipes() + elif args.task == 'get-recipe': + get_recipe(**vars(args)) + elif args.task == 'stats': + 
show_stats(*args.did) elif args.task == 'report': generate_report(args.langs, names=args.names, not_names=args.not_names) else: diff --git a/mtdata/parser.py b/mtdata/parser.py index d362e2a..edc9ada 100644 --- a/mtdata/parser.py +++ b/mtdata/parser.py @@ -6,7 +6,7 @@ from typing import Optional, Union, Tuple, List from dataclasses import dataclass from pathlib import Path -from mtdata import log +from mtdata import log, pbar_man from mtdata.entry import Entry from itertools import zip_longest @@ -85,17 +85,22 @@ def read_segs(self): raise Exception(f'Not supported {self.ext} : {p}') if len(readers) == 1: - yield from readers[0] + data = readers[0] elif self.ext == 'tmx' or self.ext == 'tsv': - for reader in readers: - yield from reader + data = (rec for reader in readers for rec in reader) # flatten all readers elif len(readers) == 2: - for seg1, seg2 in zip_longest(*readers): - if seg1 is None or seg2 is None: - raise Exception(f'{self.paths} have unequal number of segments') - yield seg1, seg2 + def _zip_n_check(): + for seg1, seg2 in zip_longest(*readers): + if seg1 is None or seg2 is None: + raise Exception(f'{self.paths} have unequal number of segments') + yield seg1, seg2 + data = _zip_n_check() else: raise Exception("This is an error") + with pbar_man.counter(color='green', unit='seg', leave=False, desc=f"Reading {self.ent.did}") as pbar: + for rec in data: + yield rec + pbar.update() def read_plain(self, path): with IO.reader(path) as stream: diff --git a/mtdata/recipe/__init__.py b/mtdata/recipe/__init__.py new file mode 100644 index 0000000..9126818 --- /dev/null +++ b/mtdata/recipe/__init__.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# +# +# Author: Thamme Gowda +# Created: 10/27/21 +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import List, Union, Dict, Optional + + +from mtdata import yaml, cache_dir, log +from mtdata.entry import lang_pair, LangPair, DatasetId, BCP47Tag, bcp47 + + +_def_recipes: Path = Path(__file__).parent / 'recipes.yml' +_cwd_recipes: Path = Path('.').expanduser() / 'mtdata.recipes.yml' +_home_recipes: Path = cache_dir / 'mtdata.recipes.yml' + + +@dataclass +class Recipe: + + id: str + langs: LangPair + train: List[DatasetId] + test: Optional[List[DatasetId]] = None + dev: Optional[List[DatasetId]] = None + desc: Optional[str] = '' + url: str = '' + + @classmethod + def parse(cls, langs, train, test, dev, **kwargs): + train, dev, test = [None if not x else + isinstance(x, list) and x or x.split(',') for x in (train, dev, test)] + langs = lang_pair(langs) + train = train and [DatasetId.parse(i) for i in train] + test = test and [DatasetId.parse(i) for i in test] + dev = dev and [DatasetId.parse(i) for i in dev] + return cls(langs=langs, train=train, test=test, dev=dev, **kwargs) + + def format(self): + rec = vars(self) + rec['langs'] = '-'.join(map(str, self.langs)) + rec['train'] = self.train and ','.join(str(did) for did in self.train) + rec['test'] = self.test and ','.join(str(did) for did in self.test) + rec['dev'] = self.dev and ','.join(str(did) for did in self.dev) + return rec + + @classmethod + def load(cls, *paths) -> Dict[str, 'Recipe']: + assert len(paths) > 0 + recipes = {} + for path in paths: + log.info(f"Loading recipes from {path}") + with open(path) as inp: + recipes_raw = yaml.load(inp) + for r in recipes_raw: + assert isinstance(r, dict), f'{r} expected to be a dict' + r = cls.parse(**r) + assert r.id not in recipes, f'{r} is a duplicate' + recipes[r.id] = r + return recipes + + @classmethod + def 
load_all(cls): + assert _def_recipes.exists(), f'{_def_recipes} file expected but not found' + paths = [_def_recipes] + if _home_recipes.exists(): + paths.append(_home_recipes) + if _cwd_recipes.exists(): + paths.append(_cwd_recipes) + return cls.load(*paths) + + +def print_all(recipes: List[Recipe], delim='\t', out=sys.stdout): + for i, val in enumerate(recipes): + kvs = val.format().items() + if i == 0: + out.write(delim.join([kv[0] or '' for kv in kvs]) + '\n') + out.write(delim.join([kv[1] or '' for kv in kvs]) + '\n') + + +RECIPES = Recipe.load_all() + + +if __name__ == '__main__': + print_all(list(RECIPES.values())) diff --git a/mtdata/recipe/recipes.yml b/mtdata/recipe/recipes.yml new file mode 100644 index 0000000..d7b49ed --- /dev/null +++ b/mtdata/recipe/recipes.yml @@ -0,0 +1,156 @@ +#- id: template1 +# langs: xxx-yyy +# desc: desc1 +# url: https://example.com +# train: +# - did1 +# - did2 +# dev: did3 +# test: +# - did4 +# - did5 + +- id: vaswani_etal_2017_ende + langs: eng-deu + desc: Transformer - Attention is all you need + url: https://arxiv.org/abs/1706.03762 + train: + - Statmt-commoncrawl_wmt13-1-deu-eng + - Statmt-europarl_wmt13-7-deu-eng + - Statmt-news_commentary_wmt18-13-deu-eng + dev: Statmt-newstest-2013-eng-deu + test: Statmt-newstest_deen-2014-deu-eng + +- id: tg01_deen_10M + langs: deu-eng + desc: 10M sentence corpus for Deu-Eng + url: + train: + - Statmt-commoncrawl_wmt13-1-deu-eng + - Statmt-europarl_wmt13-7-deu-eng + - Statmt-news_commentary_wmt18-13-deu-eng + - EU-dcep-1-deu-eng + - Tilde-eesc-2017-deu-eng + - Tilde-ema-2016-deu-eng + dev: Statmt-newstest_deen-2019-deu-eng + test: Statmt-newstest_deen-2020-deu-eng + +- id: tg01_aren_16M + langs: ara-eng + desc: 16M sentence corpus for ara-eng + train: OPUS-unpc-1.0-ara-eng + dev: UN-un_dev-1-eng-ara + test: UN-un_test-1-eng-ara + +- id: tg01_hien_1M + langs: hin-eng + desc: 1M sentence corpus for hin-eng # its 1.3M lines, but okay to me on log scale + train: IITB-hien_train-1.5-hin-eng + dev: IITB-hien_dev-1.5-hin-eng + test: IITB-hien_test-1.5-hin-eng + +- id: tg01_ruen_1M + langs: rus-eng + desc: 1M sentence corpus for hin-eng + train: Statmt-commoncrawl_wmt13-1-rus-eng,Statmt-wiki_titles-2-rus-eng + dev: Statmt-newstest_ruen-2019-rus-eng + test: Statmt-newstest_ruen-2020-rus-eng + +- id: tg01_knen_100K + langs: kan-eng + desc: 160K sentence Kan-Eng corpus + train: Statmt-ccaligned-1-eng-kan_IN + dev: AI4Bharath-wat_dev-2021-eng-kan + test: AI4Bharath-wat_test-2021-eng-kan + +- id: tg01_teen_100k + langs: tel-eng + train: OPUS-ted2020-1-eng-tel,Statmt-pmindia-1-eng-tel,JoshuaDec-indian_training-1-tel-eng + dev: AI4Bharath-wat_dev-2021-eng-tel + test: AI4Bharath-wat_test-2021-eng-tel + +- id: tg01_smeeng_40K + langs: sme-eng + train: OPUS-kde4-2-eng-sme + dev: OPUS-opus100_dev-1-eng-sme + test: OPUS-opus100_test-1-eng-sme + +- id: tg01_myaeng_18k + langs: mya-eng + desc: "18K mya-eng" + train: WAT-alt_train-2020-mya-eng + dev: WAT-alt_dev-2020-mya-eng + test: WAT-alt_test-2020-mya-eng + +- id: tg01_boseng_1k + langs: bos-eng + desc: "1K bos-eng" + train: OPUS-tatoeba-20210310-bos-eng,Neulab-tedtalks_test-1-eng-bos + dev: OPUS-opus100_dev-1-bos-eng + test: OPUS-opus100_test-1-bos-eng + +- id: tg01_uzbeng_5k + langs: uzb-eng + desc: "5K uzb-eng" + train: OPUS-wikimedia-20210402-eng-uzb + dev: OPUS-opus100_dev-1-eng-uzb + test: OPUS-opus100_test-1-eng-uzb + +- id: tg01_10toeng + langs: mul-eng # mul is multiple languages + desc: 10 languages to english, with various scales + train: + - 
Statmt-commoncrawl_wmt13-1-deu-eng # deu-eng : 10M + - Statmt-europarl_wmt13-7-deu-eng + - Statmt-news_commentary_wmt18-13-deu-eng + - EU-dcep-1-deu-eng + - Tilde-eesc-2017-deu-eng + - Tilde-ema-2016-deu-eng + - OPUS-unpc-1.0-ara-eng #16M + - IITB-hien_train-1.5-hin-eng # 1.3M + - Statmt-commoncrawl_wmt13-1-rus-eng # rus-eng 1M + - Statmt-wiki_titles-2-rus-eng + - Statmt-ccaligned-1-eng-kan_IN # kan-eng 400K + - OPUS-ted2020-1-eng-tel # tel-eng 100K + - Statmt-pmindia-1-eng-tel + - JoshuaDec-indian_training-1-tel-eng + - OPUS-kde4-2-eng-sme # 40K + - WAT-alt_train-2020-mya-eng # 18k + - OPUS-tatoeba-20210310-bos-eng # 500 + - Neulab-tedtalks_test-1-eng-bos # +500 = 1K + - OPUS-wikimedia-20210402-eng-uzb # 5K + dev: + - Statmt-newstest_deen-2019-deu-eng + - UN-un_dev-1-eng-ara + - IITB-hien_dev-1.5-hin-eng + - Statmt-newstest_ruen-2019-rus-eng + - AI4Bharath-wat_dev-2021-eng-kan + - AI4Bharath-wat_dev-2021-eng-tel + - OPUS-opus100_dev-1-eng-sme + - WAT-alt_dev-2020-mya-eng + - OPUS-opus100_dev-1-bos-eng + - OPUS-opus100_dev-1-eng-uzb + test: + - Statmt-newstest_deen-2020-deu-eng + - UN-un_test-1-eng-ara + - IITB-hien_test-1.5-hin-eng + - Statmt-newstest_ruen-2020-rus-eng + - AI4Bharath-wat_test-2021-eng-kan + - AI4Bharath-wat_test-2021-eng-tel + - OPUS-opus100_test-1-eng-sme + - WAT-alt_test-2020-mya-eng + - OPUS-opus100_test-1-eng-uzb + - OPUS-opus100_test-1-bos-eng + +- id: tg01_2to1_test + desc: testing multilingual + langs: mul-eng # mul is multiple languages + train: + - WAT-alt_train-2020-mya-eng + - Statmt-pmindia-1-eng-tel + dev: + - AI4Bharath-wat_dev-2021-eng-tel + - WAT-alt_dev-2020-mya-eng + test: + - AI4Bharath-wat_test-2021-eng-tel + - WAT-alt_test-2020-mya-eng \ No newline at end of file diff --git a/mtdata/utils.py b/mtdata/utils.py index d471ad2..8311092 100644 --- a/mtdata/utils.py +++ b/mtdata/utils.py @@ -7,8 +7,10 @@ import tarfile import zipfile from dataclasses import dataclass +import portalocker -from mtdata import log + +from mtdata import log, FILE_LOCK_TIMEOUT import shutil from datetime import datetime from pathlib import Path @@ -142,7 +144,7 @@ def exists(self): def open(self, mode='r', **kwargs): assert mode in ('r', 'rt'), f'only "r" is supported, given: {mode}' - log.debug(f"Reading from zip file : {self.root} // {self.name}") + log.debug(f"Reading zip: {self.root}?{self.name}") container = zipfile.ZipFile(self.root, mode='r') stream = container.open(self.name, 'r') reader = io.TextIOWrapper(stream, **kwargs) @@ -159,13 +161,19 @@ def close(*args, **kwargs): @dataclass class TarPath(ArchivedPath): + def __post_init__(self): + self.ext_dir = self.extract() + matches = list(self.ext_dir.glob(self.name)) + assert len(matches) == 1 + self.child = matches[0] + self.open = self.child.open + def exists(self): - with tarfile.open(self.root, encoding='utf-8') as root: - return self.name in root.getnames() + return self.child.exists() - def open(self, mode='r', **kwargs): + def open_old(self, mode='r', **kwargs): assert mode in ('r', 'rt'), f'only "r" is supported, given: {mode}' - log.info(f"Reading from tar file : {self.root} // {self.name}") + log.info(f"Reading tar: {self.root}?{self.name}") container = tarfile.open(self.root, mode='r', encoding='utf-8') stream = container.extractfile(self.name) reader = io.TextIOWrapper(stream, **kwargs) @@ -174,6 +182,31 @@ def open(self, mode='r', **kwargs): def close(*args, **kwargs): reader_close() stream.close() - container.close() + container and container.close() reader.close = close # hijack return reader + + def 
extract(self): + dir_name = self.extracted_name() + out_path = self.root.parent / dir_name + valid_path = self.root.parent / (dir_name + '.valid') + lock_path = self.root.parent / (dir_name + '.lock') + if not valid_path.exists(): + with portalocker.Lock(lock_path, 'w', timeout=FILE_LOCK_TIMEOUT) as _: + if valid_path.exists(): + return # extracted by parallel process + log.info(f"extracting {self.root}") + with tarfile.open(self.root) as tar: + tar.extractall(out_path) + valid_path.touch() + return out_path + + def extracted_name(self): + exts = ['.tar', '.tar.gz', '.tar.bz2', '.tar.xz'] + name = self.root.name + dir_name = name + '-extracted' + for ext in exts: + if self.root.name.endswith(ext): + dir_name = name[:-len(ext)] + break + return dir_name \ No newline at end of file diff --git a/setup.py b/setup.py index c6abf54..82276b6 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ # Author: Thamme Gowda [tg (at) isi (dot) edu] # Created: 4/6/20 -import mtdata +import re from pathlib import Path from setuptools import setup, find_namespace_packages @@ -22,10 +22,20 @@ 'Programming Language :: Python :: 3 :: Only', ] +init_file = Path(__file__).parent / 'mtdata' / '__init__.py' +init_txt = init_file.read_text() +version_re = re.compile(r'''__version__ = ['"]([0-9.]+(-dev)?)['"]''') +__version__ = version_re.search(init_txt).group(1) +desc_re = re.compile(r'''__description__ = ['"](.*)['"]''') +__description__ = desc_re.search(init_txt).group(1) +assert __version__ +assert __description__ + + setup( name='mtdata', - version=mtdata.__version__, - description=mtdata.__description__, + version=__version__, + description=__description__, long_description=long_description, long_description_content_type='text/markdown', classifiers=classifiers, @@ -40,7 +50,7 @@ 'computational linguistics'], entry_points={ 'console_scripts': [ - 'mtdata=mtdata.main:main', + 'mtdata=mtdata.__main__:main', 'mtdata-iso=mtdata.iso.__main__:main', ], }, @@ -49,6 +59,7 @@ 'enlighten==1.10.1', 'portalocker==2.3.0', 'pybtex==0.24.0', + 'ruamel.yaml >= 0.17.10', ], include_package_data=True, zip_safe=False, diff --git a/tests/test_recipe.py b/tests/test_recipe.py new file mode 100644 index 0000000..2a8d290 --- /dev/null +++ b/tests/test_recipe.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# +# +# Author: Thamme Gowda +# Created: 10/27/21 + +from pathlib import Path +from tempfile import TemporaryDirectory +from mtdata.main import get_recipe + + +def test_recipe_multilingual(): + recipe_id = 'tg01_2to1_test' + with TemporaryDirectory() as out_dir: + out_dir = Path(out_dir) + get_recipe(recipe_id=recipe_id, out_dir=out_dir, drop_dupes=True, drop_tests=True) + assert (out_dir / 'mtdata.signature.txt').stat().st_size > 0 + assert (out_dir / 'train.eng').stat().st_size > 0 + assert (out_dir / 'train.mul').stat().st_size > 0 + assert (out_dir / 'dev.eng').stat().st_size > 0 + assert (out_dir / 'dev.mul').stat().st_size > 0 +
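
For readers who prefer to drive the new recipe support from Python rather than the CLI, the sketch below mirrors the new `tests/test_recipe.py` above: it lists the bundled recipes via `mtdata.recipe.RECIPES` and rebuilds one with `mtdata.main.get_recipe`, both added in this diff. The recipe id and output file names are taken from the test; the loop and prints are illustrative only, and running it downloads the referenced datasets.

```python
# Minimal sketch (editor's illustration, not part of the patch) of the recipe API
# introduced in this change. Mirrors tests/test_recipe.py.
from pathlib import Path
from tempfile import TemporaryDirectory

from mtdata.main import get_recipe      # programmatic equivalent of `mtdata get-recipe`
from mtdata.recipe import RECIPES       # dict: recipe id -> Recipe, loaded from recipes.yml files

# same information as `mtdata list-recipe`
for rid, recipe in RECIPES.items():
    langs = '-'.join(str(l) for l in recipe.langs)
    print(rid, langs, len(recipe.train or []), 'train sets')

with TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    # downloads the parts, merges train, concatenates the multiple dev sets,
    # and drops dev/test overlap and duplicates from train
    get_recipe(recipe_id='tg01_2to1_test', out_dir=out_dir, drop_dupes=True, drop_tests=True)
    for name in ('train.mul', 'train.eng', 'dev.mul', 'dev.eng', 'mtdata.signature.txt'):
        print(name, (out_dir / name).stat().st_size, 'bytes')
```

The same result should be obtainable from the shell with `mtdata get-recipe -ri tg01_2to1_test -dd -dt -o <out-dir>`.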
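
Another piece worth calling out is the new `BCP47Tag.check_compat_swap`, which `data.py` now uses in place of the old ad-hoc compatibility checks when deciding whether an entry's two sides must be flipped to match the requested pair, including the `mul` wildcard. A rough usage sketch follows; the language codes are illustrative and not taken from the diff.

```python
# Rough sketch of BCP47Tag.check_compat_swap() (added in mtdata/iso/bcp47.py in this diff).
from mtdata.iso.bcp47 import BCP47Tag, bcp47

requested = (bcp47('deu'), bcp47('eng'))    # pair the user asked for, e.g. -l deu-eng
entry_langs = (bcp47('eng'), bcp47('deu'))  # how a particular dataset entry is labelled

compat, swapped = BCP47Tag.check_compat_swap(requested, entry_langs)
assert compat and swapped  # same languages, opposite order: compatible, but swap the two files

# 'mul' acts as a wildcard on one side, which is what the multilingual recipes rely on:
# an eng-tel entry can be folded into a mul-eng dataset, with tel landing on the 'mul' side.
compat, swapped = BCP47Tag.check_compat_swap((bcp47('mul'), bcp47('eng')),
                                             (bcp47('eng'), bcp47('tel')))
print('eng-tel entry fits mul-eng:', bool(compat), '| swap sides:', swapped)
```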