Skip to content

Commit

Permalink
Account for differing histogram* signatures across numpy versions.
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleaoman committed Dec 16, 2024
1 parent 459b90e commit ab2545b
Showing 1 changed file with 121 additions and 29 deletions.
150 changes: 121 additions & 29 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,25 @@ def _sanitize_range(_range, units):
return new_range.squeeze()


@implements(np.histogram)
def histogram(
a,
bins=10,
range=None,
density=None,
weights=None,
*args,
**kwargs,
):
def _histogram(a, bins=10, range=None, density=None, weights=None, normed=None):
range = _sanitize_range(range, units=[a.units])
counts, bins = np.histogram._implementation(
np.asarray(a), bins, range, *args, **kwargs
)
if NUMPY_VERSION > Version("1.23"):
counts, bins = np.histogram._implementation(
np.asarray(a),
bins=bins,
range=range,
density=density,
weights=np.asarray(weights) if weights is not None else None,
)
else:
counts, bins = np.histogram._implementation(
np.asarray(a),
bins=bins,
range=range,
normed=normed,
weights=np.asarray(weights) if weights is not None else None,
density=density,
)
# a and/or weights could have units, only apply if present
# don't getattr(..., "units", NULL_UNIT) because e.g. we don't want
# a unyt_array if weights are not a unyt_array and not density
Expand All @@ -182,12 +187,42 @@ def histogram(
return counts, bins * a.units


@implements(np.histogram2d)
def histogram2d(x, y, bins=10, range=None, density=None, weights=None, *args, **kwargs):
if NUMPY_VERSION > Version("1.23"):

@implements(np.histogram)
def histogram(a, bins=10, range=None, density=None, weights=None):
return _histogram(a, bins=bins, range=range, density=density, weights=weights)

else:

@implements(np.histogram)
def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
return _histogram(
a, bins=bins, range=range, normed=normed, weights=weights, density=density
)


def _histogram2d(x, y, bins=10, range=None, density=None, weights=None, normed=None):
range = _sanitize_range(range, units=[x.units, y.units])
counts, xbins, ybins = np.histogram2d._implementation(
np.asarray(x), np.asarray(y), bins, range, *args, **kwargs
)
if NUMPY_VERSION > Version("1.23"):
counts, xbins, ybins = np.histogram2d._implementation(
np.asarray(x),
np.asarray(y),
bins=bins,
range=range,
density=density,
weights=np.asarray(weights) if weights is not None else None,
)
else:
counts, xbins, ybins = np.histogram2d._implementation(
np.asarray(x),
np.asarray(y),
bins=bins,
range=range,
normed=normed,
weights=np.asarray(weights) if weights is not None else None,
density=density,
)
# x, y and/or weights could have units, only apply if present
# don't getattr(..., "units", NULL_UNIT) because e.g. we don't want
# a unyt_array if weights are not a unyt_array and not density
Expand All @@ -201,15 +236,49 @@ def histogram2d(x, y, bins=10, range=None, density=None, weights=None, *args, **
return counts, xbins * x.units, ybins * y.units


@implements(np.histogramdd)
def histogramdd(
sample, bins=10, range=None, density=None, weights=None, *args, **kwargs
):
units = [_.units for _ in sample]
if NUMPY_VERSION > Version("1.23"):

@implements(np.histogram2d)
def histogram2d(x, y, bins=10, range=None, density=None, weights=None):
return _histogram2d(
x, y, bins=bins, range=range, density=density, weights=weights
)

else:

@implements(np.histogram2d)
def histogram2d(x, y, bins=10, range=None, normed=None, weights=None, density=None):
return _histogram2d(
x,
y,
bins=bins,
range=range,
normed=normed,
weights=weights,
density=density,
)


def _histogramdd(sample, bins=10, range=None, density=None, weights=None, normed=None):
units = [getattr(_, "units", NULL_UNIT) for _ in sample]
range = _sanitize_range(range, units=units)
counts, bins = np.histogramdd._implementation(
[np.asarray(_) for _ in sample], bins, range, *args, **kwargs
)
if NUMPY_VERSION > Version("1.23"):
counts, bins = np.histogramdd._implementation(
[np.asarray(_) for _ in sample],
bins=bins,
range=range,
density=density,
weights=np.asarray(weights) if weights is not None else None,
)
else:
counts, bins = np.histogramdd._implementation(
[np.asarray(_) for _ in sample],
bins=bins,
range=range,
normed=normed,
weights=np.asarray(weights) if weights is not None else None,
density=density,
)
# sample(s) and/or weights could have units, only apply if present
# don't getattr(..., "units", NULL_UNIT) because e.g. we don't want
# a unyt_array if weights are not a unyt_array and not density
Expand All @@ -222,6 +291,30 @@ def histogramdd(
return counts, tuple(_bin * u for _bin, u in zip(bins, units))


if NUMPY_VERSION > Version("1.23"):

@implements(np.histogramdd)
def histogramdd(sample, bins=10, range=None, density=None, weights=None):
return _histogramdd(
sample, bins=bins, range=range, density=density, weights=weights
)

else:

@implements(np.histogramdd)
def histogramdd(
sample, bins=10, range=None, normed=None, weights=None, density=None
):
return _histogramdd(
sample,
bins=bins,
range=range,
normed=normed,
weights=weights,
density=density,
)


@implements(np.histogram_bin_edges)
def histogram_bin_edges(a, *args, **kwargs):
return (
Expand Down Expand Up @@ -600,9 +693,8 @@ def nanquantile(a, *args, **kwargs):

@implements(np.linalg.det)
def linalg_det(a, *args, **kwargs):
return (
np.linalg.det._implementation(np.asarray(a), *args, **kwargs)
* a.units ** (a.shape[0])
return np.linalg.det._implementation(np.asarray(a), *args, **kwargs) * a.units ** (
a.shape[0]
)


Expand Down

0 comments on commit ab2545b

Please sign in to comment.