From 6ae92c09dbc68fbcf1d9693d75249c7eb4fe5fc0 Mon Sep 17 00:00:00 2001
From: Timothy Millar
Date: Sat, 18 Jun 2022 19:27:03 +1200
Subject: [PATCH] Generic cohort reductions #730

---
 sgkit/stats/utils.py            | 216 +++++++++++++++++++++++++++++++-
 sgkit/tests/test_stats_utils.py |  57 +++++++++
 2 files changed, 272 insertions(+), 1 deletion(-)

diff --git a/sgkit/stats/utils.py b/sgkit/stats/utils.py
index 0b736efbc..669d0a394 100644
--- a/sgkit/stats/utils.py
+++ b/sgkit/stats/utils.py
@@ -1,9 +1,11 @@
-from typing import Hashable, Tuple
+from functools import wraps
+from typing import Callable, Hashable, Tuple
 
 import dask.array as da
 import numpy as np
 import xarray as xr
 from dask.array import Array
+from numba import guvectorize
 from xarray import DataArray, Dataset
 
 from ..typing import ArrayLike
@@ -109,3 +111,215 @@ def map_blocks_asnumpy(x: Array) -> Array:
         x = x.map_blocks(cp.asnumpy)
 
     return x
+
+
+def cohort_reduction(gufunc: Callable) -> Callable:
+    @wraps(gufunc)
+    def func(x: ArrayLike, cohort: ArrayLike, n: int, axis: int = -1) -> ArrayLike:
+        x = da.swapaxes(da.asarray(x), axis, -1)
+        replaced = len(x.shape) - 1
+        chunks = x.chunks[0:-1] + (n,)
+        out = da.map_blocks(
+            gufunc,
+            x,
+            cohort,
+            np.empty(n, np.int8),
+            chunks=chunks,
+            drop_axis=replaced,
+            new_axis=replaced,
+        )
+        return da.swapaxes(out, axis, -1)
+
+    return func
+
+
+@cohort_reduction
+@guvectorize(
+    [
+        "(uint8[:], int64[:], int8[:], uint64[:])",
+        "(uint64[:], int64[:], int8[:], uint64[:])",
+        "(int8[:], int64[:], int8[:], int64[:])",
+        "(int64[:], int64[:], int8[:], int64[:])",
+        "(float32[:], int64[:], int8[:], float32[:])",
+        "(float64[:], int64[:], int8[:], float64[:])",
+    ],
+    "(n),(n),(c)->(c)",
+)
+def cohort_sum(
+    x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
+) -> None:  # pragma: no cover
+    """Sum of values by cohort.
+
+    Parameters
+    ----------
+    x
+        Array of values corresponding to each sample.
+    cohort
+        Array of integers indicating the cohort membership of
+        each sample with negative values indicating no cohort.
+    n
+        Number of cohorts.
+    axis
+        The axis of array x corresponding to samples (defaults
+        to final axis).
+
+    Returns
+    -------
+    An array with the same number of dimensions as x in which
+    the sample axis has been replaced with a cohort axis of
+    size n.
+    """
+    out[:] = 0
+    n = len(x)
+    for i in range(n):
+        c = cohort[i]
+        if c >= 0:
+            out[c] += x[i]
+    return
+
+
+@cohort_reduction
+@guvectorize(
+    [
+        "(uint8[:], int64[:], int8[:], uint64[:])",
+        "(uint64[:], int64[:], int8[:], uint64[:])",
+        "(int8[:], int64[:], int8[:], int64[:])",
+        "(int64[:], int64[:], int8[:], int64[:])",
+        "(float32[:], int64[:], int8[:], float32[:])",
+        "(float64[:], int64[:], int8[:], float64[:])",
+    ],
+    "(n),(n),(c)->(c)",
+)
+def cohort_nansum(
+    x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
+) -> None:  # pragma: no cover
+    """Sum of values by cohort ignoring nan values.
+
+    Parameters
+    ----------
+    x
+        Array of values corresponding to each sample.
+    cohort
+        Array of integers indicating the cohort membership of
+        each sample with negative values indicating no cohort.
+    n
+        Number of cohorts.
+    axis
+        The axis of array x corresponding to samples (defaults
+        to final axis).
+
+    Returns
+    -------
+    An array with the same number of dimensions as x in which
+    the sample axis has been replaced with a cohort axis of
+    size n.
+    """
+    out[:] = 0
+    n = len(x)
+    for i in range(n):
+        c = cohort[i]
+        v = x[i]
+        if (not np.isnan(v)) and (c >= 0):
+            out[cohort[i]] += v
+    return
+
+
+@cohort_reduction
+@guvectorize(
+    [
+        "(uint8[:], int64[:], int8[:], float64[:])",
+        "(uint64[:], int64[:], int8[:], float64[:])",
+        "(int8[:], int64[:], int8[:], float64[:])",
+        "(int64[:], int64[:], int8[:], float64[:])",
+        "(float32[:], int64[:], int8[:], float32[:])",
+        "(float64[:], int64[:], int8[:], float64[:])",
+    ],
+    "(n),(n),(c)->(c)",
+)
+def cohort_mean(
+    x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
+) -> None:  # pragma: no cover
+    """Mean of values by cohort.
+
+    Parameters
+    ----------
+    x
+        Array of values corresponding to each sample.
+    cohort
+        Array of integers indicating the cohort membership of
+        each sample with negative values indicating no cohort.
+    n
+        Number of cohorts.
+    axis
+        The axis of array x corresponding to samples (defaults
+        to final axis).
+
+    Returns
+    -------
+    An array with the same number of dimensions as x in which
+    the sample axis has been replaced with a cohort axis of
+    size n.
+    """
+    out[:] = 0
+    n = len(x)
+    c = len(_)
+    count = np.zeros(c)
+    for i in range(n):
+        j = cohort[i]
+        if j >= 0:
+            out[j] += x[i]
+            count[j] += 1
+    for j in range(c):
+        out[j] /= count[j]
+    return
+
+
+@cohort_reduction
+@guvectorize(
+    [
+        "(uint8[:], int64[:], int8[:], float64[:])",
+        "(uint64[:], int64[:], int8[:], float64[:])",
+        "(int8[:], int64[:], int8[:], float64[:])",
+        "(int64[:], int64[:], int8[:], float64[:])",
+        "(float32[:], int64[:], int8[:], float32[:])",
+        "(float64[:], int64[:], int8[:], float64[:])",
+    ],
+    "(n),(n),(c)->(c)",
+)
+def cohort_nanmean(
+    x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
+) -> None:  # pragma: no cover
+    """Mean of values by cohort ignoring nan values.
+
+    Parameters
+    ----------
+    x
+        Array of values corresponding to each sample.
+    cohort
+        Array of integers indicating the cohort membership of
+        each sample with negative values indicating no cohort.
+    n
+        Number of cohorts.
+    axis
+        The axis of array x corresponding to samples (defaults
+        to final axis).
+
+    Returns
+    -------
+    An array with the same number of dimensions as x in which
+    the sample axis has been replaced with a cohort axis of
+    size n.
+    """
+    out[:] = 0
+    n = len(x)
+    c = len(_)
+    count = np.zeros(c)
+    for i in range(n):
+        j = cohort[i]
+        v = x[i]
+        if (not np.isnan(v)) and (j >= 0):
+            out[j] += v
+            count[j] += 1
+    for j in range(c):
+        out[j] /= count[j]
+    return
diff --git a/sgkit/tests/test_stats_utils.py b/sgkit/tests/test_stats_utils.py
index c9f4f5e2c..7fcee105a 100644
--- a/sgkit/tests/test_stats_utils.py
+++ b/sgkit/tests/test_stats_utils.py
@@ -12,6 +12,10 @@
     assert_array_shape,
     assert_block_shape,
     assert_chunk_shape,
