Skip to content

Commit

Permalink
Add helper to convert Bio.Phylo trees to PyTorch (#2557)
Browse files Browse the repository at this point in the history
* Add helper to convert Bio.Phylo trees to PyTorch

* Fix docs

* Fix parsing format

* Remove debugging statements

* Do not require biopython
  • Loading branch information
fritzo authored Jul 15, 2020
1 parent 776b6ef commit acde6b5
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
'funsor': ('http://funsor.pyro.ai/en/stable/', None),
'opt_einsum': ('https://optimized-einsum.readthedocs.io/en/stable/', None),
'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None),
'Bio': ('https://biopython.readthedocs.io/en/latest/', None),
}

# document class constructors (__init__ methods):
Expand Down
2 changes: 2 additions & 0 deletions docs/source/contrib.epidemiology.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,5 @@ Distributions
:show-inheritance:
:member-order: bysource
:special-members: __call__

.. autofunction:: pyro.distributions.coalescent.bio_phylo_to_times
3 changes: 3 additions & 0 deletions pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from pyro.distributions.coalescent import bio_phylo_to_times

from .compartmental import CompartmentalModel
from .distributions import beta_binomial_dist, binomial_dist, infection_dist

# Public names exported by this package. ``bio_phylo_to_times`` is implemented
# in ``pyro.distributions.coalescent`` and re-exported here.
__all__ = [
"CompartmentalModel",
"beta_binomial_dist",
"binomial_dist",
"bio_phylo_to_times",
"infection_dist",
]
47 changes: 47 additions & 0 deletions pyro/distributions/coalescent.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,53 @@ def __call__(self, rate_grid, t=slice(None)):
return const + linear + log


def bio_phylo_to_times(tree, *, get_time=None):
    """
    Extracts coalescent summary statistics from a phylogeny, suitable for use
    with :class:`~pyro.distributions.CoalescentRateLikelihood`.

    :param Bio.Phylo.BaseTree.Clade tree: A phylogenetic tree.
    :param callable get_time: Optional function to extract the time point of
        each sub-:class:`~Bio.Phylo.BaseTree.Clade`. If absent, times will be
        computed by cumulative `.branch_length`, where a missing (``None``)
        branch length is treated as ``1.0``.
    :returns: A pair of :class:`~torch.Tensor` s ``(leaf_times, coal_times)``
        where ``leaf_times`` are times of sampling events (leaf nodes in the
        phylogenetic tree) and ``coal_times`` are times of coalescences
        (internal nodes in the phylogenetic tree). An internal node with
        ``n`` children contributes ``n - 1`` identical entries to
        ``coal_times``, so that ``len(leaf_times) == 1 + len(coal_times)``.
    :rtype: tuple
    """
    if get_time is None:
        # Compute time as cumulative branch length.
        def get_branch_length(clade):
            branch_length = clade.branch_length
            return 1.0 if branch_length is None else branch_length

        # Maps each clade discovered so far to its cumulative time; seeded
        # with the root and extended with children as each clade is visited.
        times = {tree.root: get_branch_length(tree.root)}

    leaf_times = []
    coal_times = []
    for clade in tree.find_clades():
        if get_time is None:
            # This lookup relies on find_clades() yielding every clade after
            # its parent, so the parent has already registered it in `times`.
            time = times[clade]
            for child in clade:
                times[child] = time + get_branch_length(child)
        else:
            time = get_time(clade)

        num_children = len(clade)
        if num_children == 0:
            leaf_times.append(time)
        else:
            # Pyro expects binary coalescent events, so we split n-ary events
            # into n-1 separate binary events.
            coal_times.extend([time] * (num_children - 1))
    # Sanity check: any rooted tree has one more leaf than binary coalescences.
    assert len(leaf_times) == 1 + len(coal_times)

    leaf_times = torch.tensor(leaf_times)
    coal_times = torch.tensor(coal_times)
    return leaf_times, coal_times


