Skip to content

Commit

Permalink
Add SpanningTree properties .mode, .edge_mean (#2727)
Browse files Browse the repository at this point in the history
* Add a SpanningTree.edge_mean() method

* Implement a SpanningTree.mean() method

* Convert methods to properties

* Fix typo
  • Loading branch information
fritzo authored Jan 3, 2021
1 parent b5e4a7b commit a5f1ae6
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 3 deletions.
44 changes: 44 additions & 0 deletions pyro/distributions/spanning_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,52 @@ at::Tensor sample_tree_approx(at::Tensor edge_logits) {
return edges;
}

// Finds a maximum weight spanning tree of the complete graph on V vertices.
//
// edge_logits: a length-K tensor of edge weights, where K = V*(V-1)/2 and
//   entry k is the weight of the k-th edge of make_complete_graph(V).
// Returns: an (E,2) long tensor of (vertex,vertex) pairs with E = V-1, with
//   rows ordered by increasing edge id. For a single-vertex graph (K == 0)
//   the result is an empty (0,2) tensor.
//
// The tree is grown greedily: it starts from the heaviest edge and then
// repeatedly adds the heaviest edge that has exactly one endpoint already in
// the tree.
at::Tensor find_best_tree(at::Tensor edge_logits) {
  torch::NoGradGuard no_grad;
  const int K = edge_logits.size(0);
  const int V = static_cast<int>(0.5 + std::sqrt(0.25 + 2 * K));
  const int E = V - 1;
  // A single-vertex graph has no edges, so there is no first edge to select.
  if (E < 1) {
    return torch::empty({0, 2}, at::kLong);
  }
  auto grid = make_complete_graph(V);

  // Each of E edges in the tree is stored as an id k in [0, K) indexing into
  // the complete graph. The id of an edge (v1,v2) is k = v1+v2*(v2-1)/2.
  auto edge_ids = torch::empty({E}, at::kLong);
  // This maps each vertex to whether it is a member of the cumulative tree.
  auto components = torch::zeros({V}, at::kBool);

  // Find the first edge. The argmax is taken over the raw logits (as in the
  // loop below) rather than over exp(logits - max): exponentiating is an
  // unnecessary O(K) temporary and can round distinct near-maximal logits to
  // the same value, which would make the selected edge depend on tie-breaking.
  int k = edge_logits.argmax(0).item().to<int>();
  components[grid[0][k]] = 1;
  components[grid[1][k]] = 1;
  edge_ids[0] = k;

  // Find edges connecting the cumulative tree to a new leaf.
  for (int e = 1; e != E; ++e) {
    auto c1 = components.index_select(0, grid[0]);
    auto c2 = components.index_select(0, grid[1]);
    // An edge is a candidate iff exactly one of its endpoints is in the tree.
    auto mask = c1.__xor__(c2);
    auto valid_logits = edge_logits.masked_select(mask);
    // Position of the best candidate within the masked subset...
    const int best = valid_logits.argmax(0).item().to<int>();
    // ...translated back to an edge id in [0, K).
    k = mask.nonzero().view(-1)[best].item().to<int>();
    components[grid[0][k]] = 1;
    components[grid[1][k]] = 1;
    edge_ids[e] = k;
  }

  // Convert edge ids to a canonical list of pairs.
  edge_ids = std::get<0>(edge_ids.sort());
  auto edges = torch::empty({E, 2}, at::kLong);
  for (int e = 0; e != E; ++e) {
    edges[e][0] = grid[0][edge_ids[e]];
    edges[e][1] = grid[1][edge_ids[e]];
  }
  return edges;
}

// Python bindings for this extension module.
// Each m.def(name, function pointer, docstring) call exposes one of the C++
// functions declared above under the given Python-visible name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Tree samplers.
m.def("sample_tree_mcmc", &sample_tree_mcmc, "Sample a random spanning tree using MCMC");
m.def("sample_tree_approx", &sample_tree_approx, "Approximate sample a random spanning tree");
// Deterministic maximum weight spanning tree search.
m.def("find_best_tree", &find_best_tree, "Finds a maximum weight spanning tree");
// Helper that builds the edge list of a complete graph.
m.def("make_complete_graph", &make_complete_graph, "Constructs a complete graph");
}
89 changes: 89 additions & 0 deletions pyro/distributions/spanning_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,39 @@ def enumerate_support(self, expand=True):
trees = enumerate_spanning_trees(self.num_vertices)
return torch.tensor(trees, dtype=torch.long)