+    cohort_mean,
+    cohort_nanmean,
+    cohort_nansum,
+    cohort_sum,
     concat_2d,
     r2_score,
 )
@@ -164,3 +168,56 @@ def _col_shape_sum(ds: Dataset) -> int:
 
 def _rename_dim(ds: Dataset, prefix: str, name: str) -> Dataset:
     return ds.rename_dims({d: name for d in ds.dims if str(d).startswith(prefix)})
+
+
+def _random_cohort_data(chunks, n, axis, missing=0.0, scale=1, dtype=float, seed=0):
+    shape = tuple(np.sum(tup) for tup in chunks)
+    np.random.seed(seed)
+    x = np.random.rand(*shape) * scale
+    idx = np.random.choice([1, 0], shape, p=[missing, 1 - missing]).astype(bool)
+    x[idx] = np.nan
+    x = da.asarray(x, chunks=chunks, dtype=dtype)
+    cohort = np.random.randint(-1, n, size=shape[axis])
+    return x, cohort, n, axis
+
+
+def _cohort_reduction(func, x, cohort, n, axis=-1):
+    # reference implementation
+    out = []
+    for i in range(n):
+        idx = np.where(cohort == i)[0]
+        x_c = np.take(x, idx, axis=axis)
+        out.append(func(x_c, axis=axis))
+    out = np.swapaxes(np.array(out), 0, axis)
+    return out
+
+
+@pytest.mark.parametrize(
+    "x, cohort, n, axis",
+    [
+        _random_cohort_data((20,), n=3, axis=0),
+        _random_cohort_data((20, 20), n=2, axis=0, dtype=np.float32),
+        _random_cohort_data((10, 10), n=2, axis=-1, scale=30, dtype=np.int16),
+        _random_cohort_data((20, 20), n=3, axis=-1, missing=0.3),
+        _random_cohort_data((7, 103, 4), n=5, axis=1, scale=7, missing=0.3),
+        _random_cohort_data(
+            ((3, 4), (50, 50, 3), 4), n=5, axis=1, scale=7, dtype=np.uint8
+        ),
+        _random_cohort_data(
+            ((6, 6), (50, 50, 7), (3, 1)), n=5, axis=1, scale=7, missing=0.3
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "reduction, func",
+    [
+        (cohort_sum, np.sum),
+        (cohort_nansum, np.nansum),
+        (cohort_mean, np.mean),
+        (cohort_nanmean, np.nanmean),
+    ],
+)
+def test_cohort_reductions(reduction, func, x, cohort, n, axis):
+    expect = _cohort_reduction(func, x, cohort, n, axis=axis)
+    actual = reduction(x, cohort, n, axis=axis)
+    np.testing.assert_array_almost_equal(expect, actual)