
Commit 3524b8f

Variant summary stats

committed · 1 parent fdd7b62 · commit 3524b8f

4 files changed: +85 -2 lines changed

requirements.txt (+1)

@@ -2,5 +2,6 @@ numpy
 xarray
 dask[array]
 scipy
+typing-extensions
 numba
 zarr

setup.cfg (+1 -1)

@@ -59,7 +59,7 @@ ignore =
 profile = black
 default_section = THIRDPARTY
 known_first_party = sgkit
-known_third_party = dask,fire,glow,hail,hypothesis,invoke,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,xarray,yaml,zarr
+known_third_party = dask,fire,glow,hail,hypothesis,invoke,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,typing_extensions,xarray,yaml,zarr
 multi_line_output = 3
 include_trailing_comma = True
 force_grid_wrap = 0
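
Both configuration changes support the same addition: the typing-extensions package back-ports typing.Literal to Python versions older than 3.8, and isort needs it listed as a third-party import. A minimal sketch of the pattern the new module code relies on (the describe function below is hypothetical, purely for illustration):

# Sketch of the Literal pattern that motivates the new dependency (assumption:
# the project still targets Python versions where typing.Literal is unavailable).
from typing_extensions import Literal

Dimension = Literal["samples", "variants"]  # only these two strings are valid

def describe(dim: Dimension) -> str:
    # Hypothetical helper: mypy flags describe("ploidy") as a type error,
    # while at runtime the argument is just an ordinary string.
    return f"reducing over the '{dim}' dimension"

print(describe("samples"))  # -> reducing over the 'samples' dimension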

sgkit/stats/aggregation.py (+59)

@@ -1,10 +1,13 @@
 import dask.array as da
 import numpy as np
 from numba import guvectorize
+from typing_extensions import Literal
 from xarray import Dataset

 from ..typing import ArrayLike

+Dimension = Literal["samples", "variants"]
+

 @guvectorize(  # type: ignore
     [
@@ -155,3 +158,59 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
         }
     )
     return ds.merge(new_ds) if merge else new_ds
+
+
+def _swap(dim: Dimension) -> Dimension:
+    return "samples" if dim == "variants" else "variants"
+
+
+def call_rate(ds: Dataset, dim: Dimension) -> Dataset:
+    odim = _swap(dim)[:-1]
+    n_called = (~ds["call_genotype_mask"].any(dim="ploidy")).sum(dim=dim)
+    return xr.Dataset(
+        {f"{odim}_n_called": n_called, f"{odim}_call_rate": n_called / ds.dims[dim]}
+    )
+
+
+def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
+    odim = _swap(dim)[:-1]
+    M, G = ds["call_genotype_mask"].any(dim="ploidy"), ds["call_genotype"]
+    n_het = (G > 0).any(dim="ploidy") & (G == 0).any(dim="ploidy")
+    n_hom_ref = (G == 0).all(dim="ploidy")
+    n_hom_alt = (G > 0).all(dim="ploidy")
+    n_non_ref = (G > 0).any(dim="ploidy")
+    agg = lambda x: xr.where(M, False, x).sum(dim=dim)  # type: ignore[no-untyped-call]
+    return xr.Dataset(
+        {
+            f"{odim}_n_het": agg(n_het),  # type: ignore[no-untyped-call]
+            f"{odim}_n_hom_ref": agg(n_hom_ref),  # type: ignore[no-untyped-call]
+            f"{odim}_n_hom_alt": agg(n_hom_alt),  # type: ignore[no-untyped-call]
+            f"{odim}_n_non_ref": agg(n_non_ref),  # type: ignore[no-untyped-call]
+        }
+    )
+
+
+def allele_frequency(ds: Dataset) -> Dataset:
+    AC = count_alleles(ds)
+
+    M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
+    AN = (~M).sum(dim="calls")  # type: ignore
+    assert AN.shape == (ds.dims["variants"],)
+
+    return xr.Dataset(
+        {
+            "variant_allele_count": AC,
+            "variant_allele_total": AN,
+            "variant_allele_frequency": AC / AN,
+        }
+    )
+
+
+def variant_stats(ds: Dataset) -> Dataset:
+    return xr.merge(
+        [
+            call_rate(ds, dim="samples"),
+            genotype_count(ds, dim="samples"),
+            allele_frequency(ds),
+        ]
+    )
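
The new functions build per-variant (or per-sample) summary statistics by reducing boolean arrays derived from call_genotype and call_genotype_mask over one dimension, with the Dimension alias restricting dim to "samples" or "variants". Note that the added code refers to xr (xr.Dataset, xr.where, xr.merge) and count_alleles, which are presumably provided elsewhere in the module; an import xarray as xr and an allele-counting helper are not visible in these hunks. A self-contained sketch of the same masking and reduction logic on a tiny hand-built dataset, assuming sgkit's conventions (dimensions variants/samples/ploidy, missing allele calls encoded as -1):

# Sketch only: reproduces the call-rate / heterozygote-count reductions on a
# small hand-built dataset; variable names follow sgkit conventions, but this
# is illustrative and not the library's own API.
import numpy as np
import xarray as xr

# 2 variants x 2 samples x ploidy 2; -1 encodes a missing allele call
gt = np.array(
    [
        [[1, 0], [-1, -1]],  # variant 0: one het call, one missing call
        [[0, 0], [1, 1]],    # variant 1: one hom-ref call, one hom-alt call
    ]
)
ds = xr.Dataset(
    {
        "call_genotype": (("variants", "samples", "ploidy"), gt),
        "call_genotype_mask": (("variants", "samples", "ploidy"), gt < 0),
    }
)

G = ds["call_genotype"]
M = ds["call_genotype_mask"].any(dim="ploidy")  # a call is missing if any allele is
n_called = (~M).sum(dim="samples")              # complete calls per variant
call_rate = n_called / ds.dims["samples"]
het = (G > 0).any(dim="ploidy") & (G == 0).any(dim="ploidy")
n_het = xr.where(M, False, het).sum(dim="samples")

print(n_called.values)   # [1 2]
print(call_rate.values)  # [0.5 1.]
print(n_het.values)      # [1 0]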

sgkit/tests/test_aggregation.py (+24 -1)

@@ -4,7 +4,7 @@
 import xarray as xr
 from xarray import Dataset

-from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles
+from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles, variant_stats
 from sgkit.testing import simulate_genotype_call_dataset
 from sgkit.typing import ArrayLike

@@ -202,3 +202,26 @@ def test_count_call_alleles__chunked():
     ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1))  # type: ignore[arg-type]
     ac2 = count_call_alleles(ds)
     xr.testing.assert_equal(ac1, ac2)  # type: ignore[no-untyped-call]
+
+
+def test_variant_stats():
+    ds = get_dataset(
+        [[[1, 0], [-1, -1]], [[1, 0], [1, 1]], [[0, 1], [1, 0]], [[-1, -1], [0, 0]]]
+    )
+    vs = variant_stats(ds)
+
+    np.testing.assert_equal(vs["variant_n_called"], np.array([1, 2, 2, 1]))
+    np.testing.assert_equal(vs["variant_call_rate"], np.array([0.5, 1.0, 1.0, 0.5]))
+    np.testing.assert_equal(vs["variant_n_hom_ref"], np.array([0, 0, 0, 1]))
+    np.testing.assert_equal(vs["variant_n_hom_alt"], np.array([0, 1, 0, 0]))
+    np.testing.assert_equal(vs["variant_n_het"], np.array([1, 1, 2, 0]))
+    np.testing.assert_equal(vs["variant_n_non_ref"], np.array([1, 2, 2, 0]))
+    np.testing.assert_equal(vs["variant_n_non_ref"], np.array([1, 2, 2, 0]))
+    np.testing.assert_equal(
+        vs["variant_allele_count"], np.array([[1, 1], [1, 3], [2, 2], [2, 0]])
+    )
+    np.testing.assert_equal(vs["variant_allele_total"], np.array([2, 4, 4, 2]))
+    np.testing.assert_equal(
+        vs["variant_allele_frequency"],
+        np.array([[0.5, 0.5], [0.25, 0.75], [0.5, 0.5], [1, 0]]),
+    )
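
The expected values can be verified by hand from the genotype matrix passed to get_dataset (a helper defined earlier in this test module). For the first variant, [[1, 0], [-1, -1]], a quick sanity check with plain NumPy (a sketch, assuming the same -1-for-missing encoding):

# Hand check of the first test variant [[1, 0], [-1, -1]]
import numpy as np

calls = np.array([[1, 0], [-1, -1]])    # 2 samples x ploidy 2
called = (calls >= 0).all(axis=1)       # [True, False]
print(called.sum())                     # 1 -> variant_n_called; call rate 1/2 = 0.5
alleles = calls[called].ravel()         # observed alleles [1, 0]: one het call
ac = np.bincount(alleles, minlength=2)  # [1, 1] -> variant_allele_count
an = alleles.size                       # 2 -> variant_allele_total
print(ac, an, ac / an)                  # [1 1] 2 [0.5 0.5] -> variant_allele_frequency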