@property
def mode(self):
    """
    The highest-weight spanning tree under this distribution's edge logits.

    The search backend is read from ``self.sampler_options["backend"]``
    and falls back to ``"python"`` when that key is absent.

    :returns: The maximum weight spanning tree.
    :rtype: Tensor
    """
    options = self.sampler_options
    chosen_backend = options.get("backend", "python")
    best_tree = find_best_tree(self.edge_logits, backend=chosen_backend)
    return best_tree

@property
def edge_mean(self):
    """
    Marginal probability of each edge appearing in a sampled tree.

    .. note:: This plays the role of other distributions' ``.mean()``
        method, but has a different shape because values of this
        distribution are not encoded as binary matrices.

    :returns: A symmetric square ``(V,V)``-shaped matrix with values
        in ``[0,1]`` denoting the marginal probability of each edge
        being in a sampled value.
    :rtype: Tensor
    """
    num_vertices = self.num_vertices
    rows, cols = make_complete_graph(num_vertices).unbind(0)

    # Shift by the largest logit before exponentiating.
    shifted_logits = self.edge_logits - self.edge_logits.max()
    edge_weights = shifted_logits.exp()

    # Dense symmetric weight matrix with a zero diagonal.
    weights = self.edge_logits.new_zeros(num_vertices, num_vertices)
    weights[rows, cols] = edge_weights
    weights[cols, rows] = edge_weights

    # Weighted graph Laplacian: degree matrix minus weight matrix.
    degree = weights.sum(-1).diag_embed()
    laplacian = degree - weights

    # NOTE(review): the constant 1 / V offset presumably regularizes the
    # singular Laplacian before inversion -- confirm against the derivation.
    inverse = (laplacian + 1 / num_vertices).pinverse()
    diagonal = inverse.diag()
    resistance = diagonal + diagonal[..., None] - 2 * inverse

    # Each edge's marginal is its weight times the resistance between
    # its endpoints.
    return resistance * weights


################################################################################
# Sampler implementation.
Expand Down Expand Up @@ -432,6 +465,62 @@ def sample_tree(edge_logits, init_edges=None, mcmc_steps=1, backend="python"):
return edges


@torch.no_grad()
def _find_best_tree(edge_logits):
    """
    Pure-Python search for a maximum weight spanning tree of a complete graph.

    The tree is grown greedily: it starts from the heaviest edge and then
    repeatedly adds the heaviest edge that has exactly one endpoint already
    in the tree.

    :param torch.Tensor edge_logits: A length-K tensor of edge weights, where
        ``K = V * (V - 1) // 2`` for some number of vertices ``V``.
    :returns: An ``(E, 2)`` long tensor of (vertex, vertex) pairs with
        ``E = V - 1``, rows ordered by increasing edge id. For a
        single-vertex graph (``K == 0``) this is an empty ``(0, 2)`` tensor.
    :rtype: torch.Tensor
    :raises AssertionError: If K is not of the form ``V * (V - 1) // 2``.
    """
    K = len(edge_logits)
    V = int(round(0.5 + (0.25 + 2 * K)**0.5))
    assert K == V * (V - 1) // 2
    E = V - 1
    if E == 0:
        # A single-vertex graph has no edges, so there is no first edge to
        # select; the only spanning tree is the empty one.
        return torch.empty((0, 2), dtype=torch.long)
    grid = make_complete_graph(V)

    # Each of E edges in the tree is stored as an id k in [0, K) indexing into
    # the complete graph. The id of an edge (v1,v2) is k = v1+v2*(v2-1)/2.
    edge_ids = torch.empty((E,), dtype=torch.long)
    # This maps each vertex to whether it is a member of the cumulative tree.
    components = torch.zeros(V, dtype=torch.bool)

    # Find the first edge.
    k = edge_logits.argmax(0).item()
    components[grid[:, k]] = 1
    edge_ids[0] = k

    # Find edges connecting the cumulative tree to a new leaf.
    for e in range(1, E):
        c1, c2 = components[grid]
        # An edge is a candidate iff exactly one endpoint is in the tree.
        mask = (c1 != c2)
        valid_logits = edge_logits[mask]
        # Position of the best candidate within the masked subset...
        best = valid_logits.argmax(0).item()
        # ...translated back to a plain int edge id in [0, K).
        k = mask.nonzero(as_tuple=False).view(-1)[best].item()
        components[grid[:, k]] = 1
        edge_ids[e] = k

    # Convert edge ids to a canonical list of pairs.
    edge_ids = edge_ids.sort()[0]
    edges = torch.empty((E, 2), dtype=torch.long)
    edges[:, 0] = grid[0, edge_ids]
    edges[:, 1] = grid[1, edge_ids]
    return edges


