Skip to content

Commit

Permalink
Implement median helper
Browse files Browse the repository at this point in the history
Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
  • Loading branch information
Dhruvanshu-Joshi and ricardoV94 committed Oct 11, 2024
1 parent ed6ca16 commit f277af7
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
43 changes: 43 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,48 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
return ret


def median(x: TensorLike, axis=None) -> TensorVariable:
    """
    Compute the median of a tensor `x` over the given axis or axes.

    The reduced axes are moved to the end and, when there are several of
    them, flattened into a single trailing dimension. That dimension is
    sorted; the result is its middle element when its length is odd and the
    mean of its two middle elements when its length is even.

    Parameters
    ----------
    x: TensorVariable
        The input tensor.
    axis: None or int or (list of int) (see `Sum`)
        Axis or axes along which the median is computed.
        None means all axes (like numpy).

    Returns
    -------
    TensorVariable
        The median values. The odd-length result is cast to the dtype of the
        even-length (averaged) result, so both cases share one dtype.
    """
    from pytensor.ifelse import ifelse

    x = as_tensor_variable(x)
    ndim = x.type.ndim
    reduced = (
        list(range(ndim)) if axis is None else list(normalize_axis_tuple(axis, ndim))
    )
    kept = [dim for dim in range(ndim) if dim not in reduced]
    kept_shape = [x.shape[dim] for dim in kept]

    # Move the reduced axes to the end; merge them into one when there are several
    moved = x.transpose(*kept, *reduced)
    if len(reduced) > 1:
        moved = moved.reshape((*kept_shape, -1))
    length = moved.shape[-1]
    upper = length // 2

    # After sorting, the median sits at `upper` (odd length) or is the mean of
    # the entries at `upper - 1` and `upper` (even length)
    ordered = moved.sort(axis=-1)
    upper_mid = ordered[..., upper]
    lower_mid = ordered[..., upper - 1]

    mean_of_mids = (upper_mid + lower_mid) / 2.0
    single_mid = upper_mid.astype(mean_of_mids.type.dtype)
    is_even = eq(mod(length, 2), 0)
    return ifelse(is_even, mean_of_mids, single_mid, name="median")


@scalar_elemwise(symbolname="scalar_maximum")
def maximum(x, y):
    """Elementwise maximum of `x` and `y`.

    See `max` for the maximum in one tensor.
    """
Expand Down Expand Up @@ -3015,6 +3057,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"sum",
"prod",
"mean",
"median",
"var",
"std",
"std",
Expand Down
31 changes: 31 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
max_and_argmax,
maximum,
mean,
median,
min,
minimum,
mod,
Expand Down Expand Up @@ -3735,3 +3736,33 @@ def test_nan_to_num(nan, posinf, neginf):
out,
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
)


@pytest.mark.parametrize(
    "ndim, axis",
    [
        (2, None),
        (2, 1),
        (2, (0, 1)),
        (3, None),
        (3, (1, 2)),
        (4, (1, 3, 0)),
    ],
)
def test_median(ndim, axis):
    """Check `median` against `np.median` for odd and even reduction lengths."""
    # Seeded generator so that a failing case can be reproduced exactly
    rng = np.random.default_rng(seed=2024)

    # Every dimension is even in one input and odd in the other, so the
    # reduced length is even for the first and odd for the second
    shape_even = np.arange(1, ndim + 1) * 2
    shape_odd = shape_even - 1

    data_even = rng.random(tuple(shape_even))
    data_odd = rng.random(tuple(shape_odd))

    x = tensor(dtype="float64", shape=(None,) * ndim)
    f = function([x], median(x, axis=axis))

    # assert_allclose reports the mismatching values on failure
    np.testing.assert_allclose(f(data_odd), np.median(data_odd, axis=axis))
    np.testing.assert_allclose(f(data_even), np.median(data_even, axis=axis))

0 comments on commit f277af7

Please sign in to comment.