Skip to content

Commit

Permalink
Add solver.ProblemsGraph test
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoinePrv committed Feb 9, 2024
1 parent 57ff9ad commit 244555e
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
8 changes: 7 additions & 1 deletion libmambapy/src/libmambapy/bindings/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mamba/solver/solution.hpp"

#include "bindings.hpp"
#include "flat_set_caster.hpp"
#include "utils.hpp"

namespace mamba::solver
Expand Down Expand Up @@ -356,7 +357,12 @@ namespace mambapy
.def("conflicts", &ProblemsGraph::conflicts_t::conflicts)
.def("in_conflict", &ProblemsGraph::conflicts_t::in_conflict)
.def("clear", [](ProblemsGraph::conflicts_t& self) { return self.clear(); })
.def("add", &ProblemsGraph::conflicts_t::add);
.def("add", &ProblemsGraph::conflicts_t::add)
.def(py::self == py::self)
.def(py::self != py::self)
.def("__copy__", &copy<ProblemsGraph::conflicts_t>)
.def("__deepcopy__", &deepcopy<ProblemsGraph::conflicts_t>, py::arg("memo"));


py_problems_graph.def("root_node", &ProblemsGraph::root_node)
.def("conflicts", &ProblemsGraph::conflicts)
Expand Down
48 changes: 47 additions & 1 deletion libmambapy/tests/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

import libmambapy.solver
import libmambapy


def test_import_submodule():
Expand Down Expand Up @@ -193,3 +193,49 @@ def test_Solution():
other = copy.deepcopy(sol)
assert other is not sol
assert len(other.actions) == len(sol.actions)


def test_ProblemsGraph():
# Create a ProblemsGraph
db = libmambapy.solver.libsolv.Database(libmambapy.specs.ChannelResolveParams())
db.add_repo_from_packages(
[
libmambapy.specs.PackageInfo(name="a", version="1.0", depends=["b>=2.0", "c>=2.1"]),
libmambapy.specs.PackageInfo(name="b", version="2.0", depends=["c<2.0"]),
libmambapy.specs.PackageInfo(name="c", version="1.0"),
libmambapy.specs.PackageInfo(name="c", version="3.0"),
],
)

request = libmambapy.solver.Request(
[libmambapy.solver.Request.Install(libmambapy.specs.MatchSpec.parse("a"))]
)

outcome = libmambapy.solver.libsolv.Solver().solve(db, request)

assert isinstance(outcome, libmambapy.solver.libsolv.UnSolvable)
pbg = outcome.problems_graph(db)

# ProblemsGraph conflicts
conflicts = pbg.conflicts()
assert len(conflicts) == 2
assert len(list(conflicts)) == 2
node, in_conflict = next(iter(conflicts))
assert conflicts.has_conflict(node)
for other in in_conflict:
assert conflicts.in_conflict(node, other)

other_conflicts = copy.deepcopy(conflicts)
assert other_conflicts is not conflicts
assert other_conflicts == conflicts

other_conflicts.clear()
assert len(other_conflicts) == 0

other_conflicts.add(7, 42)
assert other_conflicts.in_conflict(7, 42)

# ProblemsGraph graph
nodes, edges = pbg.graph()
assert len(nodes) > 0
assert len(edges) > 0

0 comments on commit 244555e

Please sign in to comment.