Skip to content

Commit 85d447c

Browse files
Merge branch 'master' into popgen_stats
2 parents f6a5431 + 36cd66e commit 85d447c

File tree

6 files changed

+132
-3
lines changed

6 files changed

+132
-3
lines changed

docs/api.rst

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ Methods
3131
hardy_weinberg_test
3232
regenie
3333
Tajimas_D
34+
variant_stats
35+
3436

3537
Utilities
3638
=========

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ numpy
22
xarray
33
dask[array]
44
scipy
5+
typing-extensions
56
numba
67
zarr

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ ignore =
6262
profile = black
6363
default_section = THIRDPARTY
6464
known_first_party = sgkit
65-
known_third_party = dask,fire,glow,hail,hypothesis,invoke,msprime,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,xarray,yaml,zarr
65+
known_third_party = dask,fire,glow,hail,hypothesis,invoke,msprime,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,typing_extensions,xarray,yaml,zarr
6666
multi_line_output = 3
6767
include_trailing_comma = True
6868
force_grid_wrap = 0

sgkit/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)
99
from .display import display_genotypes
1010
from .io.vcfzarr_reader import read_vcfzarr
11-
from .stats.aggregation import count_call_alleles, count_variant_alleles
11+
from .stats.aggregation import count_call_alleles, count_variant_alleles, variant_stats
1212
from .stats.association import gwas_linear_regression
1313
from .stats.hwe import hardy_weinberg_test
1414
from .stats.popgen import Fst, Tajimas_D, divergence, diversity
@@ -32,4 +32,5 @@
3232
"divergence",
3333
"Fst",
3434
"Tajimas_D",
35+
"variant_stats",
3536
]

sgkit/stats/aggregation.py

+94
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
from typing import Any, Dict, Hashable
2+
13
import dask.array as da
24
import numpy as np
5+
import xarray as xr
36
from numba import guvectorize
7+
from typing_extensions import Literal
48
from xarray import Dataset
59

610
from sgkit.typing import ArrayLike
711
from sgkit.utils import merge_datasets
812

13+
Dimension = Literal["samples", "variants"]
14+
915

1016
@guvectorize( # type: ignore
1117
[
@@ -162,3 +168,91 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
162168
}
163169
)
164170
return merge_datasets(ds, new_ds) if merge else new_ds
171+
172+
173+
def _swap(dim: Dimension) -> Dimension:
    """Return the other member of the ("samples", "variants") pair.

    "variants" maps to "samples"; every other value maps to "variants".
    """
    if dim == "variants":
        return "samples"
    return "variants"
175+
176+
177+
def call_rate(ds: Dataset, dim: Dimension) -> Dataset:
    """Count called genotypes along ``dim`` and express them as a fraction.

    A genotype is treated as called when none of its entries along the
    ``ploidy`` dimension is flagged in ``call_genotype_mask``.

    Parameters
    ----------
    ds : Dataset
        Dataset providing the ``call_genotype_mask`` variable.
    dim : Dimension
        Dimension to aggregate over ("samples" or "variants").

    Returns
    -------
    Dataset
        Two variables named after the dimension that is *not* aggregated
        (e.g. ``variant_n_called`` and ``variant_call_rate`` when
        ``dim="samples"``): the number of called genotypes, and that number
        divided by the size of ``dim``.
    """
    # Output names use the singular of the opposite dimension:
    # "samples" -> "variant", "variants" -> "sample".
    prefix = _swap(dim)[:-1]
    any_missing = ds["call_genotype_mask"].any(dim="ploidy")
    n_called = (~any_missing).sum(dim=dim)
    rate = n_called / ds.dims[dim]
    return xr.Dataset(
        {
            f"{prefix}_n_called": n_called,
            f"{prefix}_call_rate": rate,
        }
    )
183+
184+
185+
def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
    """Count genotype classes (het, hom-ref, hom-alt, non-ref) along ``dim``.

    Parameters
    ----------
    ds : Dataset
        Dataset providing ``call_genotype`` and ``call_genotype_mask``.
    dim : Dimension
        Dimension to aggregate over ("samples" or "variants").

    Returns
    -------
    Dataset
        Variables ``<prefix>_n_het``, ``<prefix>_n_hom_ref``,
        ``<prefix>_n_hom_alt`` and ``<prefix>_n_non_ref``, where ``<prefix>``
        is the singular name of the dimension that is not aggregated.
        Genotypes with any masked entry along ``ploidy`` contribute to none
        of the counts.
    """
    prefix = _swap(dim)[:-1]
    genotypes = ds["call_genotype"]
    missing = ds["call_genotype_mask"].any(dim="ploidy")

    # Per-genotype boolean flags (not yet counts).
    is_hom_ref = (genotypes == 0).all(dim="ploidy")
    # Every allele is non-reference and equal to the first allele.
    # NOTE(review): the positional index assumes ``ploidy`` is the last
    # dimension of ``call_genotype`` -- confirm against the dataset schema.
    is_hom_alt = ((genotypes > 0) & (genotypes[..., 0] == genotypes)).all(
        dim="ploidy"
    )
    is_non_ref = (genotypes > 0).any(dim="ploidy")
    # Anything that is neither kind of homozygote; this is also True for
    # missing calls, which is why the mask is applied before summing.
    is_het = ~(is_hom_alt | is_hom_ref)

    def count(flags: Any) -> Any:
        # Zero out genotypes with any missing call, then sum along ``dim``.
        return xr.where(missing, False, flags).sum(dim=dim)  # type: ignore[no-untyped-call]

    return Dataset(
        {
            f"{prefix}_n_het": count(is_het),
            f"{prefix}_n_hom_ref": count(is_hom_ref),
            f"{prefix}_n_hom_alt": count(is_hom_alt),
            f"{prefix}_n_non_ref": count(is_non_ref),
        }
    )
