Skip to content

Commit

Permalink
changed from deprecated numba @generated_jit to @overload
Browse files Browse the repository at this point in the history
  • Loading branch information
sibirrer committed Jul 11, 2023
1 parent 9724c6c commit 04eb495
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
16 changes: 11 additions & 5 deletions lenstronomy/Util/numba_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
if numba_enabled:
try:
import numba
from numba import extending
except ImportError:
numba_enabled = False
numba = None
extending = None

__all__ = ['jit', 'generated_jit', 'nan_to_num', 'nan_to_num_arr', 'nan_to_num_single']
__all__ = ['jit', 'overload', 'nan_to_num', 'nan_to_num_arr', 'nan_to_num_single']


def jit(nopython=nopython, cache=cache, parallel=parallel, fastmath=fastmath, error_model=error_model, inline='never'):
Expand All @@ -40,23 +42,27 @@ def wrapper(func):
return wrapper


def generated_jit(nopython=nopython, cache=cache, parallel=parallel, fastmath=fastmath, error_model=error_model):
def overload(nopython=nopython, cache=cache, parallel=parallel, fastmath=fastmath, error_model=error_model):
"""
Wrapper around numba.generated_jit. Allows you to redirect a function to another based on its type
- see the Numba docs for more info
"""
if numba_enabled:
def wrapper(func):
return numba.generated_jit(func, nopython=nopython, cache=cache, parallel=parallel, fastmath=fastmath,
error_model=error_model)
return func

@numba.extending.overload(wrapper, nopython=nopython, cache=cache, parallel=parallel, fastmath=fastmath,
error_model=error_model)
def ol_wrapper(func):
return func
else:
def wrapper(func):
return func

return wrapper


@generated_jit()
@overload()
def nan_to_num(x, posinf=1e10, neginf=-1e10, nan=0.):
"""
Implements a Numba equivalent to np.nan_to_num (with copy=False!) array or scalar in Numba.
Expand Down
2 changes: 1 addition & 1 deletion test/test_LensModel/test_Profiles/test_interpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,4 @@ def test_shift(self):


if __name__ == '__main__':
pytest.main("-k TestInterpol")
pytest.main()

0 comments on commit 04eb495

Please sign in to comment.