Skip to content

Commit

Permalink
Use cohort_statistic in observed_heterozygosity
Browse files Browse the repository at this point in the history
  • Loading branch information
timothymillar committed Dec 7, 2021
1 parent df3d633 commit 01d29f7
Showing 1 changed file with 6 additions and 58 deletions.
64 changes: 6 additions & 58 deletions sgkit/stats/popgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sgkit.window import has_windows, window_statistic

from .. import variables
from ..utils import cohort_statistic
from .aggregation import (
count_cohort_alleles,
count_variant_alleles,
Expand Down Expand Up @@ -939,52 +940,6 @@ def Garud_H(
return conditional_merge_datasets(ds, new_ds, merge)


@guvectorize(  # type: ignore
    [
        "void(float64[:], int32[:], uint8[:], float64[:])",
        "void(float64[:], int64[:], uint8[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
    nopython=True,
    cache=True,
)
def _cohort_observed_heterozygosity(
    hi: ArrayLike, cohorts: ArrayLike, _: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Generalized universal function for cohort observed heterozygosity.

    Parameters
    ----------
    hi
        Individual sample heterozygosity of shape (samples,).
        NaN entries are ignored.
    cohorts
        Cohort indexes for samples of shape (samples,). Samples with a
        negative index, or an index that does not fit in ``out``, are
        ignored.
    _
        Dummy variable of type `uint8` and shape (cohorts,) used only to
        define the number of cohorts. It is neither read nor modified.
    out
        Observed heterozygosity with shape (cohorts,) and values corresponding
        to the mean individual heterozygosity of each cohort. Cohorts without
        any contributing sample are set to NaN.
    """
    n_cohorts = len(out)
    # Count samples in a dedicated int64 array. Counting in the uint8 dummy
    # argument would wrap around at 256 samples per cohort and corrupt the
    # mean (and would also mutate an input operand of the gufunc).
    counts = np.zeros(n_cohorts, dtype=np.int64)
    out[:] = 0
    for i in range(len(hi)):
        h = hi[i]
        if np.isnan(h):
            continue
        c = cohorts[i]
        # Negative indexes mark samples without a cohort; the upper bound
        # check avoids an unchecked out-of-bounds write in nopython mode.
        if c < 0 or c >= n_cohorts:
            continue
        out[c] += h
        counts[c] += 1
    for j in range(n_cohorts):
        n = counts[j]
        if n != 0:
            out[j] /= n
        else:
            out[j] = np.nan


def observed_heterozygosity(
ds: Dataset,
*,
Expand Down Expand Up @@ -1061,18 +1016,11 @@ def observed_heterozygosity(
variables.validate(ds, {call_heterozygosity: variables.call_heterozygosity_spec})
hi = da.asarray(ds[call_heterozygosity])
sc = da.asarray(ds[sample_cohort])
n_cohorts = sc.max().compute() + 1
shape = (hi.chunks[0], n_cohorts)
n = da.zeros(n_cohorts, dtype=np.uint8)
ho = da.map_blocks(
_cohort_observed_heterozygosity,
hi,
sc,
n,
chunks=shape,
drop_axis=1,
new_axis=1,
dtype=np.float64,
ho = cohort_statistic(
values=hi,
statistic=np.nanmean,
cohorts=sc,
axis=1,
)
if has_windows(ds):
ho_sum = window_statistic(
Expand Down

0 comments on commit 01d29f7

Please sign in to comment.