Skip to content

Commit

Permalink
Reintroduce cdef _union & _find
Browse files Browse the repository at this point in the history
This avoids potential SEGFAULTs from Python calls
  • Loading branch information
gmou3 committed Apr 22, 2024
1 parent be8144d commit 7c91033
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/sage/sets/disjoint_set.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ cdef class DisjointSet_class(SageObject):
cdef class DisjointSet_of_integers(DisjointSet_class):
cpdef int find(self, int i)
cpdef void union(self, int i, int j)
cdef inline int _find(self, int i)
cdef inline void _union(self, int i, int j)
cpdef root_to_elements_dict(self)
cpdef element_to_root_dict(self)
cpdef to_digraph(self)
Expand Down
77 changes: 67 additions & 10 deletions src/sage/sets/disjoint_set.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ AUTHORS:
- Sébastien Labbé (2008) - Initial version.
- Sébastien Labbé (2009-11-24) - Pickling support
- Sébastien Labbé (2010-01) - Inclusion into sage (:issue:`6775`).
- Giorgos Mousa (2024-04-22): Optimize
EXAMPLES:
Expand Down Expand Up @@ -447,7 +448,7 @@ cdef class DisjointSet_of_integers(DisjointSet_class):
INPUT:
- ``i`` -- element in ``self`` (no input checking)
- ``i`` -- element in ``self``
EXAMPLES::
Expand All @@ -468,8 +469,31 @@ cdef class DisjointSet_of_integers(DisjointSet_class):
{{0}, {1, 2, 3, 4}}
sage: [e.find(i) for i in range(5)]
[0, 1, 1, 1, 1]
sage: e.find(5) # no input checking
0
sage: e.find(2**10)
ValueError: i(=1024) must be between 0 and 4
...
"""
card = self.cardinality()
if i < 0 or i>= card:
raise ValueError('i(=%s) must be between 0 and %s' % (i, card - 1))
return OP_find(self._nodes, i)

cdef inline int _find(self, int i):
r"""
Return the representative of the set that ``i`` currently belongs to.
INPUT:
- ``i`` -- element in ``self``
EXAMPLES::
sage: e = DisjointSet(5)
sage: e._find(5) # only C-callable
Traceback (most recent call last):
...
AttributeError: 'sage.sets.disjoint_set.DisjointSet_of_integers'
object has no attribute '_find'. Did you mean: 'find'?
"""
return OP_find(self._nodes, i)

Expand All @@ -495,8 +519,38 @@ cdef class DisjointSet_of_integers(DisjointSet_class):
{{0, 1}, {2, 4}, {3}}
sage: d.union(1, 4); d
{{0, 1, 2, 4}, {3}}
sage: d.union(1, 5); d # no input checking
{{0, 1, 2, 4}, {3}}
sage: d.union(1, 5)
ValueError: j(=5) must be between 0 and 4
...
"""
cdef int card = self._nodes.degree
if i < 0 or i >= card:
raise ValueError('i(=%s) must be between 0 and %s' % (i, card - 1))
if j < 0 or j >= card:
raise ValueError('j(=%s) must be between 0 and %s' % (j, card - 1))
OP_join(self._nodes, i, j)

cdef inline void _union(self, int i, int j):
r"""
Combine the set of ``i`` and the set of ``j`` into one.
All elements in those two sets will share the same representative
that can be gotten using find.
INPUT:
- ``i`` -- element in ``self``
- ``j`` -- element in ``self``
EXAMPLES::
sage: d = DisjointSet(5); d
{{0}, {1}, {2}, {3}, {4}}
sage: d._union(0, 1) # only C-callable
Traceback (most recent call last):
...
AttributeError: 'sage.sets.disjoint_set.DisjointSet_of_integers'
object has no attribute '_union'. Did you mean: 'union'?
"""
OP_join(self._nodes, i, j)

Expand Down Expand Up @@ -727,7 +781,7 @@ cdef class DisjointSet_of_hashables(DisjointSet_class):
INPUT:
- ``e`` -- element in ``self`` (no input checking)
- ``e`` -- element in ``self``
EXAMPLES::
Expand All @@ -754,7 +808,7 @@ cdef class DisjointSet_of_hashables(DisjointSet_class):
KeyError: 5
"""
cdef int i = <int> self._el_to_int[e]
cdef int r = <int> self._d.find(i)
cdef int r = <int> self._d._find(i)
return self._int_to_el[r]

cpdef void union(self, e, f):
Expand All @@ -766,8 +820,8 @@ cdef class DisjointSet_of_hashables(DisjointSet_class):
INPUT:
- ``e`` -- element in ``self`` (no input checking)
- ``f`` -- element in ``self`` (no input checking)
- ``e`` -- element in ``self``
- ``f`` -- element in ``self``
EXAMPLES::
Expand All @@ -779,10 +833,13 @@ cdef class DisjointSet_of_hashables(DisjointSet_class):
{{'a', 'b'}, {'c', 'e'}, {'d'}}
sage: e.union('b', 'e'); e
{{'a', 'b', 'c', 'e'}, {'d'}}
sage: e.union('a', 2**10)
KeyError: 1024
...
"""
cdef int i = <int> self._el_to_int[e]
cdef int j = <int> self._el_to_int[f]
self._d.union(i, j)
self._d._union(i, j)

cpdef root_to_elements_dict(self):
r"""
Expand Down

0 comments on commit 7c91033

Please sign in to comment.