From ab2545b1e9ba99a7a57d52ec8c32fa44e6bba69b Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Mon, 16 Dec 2024 12:25:10 +0000 Subject: [PATCH] Account for differing histogram* signatures across numpy versions. --- unyt/_array_functions.py | 150 +++++++++++++++++++++++++++++++-------- 1 file changed, 121 insertions(+), 29 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 8ea16c2b..ac31e363 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -158,20 +158,25 @@ def _sanitize_range(_range, units): return new_range.squeeze() -@implements(np.histogram) -def histogram( - a, - bins=10, - range=None, - density=None, - weights=None, - *args, - **kwargs, -): +def _histogram(a, bins=10, range=None, density=None, weights=None, normed=None): range = _sanitize_range(range, units=[a.units]) - counts, bins = np.histogram._implementation( - np.asarray(a), bins, range, *args, **kwargs - ) + if NUMPY_VERSION > Version("1.23"): + counts, bins = np.histogram._implementation( + np.asarray(a), + bins=bins, + range=range, + density=density, + weights=np.asarray(weights) if weights is not None else None, + ) + else: + counts, bins = np.histogram._implementation( + np.asarray(a), + bins=bins, + range=range, + normed=normed, + weights=np.asarray(weights) if weights is not None else None, + density=density, + ) # a and/or weights could have units, only apply if present # don't getattr(..., "units", NULL_UNIT) because e.g. we don't want # a unyt_array if weights are not a unyt_array and not density @@ -182,12 +187,42 @@ def histogram( return counts, bins * a.units -@implements(np.histogram2d) -def histogram2d(x, y, bins=10, range=None, density=None, weights=None, *args, **kwargs): +if NUMPY_VERSION > Version("1.23"): + + @implements(np.histogram) + def histogram(a, bins=10, range=None, density=None, weights=None): + return _histogram(a, bins=bins, range=range, density=density, weights=weights) + +else: + + @implements(np.histogram) + def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): + return _histogram( + a, bins=bins, range=range, normed=normed, weights=weights, density=density + ) + + +def _histogram2d(x, y, bins=10, range=None, density=None, weights=None, normed=None): range = _sanitize_range(range, units=[x.units, y.units]) - counts, xbins, ybins = np.histogram2d._implementation( - np.asarray(x), np.asarray(y), bins, range, *args, **kwargs - ) + if NUMPY_VERSION > Version("1.23"): + counts, xbins, ybins = np.histogram2d._implementation( + np.asarray(x), + np.asarray(y), + bins=bins, + range=range, + density=density, + weights=np.asarray(weights) if weights is not None else None, + ) + else: + counts, xbins, ybins = np.histogram2d._implementation( + np.asarray(x), + np.asarray(y), + bins=bins, + range=range, + normed=normed, + weights=np.asarray(weights) if weights is not None else None, + density=density, + ) # x, y and/or weights could have units, only apply if present # don't getattr(..., "units", NULL_UNIT) because e.g. we don't want # a unyt_array if weights are not a unyt_array and not density @@ -201,15 +236,49 @@ def histogram2d(x, y, bins=10, range=None, density=None, weights=None, *args, ** return counts, xbins * x.units, ybins * y.units -@implements(np.histogramdd) -def histogramdd( - sample, bins=10, range=None, density=None, weights=None, *args, **kwargs -): - units = [_.units for _ in sample] +if NUMPY_VERSION > Version("1.23"): + + @implements(np.histogram2d) + def histogram2d(x, y, bins=10, range=None, density=None, weights=None): + return _histogram2d( + x, y, bins=bins, range=range, density=density, weights=weights + ) + +else: + + @implements(np.histogram2d) + def histogram2d(x, y, bins=10, range=None, normed=None, weights=None, density=None): + return _histogram2d( + x, + y, + bins=bins, + range=range, + normed=normed, + weights=weights, + density=density, + ) + + +def _histogramdd(sample, bins=10, range=None, density=None, weights=None, normed=None): + units = [getattr(_, "units", NULL_UNIT) for _ in sample] range = _sanitize_range(range, units=units) - counts, bins = np.histogramdd._implementation( - [np.asarray(_) for _ in sample], bins, range, *args, **kwargs - ) + if NUMPY_VERSION > Version("1.23"): + counts, bins = np.histogramdd._implementation( + [np.asarray(_) for _ in sample], + bins=bins, + range=range, + density=density, + weights=np.asarray(weights) if weights is not None else None, + ) + else: + counts, bins = np.histogramdd._implementation( + [np.asarray(_) for _ in sample], + bins=bins, + range=range, + normed=normed, + weights=np.asarray(weights) if weights is not None else None, + density=density, + ) # sample(s) and/or weights could have units, only apply if present # don't getattr(..., "units", NULL_UNIT) because e.g. we don't want # a unyt_array if weights are not a unyt_array and not density @@ -222,6 +291,30 @@ def histogramdd( return counts, tuple(_bin * u for _bin, u in zip(bins, units)) +if NUMPY_VERSION > Version("1.23"): + + @implements(np.histogramdd) + def histogramdd(sample, bins=10, range=None, density=None, weights=None): + return _histogramdd( + sample, bins=bins, range=range, density=density, weights=weights + ) + +else: + + @implements(np.histogramdd) + def histogramdd( + sample, bins=10, range=None, normed=None, weights=None, density=None + ): + return _histogramdd( + sample, + bins=bins, + range=range, + normed=normed, + weights=weights, + density=density, + ) + + @implements(np.histogram_bin_edges) def histogram_bin_edges(a, *args, **kwargs): return ( @@ -600,9 +693,8 @@ def nanquantile(a, *args, **kwargs): @implements(np.linalg.det) def linalg_det(a, *args, **kwargs): - return ( - np.linalg.det._implementation(np.asarray(a), *args, **kwargs) - * a.units ** (a.shape[0]) + return np.linalg.det._implementation(np.asarray(a), *args, **kwargs) * a.units ** ( + a.shape[0] )