Skip to content

Commit 320ebc9

Browse files
Implement count_allele_calls #85
* Implement count_call_alleles #85 * Fixes for mypy and black * Add dependency on numba * gufunc implementation of count_alleles * Fix count alleles bug for chunking * Fix doctest for count_call_alleles * Docstring for gufunc * Remove duplication in setup.cfg * Fixes for pre-commit * Add ignore type checking for guvectorize decorator * Explicit import of guvectorize * Use display_genotypes in aggregation docstrings * Numpy style docstring for count_alleles Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 84f79c2 commit 320ebc9

File tree

4 files changed

+223
-56
lines changed

4 files changed

+223
-56
lines changed

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ xarray
33
dask[array]
44
scipy
55
numba
6-
zarr
6+
zarr

sgkit/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)
99
from .display import display_genotypes
1010
from .io.vcfzarr_reader import read_vcfzarr
11-
from .stats.aggregation import count_alleles
11+
from .stats.aggregation import count_call_alleles, count_variant_alleles
1212
from .stats.association import gwas_linear_regression
1313
from .stats.hwe import hardy_weinberg_test
1414
from .stats.regenie import regenie
@@ -19,7 +19,8 @@
1919
"DIM_SAMPLE",
2020
"DIM_VARIANT",
2121
"create_genotype_call_dataset",
22-
"count_alleles",
22+
"count_variant_alleles",
23+
"count_call_alleles",
2324
"create_genotype_dosage_dataset",
2425
"display_genotypes",
2526
"gwas_linear_regression",

sgkit/stats/aggregation.py

+107-37
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,105 @@
11
import dask.array as da
22
import numpy as np
33
import xarray as xr
4+
from numba import guvectorize
45
from xarray import DataArray, Dataset
56

7+
from ..typing import ArrayLike
68

