Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enforcement for np.sort and np.argsort #918

Merged
merged 59 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
9a81558
set mergesort as default and disable unstable kinds
yuema137 Oct 23, 2024
339277f
add unittest
yuema137 Oct 23, 2024
b127ee9
formatting
yuema137 Oct 23, 2024
cda8c43
formatting
yuema137 Oct 23, 2024
fb1419d
change name to sort_enforcement
yuema137 Oct 23, 2024
c1e8246
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
05c594c
break long error messages
yuema137 Oct 23, 2024
cdeb036
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 23, 2024
fda2638
keep the original sorting in numpy
yuema137 Oct 25, 2024
94f9159
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
1989b73
reemove unused import
yuema137 Oct 30, 2024
d0aa707
always use stablesort
yuema137 Oct 31, 2024
348f8e3
add numba-supported version of stableargsort
yuema137 Oct 31, 2024
c776020
use better naming for stablesort
yuema137 Oct 31, 2024
939777a
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 31, 2024
b609e71
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
8f23ef5
use jitable to allow both regular function and numba-decorated functi…
yuema137 Oct 31, 2024
3347799
remove redundant numba_sort
yuema137 Oct 31, 2024
f8ac388
explicitly import stablesort from strax for numba decorated functions
yuema137 Oct 31, 2024
2cf9753
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 31, 2024
bfc89e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
8a7c00d
consistent import style within one module
yuema137 Oct 31, 2024
2fae5ce
remove unused import
yuema137 Oct 31, 2024
4b933e0
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 31, 2024
ab9ff71
add sorting error
yuema137 Oct 31, 2024
66f4d2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
8f759c6
disable numba support for stable_sort
yuema137 Oct 31, 2024
478a61f
consistent import style for stable sort
yuema137 Oct 31, 2024
92a7807
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 31, 2024
0be548f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
3b236d3
add kwargs
yuema137 Nov 1, 2024
074d597
merge master
yuema137 Nov 1, 2024
b9e203d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
024989a
modify docstring for stable_sort
yuema137 Nov 1, 2024
6e9491f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
66839ae
remove kwargs
yuema137 Nov 13, 2024
edfc4a6
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
4c11074
update variable name
yuema137 Nov 13, 2024
0168bc8
update test_sort with hypothesis
yuema137 Nov 13, 2024
413c8c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
898a8cf
rewrite hithest_density_region to decoupld stable_sort from numba part
yuema137 Nov 13, 2024
f942bc6
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
a9006f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
507a6cd
remove unused import
yuema137 Nov 13, 2024
93e4b5b
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
cf9d276
break long lines
yuema137 Nov 13, 2024
ddb242d
Merge branch 'master' into set_default_as_mergesort
yuema137 Nov 13, 2024
4ad97ae
remove numba decorator for the main function
yuema137 Nov 13, 2024
6d0b394
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
c311544
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
586d7df
fix typo
yuema137 Nov 13, 2024
f4aefe8
rewrite hitlets to use non-numba HDR region
yuema137 Nov 13, 2024
e32d49d
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
b0fb1ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
3d9cc73
format hitlets.py
yuema137 Nov 13, 2024
13cc0d7
unify growing_result import to fix mypy error
yuema137 Nov 13, 2024
5a3beab
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
48d4e93
remove redundant space
yuema137 Nov 14, 2024
248a2a8
Remove unnecessary indent
dachengx Nov 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions strax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Glue the package together
# See https://www.youtube.com/watch?v=0oTh1CXRaQ0 if this confuses you
# The order of subpackages matters, since we use strax.xxx inside strax
from .sort_enforcement import *
from .utils import *
from .chunk import *
from .dtypes import *
Expand Down
79 changes: 79 additions & 0 deletions strax/sort_enforcement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import functools
import warnings

# Keep references to NumPy's own implementations before anything is patched,
# so the wrappers can delegate to them and the patch can be undone.
original_sort, original_argsort = np.sort, np.argsort


class SortingError(Exception):
    """Raised when a sort is requested in a way this module does not allow."""


def enforce_mergesort(func_name):
    """Emit a ``UserWarning`` saying that ``func_name`` was used directly.

    The warning is issued unconditionally; deciding whether it applies is left
    to the caller. ``stacklevel=3`` attributes the warning two frames above
    this function rather than to this function itself.

    Args:
        func_name: Name of the sorting function to mention in the message.

    """
    message = (
        f"Direct use of {func_name} detected.\n"
        "Please use the mergesort wrapper to ensure deterministic behavior."
    )
    warnings.warn(message, category=UserWarning, stacklevel=3)


