diff --git a/docs/source/contrib.examples.rst b/docs/source/contrib.examples.rst index 0d0f415ea5..65b2c0eb8f 100644 --- a/docs/source/contrib.examples.rst +++ b/docs/source/contrib.examples.rst @@ -20,6 +20,14 @@ BART Ridership :show-inheritance: :member-order: bysource +Nextstrain SARS-CoV-2 counts +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: pyro.contrib.examples.nextstrain + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + Utilities --------- .. automodule:: pyro.contrib.examples.util diff --git a/pyro/contrib/examples/nextstrain.py b/pyro/contrib/examples/nextstrain.py new file mode 100644 index 0000000000..df21c710de --- /dev/null +++ b/pyro/contrib/examples/nextstrain.py @@ -0,0 +1,44 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import subprocess +import urllib + +import torch + +from .util import _mkdir_p, get_data_directory + +DATA = get_data_directory(__file__) +URL = "https://github.com/pyro-ppl/datasets/raw/master/nextstrain.data.pt.gz" + + +def load_nextstrain_counts(map_location=None) -> dict: + """ + Loads a SARS-CoV-2 dataset. + + The original dataset is a preprocessed intermediate ``metadata.tsv.gz`` available via + `nextstrain `_. + The ``metadata.tsv.gz`` file was then aggregated to + (month,location,lineage) and (lineage,mutation) bins by the Broad Institute's + `preprocessing script `_. + """ + # Download the gzipped file. + _mkdir_p(DATA) + basename = URL.split("/")[-1] + gz_filename = os.path.join(DATA, basename) + if not os.path.exists(gz_filename): + logging.debug(f"downloading {URL}") + urllib.request.urlretrieve(URL, gz_filename) + + # Decompress the file. + filename = gz_filename.replace(".gz", "") + if not os.path.exists(filename): + logging.debug(f"unzipping {gz_filename}") + subprocess.check_call(["gunzip", "-k", gz_filename]) + + # Load tensors to the default location. + if map_location is None: + map_location = torch.tensor(0.0).device + return torch.load(filename, map_location=map_location)