7-
def count_alleles(ds: Dataset) -> DataArray:
9+
10+
@guvectorize( # type: ignore
11+
[
12+
"void(int8[:], uint8[:], uint8[:])",
13+
"void(int16[:], uint8[:], uint8[:])",
14+
"void(int32[:], uint8[:], uint8[:])",
15+
"void(int64[:], uint8[:], uint8[:])",
16+
],
17+
"(k),(n)->(n)",
18+
nopython=True,
19+
)
20+
def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
21+
"""Generalized U-function for computing per sample allele counts.
22+
23+
Parameters
24+
----------
25+
g : array_like
26+
Genotype call of shape (ploidy,) containing alleles encoded as
27+
type `int` with values < 0 indicating a missing allele.
28+
_: array_like
29+
Dummy variable of type `uint8` and shape (alleles,) used to
30+
define the number of unique alleles to be counted in the
31+
return value.
32+
33+
Returns
34+
-------
35+
ac : ndarray
36+
Allele counts with shape (alleles,) and values corresponding to
37+
the number of non-missing occurrences of each allele.
38+
39+
"""
40+
out[:] = 0
41+
n_allele = len(g)
42+
for i in range(n_allele):
43+
a = g[i]
44+
if a >= 0:
45+
out[a] += 1
46+
47+
48+
def count_call_alleles(ds: Dataset) -> DataArray:
49+
"""Compute per sample allele counts from genotype calls.
50+
51+
Parameters
52+
----------
53+
ds : Dataset
54+
Genotype call dataset such as from
55+
`sgkit.create_genotype_call_dataset`.
56+
57+
Returns
58+
-------
59+
call_allele_count : DataArray
60+
Allele counts with shape (variants, samples, alleles) and values
61+
corresponding to the number of non-missing occurrences
62+
of each allele.
63+
64+
Examples
65+
--------
66+
67+
>>> import sgkit as sg
68+
>>> from sgkit.testing import simulate_genotype_call_dataset
69+
>>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1)
70+
>>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE
71+
samples S0 S1
72+
variants
73+
0 1/0 1/0
74+
1 1/0 1/1
75+
2 0/1 1/0
76+
3 0/0 0/0
77+
78+
>>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
79+
array([[[1, 1],
80+
[1, 1]],
81+
<BLANKLINE>
82+
[[1, 1],
83+
[0, 2]],
84+
<BLANKLINE>
85+
[[1, 1],
86+
[1, 1]],
87+
<BLANKLINE>
88+
[[2, 0],
89+
[2, 0]]], dtype=uint8)
90+
"""
91+
n_alleles = ds.dims["alleles"]
92+
G = da.asarray(ds["call_genotype"])
93+
shape = (G.chunks[0], G.chunks[1], n_alleles)
94+
N = da.empty(n_alleles, dtype=np.uint8)
95+
return xr.DataArray(
96+
da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2),
97+
dims=("variants", "samples", "alleles"),
98+
name="call_allele_count",
99+
)
100+
101+
102+
def count_variant_alleles(ds: Dataset) -> DataArray:
8103
"""Compute allele count from genotype calls.
9104
10105
Parameters
@@ -26,46 +121,21 @@ def count_alleles(ds: Dataset) -> DataArray:
26121
>>> import sgkit as sg
27122
>>> from sgkit.testing import simulate_genotype_call_dataset
28123
>>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1)
29-
>>> ds['call_genotype'].to_series().unstack().astype(str).apply('/'.join, axis=1).unstack() # doctest: +NORMALIZE_WHITESPACE
30-
samples 0 1
124+
>>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE
125+
samples S0 S1
31126
variants
32-
0 1/0 1/0
33-
1 1/0 1/1
34-
2 0/1 1/0
35-
3 0/0 0/0
127+
0 1/0 1/0
128+
1 1/0 1/1
129+
2 0/1 1/0
130+
3 0/0 0/0
36131
37-
>>> sg.count_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
132+
>>> sg.count_variant_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
38133
array([[2, 2],
39134
[1, 3],
40135
[2, 2],
41-
[4, 0]])
136+
[4, 0]], dtype=uint64)
42137
"""
43-
# Count each allele index individually as a 1D vector and
44-
# restack into new alleles dimension with same order
45-
G = ds["call_genotype"].stack(calls=("samples", "ploidy"))
46-
M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
47-
n_variant, n_allele = G.shape[0], ds.dims["alleles"]
48-
max_allele = n_allele + 1
49-
50-
# Recode missing values as max allele index
51-
G = xr.where(M, n_allele, G) # type: ignore[no-untyped-call]
52-
G = da.asarray(G)
53-
54-
# Count allele indexes within each block
55-
CT = da.map_blocks(
56-
lambda x: np.apply_along_axis(np.bincount, 1, x, minlength=max_allele),
57-
G,
58-
chunks=(G.chunks[0], max_allele),
138+
return xr.DataArray(
139+
count_call_alleles(ds).sum(dim="samples").rename("variant_allele_count"),
140+
dims=("variants", "alleles"),
59141
)
60-
assert CT.shape == (n_variant, G.numblocks[1] * max_allele)
61-
62-
# Stack the column blocks on top of each other
63-
CTS = da.stack([CT.blocks[:, i] for i in range(CT.numblocks[1])])
64-
assert CTS.shape == (CT.numblocks[1], n_variant, max_allele)
65-
66-
# Sum over column blocks and slice off allele
67-
# index corresponding to missing values
68-
AC = CTS.sum(axis=0)[:, :n_allele]
69-
assert AC.shape == (n_variant, n_allele)
70-
71-
return DataArray(data=AC, dims=("variants", "alleles"), name="variant_allele_count")

sgkit/tests/test_aggregation.py

+112-16
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import xarray as xr
55
from xarray import Dataset
66

7-
from sgkit.stats.aggregation import count_alleles
7+
from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles
88
from sgkit.testing import simulate_genotype_call_dataset
99
from sgkit.typing import ArrayLike
1010

@@ -20,23 +20,23 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset:
2020
return ds
2121

2222

23-
def test_count_alleles__single_variant_single_sample():
24-
ac = count_alleles(get_dataset([[[1, 0]]]))
23+
def test_count_variant_alleles__single_variant_single_sample():
24+
ac = count_variant_alleles(get_dataset([[[1, 0]]]))
2525
np.testing.assert_equal(ac, np.array([[1, 1]]))
2626

2727

28-
def test_count_alleles__multi_variant_single_sample():
29-
ac = count_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
28+
def test_count_variant_alleles__multi_variant_single_sample():
29+
ac = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
3030
np.testing.assert_equal(ac, np.array([[2, 0], [1, 1], [1, 1], [0, 2]]))
3131

3232

33-
def test_count_alleles__single_variant_multi_sample():
34-
ac = count_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
33+
def test_count_variant_alleles__single_variant_multi_sample():
34+
ac = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
3535
np.testing.assert_equal(ac, np.array([[4, 4]]))
3636

3737

38-
def test_count_alleles__multi_variant_multi_sample():
39-
ac = count_alleles(
38+
def test_count_variant_alleles__multi_variant_multi_sample():
39+
ac = count_variant_alleles(
4040
get_dataset(
4141
[
4242
[[0, 0], [0, 0], [0, 0]],
@@ -49,8 +49,8 @@ def test_count_alleles__multi_variant_multi_sample():
4949
np.testing.assert_equal(ac, np.array([[6, 0], [5, 1], [2, 4], [0, 6]]))
5050

5151

52-
def test_count_alleles__missing_data():
53-
ac = count_alleles(
52+
def test_count_variant_alleles__missing_data():
53+
ac = count_variant_alleles(
5454
get_dataset(
5555
[
5656
[[-1, -1], [-1, -1], [-1, -1]],
@@ -63,8 +63,8 @@ def test_count_alleles__missing_data():
6363
np.testing.assert_equal(ac, np.array([[0, 0], [2, 1], [1, 2], [0, 6]]))
6464

6565

66-
def test_count_alleles__higher_ploidy():
67-
ac = count_alleles(
66+
def test_count_variant_alleles__higher_ploidy():
67+
ac = count_variant_alleles(
6868
get_dataset(
6969
[
7070
[[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
@@ -77,12 +77,108 @@ def test_count_alleles__higher_ploidy():
7777
np.testing.assert_equal(ac, np.array([[1, 1, 1, 0], [1, 2, 2, 1]]))
7878

7979

80-
def test_count_alleles__chunked():
80+
def test_count_variant_alleles__chunked():
8181
rs = np.random.RandomState(0)
8282
calls = rs.randint(0, 1, size=(50, 10, 2))
8383
ds = get_dataset(calls)
84-
ac1 = count_alleles(ds)
84+
ac1 = count_variant_alleles(ds)
8585
# Coerce from numpy to multiple chunks in all dimensions
8686
ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: ignore[arg-type]
87-
ac2 = count_alleles(ds)
87+
ac2 = count_variant_alleles(ds)
88+
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]
89+
90+
91+
def test_count_call_alleles__single_variant_single_sample():
92+
ac = count_call_alleles(get_dataset([[[1, 0]]]))
93+
np.testing.assert_equal(ac, np.array([[[1, 1]]]))
94+
95+
96+
def test_count_call_alleles__multi_variant_single_sample():
97+
ac = count_call_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
98+
np.testing.assert_equal(ac, np.array([[[2, 0]], [[1, 1]], [[1, 1]], [[0, 2]]]))
99+
100+
101+
def test_count_call_alleles__single_variant_multi_sample():
102+
ac = count_call_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
103+
np.testing.assert_equal(ac, np.array([[[2, 0], [1, 1], [1, 1], [0, 2]]]))
104+
105+
106+
def test_count_call_alleles__multi_variant_multi_sample():
107+
ac = count_call_alleles(
108+
get_dataset(
109+
[
110+
[[0, 0], [0, 0], [0, 0]],
111+
[[0, 0], [0, 0], [0, 1]],
112+
[[1, 1], [0, 1], [1, 0]],
113+
[[1, 1], [1, 1], [1, 1]],
114+
]
115+
)
116+
)
117+
np.testing.assert_equal(
118+
ac,
119+
np.array(
120+
[
121+
[[2, 0], [2, 0], [2, 0]],
122+
[[2, 0], [2, 0], [1, 1]],
123+
[[0, 2], [1, 1], [1, 1]],
124+
[[0, 2], [0, 2], [0, 2]],
125+
]
126+
),
127+
)
128+
129+
130+
def test_count_call_alleles__missing_data():
131+
ac = count_call_alleles(
132+
get_dataset(
133+
[
134+
[[-1, -1], [-1, -1], [-1, -1]],
135+
[[-1, -1], [0, 0], [-1, 1]],
136+
[[1, 1], [-1, -1], [-1, 0]],
137+
[[1, 1], [1, 1], [1, 1]],
138+
]
139+
)
140+
)
141+
np.testing.assert_equal(
142+
ac,
143+
np.array(
144+
[
145+
[[0, 0], [0, 0], [0, 0]],
146+
[[0, 0], [2, 0], [0, 1]],
147+
[[0, 2], [0, 0], [1, 0]],
148+
[[0, 2], [0, 2], [0, 2]],
149+
]
150+
),
151+
)
152+
153+
154+
def test_count_call_alleles__higher_ploidy():
155+
ac = count_call_alleles(
156+
get_dataset(
157+
[
158+
[[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
159+
[[0, 1, 2], [1, 2, 3], [-1, -1, -1]],
160+
],
161+
n_allele=4,
162+
n_ploidy=3,
163+
)
164+
)
165+
np.testing.assert_equal(
166+
ac,
167+
np.array(
168+
[
169+
[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
170+
[[1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 0, 0]],
171+
]
172+
),
173+
)
174+
175+
176+
def test_count_call_alleles__chunked():
177+
rs = np.random.RandomState(0)
178+
calls = rs.randint(0, 1, size=(50, 10, 2))
179+
ds = get_dataset(calls)
180+
ac1 = count_call_alleles(ds)
181+
# Coerce from numpy to multiple chunks in all dimensions
182+
ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: ignore[arg-type]
183+
ac2 = count_call_alleles(ds)
88184
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]

0 commit comments

Comments
 (0)