From 23fdea96968e7daaec01ebd7ca738efacc6900f8 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Sun, 16 Aug 2020 21:01:31 +1200 Subject: [PATCH 01/13] Implement count_call_alleles #85 --- sgkit/__init__.py | 3 +- sgkit/stats/aggregation.py | 116 ++++++++++++++++++++++++++++++++ sgkit/tests/test_aggregation.py | 95 +++++++++++++++++++++++++- 3 files changed, 212 insertions(+), 2 deletions(-) diff --git a/sgkit/__init__.py b/sgkit/__init__.py index 5b70b304b..eabbecb71 100644 --- a/sgkit/__init__.py +++ b/sgkit/__init__.py @@ -7,7 +7,7 @@ create_genotype_dosage_dataset, ) from .display import display_genotypes -from .stats.aggregation import count_alleles +from .stats.aggregation import count_alleles, count_call_alleles from .stats.association import gwas_linear_regression from .stats.regenie import regenie @@ -18,6 +18,7 @@ "DIM_VARIANT", "create_genotype_call_dataset", "count_alleles", + "count_call_alleles", "create_genotype_dosage_dataset", "display_genotypes", "gwas_linear_regression", diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index ae7098f87..592f5dbfd 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -2,6 +2,7 @@ import numpy as np import xarray as xr from xarray import DataArray, Dataset +from numba import njit def count_alleles(ds: Dataset) -> DataArray: @@ -69,3 +70,118 @@ def count_alleles(ds: Dataset) -> DataArray: assert AC.shape == (n_variant, n_allele) return DataArray(data=AC, dims=("variants", "alleles"), name="variant_allele_count") + + +def count_call_alleles_ndarray(genotypes, mask, n_alleles=None, dtype=np.uint8): + """Compute allele count from genotype calls. + + Parameters + ---------- + genotypes : ndarray, int, shape (variants, samples, ploidy) + Array of genotype calls. + mask : ndarray, bool, shape (variants, samples, ploidy) + Array of booleans indicating individual allele calls + which should not be counted. + n_alleles : int, optional. 
+ The number of unique alleles to be counted + (defaults to all alleles). + dtype : type, optional + Dtype of the allele counts. + + Returns + ------- + call_allele_count : ndarray, shape (variants, samples, alleles) + Allele counts with values corresponding to the number + of non-missing occurrences of each allele. + + """ + assert genotypes.shape == mask.shape + n_variants, n_samples, ploidy = genotypes.shape + + # default to counting all alleles + if n_alleles is None: + n_alleles = np.max(genotypes) + 1 + + ac = np.zeros((n_variants, n_samples, n_alleles), dtype=dtype) + for i in range(n_variants): + for j in range(n_samples): + for k in range(ploidy): + if mask[i, j, k]: + pass + else: + a = genotypes[i, j, k] + if a < 0: + raise ValueError('Encountered unmasked negative allele value.') + if a >= n_alleles: + pass + else: + ac[i, j, a] += 1 + return ac + + +count_call_alleles_ndarray_jit = njit(count_call_alleles_ndarray, nogil=True) + + +def count_call_alleles(ds: Dataset, dtype: type = np.uint8) -> DataArray: + """Compute allele count from genotype calls. + + Parameters + ---------- + ds : Dataset + Genotype call dataset such as from + `sgkit.create_genotype_call_dataset`. + dtype : type, optional + Dtype of the allele counts. + + Returns + ------- + call_allele_count : DataArray + Allele counts with shape (variants, samples, alleles) and values + corresponding to the number of non-missing occurrences + of each allele. 
+ + Examples + -------- + + >>> import sgkit as sg + >>> from sgkit.testing import simulate_genotype_call_dataset + >>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1) + >>> ds['call_genotype'].to_series().unstack().astype(str).apply('/'.join, axis=1).unstack() # doctest: +NORMALIZE_WHITESPACE + samples 0 1 + variants + 0 1/0 1/0 + 1 1/0 1/1 + 2 0/1 1/0 + 3 0/0 0/0 + + >>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE + array([[[1, 1], + [1, 1]], + + [[1, 1], + [0, 2]], + + [[1, 1], + [1, 1]], + + [[2, 0], + [2, 0]]], dtype=uint8 + """ + # dask arrays must have matching chunk size + g = da.asarray(ds["call_genotype"].values) + m = da.asarray(ds["call_genotype_mask"].values).rechunk(g.chunks) + assert g.chunks == m.chunks + + # shape of resulting chunks + n_allele = ds.dims["alleles"] + shape = (g.chunks[0], g.chunks[1], n_allele) + + # map function ensuring constant allele dimension size + func = lambda x, y: count_call_alleles_ndarray_jit(x, y, n_allele, dtype) + ac = da.map_overlap(func, g, m, chunks=shape) + + return DataArray( + data=ac, + dims=("variants", "samples", "alleles"), + name="call_allele_count", + ) diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index 53ac0bd9b..779829c57 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -4,7 +4,7 @@ import xarray as xr from xarray import Dataset -from sgkit.stats.aggregation import count_alleles +from sgkit.stats.aggregation import count_alleles, count_call_alleles from sgkit.testing import simulate_genotype_call_dataset from sgkit.typing import ArrayLike @@ -86,3 +86,96 @@ def test_count_alleles__chunked(): ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: ignore[arg-type] ac2 = count_alleles(ds) xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call] + + +def test_count_call_alleles__single_variant_single_sample(): + ac = count_call_alleles(get_dataset([[[1, 0]]])) + 
np.testing.assert_equal(ac, np.array([[[1, 1]]])) + + +def test_count_call_alleles__multi_variant_single_sample(): + ac = count_call_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]])) + np.testing.assert_equal(ac, np.array([[[2, 0]], [[1, 1]], [[1, 1]], [[0, 2]]])) + + +def test_count_call_alleles__single_variant_multi_sample(): + ac = count_call_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]])) + np.testing.assert_equal(ac, np.array([[[2, 0], [1, 1], [1, 1], [0, 2]]])) + + +def test_count_call_alleles__multi_variant_multi_sample(): + ac = count_call_alleles( + get_dataset( + [ + [[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 1]], + [[1, 1], [0, 1], [1, 0]], + [[1, 1], [1, 1], [1, 1]], + ] + ) + ) + np.testing.assert_equal( + ac, np.array( + [ + [[2, 0], [2, 0], [2, 0]], + [[2, 0], [2, 0], [1, 1]], + [[0, 2], [1, 1], [1, 1]], + [[0, 2], [0, 2], [0, 2]], + ] + ) + ) + + +def test_count_call_alleles__missing_data(): + ac = count_call_alleles( + get_dataset( + [ + [[-1, -1], [-1, -1], [-1, -1]], + [[-1, -1], [0, 0], [-1, 1]], + [[1, 1], [-1, -1], [-1, 0]], + [[1, 1], [1, 1], [1, 1]], + ] + ) + ) + np.testing.assert_equal( + ac, np.array( + [ + [[0, 0], [0, 0], [0, 0]], + [[0, 0], [2, 0], [0, 1]], + [[0, 2], [0, 0], [1, 0]], + [[0, 2], [0, 2], [0, 2]], + ] + ) + ) + + +def test_count_call_alleles__higher_ploidy(): + ac = count_call_alleles( + get_dataset( + [ + [[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]], + [[0, 1, 2], [1, 2, 3], [-1, -1, -1]], + ], + n_allele=4, + n_ploidy=3, + ) + ) + np.testing.assert_equal( + ac, np.array( + [ + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], + [[1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 0, 0]], + ] + ) + ) + + +def test_count_call_alleles__chunked(): + rs = np.random.RandomState(0) + calls = rs.randint(0, 1, size=(50, 10, 2)) + ds = get_dataset(calls) + ac1 = count_call_alleles(ds) + # Coerce from numpy to multiple chunks in all dimensions + ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: 
ignore[arg-type] + ac2 = count_call_alleles(ds) + xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call] From 39995ede7eb54bfb40daaef51019b94a75639abe Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Sun, 16 Aug 2020 21:18:15 +1200 Subject: [PATCH 02/13] Fixes for mypy and black --- sgkit/stats/aggregation.py | 24 ++++++++++++------------ sgkit/tests/test_aggregation.py | 15 +++++++++------ 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 592f5dbfd..8e26c52ad 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -1,8 +1,8 @@ import dask.array as da import numpy as np import xarray as xr -from xarray import DataArray, Dataset from numba import njit +from xarray import DataArray, Dataset def count_alleles(ds: Dataset) -> DataArray: @@ -72,12 +72,14 @@ def count_alleles(ds: Dataset) -> DataArray: return DataArray(data=AC, dims=("variants", "alleles"), name="variant_allele_count") -def count_call_alleles_ndarray(genotypes, mask, n_alleles=None, dtype=np.uint8): +def count_call_alleles_ndarray( + g: np.ndarray, mask: np.ndarray, n_alleles: int = -1, dtype: type = np.uint8 +) -> np.ndarray: """Compute allele count from genotype calls. Parameters ---------- - genotypes : ndarray, int, shape (variants, samples, ploidy) + g : ndarray, int, shape (variants, samples, ploidy) Array of genotype calls. mask : ndarray, bool, shape (variants, samples, ploidy) Array of booleans indicating individual allele calls @@ -95,12 +97,12 @@ def count_call_alleles_ndarray(genotypes, mask, n_alleles=None, dtype=np.uint8): of non-missing occurrences of each allele. 
""" - assert genotypes.shape == mask.shape - n_variants, n_samples, ploidy = genotypes.shape + assert g.shape == mask.shape + n_variants, n_samples, ploidy = g.shape # default to counting all alleles - if n_alleles is None: - n_alleles = np.max(genotypes) + 1 + if n_alleles < 0: + n_alleles = np.max(g) + 1 ac = np.zeros((n_variants, n_samples, n_alleles), dtype=dtype) for i in range(n_variants): @@ -109,9 +111,9 @@ def count_call_alleles_ndarray(genotypes, mask, n_alleles=None, dtype=np.uint8): if mask[i, j, k]: pass else: - a = genotypes[i, j, k] + a = g[i, j, k] if a < 0: - raise ValueError('Encountered unmasked negative allele value.') + raise ValueError("Encountered unmasked negative allele value.") if a >= n_alleles: pass else: @@ -181,7 +183,5 @@ def count_call_alleles(ds: Dataset, dtype: type = np.uint8) -> DataArray: ac = da.map_overlap(func, g, m, chunks=shape) return DataArray( - data=ac, - dims=("variants", "samples", "alleles"), - name="call_allele_count", + data=ac, dims=("variants", "samples", "alleles"), name="call_allele_count", ) diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index 779829c57..6f757a180 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -115,14 +115,15 @@ def test_count_call_alleles__multi_variant_multi_sample(): ) ) np.testing.assert_equal( - ac, np.array( + ac, + np.array( [ [[2, 0], [2, 0], [2, 0]], [[2, 0], [2, 0], [1, 1]], [[0, 2], [1, 1], [1, 1]], [[0, 2], [0, 2], [0, 2]], ] - ) + ), ) @@ -138,14 +139,15 @@ def test_count_call_alleles__missing_data(): ) ) np.testing.assert_equal( - ac, np.array( + ac, + np.array( [ [[0, 0], [0, 0], [0, 0]], [[0, 0], [2, 0], [0, 1]], [[0, 2], [0, 0], [1, 0]], [[0, 2], [0, 2], [0, 2]], ] - ) + ), ) @@ -161,12 +163,13 @@ def test_count_call_alleles__higher_ploidy(): ) ) np.testing.assert_equal( - ac, np.array( + ac, + np.array( [ [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], [[1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 0, 0]], ] - ) + ), ) 
From e888bc43d63e59d01dcae4671d17dfef8b782269 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Sun, 16 Aug 2020 22:01:25 +1200 Subject: [PATCH 03/13] Add dependency on numba --- requirements.txt | 1 + setup.cfg | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1990eaac5..45e97f470 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ xarray dask[array] scipy zarr +numba diff --git a/setup.cfg b/setup.cfg index 139df0d62..c8ccddc21 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,7 +58,7 @@ ignore = profile = black default_section = THIRDPARTY known_first_party = sgkit -known_third_party = dask,fire,glow,hail,hypothesis,invoke,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,xarray,yaml,zarr +known_third_party = dask,fire,glow,hail,hypothesis,invoke,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,xarray,yaml,zarr multi_line_output = 3 include_trailing_comma = True force_grid_wrap = 0 @@ -87,3 +87,5 @@ allow_redefinition = True disallow_untyped_defs = False [mypy-validation.*] ignore_errors = True +[mypy-numba.*] +ignore_missing_imports = True From 324fa9a79dadf55505c96226b2e98679ec192329 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Tue, 18 Aug 2020 22:18:44 +1200 Subject: [PATCH 04/13] gufunc implementation of count_alleles --- sgkit/__init__.py | 4 +- sgkit/stats/aggregation.py | 171 ++++++++++---------------------- sgkit/tests/test_aggregation.py | 32 +++--- 3 files changed, 68 insertions(+), 139 deletions(-) diff --git a/sgkit/__init__.py b/sgkit/__init__.py index a7754cced..cdc7a4753 100644 --- a/sgkit/__init__.py +++ b/sgkit/__init__.py @@ -8,7 +8,7 @@ ) from .display import display_genotypes from .io.vcfzarr_reader import read_vcfzarr -from .stats.aggregation import count_alleles, count_call_alleles +from .stats.aggregation import count_variant_alleles, count_call_alleles from .stats.association import gwas_linear_regression from .stats.regenie 
import regenie @@ -18,7 +18,7 @@ "DIM_SAMPLE", "DIM_VARIANT", "create_genotype_call_dataset", - "count_alleles", + "count_variant_alleles", "count_call_alleles", "create_genotype_dosage_dataset", "display_genotypes", diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 8e26c52ad..097a66bcd 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -1,12 +1,25 @@ import dask.array as da import numpy as np import xarray as xr -from numba import njit +import numba from xarray import DataArray, Dataset -def count_alleles(ds: Dataset) -> DataArray: - """Compute allele count from genotype calls. +@numba.guvectorize([ + 'void(numba.int8[:], numba.uint8[:], numba.uint8[:])', + 'void(numba.int16[:], numba.uint8[:], numba.uint8[:])', + 'void(numba.int32[:], numba.uint8[:], numba.uint8[:])', + 'void(numba.int64[:], numba.uint8[:], numba.uint8[:])', + ], '(n),(k)->(k)') +def count_alleles(x, _, out): + out[:] = 0 + for v in x: + if v >= 0: + out[v] += 1 + + +def count_call_alleles(ds: Dataset) -> DataArray: + """Compute per sample allele counts from genotype calls. Parameters ---------- @@ -16,8 +29,8 @@ def count_alleles(ds: Dataset) -> DataArray: Returns ------- - variant_allele_count : DataArray - Allele counts with shape (variants, alleles) and values + call_allele_count : DataArray + Allele counts with shape (variants, samples, alleles) and values corresponding to the number of non-missing occurrences of each allele. 
@@ -35,96 +48,32 @@ def count_alleles(ds: Dataset) -> DataArray: 2 0/1 1/0 3 0/0 0/0 - >>> sg.count_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE - array([[2, 2], - [1, 3], - [2, 2], - [4, 0]]) - """ - # Count each allele index individually as a 1D vector and - # restack into new alleles dimension with same order - G = ds["call_genotype"].stack(calls=("samples", "ploidy")) - M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy")) - n_variant, n_allele = G.shape[0], ds.dims["alleles"] - max_allele = n_allele + 1 - - # Recode missing values as max allele index - G = xr.where(M, n_allele, G) # type: ignore[no-untyped-call] - G = da.asarray(G) - - # Count allele indexes within each block - CT = da.map_blocks( - lambda x: np.apply_along_axis(np.bincount, 1, x, minlength=max_allele), - G, - chunks=(G.chunks[0], max_allele), - ) - assert CT.shape == (n_variant, G.numblocks[1] * max_allele) - - # Stack the column blocks on top of each other - CTS = da.stack([CT.blocks[:, i] for i in range(CT.numblocks[1])]) - assert CTS.shape == (CT.numblocks[1], n_variant, max_allele) - - # Sum over column blocks and slice off allele - # index corresponding to missing values - AC = CTS.sum(axis=0)[:, :n_allele] - assert AC.shape == (n_variant, n_allele) - - return DataArray(data=AC, dims=("variants", "alleles"), name="variant_allele_count") + >>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE + array([[[1, 1], + [1, 1]], + [[1, 1], + [0, 2]], -def count_call_alleles_ndarray( - g: np.ndarray, mask: np.ndarray, n_alleles: int = -1, dtype: type = np.uint8 -) -> np.ndarray: - """Compute allele count from genotype calls. + [[1, 1], + [1, 1]], - Parameters - ---------- - g : ndarray, int, shape (variants, samples, ploidy) - Array of genotype calls. - mask : ndarray, bool, shape (variants, samples, ploidy) - Array of booleans indicating individual allele calls - which should not be counted. - n_alleles : int, optional. 
- The number of unique alleles to be counted - (defaults to all alleles). - dtype : type, optional - Dtype of the allele counts. + [[2, 0], + [2, 0]]], dtype=uint8 + """ + G = da.asarray(ds.call_genotype) + # This array is only necessary to tell dask/numba what the + # dimensions and dtype are for the output array + O = da.empty(G.shape[:2] + (ds.dims['alleles'],), dtype=np.uint8) + O = O.rechunk(G.chunks[:2] + (-1,)) + return xr.DataArray( + count_alleles(G, O), + dims=('variants', 'samples', 'alleles'), + name='call_allele_count' + ) - Returns - ------- - call_allele_count : ndarray, shape (variants, samples, alleles) - Allele counts with values corresponding to the number - of non-missing occurrences of each allele. - """ - assert g.shape == mask.shape - n_variants, n_samples, ploidy = g.shape - - # default to counting all alleles - if n_alleles < 0: - n_alleles = np.max(g) + 1 - - ac = np.zeros((n_variants, n_samples, n_alleles), dtype=dtype) - for i in range(n_variants): - for j in range(n_samples): - for k in range(ploidy): - if mask[i, j, k]: - pass - else: - a = g[i, j, k] - if a < 0: - raise ValueError("Encountered unmasked negative allele value.") - if a >= n_alleles: - pass - else: - ac[i, j, a] += 1 - return ac - - -count_call_alleles_ndarray_jit = njit(count_call_alleles_ndarray, nogil=True) - - -def count_call_alleles(ds: Dataset, dtype: type = np.uint8) -> DataArray: +def count_variant_alleles(ds: Dataset) -> DataArray: """Compute allele count from genotype calls. Parameters @@ -132,13 +81,11 @@ def count_call_alleles(ds: Dataset, dtype: type = np.uint8) -> DataArray: ds : Dataset Genotype call dataset such as from `sgkit.create_genotype_call_dataset`. - dtype : type, optional - Dtype of the allele counts. 
Returns ------- - call_allele_count : DataArray - Allele counts with shape (variants, samples, alleles) and values + variant_allele_count : DataArray + Allele counts with shape (variants, alleles) and values corresponding to the number of non-missing occurrences of each allele. @@ -156,32 +103,14 @@ def count_call_alleles(ds: Dataset, dtype: type = np.uint8) -> DataArray: 2 0/1 1/0 3 0/0 0/0 - >>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE - array([[[1, 1], - [1, 1]], - - [[1, 1], - [0, 2]], - - [[1, 1], - [1, 1]], - - [[2, 0], - [2, 0]]], dtype=uint8 + >>> sg.count_variant_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE + array([[2, 2], + [1, 3], + [2, 2], + [4, 0]], dtype=uint64) """ - # dask arrays must have matching chunk size - g = da.asarray(ds["call_genotype"].values) - m = da.asarray(ds["call_genotype_mask"].values).rechunk(g.chunks) - assert g.chunks == m.chunks - - # shape of resulting chunks - n_allele = ds.dims["alleles"] - shape = (g.chunks[0], g.chunks[1], n_allele) - - # map function ensuring constant allele dimension size - func = lambda x, y: count_call_alleles_ndarray_jit(x, y, n_allele, dtype) - ac = da.map_overlap(func, g, m, chunks=shape) - - return DataArray( - data=ac, dims=("variants", "samples", "alleles"), name="call_allele_count", + return ( + count_call_alleles(ds) + .sum(dim='samples') + .rename('variant_allele_count') ) diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index 6f757a180..8b6be9e56 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -4,7 +4,7 @@ import xarray as xr from xarray import Dataset -from sgkit.stats.aggregation import count_alleles, count_call_alleles +from sgkit.stats.aggregation import count_variant_alleles, count_call_alleles from sgkit.testing import simulate_genotype_call_dataset from sgkit.typing import ArrayLike @@ -20,23 +20,23 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset: return ds -def 
test_count_alleles__single_variant_single_sample(): - ac = count_alleles(get_dataset([[[1, 0]]])) +def test_count_variant_alleles__single_variant_single_sample(): + ac = count_variant_alleles(get_dataset([[[1, 0]]])) np.testing.assert_equal(ac, np.array([[1, 1]])) -def test_count_alleles__multi_variant_single_sample(): - ac = count_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]])) +def test_count_variant_alleles__multi_variant_single_sample(): + ac = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]])) np.testing.assert_equal(ac, np.array([[2, 0], [1, 1], [1, 1], [0, 2]])) -def test_count_alleles__single_variant_multi_sample(): - ac = count_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]])) +def test_count_variant_alleles__single_variant_multi_sample(): + ac = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]])) np.testing.assert_equal(ac, np.array([[4, 4]])) -def test_count_alleles__multi_variant_multi_sample(): - ac = count_alleles( +def test_count_variant_alleles__multi_variant_multi_sample(): + ac = count_variant_alleles( get_dataset( [ [[0, 0], [0, 0], [0, 0]], @@ -49,8 +49,8 @@ def test_count_alleles__multi_variant_multi_sample(): np.testing.assert_equal(ac, np.array([[6, 0], [5, 1], [2, 4], [0, 6]])) -def test_count_alleles__missing_data(): - ac = count_alleles( +def test_count_variant_alleles__missing_data(): + ac = count_variant_alleles( get_dataset( [ [[-1, -1], [-1, -1], [-1, -1]], @@ -63,8 +63,8 @@ def test_count_alleles__missing_data(): np.testing.assert_equal(ac, np.array([[0, 0], [2, 1], [1, 2], [0, 6]])) -def test_count_alleles__higher_ploidy(): - ac = count_alleles( +def test_count_variant_alleles__higher_ploidy(): + ac = count_variant_alleles( get_dataset( [ [[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]], @@ -77,14 +77,14 @@ def test_count_alleles__higher_ploidy(): np.testing.assert_equal(ac, np.array([[1, 1, 1, 0], [1, 2, 2, 1]])) -def test_count_alleles__chunked(): +def 
test_count_variant_alleles__chunked(): rs = np.random.RandomState(0) calls = rs.randint(0, 1, size=(50, 10, 2)) ds = get_dataset(calls) - ac1 = count_alleles(ds) + ac1 = count_variant_alleles(ds) # Coerce from numpy to multiple chunks in all dimensions ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: ignore[arg-type] - ac2 = count_alleles(ds) + ac2 = count_variant_alleles(ds) xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call] From 96666a9608f73b6c04e1cfdeb6c0bc9ff019b573 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Wed, 19 Aug 2020 17:28:32 +1200 Subject: [PATCH 05/13] Fix count alleles bug for chunking --- sgkit/stats/aggregation.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 097a66bcd..819e16dc6 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -10,12 +10,14 @@ 'void(numba.int16[:], numba.uint8[:], numba.uint8[:])', 'void(numba.int32[:], numba.uint8[:], numba.uint8[:])', 'void(numba.int64[:], numba.uint8[:], numba.uint8[:])', - ], '(n),(k)->(k)') -def count_alleles(x, _, out): + ], '(n),(k)->(k)', nopython=True) +def count_alleles(g, _, out): out[:] = 0 - for v in x: - if v >= 0: - out[v] += 1 + n_allele = len(g) + for i in range(n_allele): + a = g[i] + if a >= 0: + out[a] += 1 def count_call_alleles(ds: Dataset) -> DataArray: @@ -61,13 +63,12 @@ def count_call_alleles(ds: Dataset) -> DataArray: [[2, 0], [2, 0]]], dtype=uint8 """ - G = da.asarray(ds.call_genotype) - # This array is only necessary to tell dask/numba what the - # dimensions and dtype are for the output array - O = da.empty(G.shape[:2] + (ds.dims['alleles'],), dtype=np.uint8) - O = O.rechunk(G.chunks[:2] + (-1,)) + n_alleles = ds.dims['alleles'] + G = da.asarray(ds['call_genotype']) + shape = (G.chunks[0], G.chunks[1], n_alleles) + K = da.empty(n_alleles, dtype=np.uint8) return xr.DataArray( - count_alleles(G, O), + 
da.map_blocks(count_alleles, G, K, chunks=shape, drop_axis=2, new_axis=2), dims=('variants', 'samples', 'alleles'), name='call_allele_count' ) From adc9d238449b13489bc34f01cb36ba2f99fb1f3c Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Wed, 19 Aug 2020 17:38:23 +1200 Subject: [PATCH 06/13] Fix doctest for count_call_alleles --- sgkit/stats/aggregation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 819e16dc6..6782cada7 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -53,15 +53,15 @@ def count_call_alleles(ds: Dataset) -> DataArray: >>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE array([[[1, 1], [1, 1]], - + [[1, 1], [0, 2]], - + [[1, 1], [1, 1]], - + [[2, 0], - [2, 0]]], dtype=uint8 + [2, 0]]], dtype=uint8) """ n_alleles = ds.dims['alleles'] G = da.asarray(ds['call_genotype']) From 62b39752162a64883c7e1251c85b90e671448599 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Wed, 19 Aug 2020 18:05:59 +1200 Subject: [PATCH 07/13] Docstring for gufunc --- sgkit/stats/aggregation.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 6782cada7..1cd4f513a 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -4,14 +4,33 @@ import numba from xarray import DataArray, Dataset +from ..typing import ArrayLike + @numba.guvectorize([ 'void(numba.int8[:], numba.uint8[:], numba.uint8[:])', 'void(numba.int16[:], numba.uint8[:], numba.uint8[:])', 'void(numba.int32[:], numba.uint8[:], numba.uint8[:])', 'void(numba.int64[:], numba.uint8[:], numba.uint8[:])', - ], '(n),(k)->(k)', nopython=True) -def count_alleles(g, _, out): + ], '(k),(n)->(n)', nopython=True) +def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike): + """Generaliszed U-function for computing per sample allele counts. 
+ + Parameters + ---------- + g : (K,) array-like, int + A genotype call with K alleles where K is the genotypes ploidy. + _: (N,) array-like, uint8 + Dummy variable of length N where N is the number of possible + unique alleles. + + Returns + ------- + ac : (N,) array-like, uint8 + Allele counts with values corresponding to the number of + non-missing occurrences of each allele whithin g. + + """ out[:] = 0 n_allele = len(g) for i in range(n_allele): @@ -66,9 +85,9 @@ def count_call_alleles(ds: Dataset) -> DataArray: n_alleles = ds.dims['alleles'] G = da.asarray(ds['call_genotype']) shape = (G.chunks[0], G.chunks[1], n_alleles) - K = da.empty(n_alleles, dtype=np.uint8) + N = da.empty(n_alleles, dtype=np.uint8) return xr.DataArray( - da.map_blocks(count_alleles, G, K, chunks=shape, drop_axis=2, new_axis=2), + da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2), dims=('variants', 'samples', 'alleles'), name='call_allele_count' ) From 218b524e347a4ab5173452423a93c03e9e8a54b5 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Wed, 19 Aug 2020 18:18:10 +1200 Subject: [PATCH 08/13] Remove duplication in setup.cfg --- setup.cfg | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index ef49db00d..638fe3dcd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -91,5 +91,3 @@ disallow_untyped_defs = False disallow_untyped_decorators = False [mypy-validation.*] ignore_errors = True -[mypy-numba.*] -ignore_missing_imports = True From ba2538a67eeb452dc7df2b230ed0b6f2fb7dee62 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Wed, 19 Aug 2020 19:31:27 +1200 Subject: [PATCH 09/13] Fixes for pre-commit --- sgkit/__init__.py | 2 +- sgkit/stats/aggregation.py | 35 ++++++++++++++++++--------------- sgkit/tests/test_aggregation.py | 2 +- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/sgkit/__init__.py b/sgkit/__init__.py index 49bcc16cf..807bf7293 100644 --- a/sgkit/__init__.py +++ b/sgkit/__init__.py @@ -8,7 +8,7 @@ ) from .display import 
display_genotypes from .io.vcfzarr_reader import read_vcfzarr -from .stats.aggregation import count_variant_alleles, count_call_alleles +from .stats.aggregation import count_call_alleles, count_variant_alleles from .stats.association import gwas_linear_regression from .stats.hwe import hardy_weinberg_test from .stats.regenie import regenie diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 1cd4f513a..8f92cd40d 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -1,19 +1,23 @@ import dask.array as da +import numba import numpy as np import xarray as xr -import numba from xarray import DataArray, Dataset from ..typing import ArrayLike -@numba.guvectorize([ - 'void(numba.int8[:], numba.uint8[:], numba.uint8[:])', - 'void(numba.int16[:], numba.uint8[:], numba.uint8[:])', - 'void(numba.int32[:], numba.uint8[:], numba.uint8[:])', - 'void(numba.int64[:], numba.uint8[:], numba.uint8[:])', - ], '(k),(n)->(n)', nopython=True) -def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike): +@numba.guvectorize( + [ + "void(numba.int8[:], numba.uint8[:], numba.uint8[:])", + "void(numba.int16[:], numba.uint8[:], numba.uint8[:])", + "void(numba.int32[:], numba.uint8[:], numba.uint8[:])", + "void(numba.int64[:], numba.uint8[:], numba.uint8[:])", + ], + "(k),(n)->(n)", + nopython=True, +) +def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None: """Generaliszed U-function for computing per sample allele counts. 
Parameters @@ -82,14 +86,14 @@ def count_call_alleles(ds: Dataset) -> DataArray: [[2, 0], [2, 0]]], dtype=uint8) """ - n_alleles = ds.dims['alleles'] - G = da.asarray(ds['call_genotype']) + n_alleles = ds.dims["alleles"] + G = da.asarray(ds["call_genotype"]) shape = (G.chunks[0], G.chunks[1], n_alleles) N = da.empty(n_alleles, dtype=np.uint8) return xr.DataArray( da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2), - dims=('variants', 'samples', 'alleles'), - name='call_allele_count' + dims=("variants", "samples", "alleles"), + name="call_allele_count", ) @@ -129,8 +133,7 @@ def count_variant_alleles(ds: Dataset) -> DataArray: [2, 2], [4, 0]], dtype=uint64) """ - return ( - count_call_alleles(ds) - .sum(dim='samples') - .rename('variant_allele_count') + return xr.DataArray( + count_call_alleles(ds).sum(dim="samples").rename("variant_allele_count"), + dims=("variants", "alleles"), ) diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index 8b6be9e56..e1e2ad5a5 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -4,7 +4,7 @@ import xarray as xr from xarray import Dataset -from sgkit.stats.aggregation import count_variant_alleles, count_call_alleles +from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles from sgkit.testing import simulate_genotype_call_dataset from sgkit.typing import ArrayLike From 6c8427d8e5e2d8a735562426a4b75de2d308bf8f Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Wed, 19 Aug 2020 20:00:53 +1200 Subject: [PATCH 10/13] Add ignore type checking for guvectorize decorator --- sgkit/stats/aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 8f92cd40d..b930240f6 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -7,7 +7,7 @@ from ..typing import ArrayLike -@numba.guvectorize( +@numba.guvectorize( # type: ignore [ "void(numba.int8[:], 
numba.uint8[:], numba.uint8[:])", "void(numba.int16[:], numba.uint8[:], numba.uint8[:])", From 360997604942bfec1af74a105000a525f555e1d4 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Thu, 20 Aug 2020 09:38:54 +1200 Subject: [PATCH 11/13] Explicit import of guvectorize --- sgkit/stats/aggregation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index b930240f6..70200cfed 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -1,18 +1,18 @@ import dask.array as da -import numba import numpy as np import xarray as xr +from numba import guvectorize from xarray import DataArray, Dataset from ..typing import ArrayLike -@numba.guvectorize( # type: ignore +@guvectorize( # type: ignore [ - "void(numba.int8[:], numba.uint8[:], numba.uint8[:])", - "void(numba.int16[:], numba.uint8[:], numba.uint8[:])", - "void(numba.int32[:], numba.uint8[:], numba.uint8[:])", - "void(numba.int64[:], numba.uint8[:], numba.uint8[:])", + "void(int8[:], uint8[:], uint8[:])", + "void(int16[:], uint8[:], uint8[:])", + "void(int32[:], uint8[:], uint8[:])", + "void(int64[:], uint8[:], uint8[:])", ], "(k),(n)->(n)", nopython=True, From 8eaa40bd87f48640f04d6e3bfac57a80c47c56a4 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Thu, 20 Aug 2020 19:21:57 +1200 Subject: [PATCH 12/13] Use display_genotypes in aggregation docstrings --- sgkit/stats/aggregation.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 70200cfed..54a714a18 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -18,7 +18,7 @@ nopython=True, ) def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None: - """Generaliszed U-function for computing per sample allele counts. + """Generalized U-function for computing per sample allele counts. 
Parameters ---------- @@ -65,13 +65,13 @@ def count_call_alleles(ds: Dataset) -> DataArray: >>> import sgkit as sg >>> from sgkit.testing import simulate_genotype_call_dataset >>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1) - >>> ds['call_genotype'].to_series().unstack().astype(str).apply('/'.join, axis=1).unstack() # doctest: +NORMALIZE_WHITESPACE - samples 0 1 + >>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE + samples S0 S1 variants - 0 1/0 1/0 - 1 1/0 1/1 - 2 0/1 1/0 - 3 0/0 0/0 + 0 1/0 1/0 + 1 1/0 1/1 + 2 0/1 1/0 + 3 0/0 0/0 >>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE array([[[1, 1], @@ -119,13 +119,13 @@ def count_variant_alleles(ds: Dataset) -> DataArray: >>> import sgkit as sg >>> from sgkit.testing import simulate_genotype_call_dataset >>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1) - >>> ds['call_genotype'].to_series().unstack().astype(str).apply('/'.join, axis=1).unstack() # doctest: +NORMALIZE_WHITESPACE - samples 0 1 + >>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE + samples S0 S1 variants - 0 1/0 1/0 - 1 1/0 1/1 - 2 0/1 1/0 - 3 0/0 0/0 + 0 1/0 1/0 + 1 1/0 1/1 + 2 0/1 1/0 + 3 0/0 0/0 >>> sg.count_variant_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE array([[2, 2], From 45ad17296c3cd16204c197d34fa39949a95fc101 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Thu, 20 Aug 2020 21:36:58 +1200 Subject: [PATCH 13/13] Numpy style docstring for count_alleles --- sgkit/stats/aggregation.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 54a714a18..15a1a16cf 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -22,17 +22,19 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None: Parameters ---------- - g : (K,) array-like, int - A genotype call with K alleles where K is the genotypes ploidy. 
- _: (N,) array-like, uint8 - Dummy variable of length N where N is the number of possible - unique alleles. + g : array_like + Genotype call of shape (ploidy,) containing alleles encoded as + type `int` with values < 0 indicating a missing allele. + _ : array_like + Dummy variable of type `uint8` and shape (alleles,) used to + define the number of unique alleles to be counted in the + return value. Returns ------- - ac : (N,) array-like, uint8 - Allele counts with values corresponding to the number of - non-missing occurrences of each allele whithin g. + ac : ndarray + Allele counts with shape (alleles,) and values corresponding to + the number of non-missing occurrences of each allele. """ out[:] = 0