def _gather(tensor, dim, index):
"""
Like :func:`torch.gather` but broadcasts.
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
'matplotlib>=1.3',
'torchvision>=0.6.0',
'visdom>=0.1.4',
# 'biopython>=1.54', # requires Python 3.6
'pandas',
'seaborn',
'wget',
Expand Down
87 changes: 86 additions & 1 deletion tests/distributions/test_coalescent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import io
import re

import pytest
import torch

import pyro
from pyro.distributions import CoalescentTimes, CoalescentTimesWithRate
from pyro.distributions.coalescent import CoalescentRateLikelihood, CoalescentTimesConstraint, _sample_coalescent_times
from pyro.distributions.coalescent import (CoalescentRateLikelihood, CoalescentTimesConstraint,
_sample_coalescent_times, bio_phylo_to_times)
from pyro.distributions.util import broadcast_shape
from tests.common import assert_close

Expand Down Expand Up @@ -144,3 +148,84 @@ def test_likelihood_sequential(num_leaves, num_steps, batch_shape, clamped):
for t in range(num_steps))

assert_close(actual, expected)


# A single small phylogeny in NEXUS format, used as test input. Every node is
# written as ``name:branch_length[&date=YYYY.YY]``; the ``date`` comment is what
# the custom time parser in the tests below extracts.
TREE_NEXUS = """
#NEXUS
Begin Trees;
Tree tree1=((EPI_ISL_408009:0.00000[&date=2020.08],
EPI_ISL_408008:0.00000[&date=2020.08]) NODE_0000004:0.17430[&date=2020.08],
(EPI_ISL_417931:0.28554[&date=2020.21],
(EPI_ISL_417332:0.11102[&date=2020.20], EPI_ISL_413931:0.08643[&date=2020.18])
NODE_0000005:0.16360[&date=2020.09], ((EPI_ISL_413558:0.11909[&date=2020.16],
(EPI_ISL_413559:0.07179[&date=2020.16],
(EPI_ISL_412862:0.00000[&date=2020.15],
EPI_ISL_413561:0.01093[&date=2020.16])
NODE_0000011:0.06086[&date=2020.15])
NODE_0000012:0.04730[&date=2020.09]) NODE_0000007:0.06603[&date=2020.04],
(EPI_ISL_411955:0.09393[&date=2020.11],
(EPI_ISL_417325:0.08372[&date=2020.17],
(EPI_ISL_417318:0.02411[&date=2020.16],
EPI_ISL_417320:0.03504[&date=2020.17])
NODE_0000009:0.05141[&date=2020.14])
NODE_0000006:0.07032[&date=2020.09])
NODE_0000014:0.04474[&date=2020.02]) NODE_0000010:0.04578[&date=2019.97],
(EPI_ISL_417933:0.15496[&date=2020.21], EPI_ISL_414648:0.13583[&date=2020.19],
(EPI_ISL_417932:0.09490[&date=2020.21],
(EPI_ISL_417937:0.05785[&date=2020.21],
EPI_ISL_417331:0.04419[&date=2020.20])
NODE_0000001:0.03705[&date=2020.15],
(EPI_ISL_417938:0.06860[&date=2020.21],
(EPI_ISL_417939:0.04394[&date=2020.21],
(EPI_ISL_417330:0.00314[&date=2020.20],
(EPI_ISL_416457:0.00000[&date=2020.21],
EPI_ISL_417935:0.00000[&date=2020.21])
NODE_0000018:0.01680[&date=2020.21])
NODE_0000017:0.02714[&date=2020.19])
NODE_0000016:0.02466[&date=2020.17])
NODE_0000015:0.02630[&date=2020.14])
NODE_0000000:0.06006[&date=2020.12])
NODE_0000002:0.13059[&date=2020.06])
NODE_0000003:0.02264[&date=2019.93]) NODE_0000008:0.10000[&date=2019.90];
End;
"""


@pytest.fixture
def tree():
    """
    Provide the single tree parsed from ``TREE_NEXUS``.

    The test is skipped when ``Bio.Phylo`` cannot be imported.
    """
    Phylo = pytest.importorskip("Bio.Phylo")
    parsed = list(Phylo.parse(io.StringIO(TREE_NEXUS), "nexus"))
    # The fixture text declares exactly one tree.
    assert len(parsed) == 1
    return parsed[0]


def _assert_lineage_counts_nonnegative(leaf_times, coal_times):
    """
    Check that the events never imply a negative number of lineages.

    Each leaf counts as +1 and each coalescence as -1. After sorting all
    events by time, the suffix sum at every position (leaves minus
    coalescences from that event onward) must be non-negative.
    """
    times = torch.cat([coal_times, leaf_times])
    signs = torch.cat([-torch.ones_like(coal_times), torch.ones_like(leaf_times)])
    _, index = times.sort(0)
    signs = signs[index]
    # flip -> cumsum -> flip computes a suffix sum along the time axis.
    lineages = signs.flip([0]).cumsum(0).flip([0])
    assert (lineages >= 0).all()


def test_bio_phylo_to_times(tree):
    # Default behavior: times are cumulative branch lengths.
    leaf_times, coal_times = bio_phylo_to_times(tree)
    assert len(coal_times) + 1 == len(leaf_times)

    # Check positivity.
    _assert_lineage_counts_nonnegative(leaf_times, coal_times)


def test_bio_phylo_to_times_custom(tree):
    # Test a custom time parser.
    def get_time(clade):
        # Read the ``date=YYYY.YY`` annotation from the clade comment and
        # rescale it by 365.25 relative to 2020.
        date_string = re.search(r"date=(\d\d\d\d\.\d\d)", clade.comment).group(1)
        return (float(date_string) - 2020) * 365.25

    leaf_times, coal_times = bio_phylo_to_times(tree, get_time=get_time)
    assert len(coal_times) + 1 == len(leaf_times)

    # Check positivity.
    _assert_lineage_counts_nonnegative(leaf_times, coal_times)

0 comments on commit acde6b5

Please sign in to comment.