Skip to content

Commit

Permalink
Add cohort_statistic function sgkit-dev#730
Browse files Browse the repository at this point in the history
  • Loading branch information
timothymillar committed Dec 7, 2021
1 parent 31dc606 commit df3d633
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
51 changes: 51 additions & 0 deletions sgkit/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sgkit.utils import (
MergeWarning,
check_array_like,
cohort_statistic,
define_variable_if_absent,
encode_array,
hash_array,
Expand Down Expand Up @@ -266,3 +267,53 @@ def test_smallest_numpy_int_dtype__overflow():

with pytest.raises(OverflowError):
smallest_numpy_int_dtype(np.iinfo(np.int64).max + 1)


@pytest.mark.parametrize(
"statistic,expect",
[
(
np.mean,
[
[1.0, 0.75, 0.5],
[2 / 3, 0.25, 0.0],
[2 / 3, 0.75, 0.5],
[2 / 3, 0.5, 1.0],
[1 / 3, 0.5, 0.0],
],
),
(np.sum, [[3, 3, 1], [2, 1, 0], [2, 3, 1], [2, 2, 2], [1, 2, 0]]),
],
)
@pytest.mark.parametrize(
"chunks",
[
((5,), (10,)),
((3, 2), (10,)),
((3, 2), (5, 5)),
],
)
def test_cohort_statistic(statistic, expect, chunks):
variables = da.asarray(
[
[1, 1, 1, 0, 1, 1, 0, 0, 1, 1],
[0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
[0, 0, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 1, 1],
[0, 1, 0, 0, 1, 1, 1, 0, 0, 0],
],
chunks=chunks,
)
cohorts = np.array([0, 1, 0, 2, 0, 1, -1, 1, 1, 2])
np.testing.assert_array_equal(
expect, cohort_statistic(variables, statistic, cohorts, axis=1)
)


def test_cohort_statistic_axis0():
variables = da.asarray([2, 3, 2, 4, 3, 1, 4, 5, 3, 1])
cohorts = np.array([0, 0, 0, 0, 0, -1, 1, 1, 1, 2])
np.testing.assert_array_equal(
[2.8, 4.0, 1.0],
cohort_statistic(variables, np.mean, cohorts, sample_axis=0, axis=0),
)
38 changes: 38 additions & 0 deletions sgkit/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Any, Callable, Hashable, List, Mapping, Optional, Set, Tuple, Union

import dask.array as da
import numpy as np
from numba import guvectorize
from xarray import Dataset
Expand Down Expand Up @@ -359,3 +360,40 @@ def smallest_numpy_int_dtype(value: int) -> Optional[DType]:
if np.iinfo(dtype).min <= value <= np.iinfo(dtype).max:
return dtype
raise OverflowError(f"Value {value} cannot be stored in np.int64")


def cohort_statistic(
values: ArrayLike,
statistic: Callable[..., ArrayLike],
cohorts: ArrayLike,
sample_axis: int = 1,
**kwargs: Any,
) -> da.Array:
"""Calculate a statistic for each cohort of samples.
Parameters
----------
values
An n-dimensional array of sample values.
statistic
A callable to apply to the samples of each cohort. The callable is
expected to consume the samples axis.
cohorts
An array of integers indicating which cohort each sample is assigned to.
Negative integers indicate that a sample is not assigned to any cohort.
sample_axis
Integer indicating the samples axis of the values array.
kwargs
Key word arguments to pass to the callable statistic.
Returns
-------
Array of results for each cohort.
"""
values = da.asarray(values)
cohorts = np.array(cohorts)
n_cohorts = cohorts.max() + 1
idx = [cohorts == c for c in range(n_cohorts)]
seq = [da.take(values, i, axis=sample_axis) for i in idx]
out = da.stack([statistic(c, **kwargs) for c in seq], axis=sample_axis)
return out

0 comments on commit df3d633

Please sign in to comment.