Skip to content

Commit

Permalink
Merge branch 'master' into nd-support-to-trim_zero
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy authored Dec 18, 2024
2 parents ee19ab4 + cabc0d7 commit 21be8f4
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions dpnp/dpnp_utils/dpnp_utils_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def _calc_median(a, axis, out=None):
return res


def _calc_nanmedian(a, axis, out=None):
def _calc_nanmedian(a, out=None):
"""Compute the median of an array along a specified axis, ignoring NaNs."""
mask = dpnp.isnan(a)
valid_counts = dpnp.sum(~mask, axis=axis)
valid_counts = dpnp.sum(~mask, axis=-1)
if out is None:
res = dpnp.empty_like(valid_counts, dtype=a.dtype)
else:
Expand All @@ -76,27 +76,19 @@ def _calc_nanmedian(a, axis, out=None):
)
res = out

# Iterate over all indices of the output shape
for idx in dpnp.ndindex(res.shape):
current_valid_counts = valid_counts[idx]
left = (valid_counts - 1) // 2
right = valid_counts // 2

if current_valid_counts > 0:
# Extract the corresponding slice from the last axis of `a`
data = a[idx][:current_valid_counts]
left = (current_valid_counts - 1) // 2
right = current_valid_counts // 2
left_data = dpnp.take_along_axis(a, left[..., None], axis=-1)
right_data = dpnp.take_along_axis(a, right[..., None], axis=-1)
res = dpnp.where(
valid_counts[..., None] > 0, (left_data + right_data) / 2.0, dpnp.nan
)

if left == right:
res[idx] = data[left]
else:
res[idx] = (data[left] + data[right]) / 2.0
else:
warnings.warn(
"All-NaN slice encountered", RuntimeWarning, stacklevel=6
)
res[idx] = dpnp.nan
if mask.all(axis=-1).any():
warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=6)

return res
return dpnp.squeeze(res)


def _flatten_array_along_axes(a, axes_to_flatten, overwrite_input):
Expand Down Expand Up @@ -232,7 +224,8 @@ def dpnp_median(

if ignore_nan:
# sorting puts NaNs at the end
res = _calc_nanmedian(a_sorted, axis=axis, out=out)
assert axis == -1
res = _calc_nanmedian(a_sorted, out=out)
else:
# We can't pass keepdims and use it in dpnp.mean and dpnp.any
# because of the reshape hack that might have been used in
Expand Down

0 comments on commit 21be8f4

Please sign in to comment.