202+
203+
204+
def allele_frequency(ds: Dataset) -> Dataset:
    """Compute per-variant allele totals and allele frequencies.

    Parameters
    ----------
    ds : Dataset
        Dataset providing ``call_genotype_mask`` and, optionally, a
        precomputed ``variant_allele_count``.

    Returns
    -------
    Dataset
        ``variant_allele_total`` (number of unmasked calls per variant) and
        ``variant_allele_frequency`` (allele count divided by that total).
        ``variant_allele_count`` is included as well when it had to be
        computed here.
    """
    outputs: Dict[Hashable, Any] = {}
    # Only compute the variant allele count if it is not already present.
    # NOTE(review): indentation was lost in the rendered diff; storing the
    # count only in the freshly computed branch is reconstructed from the
    # original comment -- confirm against the repository.
    if "variant_allele_count" not in ds:
        allele_count = count_variant_alleles(ds, merge=False)["variant_allele_count"]
        outputs["variant_allele_count"] = allele_count
    else:
        allele_count = ds["variant_allele_count"]

    # Flatten samples x ploidy into a single "calls" dimension so that every
    # unmasked allele call of a variant can be counted with one reduction.
    stacked_mask = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
    allele_total = (~stacked_mask).sum(dim="calls")  # type: ignore
    assert allele_total.shape == (ds.dims["variants"],)

    outputs["variant_allele_total"] = allele_total
    outputs["variant_allele_frequency"] = allele_count / allele_total
    return Dataset(outputs)
220+
221+
222+
def variant_stats(ds: Dataset, merge: bool = True) -> Dataset:
    """Compute quality control variant statistics from genotype calls.

    Parameters
    ----------
    ds : Dataset
        Genotype call dataset such as from
        `sgkit.create_genotype_call_dataset`.
    merge : bool, optional
        If True (the default), merge the input dataset and the computed
        output variables into a single dataset. Output variables will
        overwrite any input variables with the same name, and a warning
        will be issued in this case.
        If False, return only the computed output variables.

    Returns
    -------
    Dataset
        A dataset containing the following variables:
        - `variant_n_called` (variants): The number of samples with called genotypes.
        - `variant_call_rate` (variants): The fraction of samples with called genotypes.
        - `variant_n_het` (variants): The number of samples with heterozygous calls.
        - `variant_n_hom_ref` (variants): The number of samples with homozygous reference calls.
        - `variant_n_hom_alt` (variants): The number of samples with homozygous alternate calls.
        - `variant_n_non_ref` (variants): The number of samples that are not homozygous reference calls.
        - `variant_allele_count` (variants, alleles): The number of occurrences of each allele.
        - `variant_allele_total` (variants): The number of occurrences of all alleles.
        - `variant_allele_frequency` (variants, alleles): The frequency of occurrence of each allele.
    """
    # Every statistic is aggregated over samples, giving one value per variant.
    per_variant = [
        call_rate(ds, dim="samples"),
        genotype_count(ds, dim="samples"),
        allele_frequency(ds),
    ]
    new_ds = xr.merge(per_variant)
    if not merge:
        return new_ds
    return merge_datasets(ds, new_ds)

sgkit/tests/test_aggregation.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from typing import Any
22

33
import numpy as np
4+
import pytest
45
import xarray as xr
56
from xarray import Dataset
67

7-
from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles
8+
from sgkit.stats.aggregation import (
9+
count_call_alleles,
10+
count_variant_alleles,
11+
variant_stats,
12+
)
813
from sgkit.testing import simulate_genotype_call_dataset
914
from sgkit.typing import ArrayLike
1015

@@ -202,3 +207,29 @@ def test_count_call_alleles__chunked():
202207
ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: ignore[arg-type]
203208
ac2 = count_call_alleles(ds)
204209
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]
210+
211+
212+
@pytest.mark.parametrize("precompute_variant_allele_count", [False, True])
def test_variant_stats(precompute_variant_allele_count):
    """Check every variable produced by ``variant_stats`` on a small dataset.

    The test runs both with and without ``variant_allele_count`` already
    present in the input, covering both branches of the allele-count lookup.
    """
    # Four variants x two samples x two alleles per call; -1 marks a missing
    # allele.
    ds = get_dataset(
        [[[1, 0], [-1, -1]], [[1, 0], [1, 1]], [[0, 1], [1, 0]], [[-1, -1], [0, 0]]]
    )
    if precompute_variant_allele_count:
        ds = count_variant_alleles(ds)
    vs = variant_stats(ds)

    np.testing.assert_equal(vs["variant_n_called"], np.array([1, 2, 2, 1]))
    np.testing.assert_equal(vs["variant_call_rate"], np.array([0.5, 1.0, 1.0, 0.5]))
    np.testing.assert_equal(vs["variant_n_hom_ref"], np.array([0, 0, 0, 1]))
    np.testing.assert_equal(vs["variant_n_hom_alt"], np.array([0, 1, 0, 0]))
    np.testing.assert_equal(vs["variant_n_het"], np.array([1, 1, 2, 0]))
    np.testing.assert_equal(vs["variant_n_non_ref"], np.array([1, 2, 2, 0]))
    np.testing.assert_equal(
        vs["variant_allele_count"], np.array([[1, 1], [1, 3], [2, 2], [2, 0]])
    )
    np.testing.assert_equal(vs["variant_allele_total"], np.array([2, 4, 4, 2]))
    np.testing.assert_equal(
        vs["variant_allele_frequency"],
        np.array([[0.5, 0.5], [0.25, 0.75], [0.5, 0.5], [1, 0]]),
    )

0 commit comments

Comments
 (0)