Skip to content

Commit

Permalink
Add helper to download preprocessed aggregated nextstrain dataset (#2978
Browse files Browse the repository at this point in the history
)

* Add code to download nextstrain dataset

* Add docs

* Address review comment
  • Loading branch information
fritzo authored Dec 7, 2021
1 parent f9b7d3d commit 02bffda
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/contrib.examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ BART Ridership
:show-inheritance:
:member-order: bysource

Nextstrain SARS-CoV-2 counts
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: pyro.contrib.examples.nextstrain
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Utilities
---------
.. automodule:: pyro.contrib.examples.util
Expand Down
44 changes: 44 additions & 0 deletions pyro/contrib/examples/nextstrain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import subprocess
import urllib

import torch

from .util import _mkdir_p, get_data_directory

DATA = get_data_directory(__file__)
URL = "https://github.com/pyro-ppl/datasets/raw/master/nextstrain.data.pt.gz"


def load_nextstrain_counts(map_location=None) -> dict:
"""
Loads a SARS-CoV-2 dataset.
The original dataset is a preprocessed intermediate ``metadata.tsv.gz`` available via
`nextstrain <https://docs.nextstrain.org/projects/ncov/en/latest/reference/remote_inputs.html>`_.
The ``metadata.tsv.gz`` file was then aggregated to
(month,location,lineage) and (lineage,mutation) bins by the Broad Institute's
`preprocessing script <https://github.com/broadinstitute/pyro-cov/blob/master/scripts/preprocess_nextstrain.py>`_.
"""
# Download the gzipped file.
_mkdir_p(DATA)
basename = URL.split("/")[-1]
gz_filename = os.path.join(DATA, basename)
if not os.path.exists(gz_filename):
logging.debug(f"downloading {URL}")
urllib.request.urlretrieve(URL, gz_filename)

# Decompress the file.
filename = gz_filename.replace(".gz", "")
if not os.path.exists(filename):
logging.debug(f"unzipping {gz_filename}")
subprocess.check_call(["gunzip", "-k", gz_filename])

# Load tensors to the default location.
if map_location is None:
map_location = torch.tensor(0.0).device
return torch.load(filename, map_location=map_location)

0 comments on commit 02bffda

Please sign in to comment.