def create_sort_wrapper(original_func, func_name):
    """Build a replacement for ``original_func`` that always sorts with mergesort.

    Args:
        original_func: Sorting callable to wrap. It is assumed to follow the
            ``np.sort`` / ``np.argsort`` calling convention
            ``(a, axis, kind, order)``, i.e. ``kind`` is the second positional
            argument after the array.
        func_name: Name used in the "direct use" warning message.

    Returns:
        A function carrying the metadata of ``original_func`` that forwards
        every argument unchanged except ``kind``, which is forced to
        ``"mergesort"`` whether it was given by keyword or by position.

    The returned wrapper raises ``SortingError`` when ``kind`` is
    ``"quicksort"`` or ``"heapsort"``.

    """
    # Position of `kind` inside *args (the array itself is bound to `arr`).
    kind_index = 1
    rejected_kinds = ("quicksort", "heapsort")

    @functools.wraps(original_func)
    def wrapper(arr, *args, **kwargs):
        kind_is_positional = len(args) > kind_index

        # Look at `kind` wherever the caller put it; checking only kwargs would
        # let e.g. np.sort(a, -1, "quicksort") slip past the rejection.
        requested_kinds = [kwargs.get("kind")]
        if kind_is_positional:
            requested_kinds.append(args[kind_index])
        if any(kind in rejected_kinds for kind in requested_kinds):
            raise SortingError(
                "quicksort and heapsort are not allowed due to non-deterministic behavior.\n"
                "Please use mergesort explicitely, "
                "or remove the 'kind' parameter to use mergesort by default."
            )

        # Always use mergesort. A positional `kind` is replaced in place:
        # adding the keyword as well would make the call fail with
        # "got multiple values for argument 'kind'".
        if kind_is_positional:
            args = args[:kind_index] + ("mergesort",) + args[kind_index + 1 :]
        else:
            kwargs["kind"] = "mergesort"

        # NOTE: this compares against whatever np.sort / np.argsort are bound to
        # at call time, so the warning only fires while NumPy is unpatched.
        if original_func in {np.sort, np.argsort}:
            enforce_mergesort(func_name)

        return original_func(arr, *args, **kwargs)

    return wrapper


# Build one mergesort-enforcing wrapper per NumPy sorting function and expose
# each under a second name intended for explicit use.
sort_wrapper = mergesort = create_sort_wrapper(original_sort, "np.sort")
argsort_wrapper = mergesort_argsort = create_sort_wrapper(original_argsort, "np.argsort")


def enable_safe_sorting():
    """Patch ``np.sort`` and ``np.argsort`` with the mergesort-enforcing wrappers.

    Only these two module-level functions are replaced; nothing else on the
    ``numpy`` module is touched here.

    Returns:
        A zero-argument function that rebinds ``np.sort`` and ``np.argsort``
        to ``original_sort`` and ``original_argsort``.

    """
    np.sort = sort_wrapper
    np.argsort = argsort_wrapper

    def restore_original_sorts():
        """Rebind ``np.sort`` / ``np.argsort`` to the saved original functions."""
        np.sort = original_sort
        np.argsort = original_argsort

    return restore_original_sorts


# Enable safe sorting immediately: importing this module patches np.sort and
# np.argsort as a side effect. Calling restore_sorts() undoes the patch.
restore_sorts = enable_safe_sorting()
73 changes: 73 additions & 0 deletions tests/test_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import unittest
import numpy as np
import warnings
from strax.sort_enforcement import SortingError, mergesort, mergesort_argsort, restore_sorts


class TestSortEnforcement(unittest.TestCase):
    """Tests for the mergesort enforcement applied to ``np.sort`` / ``np.argsort``.

    Most tests rely on ``np.sort`` and ``np.argsort`` being patched, which
    happens when ``strax.sort_enforcement`` is imported.

    """

    def setUp(self):
        """Set up test fixtures before each test method."""
        self.arr = np.array([3, 1, 4, 1, 5, 9, 2, 6])
        # Expected results of a stable sort: the two 1s keep their input order.
        self.expected_sorted = np.array([1, 1, 2, 3, 4, 5, 6, 9])
        self.expected_argsort = np.array([1, 3, 6, 0, 2, 4, 7, 5])

    def test_explicit_mergesort(self):
        """Test explicit mergesort function (should not warn)"""
        with warnings.catch_warnings():
            warnings.simplefilter("error")  # Turn warnings into errors
            sorted_arr = mergesort(self.arr)
        np.testing.assert_array_equal(sorted_arr, self.expected_sorted)

    def test_explicit_mergesort_argsort(self):
        """Test explicit mergesort_argsort function (should not warn)"""
        with warnings.catch_warnings():
            warnings.simplefilter("error")  # Turn warnings into errors
            sorted_indices = mergesort_argsort(self.arr)
        np.testing.assert_array_equal(sorted_indices, self.expected_argsort)

    def test_quicksort_rejection(self):
        """Test that quicksort and heapsort raise errors for both sort and argsort."""
        patched_functions = (("np.sort", np.sort), ("np.argsort", np.argsort))
        for name, func in patched_functions:
            for kind in ("quicksort", "heapsort"):
                with self.subTest(func=name, kind=kind):
                    with self.assertRaises(SortingError):
                        func(self.arr, kind=kind)

    def test_sort_stability(self):
        """Test that sorting is stable (mergesort property)"""
        # Structured array whose sort key ("num") contains duplicates
        arr = np.array(
            [(1, "a"), (2, "b"), (1, "c"), (2, "d")], dtype=[("num", int), ("letter", "U1")]
        )
        sorted_arr = np.sort(arr, order="num")
        # Equal keys must keep their relative input order
        self.assertEqual(list(sorted_arr["letter"]), ["a", "c", "b", "d"])

    def test_restore_functionality(self):
        """Test that restore_sorts function works correctly."""
        # restore_sorts() rebinds module-level NumPy attributes. Put the patched
        # functions back afterwards so this test does not change what the other
        # tests in the process see.
        self.addCleanup(setattr, np, "sort", np.sort)
        self.addCleanup(setattr, np, "argsort", np.argsort)

        # First verify mergesort is enforced
        with self.assertRaises(SortingError):
            np.sort(self.arr, kind="quicksort")

        # Restore original behavior
        restore_sorts()

        # Now quicksort should work without error
        try:
            np.sort(self.arr, kind="quicksort")
        except SortingError:
            self.fail("quicksort raised SortingError after restore")


# Run the test suite with verbose output when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
Loading