diff --git a/src/sage/sets/disjoint_set.pxd b/src/sage/sets/disjoint_set.pxd index e8795776fc2..de7f9bf5890 100644 --- a/src/sage/sets/disjoint_set.pxd +++ b/src/sage/sets/disjoint_set.pxd @@ -22,6 +22,8 @@ cdef class DisjointSet_class(SageObject): cdef class DisjointSet_of_integers(DisjointSet_class): cpdef int find(self, int i) cpdef void union(self, int i, int j) + cdef inline int _find(self, int i) + cdef inline void _union(self, int i, int j) cpdef root_to_elements_dict(self) cpdef element_to_root_dict(self) cpdef to_digraph(self) diff --git a/src/sage/sets/disjoint_set.pyx b/src/sage/sets/disjoint_set.pyx index 9f95afd5eb2..fa1bc17af2c 100644 --- a/src/sage/sets/disjoint_set.pyx +++ b/src/sage/sets/disjoint_set.pyx @@ -14,6 +14,7 @@ AUTHORS: - Sébastien Labbé (2008) - Initial version. - Sébastien Labbé (2009-11-24) - Pickling support - Sébastien Labbé (2010-01) - Inclusion into sage (:issue:`6775`). +- Giorgos Mousa (2024-04-22): Optimize EXAMPLES: @@ -447,7 +448,7 @@ cdef class DisjointSet_of_integers(DisjointSet_class): INPUT: - - ``i`` -- element in ``self`` (no input checking) + - ``i`` -- element in ``self`` EXAMPLES:: @@ -468,8 +469,31 @@ cdef class DisjointSet_of_integers(DisjointSet_class): {{0}, {1, 2, 3, 4}} sage: [e.find(i) for i in range(5)] [0, 1, 1, 1, 1] - sage: e.find(5) # no input checking - 0 + sage: e.find(2**10) + ValueError: i(=1024) must be between 0 and 4 + ... + """ + card = self.cardinality() + if i < 0 or i>= card: + raise ValueError('i(=%s) must be between 0 and %s' % (i, card - 1)) + return OP_find(self._nodes, i) + + cdef inline int _find(self, int i): + r""" + Return the representative of the set that ``i`` currently belongs to. + + INPUT: + + - ``i`` -- element in ``self`` + + EXAMPLES:: + + sage: e = DisjointSet(5) + sage: e._find(5) # only C-callable + Traceback (most recent call last): + ... + AttributeError: 'sage.sets.disjoint_set.DisjointSet_of_integers' + object has no attribute '_find'. Did you mean: 'find'? """ return OP_find(self._nodes, i) @@ -495,8 +519,38 @@ cdef class DisjointSet_of_integers(DisjointSet_class): {{0, 1}, {2, 4}, {3}} sage: d.union(1, 4); d {{0, 1, 2, 4}, {3}} - sage: d.union(1, 5); d # no input checking - {{0, 1, 2, 4}, {3}} + sage: d.union(1, 5) + ValueError: j(=5) must be between 0 and 4 + ... + """ + cdef int card = self._nodes.degree + if i < 0 or i >= card: + raise ValueError('i(=%s) must be between 0 and %s' % (i, card - 1)) + if j < 0 or j >= card: + raise ValueError('j(=%s) must be between 0 and %s' % (j, card - 1)) + OP_join(self._nodes, i, j) + + cdef inline void _union(self, int i, int j): + r""" + Combine the set of ``i`` and the set of ``j`` into one. + + All elements in those two sets will share the same representative + that can be gotten using find. + + INPUT: + + - ``i`` -- element in ``self`` + - ``j`` -- element in ``self`` + + EXAMPLES:: + + sage: d = DisjointSet(5); d + {{0}, {1}, {2}, {3}, {4}} + sage: d._union(0, 1) # only C-callable + Traceback (most recent call last): + ... + AttributeError: 'sage.sets.disjoint_set.DisjointSet_of_integers' + object has no attribute '_union'. Did you mean: 'union'? """ OP_join(self._nodes, i, j) @@ -727,7 +781,7 @@ cdef class DisjointSet_of_hashables(DisjointSet_class): INPUT: - - ``e`` -- element in ``self`` (no input checking) + - ``e`` -- element in ``self`` EXAMPLES:: @@ -754,7 +808,7 @@ cdef class DisjointSet_of_hashables(DisjointSet_class): KeyError: 5 """ cdef int i = self._el_to_int[e] - cdef int r = self._d.find(i) + cdef int r = self._d._find(i) return self._int_to_el[r] cpdef void union(self, e, f): @@ -766,8 +820,8 @@ cdef class DisjointSet_of_hashables(DisjointSet_class): INPUT: - - ``e`` -- element in ``self`` (no input checking) - - ``f`` -- element in ``self`` (no input checking) + - ``e`` -- element in ``self`` + - ``f`` -- element in ``self`` EXAMPLES:: @@ -779,10 +833,13 @@ cdef class DisjointSet_of_hashables(DisjointSet_class): {{'a', 'b'}, {'c', 'e'}, {'d'}} sage: e.union('b', 'e'); e {{'a', 'b', 'c', 'e'}, {'d'}} + sage: e.union('a', 2**10) + KeyError: 1024 + ... """ cdef int i = self._el_to_int[e] cdef int j = self._el_to_int[f] - self._d.union(i, j) + self._d._union(i, j) cpdef root_to_elements_dict(self): r""" @@ -845,7 +902,7 @@ cdef class DisjointSet_of_hashables(DisjointSet_class): sage: g.edges(sort=True) # needs sage.graphs [(0, 0, None), (1, 2, None), (2, 2, None), (3, 2, None), (4, 2, None)] - The result depends on the ordering of the join:: + The result depends on the ordering of the union:: sage: d = DisjointSet(range(5)) sage: d.union(1, 2)