Merge output variables with input dataset #217

Merged (4 commits) on Sep 2, 2020
2 changes: 2 additions & 0 deletions setup.cfg
@@ -42,6 +42,8 @@ fail_under = 100
[tool:pytest]
addopts = --doctest-modules --ignore=validation
norecursedirs = .eggs docs
filterwarnings =
error

[flake8]
ignore =
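For context, the `filterwarnings = error` setting makes pytest escalate every warning raised during the test run into an error, so the `MergeWarning` introduced in this PR has to be asserted explicitly (as the new `pytest.warns(MergeWarning)` test below does). A minimal sketch of the semantics, using a plain `UserWarning` and `warnings.simplefilter("error")` to stand in for the ini option; the function name is illustrative only:

import warnings

import pytest

# Roughly what "filterwarnings = error" does for the whole test session:
# every warning is raised as an exception.
warnings.simplefilter("error")


def warns_about_clobbering():
    warnings.warn("variable will be replaced", UserWarning)


# An unexpected warning now surfaces as an exception and fails the test ...
with pytest.raises(UserWarning):
    warns_about_clobbering()

# ... while an expected warning is captured explicitly; pytest.warns
# temporarily resets the filter and records the warning instead of raising.
with pytest.warns(UserWarning, match="replaced"):
    warns_about_clobbering()
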
67 changes: 45 additions & 22 deletions sgkit/stats/aggregation.py
@@ -1,10 +1,10 @@
import dask.array as da
import numpy as np
import xarray as xr
from numba import guvectorize
from xarray import DataArray, Dataset
from xarray import Dataset

from ..typing import ArrayLike
from sgkit.typing import ArrayLike
from sgkit.utils import merge_datasets


@guvectorize( # type: ignore
@@ -45,21 +45,27 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
out[a] += 1


def count_call_alleles(ds: Dataset) -> DataArray:
def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
"""Compute per sample allele counts from genotype calls.

Parameters
----------
ds : Dataset
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
merge : bool, optional
If True (the default), merge the input dataset and the computed
output variables into a single dataset. Output variables will
overwrite any input variables with the same name, and a warning
will be issued in this case.
If False, return only the computed output variables.

Returns
-------
call_allele_count : DataArray
Allele counts with shape (variants, samples, alleles) and values
corresponding to the number of non-missing occurrences
of each allele.
Dataset
Array `call_allele_count` of allele counts with
shape (variants, samples, alleles) and values corresponding to
the number of non-missing occurrences of each allele.

Examples
--------
@@ -75,7 +81,7 @@ def count_call_alleles(ds: Dataset) -> DataArray:
2 0/1 1/0
3 0/0 0/0

>>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
>>> sg.count_call_alleles(ds)["call_allele_count"].values # doctest: +NORMALIZE_WHITESPACE
array([[[1, 1],
[1, 1]],
<BLANKLINE>
@@ -92,28 +98,40 @@ def count_call_alleles(ds: Dataset) -> DataArray:
G = da.asarray(ds["call_genotype"])
shape = (G.chunks[0], G.chunks[1], n_alleles)
N = da.empty(n_alleles, dtype=np.uint8)
return xr.DataArray(
da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2),
dims=("variants", "samples", "alleles"),
name="call_allele_count",
new_ds = Dataset(
{
"call_allele_count": (
("variants", "samples", "alleles"),
da.map_blocks(
count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2
),
)
}
)
return merge_datasets(ds, new_ds) if merge else new_ds


def count_variant_alleles(ds: Dataset) -> DataArray:
def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
"""Compute allele count from genotype calls.

Parameters
----------
ds : Dataset
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
merge : bool, optional
If True (the default), merge the input dataset and the computed
output variables into a single dataset. Output variables will
overwrite any input variables with the same name, and a warning
will be issued in this case.
If False, return only the computed output variables.

Returns
-------
variant_allele_count : DataArray
Allele counts with shape (variants, alleles) and values
corresponding to the number of non-missing occurrences
of each allele.
Dataset
Array `variant_allele_count` of allele counts with
shape (variants, alleles) and values corresponding to
the number of non-missing occurrences of each allele.

Examples
--------
@@ -129,13 +147,18 @@ def count_variant_alleles(ds: Dataset) -> DataArray:
2 0/1 1/0
3 0/0 0/0

>>> sg.count_variant_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
>>> sg.count_variant_alleles(ds)["variant_allele_count"].values # doctest: +NORMALIZE_WHITESPACE
array([[2, 2],
[1, 3],
[2, 2],
[4, 0]], dtype=uint64)
"""
return xr.DataArray(
count_call_alleles(ds).sum(dim="samples").rename("variant_allele_count"),
dims=("variants", "alleles"),
new_ds = Dataset(
{
"variant_allele_count": (
("variants", "alleles"),
count_call_alleles(ds)["call_allele_count"].sum(dim="samples"),
)
}
)
return merge_datasets(ds, new_ds) if merge else new_ds
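
As a usage sketch of the new `merge` behaviour (the dataset constructor `sg.simulate_genotype_call_dataset` is assumed here purely for illustration, mirroring the doctest setup; the variable names come from the diff above):

import sgkit as sg

# Build a small genotype call dataset (constructor assumed for illustration).
ds = sg.simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1)

# Default merge=True: the result keeps the input variables and adds the counts.
merged = sg.count_variant_alleles(ds)
assert "call_genotype" in merged
assert "variant_allele_count" in merged

# merge=False: only the computed output variable is returned.
counts_only = sg.count_variant_alleles(ds, merge=False)
assert list(counts_only.data_vars) == ["variant_allele_count"]
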
44 changes: 32 additions & 12 deletions sgkit/tests/test_aggregation.py
@@ -21,22 +21,26 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset:


def test_count_variant_alleles__single_variant_single_sample():
ac = count_variant_alleles(get_dataset([[[1, 0]]]))
ds = count_variant_alleles(get_dataset([[[1, 0]]]))
assert "call_genotype" in ds
ac = ds["variant_allele_count"]
np.testing.assert_equal(ac, np.array([[1, 1]]))


def test_count_variant_alleles__multi_variant_single_sample():
ac = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
ds = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
ac = ds["variant_allele_count"]
np.testing.assert_equal(ac, np.array([[2, 0], [1, 1], [1, 1], [0, 2]]))


def test_count_variant_alleles__single_variant_multi_sample():
ac = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
ds = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
ac = ds["variant_allele_count"]
np.testing.assert_equal(ac, np.array([[4, 4]]))


def test_count_variant_alleles__multi_variant_multi_sample():
ac = count_variant_alleles(
ds = count_variant_alleles(
get_dataset(
[
[[0, 0], [0, 0], [0, 0]],
@@ -46,11 +50,12 @@ def test_count_variant_alleles__multi_variant_multi_sample():
]
)
)
ac = ds["variant_allele_count"]
np.testing.assert_equal(ac, np.array([[6, 0], [5, 1], [2, 4], [0, 6]]))


def test_count_variant_alleles__missing_data():
ac = count_variant_alleles(
ds = count_variant_alleles(
get_dataset(
[
[[-1, -1], [-1, -1], [-1, -1]],
@@ -60,11 +65,12 @@ def test_count_variant_alleles__missing_data():
]
)
)
ac = ds["variant_allele_count"]
np.testing.assert_equal(ac, np.array([[0, 0], [2, 1], [1, 2], [0, 6]]))


def test_count_variant_alleles__higher_ploidy():
ac = count_variant_alleles(
ds = count_variant_alleles(
get_dataset(
[
[[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
@@ -74,6 +80,7 @@ def test_count_variant_alleles__higher_ploidy():
n_ploidy=3,
)
)
ac = ds["variant_allele_count"]
np.testing.assert_equal(ac, np.array([[1, 1, 1, 0], [1, 2, 2, 1]]))


@@ -88,23 +95,33 @@ def test_count_variant_alleles__chunked():
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]


def test_count_variant_alleles__no_merge():
ds = count_variant_alleles(get_dataset([[[1, 0]]]), merge=False)
assert "call_genotype" not in ds
ac = ds["variant_allele_count"]
np.testing.assert_equal(ac, np.array([[1, 1]]))


def test_count_call_alleles__single_variant_single_sample():
ac = count_call_alleles(get_dataset([[[1, 0]]]))
ds = count_call_alleles(get_dataset([[[1, 0]]]))
ac = ds["call_allele_count"]
np.testing.assert_equal(ac, np.array([[[1, 1]]]))


def test_count_call_alleles__multi_variant_single_sample():
ac = count_call_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
ds = count_call_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
ac = ds["call_allele_count"]
np.testing.assert_equal(ac, np.array([[[2, 0]], [[1, 1]], [[1, 1]], [[0, 2]]]))


def test_count_call_alleles__single_variant_multi_sample():
ac = count_call_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
ds = count_call_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
ac = ds["call_allele_count"]
np.testing.assert_equal(ac, np.array([[[2, 0], [1, 1], [1, 1], [0, 2]]]))


def test_count_call_alleles__multi_variant_multi_sample():
ac = count_call_alleles(
ds = count_call_alleles(
get_dataset(
[
[[0, 0], [0, 0], [0, 0]],
@@ -114,6 +131,7 @@ def test_count_call_alleles__multi_variant_multi_sample():
]
)
)
ac = ds["call_allele_count"]
np.testing.assert_equal(
ac,
np.array(
Expand All @@ -128,7 +146,7 @@ def test_count_call_alleles__multi_variant_multi_sample():


def test_count_call_alleles__missing_data():
ac = count_call_alleles(
ds = count_call_alleles(
get_dataset(
[
[[-1, -1], [-1, -1], [-1, -1]],
@@ -138,6 +156,7 @@ def test_count_call_alleles__missing_data():
]
)
)
ac = ds["call_allele_count"]
np.testing.assert_equal(
ac,
np.array(
Expand All @@ -152,7 +171,7 @@ def test_count_call_alleles__missing_data():


def test_count_call_alleles__higher_ploidy():
ac = count_call_alleles(
ds = count_call_alleles(
get_dataset(
[
[[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
@@ -162,6 +181,7 @@ def test_count_call_alleles__higher_ploidy():
n_ploidy=3,
)
)
ac = ds["call_allele_count"]
np.testing.assert_equal(
ac,
np.array(
24 changes: 23 additions & 1 deletion sgkit/tests/test_utils.py
@@ -1,11 +1,19 @@
from typing import Any, List

import dask.array as da
import numpy as np
import pytest
import xarray as xr
from hypothesis import given, settings
from hypothesis import strategies as st

from sgkit.utils import check_array_like, encode_array, split_array_chunks
from sgkit.utils import (
MergeWarning,
check_array_like,
encode_array,
merge_datasets,
split_array_chunks,
)


def test_check_array_like():
@@ -66,6 +74,20 @@ def test_encode_array(
np.testing.assert_equal(n, expected_names)


def test_merge_datasets():
ds = xr.Dataset(dict(x=xr.DataArray(da.zeros(100))))

new_ds1 = xr.Dataset(dict(y=xr.DataArray(da.zeros(100))))
new_ds2 = xr.Dataset(dict(y=xr.DataArray(da.ones(100))))

ds = merge_datasets(ds, new_ds1)
assert "y" in ds

with pytest.warns(MergeWarning):
ds = merge_datasets(ds, new_ds2)
np.testing.assert_equal(ds["y"].values, np.ones(100))


@pytest.mark.parametrize(
"n,blocks,expected_chunks",
[
35 changes: 35 additions & 0 deletions sgkit/utils.py
@@ -1,6 +1,8 @@
import warnings
from typing import Any, List, Set, Tuple, Union

import numpy as np
from xarray import Dataset

from .typing import ArrayLike, DType

@@ -100,6 +102,39 @@ def encode_array(x: ArrayLike) -> Tuple[ArrayLike, List[Any]]:
return rank[inverse], names[index]


class MergeWarning(UserWarning):
"""Warnings about merging datasets."""

pass


def merge_datasets(input: Dataset, output: Dataset) -> Dataset:
"""Merge the input and output datasets into a new dataset, giving precedence to variables in the output.

Parameters
----------
input : Dataset
The input dataset.
output : Dataset
The output dataset.

Returns
-------
Dataset
The merged dataset. If `input` and `output` have variables with the same name, a `MergeWarning`
is issued, and the variables from the `output` dataset are used.
"""
input_vars = {str(v) for v in input.data_vars.keys()}
output_vars = {str(v) for v in output.data_vars.keys()}
clobber_vars = sorted(list(input_vars & output_vars))
if len(clobber_vars) > 0:
warnings.warn(
f"The following variables in the input dataset will be replaced in the output: {', '.join(clobber_vars)}",
MergeWarning,
)
return output.merge(input, compat="override")


def split_array_chunks(n: int, blocks: int) -> Tuple[int, ...]:
"""Compute chunk sizes for an array split into blocks.

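The precedence rule in `merge_datasets` comes from xarray's `Dataset.merge` with `compat="override"`, which skips equality checks and keeps the version of a conflicting variable from the dataset the method is called on. A minimal, self-contained sketch of why `output.merge(input, compat="override")` lets output variables win:

import numpy as np
import xarray as xr

input_ds = xr.Dataset({"x": ("d", np.zeros(3)), "y": ("d", np.zeros(3))})
output_ds = xr.Dataset({"y": ("d", np.ones(3))})

# compat="override" takes conflicting variables from the calling dataset,
# so "y" comes from output_ds while "x" is carried over from input_ds.
merged = output_ds.merge(input_ds, compat="override")
print(merged["y"].values)  # [1. 1. 1.]
print("x" in merged)       # True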