def find_best_tree(edge_logits, backend="python"):
    """
    Compute the maximum weight spanning tree of a dense weighted graph.

    This is a thin dispatcher over the available implementations.

    :param torch.Tensor edge_logits: A length-K array of nonnormalized log
        probabilities.
    :param str backend: Which implementation to run, either ``"python"``
        or ``"cpp"``.
    :returns: An E x 2 tensor of edges in the form of (vertex,vertex) pairs.
        Each edge should be sorted and the entire tensor should be
        lexicographically sorted.
    :rtype: torch.Tensor
    :raises ValueError: If ``backend`` is not a recognized name.
    """
    # Guard-clause style: return as soon as a known backend matches.
    if backend == "cpp":
        cpp_module = _get_cpp_module()
        return cpp_module.find_best_tree(edge_logits)
    if backend == "python":
        return _find_best_tree(edge_logits)
    raise ValueError("unknown backend: {}".format(repr(backend)))


################################################################################
# Enumeration implementation.
################################################################################
Expand Down
61 changes: 58 additions & 3 deletions tests/distributions/test_spanning_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import os
from collections import Counter

import pyro
import pytest
import torch

import pyro
from pyro.distributions.spanning_tree import NUM_SPANNING_TREES, SpanningTree, make_complete_graph, sample_tree
from pyro.distributions.spanning_tree import (NUM_SPANNING_TREES, SpanningTree, find_best_tree, make_complete_graph,
sample_tree)
from tests.common import assert_equal, xfail_if_not_implemented

pytestmark = pytest.mark.skipif("CUDA_TEST" in os.environ, reason="spanning_tree unsupported on CUDA.")
Expand Down Expand Up @@ -58,6 +58,19 @@ def test_sample_tree_approx_smoke(num_edges, backend):
sample_tree(edge_logits, backend=backend)


@pytest.mark.filterwarnings("always")
@pytest.mark.parametrize('num_edges', [1, 3, 10, 30, 100])
@pytest.mark.parametrize('backend', ["python", "cpp"])
def test_find_best_tree_smoke(num_edges, backend):
    """Smoke test: find_best_tree runs without error on random logits."""
    pyro.set_rng_seed(num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2

    # Only the largest python-backend case is limited to a single trial
    # (presumably to bound test runtime).
    if backend == "cpp" or num_edges <= 30:
        num_trials = 10
    else:
        num_trials = 1

    for _ in range(num_trials):
        random_logits = torch.rand(num_pairs)
        find_best_tree(random_logits, backend=backend)


@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6])
def test_enumerate_support(num_edges):
pyro.set_rng_seed(2 ** 32 - num_edges)
Expand Down Expand Up @@ -107,6 +120,48 @@ def test_log_prob(num_edges):
assert abs(log_total) < 1e-6, log_total


@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6])
def test_edge_mean_function(num_edges):
    """Check edge_mean against marginals computed by brute-force enumeration."""
    pyro.set_rng_seed(2 ** 32 - num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.randn(num_pairs)
    dist = SpanningTree(edge_logits)

    with xfail_if_not_implemented():
        support = dist.enumerate_support()

    # Map every edge of every enumerated tree to its id in the complete graph.
    first = support[..., 0]
    second = support[..., 1]
    edge_ids = first + second * (second - 1) // 2

    # Accumulate each tree's probability onto all of the edges it contains.
    tree_probs = dist.log_prob(support).exp()
    per_edge_probs = tree_probs[:, None].expand_as(edge_ids)
    expected = torch.zeros(num_pairs)
    expected.scatter_add_(0, edge_ids.reshape(-1), per_edge_probs.reshape(-1))

    actual = dist.edge_mean
    assert actual.shape == (num_vertices, num_vertices)
    v1, v2 = make_complete_graph(num_vertices)
    assert (actual[v1, v2] - expected).abs().max() < 1e-5, (actual, expected)


@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6])
@pytest.mark.parametrize('backend', ["python", "cpp"])
def test_mode(num_edges, backend):
    """Check that .mode equals the best tree found by exhaustive enumeration."""
    pyro.set_rng_seed(2 ** 32 - num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.randn(num_pairs)
    dist = SpanningTree(edge_logits, sampler_options={"backend": backend})

    with xfail_if_not_implemented():
        support = dist.enumerate_support()

    # Score every enumerated tree by the sum of its edge logits.
    first = support[..., 0]
    second = support[..., 1]
    edge_ids = first + second * (second - 1) // 2
    tree_scores = edge_logits[edge_ids].sum(-1)
    expected = support[tree_scores.argmax(0)]

    actual = dist.mode
    assert (actual == expected).all()


@pytest.mark.filterwarnings("always")
@pytest.mark.parametrize('pattern', ["uniform", "random", "sparse"])
@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5])
Expand Down

0 comments on commit a5f1ae6

Please sign in to comment.