Skip to content

Commit

Permalink
Track assignment: replace munkres with lapjv
Browse files Browse the repository at this point in the history
See the following comparison between several implementations to solve
this problem: https://github.com/berhane/LAP-solvers
  • Loading branch information
snejus committed Dec 27, 2024
1 parent 2277e2a commit 420117b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 38 deletions.
37 changes: 13 additions & 24 deletions beets/autotag/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
import re
from collections.abc import Iterable, Sequence
from enum import IntEnum
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast

from munkres import Munkres
import lap
import numpy as np

from beets import config, logging, plugins
from beets.autotag import (
Expand Down Expand Up @@ -126,21 +127,15 @@ def assign_items(
of objects of the two types.
"""
# Construct the cost matrix.
costs: list[list[Distance]] = []
for item in items:
row = []
for track in tracks:
row.append(track_distance(item, track))
costs.append(row)

costs = [[float(track_distance(i, t)) for t in tracks] for i in items]
# Find a minimum-cost bipartite matching.
log.debug("Computing track assignment...")
matching = Munkres().compute(costs)
cost, _, assigned_idxs = lap.lapjv(np.array(costs), extend_cost=True)
log.debug("...done.")

# Produce the output matching.
mapping = {items[i]: tracks[j] for (i, j) in matching}
extra_items = list(set(items) - set(mapping.keys()))
mapping = {items[i]: tracks[t] for (t, i) in enumerate(assigned_idxs)}
extra_items = list(set(items) - mapping.keys())
extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
extra_tracks = list(set(tracks) - set(mapping.values()))
extra_tracks.sort(key=lambda t: (t.index, t.title))
Expand All @@ -154,6 +149,10 @@ def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
return item.track not in (track_info.medium_index, track_info.index)


# Length-matching thresholds read from the "match" config section. They are
# looked up once here instead of on every track_distance() call.
# NOTE(review): these appear to be module-level, so the values are captured
# when the module is imported; configuration applied afterwards would not be
# reflected -- confirm this is acceptable.
track_length_grace = config["match"]["track_length_grace"].as_number()
track_length_max = config["match"]["track_length_max"].as_number()


def track_distance(
item: Item,
track_info: TrackInfo,
Expand All @@ -166,18 +165,8 @@ def track_distance(
dist = hooks.Distance()

# Length.
if track_info.length:
item_length = cast(float, item.length)
track_length_grace = cast(
Union[float, int],
config["match"]["track_length_grace"].as_number(),
)
track_length_max = cast(
Union[float, int],
config["match"]["track_length_max"].as_number(),
)

diff = abs(item_length - track_info.length) - track_length_grace
if info_length := track_info.length:
diff = abs(item.length - info_length) - track_length_grace
dist.add_ratio("track_length", diff, track_length_max)

# Title.
Expand Down
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ Bug fixes:
:bug:`5265`
:bug:`5371`
:bug:`4715`
* :ref:`import-cmd`: Fix ``MemoryError`` and improve performance when tagging
large albums by replacing the ``munkres`` library with ``lap.lapjv``.
:bug:`5207`

For packagers:

Expand Down
81 changes: 68 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ python = ">=3.9,<4"
colorama = { version = "*", markers = "sys_platform == 'win32'" }
confuse = ">=1.5.0"
jellyfish = "*"
lap = ">=0.5.12"
mediafile = ">=0.12.0"
munkres = ">=1.0.0"
musicbrainzngs = ">=0.4"
numpy = ">=1.24.4"
platformdirs = ">=3.5.0"
pyyaml = "*"
typing_extensions = { version = "*", python = "<=3.10" }
Expand Down

0 comments on commit 420117b

Please sign in to comment.