-
Notifications
You must be signed in to change notification settings - Fork 28
/
kirchhoff.py
207 lines (180 loc) · 7.87 KB
/
kirchhoff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import logging
import math
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
from pyro.distributions import constraints
from pyro.infer.reparam import GumbelSoftmaxReparam
from pyro.nn import PyroModule, PyroSample
from sklearn.cluster import AgglomerativeClustering
logger = logging.getLogger(__name__)
# TODO replace with CoalescentTimes(self.leaf_times, ordered=False)
class CoalescentTimes(dist.TransformedDistribution):
    """
    Distribution over the times of the internal nodes of a tree.

    Samples are the negation of ``len(leaf_times) - 1`` independent unit-rate
    exponential draws, so every sampled time is negative. Only the case in
    which all leaves sit at time zero is implemented.

    :param leaf_times: Tensor-like collection of leaf times; every entry must
        equal zero.
    :raises NotImplementedError: If any leaf time is nonzero.
    """
    support = constraints.less_than(0.)

    def __init__(self, leaf_times):
        leaves_at_origin = (leaf_times == 0).all()
        if not leaves_at_origin:
            raise NotImplementedError
        num_internal = len(leaf_times) - 1
        # One unit-rate waiting time per internal node, treated as one event.
        waiting_times = dist.Exponential(torch.ones(num_internal)).to_event(1)
        # Multiply by -1 so that samples land before the leaves (time zero).
        negate = dist.transforms.AffineTransform(0., -1.)
        super().__init__(waiting_times, negate)
class GTRSubstitutionModel(PyroModule):
    """
    Generalized time-reversible substitution model among ``dim``-many states.

    Two latent quantities are declared as Pyro sample attributes:

    - ``stationary``: a point on the ``dim``-simplex with a
      ``Dirichlet(2, ..., 2)`` prior.
    - ``rates``: ``dim * (dim - 1) // 2`` nonnegative values with independent
      unit-rate exponential priors, one per unordered pair of states.

    :param int dim: Number of states.
    """

    def __init__(self, dim=4):
        super().__init__()
        self.dim = dim
        num_pairs = dim * (dim - 1) // 2
        self.stationary = PyroSample(dist.Dirichlet(torch.full((dim,), 2.)))
        self.rates = PyroSample(
            dist.Exponential(torch.ones(num_pairs)).to_event(1))
        # Precompute (row, column) indices of the strictly upper triangle,
        # stored as a 2 x num_pairs tensor.
        state_ids = torch.arange(dim)
        above_diagonal = state_ids[None, :] > state_ids[:, None]
        self._index = above_diagonal.nonzero(as_tuple=False).t()

    @property
    def transition(self):
        """
        Build the ``dim x dim`` rate matrix from ``stationary`` and ``rates``.

        Entries above the diagonal are taken directly from ``rates``. Entries
        below the diagonal mirror them, rescaled by the ratio of stationary
        probabilities. The diagonal is then set to minus the sum of the
        off-diagonal entries of its row, so that every row sums to zero.
        """
        # Read ``stationary`` before ``rates`` to keep the site order fixed.
        stationary = self.stationary
        rows, cols = self._index
        upper = torch.zeros(self.dim, self.dim)
        upper[rows, cols] = self.rates
        # ratio[r, c] == stationary[c] / stationary[r]
        ratio = stationary / stationary[:, None]
        off_diagonal = upper + upper.T * ratio
        row_totals = off_diagonal.sum(dim=-1)
        return off_diagonal - row_totals.diag_embed()

    def forward(self, times, states):
        """
        Propagate ``states`` through the substitution process for ``|times|``.

        Computes ``states @ matrix_exp(transition * |times|[..., None])^T``.

        NOTE(review): ``times[..., None]`` appends a single trailing axis, so
        for ``times`` with a batch shape it broadcasts against the last axis
        of the rate matrix rather than creating one matrix per time. Confirm
        the shapes callers are expected to pass.

        :param torch.Tensor times: Elapsed times; only the absolute value is
            used.
        :param torch.Tensor states: Tensor whose last axis indexes states.
        :return: The transformed states.
        :rtype: torch.Tensor
        """
        rate_matrix = self.transition
        elapsed = times.abs()
        propagator = (rate_matrix * elapsed[..., None]).matrix_exp()
        return states @ propagator.transpose(-1, -2)
