Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace lap.lapjv() with scipy.optimize.linear_sum_assignment() #3267

Merged
merged 2 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ opt_einsum>=2.3.2
pyro-api>=0.1.1
tqdm>=4.36
funsor[torch]
setuptools<60
setuptools
10 changes: 3 additions & 7 deletions pyro/distributions/one_one_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,15 @@ def sample(self, sample_shape=torch.Size()):
def mode(self):
"""
Computes a maximum probability matching.

.. note:: This requires the `lap <https://pypi.org/project/lap/>`_
package and runs on CPU.
"""
return maximum_weight_matching(self.logits)


@torch.no_grad()
def maximum_weight_matching(logits):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=ImportWarning)
import lap
from scipy.optimize import linear_sum_assignment

cost = -logits.cpu()
value = lap.lapjv(cost.numpy())[1]
value = linear_sum_assignment(cost.numpy())[1]
value = torch.tensor(value, dtype=torch.long, device=logits.device)
return value
10 changes: 3 additions & 7 deletions pyro/distributions/one_two_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,6 @@ def sample(self, sample_shape=torch.Size()):
def mode(self):
"""
Computes a maximum probability matching.

.. note:: This requires the `lap <https://pypi.org/project/lap/>`_
package and runs on CPU.
"""
return maximum_weight_matching(self.logits)

Expand Down Expand Up @@ -204,12 +201,11 @@ def enumerate_one_two_matchings(num_destins):

@torch.no_grad()
def maximum_weight_matching(logits):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=ImportWarning)
import lap
from scipy.optimize import linear_sum_assignment

cost = -logits.cpu()
cost = torch.cat([cost, cost], dim=-1) # Duplicate destinations.
value = lap.lapjv(cost.numpy())[1]
value = linear_sum_assignment(cost.numpy())[1]
value = torch.tensor(value, dtype=torch.long, device=logits.device)
value %= logits.size(1)
return value
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"scikit-learn",
"seaborn>=0.11.0",
"wget",
"lap", # Requires setuptools<60
"scipy>=1.1",
# 'biopython>=1.54',
# 'scanpy>=1.4', # Requires HDF5
# 'scvi>=0.6', # Requires loopy and other fragile packages
Expand Down Expand Up @@ -115,7 +115,6 @@
"pytest-xdist",
"pytest>=5.0",
"ruff",
"scipy>=1.1",
],
"profile": ["prettytable", "pytest-benchmark", "snakeviz"],
"dev": EXTRAS_REQUIRE
Expand All @@ -131,7 +130,6 @@
"pytest-xdist",
"pytest>=5.0",
"ruff",
"scipy>=1.1",
"sphinx",
"sphinx_rtd_theme",
"yapf",
Expand Down
4 changes: 3 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ def assert_tensors_equal(a, b, prec=0.0, msg=""):
assert a.size() == b.size(), msg
if isinstance(prec, numbers.Number) and prec == 0:
assert (a == b).all(), msg
return
if a.numel() == 0 and b.numel() == 0:
return
b = b.type_as(a)
b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu()
if not a.dtype.is_floating_point:
return (a == b).all()
assert (a == b).all(), msg
return
# check that NaNs are in the same locations
nan_mask = a != a
assert torch.equal(nan_mask, b != b), msg
Expand Down
4 changes: 1 addition & 3 deletions tests/distributions/test_one_one_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,19 @@ def test_grad_hard(num_nodes):
@pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
@pytest.mark.parametrize("num_nodes", [1, 2, 3, 4, 5, 6, 7, 8])
def test_mode(num_nodes, dtype):
pytest.importorskip("lap")
logits = torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
d = dist.OneOneMatching(logits)
values = d.enumerate_support()
i = d.log_prob(values).max(0).indices.item()
expected = values[i]
actual = d.mode()
assert_equal(actual, expected)
assert (actual == expected).all()


@pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
@pytest.mark.parametrize("num_nodes", [3, 5, 8, 13, 100, 1000])
def test_mode_smoke(num_nodes, dtype):
pytest.importorskip("lap")
logits = torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
d = dist.OneOneMatching(logits)
value = d.mode()
Expand All @@ -136,7 +135,6 @@ def test_mode_smoke(num_nodes, dtype):
@pytest.mark.parametrize("num_nodes", [2, 3, 4, 5, 6])
@pytest.mark.parametrize("bp_iters", [None, BP_ITERS], ids=["exact", "bp"])
def test_sample(num_nodes, dtype, bp_iters):
pytest.importorskip("lap")
logits = torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
d = dist.OneOneMatching(logits, bp_iters=bp_iters)

Expand Down
6 changes: 0 additions & 6 deletions tests/distributions/test_one_two_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def test_grad_phylo(num_leaves):
@pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
@pytest.mark.parametrize("num_destins", [1, 2, 3, 4, 5])
def test_mode_full(num_destins, dtype):
pytest.importorskip("lap")
num_sources = 2 * num_destins
logits = torch.randn(num_sources, num_destins, dtype=dtype) * 10
d = dist.OneTwoMatching(logits)
Expand All @@ -189,7 +188,6 @@ def test_mode_full(num_destins, dtype):
@pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
@pytest.mark.parametrize("num_leaves", [2, 3, 4, 5, 6])
def test_mode_phylo(num_leaves, dtype):
pytest.importorskip("lap")
logits, times = random_phylo_logits(num_leaves, dtype)
d = dist.OneTwoMatching(logits)
values = d.enumerate_support()
Expand All @@ -202,7 +200,6 @@ def test_mode_phylo(num_leaves, dtype):
@pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
@pytest.mark.parametrize("num_destins", [3, 5, 8, 13, 100, 1000])
def test_mode_full_smoke(num_destins, dtype):
pytest.importorskip("lap")
num_sources = 2 * num_destins
logits = torch.randn(num_sources, num_destins, dtype=dtype) * 10
d = dist.OneTwoMatching(logits)
Expand All @@ -213,7 +210,6 @@ def test_mode_full_smoke(num_destins, dtype):
@pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
@pytest.mark.parametrize("num_leaves", [3, 5, 8, 13, 100, 1000])
def test_mode_phylo_smoke(num_leaves, dtype):
pytest.importorskip("lap")
logits, times = random_phylo_logits(num_leaves, dtype)
d = dist.OneTwoMatching(logits, bp_iters=10)
value = d.mode()
Expand All @@ -224,7 +220,6 @@ def test_mode_phylo_smoke(num_leaves, dtype):
@pytest.mark.parametrize("num_destins", [2, 3, 4])
@pytest.mark.parametrize("bp_iters", [None, BP_ITERS], ids=["exact", "bp"])
def test_sample_full(num_destins, dtype, bp_iters):
pytest.importorskip("lap")
num_sources = 2 * num_destins
logits = torch.randn(num_sources, num_destins, dtype=dtype) * 10
d = dist.OneTwoMatching(logits, bp_iters=bp_iters)
Expand All @@ -251,7 +246,6 @@ def test_sample_full(num_destins, dtype, bp_iters):
@pytest.mark.parametrize("num_leaves", [3, 4, 5])
@pytest.mark.parametrize("bp_iters", [None, BP_ITERS], ids=["exact", "bp"])
def test_sample_phylo(num_leaves, dtype, bp_iters):
pytest.importorskip("lap")
logits, times = random_phylo_logits(num_leaves, dtype)
num_sources, num_destins = logits.shape
d = dist.OneTwoMatching(logits, bp_iters=bp_iters)
Expand Down
Loading