Skip to content

Commit 88c08b9

Browse files
committed
Add merge parameter to variant_stats
1 parent cc4f4cd commit 88c08b9

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

sgkit/stats/aggregation.py

+17-11
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
1+
from typing import Any, Dict, Hashable
2+
13
import dask.array as da
24
import numpy as np
5+
import xarray as xr
36
from numba import guvectorize
47
from typing_extensions import Literal
58
from xarray import Dataset
@@ -181,7 +184,7 @@ def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
181184
n_het = ~(n_hom_alt | n_hom_ref)
182185
# This would 0 out the `het` case with any missing calls
183186
agg = lambda x: xr.where(M, False, x).sum(dim=dim) # type: ignore[no-untyped-call]
184-
return xr.Dataset(
187+
return Dataset(
185188
{
186189
f"{odim}_n_het": agg(n_het), # type: ignore[no-untyped-call]
187190
f"{odim}_n_hom_ref": agg(n_hom_ref), # type: ignore[no-untyped-call]
@@ -192,26 +195,29 @@ def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
192195

193196

194197
def allele_frequency(ds: Dataset) -> Dataset:
195-
AC = count_variant_alleles(ds)
198+
data_vars: Dict[Hashable, Any] = {}
199+
# only compute variant allele count if not already in dataset
200+
if "variant_allele_count" in ds:
201+
AC = ds["variant_allele_count"]
202+
else:
203+
AC = count_variant_alleles(ds, merge=False)["variant_allele_count"]
204+
data_vars["variant_allele_count"] = AC
196205

197206
M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
198207
AN = (~M).sum(dim="calls") # type: ignore
199208
assert AN.shape == (ds.dims["variants"],)
200209

201-
return xr.Dataset(
202-
{
203-
"variant_allele_count": AC,
204-
"variant_allele_total": AN,
205-
"variant_allele_frequency": AC / AN,
206-
}
207-
)
210+
data_vars["variant_allele_total"] = AN
211+
data_vars["variant_allele_frequency"] = AC / AN
212+
return Dataset(data_vars)
208213

209214

210-
def variant_stats(ds: Dataset) -> Dataset:
211-
return xr.merge(
215+
def variant_stats(ds: Dataset, merge: bool = True) -> Dataset:
216+
new_ds = xr.merge(
212217
[
213218
call_rate(ds, dim="samples"),
214219
genotype_count(ds, dim="samples"),
215220
allele_frequency(ds),
216221
]
217222
)
223+
return ds.merge(new_ds) if merge else new_ds

sgkit/tests/test_aggregation.py

+10-2
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,15 @@
11
from typing import Any
22

33
import numpy as np
4+
import pytest
45
import xarray as xr
56
from xarray import Dataset
67

7-
from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles, variant_stats
8+
from sgkit.stats.aggregation import (
9+
count_call_alleles,
10+
count_variant_alleles,
11+
variant_stats,
12+
)
813
from sgkit.testing import simulate_genotype_call_dataset
914
from sgkit.typing import ArrayLike
1015

@@ -204,10 +209,13 @@ def test_count_call_alleles__chunked():
204209
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]
205210

206211

207-
def test_variant_stats():
212+
@pytest.mark.parametrize("precompute_variant_allele_count", [False, True])
213+
def test_variant_stats(precompute_variant_allele_count):
208214
ds = get_dataset(
209215
[[[1, 0], [-1, -1]], [[1, 0], [1, 1]], [[0, 1], [1, 0]], [[-1, -1], [0, 0]]]
210216
)
217+
if precompute_variant_allele_count:
218+
ds = count_variant_alleles(ds)
211219
vs = variant_stats(ds)
212220

213221
np.testing.assert_equal(vs["variant_n_called"], np.array([1, 2, 2, 1]))

0 commit comments

Comments
 (0)