class KirchhoffModel(PyroModule):
    """
    A phylogenetic tree model that marginalizes over tree structure and relaxes
    over the states of internal nodes.

    Nodes are indexed so that the ``L`` leaves come first, followed by the
    ``L - 1`` internal nodes, for ``N = 2 * L - 1`` nodes in total. Each node
    carries ``C`` characters, each a (relaxed) one-hot vector over ``D``
    states, where ``D`` is the span ``1 + max - min`` of ``leaf_data``.

    :param torch.Tensor leaf_times: Nondecreasing 1-dimensional tensor of
        ``L`` leaf times.
    :param torch.Tensor leaf_data: Integer tensor of shape ``(L, C)`` holding
        the state of each character at each leaf. States are indexed relative
        to the smallest value present.
    :param torch.Tensor leaf_mask: Boolean tensor of shape ``(L, C)`` that is
        True where ``leaf_data`` is observed and False where it is missing.
    :param float temperature: Positive temperature of the relaxed one-hot
        distribution over latent states.
    """

    def __init__(self, leaf_times, leaf_data, leaf_mask, *,
                 temperature=1.):
        super().__init__()
        assert leaf_times.dim() == 1
        assert (leaf_times[:-1] <= leaf_times[1:]).all()
        assert leaf_data.dim() == 2
        assert leaf_mask.shape == leaf_data.shape
        assert leaf_data.shape[:1] == leaf_times.shape
        assert temperature > 0
        L, C = leaf_data.shape
        # D is the span of observed values, so one-hot indices must be taken
        # relative to the minimum; otherwise data that does not start at zero
        # would index outside of the last axis in scatter_().
        lowest = leaf_data.min().item()
        D = 1 + leaf_data.max().item() - lowest
        self.leaf_times = leaf_times
        self.leaf_states = torch.zeros(L, C, D).scatter_(
            -1, (leaf_data - lowest)[..., None], 1)
        self.subs_model = GTRSubstitutionModel(dim=D)
        self.temperature = torch.tensor(float(temperature))
        self.leaf_mask = leaf_mask
        # Missing leaf entries and all internal-node entries are latent.
        self.is_latent = torch.cat([~leaf_mask, leaf_mask.new_ones(L - 1, C)], dim=0)
        self.num_nodes = 2 * L - 1
        self._initialize()

    def forward(self, sample_tree=False):
        """
        Run the generative model.

        :param bool sample_tree: If False, add the log partition function of
            the spanning tree distribution as a factor, thereby marginalizing
            over trees. If True, sample a tree instead.
        :return: The sampled tree if ``sample_tree`` is True, otherwise None.
        """
        L, C, D = self.leaf_states.shape
        N = 2 * L - 1

        # Impute missing states.
        with pyro.plate("nodes", N, dim=-2), \
                pyro.plate("characters", C, dim=-1), \
                poutine.mask(mask=self.is_latent), \
                poutine.reparam(config={"states": GumbelSoftmaxReparam()}):
            states = pyro.sample(
                "states",
                dist.RelaxedOneHotCategorical(self.temperature, torch.ones(D)))
        # Interleave with observed states.
        states = torch.cat([torch.where(self.leaf_mask[..., None],
                                        self.leaf_states, states[:L]),
                            states[L:]], dim=0)

        # Sample times of internal nodes.
        internal_times = pyro.sample("internal_times", CoalescentTimes(self.leaf_times))
        times = torch.cat([self.leaf_times, internal_times])

        # Account for random tree structure.
        edge_logits = self.kernel(states, times)
        tree_dist = dist.SpanningTree(edge_logits, {"backend": "cpp"})
        if not sample_tree:
            # During training, analytically marginalize over trees.
            pyro.factor("tree_likelihood", tree_dist.log_partition_function)
        else:
            # During prediction, simply sample a tree.
            return pyro.sample("tree", tree_dist)

    def kernel(self, states, times):
        """
        Given states and times, compute pairwise transition log probability
        between every undirected pair of states. This will be -inf for
        infeasible pairs, namely leaf-leaf and internal-after-leaf.

        :param torch.Tensor states: Tensor of shape ``(N, C, D)`` of node
            states.
        :param torch.Tensor times: Tensor of shape ``(N,)`` of node times.
        :return: Tensor of shape ``(N * (N - 1) // 2,)`` of edge logits, in
            the dtype of the input ``states``.
        :rtype: torch.Tensor
        """
        N, C, D = states.shape
        assert times.shape == (N,)
        L = (N + 1) // 2

        # Select feasible time-ordered pairs.
        # v0 is always an internal node that is strictly earlier than v1.
        with torch.no_grad():
            feasible = times[:, None] < times
            feasible[:L] = False  # Leaves are terminal.
            v0, v1 = feasible.nonzero(as_tuple=False).unbind(-1)

        # Convert dense square float64 -> sparse float32.
        dtype = states.dtype
        states = states.float()
        times = times.float()
        x0 = states[v0]
        x1 = states[v1]
        # The small offset keeps every elapsed time strictly positive.
        dt = times[v1] - times[v0] + 1e-6
        m = self.subs_model.transition.float()

        # There are multiple ways to extend the mutation likelihood function to
        # the interior of the relaxed space.
        exp_mt = (dt[:, None, None] * m).matrix_exp()
        kernel_version = 1
        if kernel_version == 0:
            sparse_logits = torch.einsum("fcd,fce,fde->fc", x0, x1, exp_mt).log().sum(-1)
        elif kernel_version == 1:
            # Accumulate sufficient statistics over characters.
            stats = torch.einsum("fcd,fce->fde", x0, x1)
            sparse_logits = torch.einsum("fde,fde->f", exp_mt.add(1e-6).log(), stats)
        assert sparse_logits.isfinite().all()

        # Convert sparse float32 -> dense triangular float64.
        v0, v1 = torch.min(v0, v1), torch.max(v0, v1)
        k = v0 + v1 * (v1 - 1) // 2  # SpanningTree canonical ordering.
        K = N * (N - 1) // 2  # Number of edges in complete graph.
        edge_logits = sparse_logits.new_full((K,), -math.inf)
        edge_logits[k] = sparse_logits
        return edge_logits.to(dtype)

    def _initialize(self):
        """
        Compute heuristic initial times and states via agglomerative
        clustering of the leaves.

        Stores the results in ``self.init_internal_times`` (shape ``(L - 1,)``)
        and ``self.init_states`` (shape ``(N, C, D)``). As a side effect,
        missing entries of ``self.leaf_states`` are overwritten in place with
        the uniform value ``1 / D``.
        """
        logger.info("Initializing via agglomerative clustering")

        # Deterministically impute, only used by initialization.
        missing = ~self.leaf_mask
        self.leaf_states[missing] = 1. / self.subs_model.dim

        # Heuristically initialize hierarchy.
        L, C, D = self.leaf_states.shape
        N = 2 * L - 1
        data = self.leaf_states.reshape(L, C * D)
        clustering = AgglomerativeClustering(
            distance_threshold=0, n_clusters=None).fit(data)
        children = clustering.children_
        assert children.shape == (L - 1, 2)

        # Heuristically initialize times and states.
        # NaN placeholders let the asserts below detect unfilled entries.
        times = torch.full((N,), math.nan)
        states = torch.full((N, C, D), math.nan)
        times[:L] = self.leaf_times
        # Shrink slightly towards uniform, away from exact zeros and ones.
        states[:L] = self.leaf_states * 0.99 + 0.01 / D
        for p, (c1, c2) in enumerate(children):
            # Each merge becomes a parent one time unit before its earlier
            # child, with the average of its children's states.
            times[L + p] = min(times[c1], times[c2]) - 1
            states[L + p] = (states[c1] + states[c2]) / 2
        assert times.isfinite().all()
        assert states.isfinite().all()
        self.init_internal_times = times[L:].clone()
        self.init_states = states

    def init_loc_fn(self, site):
        """
        Heuristic initialization for guides.

        :param dict site: A Pyro sample site; its ``"name"`` and ``"fn"``
            entries are read.
        :return: An initial value for the given site.
        :rtype: torch.Tensor
        :raises ValueError: If the site name is not recognized.
        """
        if site["name"] == "states":
            return self.init_states
        if site["name"] == "states_uniform":
            # This is the GumbelSoftmaxReparam latent variable.
            return 0.1 + 0.8 * self.init_states
        if site["name"] == "internal_times":
            return self.init_internal_times
        if site["name"].endswith("subs_model.rates"):
            # Initialize to low mutation rate.
            return torch.full(site["fn"].shape(), 0.1)
        if site["name"].endswith("subs_model.stationary"):
            D, = site["fn"].event_shape
            return torch.ones(D) / D
        raise ValueError("unknown site {}".format(site["name"]))