Add SpanningTree properties .mode, .edge_mean #2727

Merged
merged 4 commits on Jan 3, 2021
44 changes: 44 additions & 0 deletions pyro/distributions/spanning_tree.cpp
@@ -175,8 +175,52 @@ at::Tensor sample_tree_approx(at::Tensor edge_logits) {
  return edges;
}

at::Tensor find_best_tree(at::Tensor edge_logits) {
  torch::NoGradGuard no_grad;
  const int K = edge_logits.size(0);
  const int V = static_cast<int>(0.5 + std::sqrt(0.25 + 2 * K));
  const int E = V - 1;
  auto grid = make_complete_graph(V);

  // Each of E edges in the tree is stored as an id k in [0, K) indexing into
  // the complete graph. The id of an edge (v1,v2) is k = v1+v2*(v2-1)/2.
  auto edge_ids = torch::empty({E}, at::kLong);
  // This maps each vertex to whether it is a member of the cumulative tree.
  auto components = torch::zeros({V}, at::kBool);

  // Find the first edge.
  auto probs = (edge_logits - edge_logits.max()).exp();
  int k = probs.argmax(0).item().to<int>();
  components[grid[0][k]] = 1;
  components[grid[1][k]] = 1;
  edge_ids[0] = k;

  // Find edges connecting the cumulative tree to a new leaf.
  for (int e = 1; e != E; ++e) {
    auto c1 = components.index_select(0, grid[0]);
    auto c2 = components.index_select(0, grid[1]);
    auto mask = c1.__xor__(c2);
    auto valid_logits = edge_logits.masked_select(mask);
    int k = valid_logits.argmax(0).item().to<int>();
    k = mask.nonzero().view(-1)[k].item().to<int>();
    components[grid[0][k]] = 1;
    components[grid[1][k]] = 1;
    edge_ids[e] = k;
  }

  // Convert edge ids to a canonical list of pairs.
  edge_ids = std::get<0>(edge_ids.sort());
  auto edges = torch::empty({E, 2}, at::kLong);
  for (int e = 0; e != E; ++e) {
    edges[e][0] = grid[0][edge_ids[e]];
    edges[e][1] = grid[1][edge_ids[e]];
  }
  return edges;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sample_tree_mcmc", &sample_tree_mcmc, "Sample a random spanning tree using MCMC");
  m.def("sample_tree_approx", &sample_tree_approx, "Approximate sample a random spanning tree");
  m.def("find_best_tree", &find_best_tree, "Finds a maximum weight spanning tree");
  m.def("make_complete_graph", &make_complete_graph, "Constructs a complete graph");
}
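
As an aside for readers of the C++ above: both the C++ and Python implementations index edges of the complete graph by the id formula in the comments, k = v1 + v2*(v2-1)/2 for v1 < v2, so edges are grouped by their larger vertex. Below is a minimal Python sketch (not part of the diff) of that convention; complete_graph_sketch is a hypothetical stand-in for the real make_complete_graph.

import torch


def complete_graph_sketch(V):
    """Hypothetical stand-in for make_complete_graph: a 2 x K grid of (v1, v2) pairs."""
    pairs = [(v1, v2) for v2 in range(V) for v1 in range(v2)]
    return torch.tensor(pairs, dtype=torch.long).t()


grid = complete_graph_sketch(4)
# grid == [[0, 0, 1, 0, 1, 2],
#          [1, 2, 2, 3, 3, 3]]
for k in range(grid.size(1)):
    v1, v2 = grid[0, k].item(), grid[1, k].item()
    assert k == v1 + v2 * (v2 - 1) // 2  # the id formula from the comments above
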
89 changes: 89 additions & 0 deletions pyro/distributions/spanning_tree.py
@@ -170,6 +170,39 @@ def enumerate_support(self, expand=True):
        trees = enumerate_spanning_trees(self.num_vertices)
        return torch.tensor(trees, dtype=torch.long)

    @property
    def mode(self):
        """
        :returns: The maximum weight spanning tree.
        :rtype: Tensor
        """
        backend = self.sampler_options.get("backend", "python")
        return find_best_tree(self.edge_logits, backend=backend)

    @property
    def edge_mean(self):
        """
        Computes marginal probabilities of each edge being active.

        .. note:: This is similar to other distributions' ``.mean()``
            method, but with a different shape because this distribution's
            values are not encoded as binary matrices.

        :returns: A symmetric square ``(V,V)``-shaped matrix with values
            in ``[0,1]`` denoting the marginal probability of each edge
            being in a sampled value.
        :rtype: Tensor
        """
        V = self.num_vertices
        v1, v2 = make_complete_graph(V).unbind(0)
        logits = self.edge_logits - self.edge_logits.max()
        w = self.edge_logits.new_zeros(V, V)
        w[v1, v2] = w[v2, v1] = logits.exp()
        laplacian = w.sum(-1).diag_embed() - w
        inv = (laplacian + 1 / V).pinverse()
        resistance = inv.diag() + inv.diag()[..., None] - 2 * inv
        return resistance * w


################################################################################
# Sampler implementation.
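
A hedged usage sketch of the two properties added above, not part of the diff (the V = 4 sizes are illustrative): .mode returns the maximum weight spanning tree via find_best_tree, and .edge_mean returns exact edge marginals computed as edge weight times effective resistance from the pseudo-inverse of the graph Laplacian, which the enumeration-based test added below cross-checks.

import torch

from pyro.distributions.spanning_tree import SpanningTree

edge_logits = torch.randn(6)  # V = 4 vertices, K = 6 candidate edges
d = SpanningTree(edge_logits, sampler_options={"backend": "python"})

best = d.mode            # (V - 1, 2) = (3, 2) tensor of (v1, v2) pairs, the MAP tree
marginals = d.edge_mean  # (V, V) = (4, 4) symmetric matrix of edge marginal probabilities

assert best.shape == (3, 2)
assert marginals.shape == (4, 4)
assert torch.allclose(marginals, marginals.t())
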
@@ -432,6 +465,62 @@ def sample_tree(edge_logits, init_edges=None, mcmc_steps=1, backend="python"):
    return edges


@torch.no_grad()
def _find_best_tree(edge_logits):
    K = len(edge_logits)
    V = int(round(0.5 + (0.25 + 2 * K)**0.5))
    assert K == V * (V - 1) // 2
    E = V - 1
    grid = make_complete_graph(V)

    # Each of E edges in the tree is stored as an id k in [0, K) indexing into
    # the complete graph. The id of an edge (v1,v2) is k = v1+v2*(v2-1)/2.
    edge_ids = torch.empty((E,), dtype=torch.long)
    # This maps each vertex to whether it is a member of the cumulative tree.
    components = torch.zeros(V, dtype=torch.bool)

    # Find the first edge.
    k = edge_logits.argmax(0).item()
    components[grid[:, k]] = 1
    edge_ids[0] = k

    # Find edges connecting the cumulative tree to a new leaf.
    for e in range(1, E):
        c1, c2 = components[grid]
        mask = (c1 != c2)
        valid_logits = edge_logits[mask]
        k = valid_logits.argmax(0).item()
        k = mask.nonzero(as_tuple=False)[k]
        components[grid[:, k]] = 1
        edge_ids[e] = k

    # Convert edge ids to a canonical list of pairs.
    edge_ids = edge_ids.sort()[0]
    edges = torch.empty((E, 2), dtype=torch.long)
    edges[:, 0] = grid[0, edge_ids]
    edges[:, 1] = grid[1, edge_ids]
    return edges


def find_best_tree(edge_logits, backend="python"):
    """
    Find the maximum weight spanning tree of a dense weighted graph.

    :param torch.Tensor edge_logits: A length-K array of nonnormalized log
        probabilities.
    :returns: An E x 2 tensor of edges in the form of (vertex,vertex) pairs.
        Each edge should be sorted and the entire tensor should be
        lexicographically sorted.
    :rtype: torch.Tensor
    """
    if backend == "python":
        return _find_best_tree(edge_logits)
    elif backend == "cpp":
        return _get_cpp_module().find_best_tree(edge_logits)
    else:
        raise ValueError("unknown backend: {}".format(repr(backend)))


################################################################################
# Enumeration implementation.
################################################################################
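
The greedy loop in _find_best_tree repeatedly attaches the highest-logit edge crossing the cut between the current tree and the remaining vertices, i.e. Prim's algorithm on the complete graph, so it returns a maximum weight spanning tree. As an illustration only (not part of the PR), here is a cross-check against Kruskal's algorithm with a plain union-find; kruskal_best_tree is a hypothetical helper, assumes distinct logits (no ties), and compares edge sets rather than row order.

import torch

from pyro.distributions.spanning_tree import find_best_tree  # added by this PR


def kruskal_best_tree(edge_logits):
    """Hypothetical reference: maximum weight spanning tree via Kruskal + union-find."""
    K = len(edge_logits)
    V = int(round(0.5 + (0.25 + 2 * K) ** 0.5))
    parent = list(range(V))

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving
            v = parent[v]
        return v

    # Same edge-id convention as above: column k of the complete graph is (v1, v2).
    pairs = [(v1, v2) for v2 in range(V) for v1 in range(v2)]
    tree = []
    for k in edge_logits.argsort(descending=True).tolist():
        v1, v2 = pairs[k]
        r1, r2 = find(v1), find(v2)
        if r1 != r2:  # the edge joins two different components, so keep it
            parent[r1] = r2
            tree.append((v1, v2))
    return torch.tensor(sorted(tree), dtype=torch.long)


edge_logits = torch.randn(10)  # K = 10 logits, so V = 5 vertices and E = 4 tree edges
expected = {tuple(e) for e in kruskal_best_tree(edge_logits).tolist()}
actual = {tuple(e) for e in find_best_tree(edge_logits).tolist()}
assert actual == expected
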
61 changes: 58 additions & 3 deletions tests/distributions/test_spanning_tree.py
@@ -5,11 +5,11 @@
import os
from collections import Counter

import pytest
import torch

import pyro
from pyro.distributions.spanning_tree import (NUM_SPANNING_TREES, SpanningTree, find_best_tree, make_complete_graph,
                                              sample_tree)
from tests.common import assert_equal, xfail_if_not_implemented

pytestmark = pytest.mark.skipif("CUDA_TEST" in os.environ, reason="spanning_tree unsupported on CUDA.")
@@ -58,6 +58,19 @@ def test_sample_tree_approx_smoke(num_edges, backend):
        sample_tree(edge_logits, backend=backend)


@pytest.mark.filterwarnings("always")
@pytest.mark.parametrize('num_edges', [1, 3, 10, 30, 100])
@pytest.mark.parametrize('backend', ["python", "cpp"])
def test_find_best_tree_smoke(num_edges, backend):
    pyro.set_rng_seed(num_edges)
    E = num_edges
    V = 1 + E
    K = V * (V - 1) // 2
    for _ in range(10 if backend == "cpp" or num_edges <= 30 else 1):
        edge_logits = torch.rand(K)
        find_best_tree(edge_logits, backend=backend)


@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6])
def test_enumerate_support(num_edges):
    pyro.set_rng_seed(2 ** 32 - num_edges)
@@ -107,6 +120,48 @@ def test_log_prob(num_edges):
    assert abs(log_total) < 1e-6, log_total


@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6])
def test_edge_mean_function(num_edges):
    pyro.set_rng_seed(2 ** 32 - num_edges)
    E = num_edges
    V = 1 + E
    K = V * (V - 1) // 2
    edge_logits = torch.randn(K)
    d = SpanningTree(edge_logits)

    with xfail_if_not_implemented():
        support = d.enumerate_support()
    v1 = support[..., 0]
    v2 = support[..., 1]
    k = v1 + v2 * (v2 - 1) // 2
    probs = d.log_prob(support).exp()[:, None].expand_as(k)
    expected = torch.zeros(K).scatter_add_(0, k.reshape(-1), probs.reshape(-1))

    actual = d.edge_mean
    assert actual.shape == (V, V)
    v1, v2 = make_complete_graph(V)
    assert (actual[v1, v2] - expected).abs().max() < 1e-5, (actual, expected)


@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6])
@pytest.mark.parametrize('backend', ["python", "cpp"])
def test_mode(num_edges, backend):
    pyro.set_rng_seed(2 ** 32 - num_edges)
    E = num_edges
    V = 1 + E
    K = V * (V - 1) // 2
    edge_logits = torch.randn(K)
    d = SpanningTree(edge_logits, sampler_options={"backend": backend})
    with xfail_if_not_implemented():
        support = d.enumerate_support()
    v1 = support[..., 0]
    v2 = support[..., 1]
    k = v1 + v2 * (v2 - 1) // 2
    expected = support[edge_logits[k].sum(-1).argmax(0)]
    actual = d.mode
    assert (actual == expected).all()


@pytest.mark.filterwarnings("always")
@pytest.mark.parametrize('pattern', ["uniform", "random", "sparse"])
@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5])
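
A final illustrative aside, not part of the PR: the marginals from .edge_mean can also be estimated empirically with the existing sample_tree API by counting how often each edge appears across many draws. The sample count and mcmc_steps values below are arbitrary choices, and the match is only approximate because consecutive MCMC draws are correlated.

import torch

from pyro.distributions.spanning_tree import SpanningTree, sample_tree

torch.manual_seed(0)
V = 4
K = V * (V - 1) // 2
edge_logits = torch.randn(K)
d = SpanningTree(edge_logits)

counts = torch.zeros(V, V)
num_samples = 2000
edges = None
for _ in range(num_samples):
    # Reuse the previous tree as the MCMC initial state.
    edges = sample_tree(edge_logits, init_edges=edges, mcmc_steps=2)
    v1, v2 = edges.unbind(-1)
    counts[v1, v2] += 1
    counts[v2, v1] += 1

empirical = counts / num_samples
print((empirical - d.edge_mean).abs().max())  # small, up to Monte Carlo error
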