Skip to content

Commit 2b22e80

Browse files
authored
HWE function updates (#334)
* Rework HWE inputs * Cleanup for HWE function
1 parent 6160f06 commit 2b22e80

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

sgkit/stats/hwe.py

+40-10
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from xarray import Dataset
99

1010
from sgkit import variables
11+
from sgkit.stats.aggregation import genotype_count
1112
from sgkit.utils import conditional_merge_datasets
1213

1314

@@ -129,7 +130,9 @@ def hardy_weinberg_test(
129130
genotype_counts: Optional[Hashable] = None,
130131
call_genotype: Hashable = variables.call_genotype,
131132
call_genotype_mask: Hashable = variables.call_genotype_mask,
132-
merge: bool = True,
133+
ploidy: Optional[int] = None,
134+
alleles: Optional[int] = None,
135+
merge: bool = True
133136
) -> Dataset:
134137
"""Exact test for HWE as described in Wigginton et al. 2005 [1].
135138
@@ -150,6 +153,16 @@ def hardy_weinberg_test(
150153
call_genotype_mask
151154
Input variable name holding call_genotype_mask.
152155
Defined by :data:`sgkit.variables.call_genotype_mask_spec`
156+
ploidy
157+
Genotype ploidy, defaults to ``ploidy`` dimension of provided dataset.
158+
If the `ploidy` dimension is not present, then this value must be set explicitly.
159+
Currently HWE calculations are only supported for diploid datasets,
160+
i.e. ``ploidy`` must equal 2.
161+
alleles
162+
Genotype allele count, defaults to ``alleles`` dimension of provided dataset.
163+
If the `alleles` dimension is not present, then this value must be set explicitly.
164+
Currently HWE calculations are only supported for biallelic datasets,
165+
i.e. ``alleles`` must equal 2.
153166
merge
154167
If True (the default), merge the input dataset and the computed
155168
output variables into a single dataset, otherwise return only
@@ -163,8 +176,9 @@ def hardy_weinberg_test(
163176
Returns
164177
-------
165178
Dataset containing (N = num variants):
179+
166180
variant_hwe_p_value : [array-like, shape: (N, O)]
167-
P values from HWE test for each variant as float in [0, 1].
181+
P values from HWE test for each variant as float in [0, 1].
168182
169183
References
170184
----------
@@ -179,10 +193,22 @@ def hardy_weinberg_test(
179193
NotImplementedError
180194
If maximum number of alleles in provided dataset != 2
181195
"""
182-
if ds.dims["ploidy"] != 2:
196+
ploidy = ploidy or ds.dims.get("ploidy")
197+
if not ploidy:
198+
raise ValueError(
199+
"`ploidy` parameter must be set when not present as dataset dimension."
200+
)
201+
if ploidy != 2:
183202
raise NotImplementedError("HWE test only implemented for diploid genotypes")
184-
if ds.dims["alleles"] != 2:
203+
204+
alleles = alleles or ds.dims.get("alleles")
205+
if not alleles:
206+
raise ValueError(
207+
"`alleles` parameter must be set when not present as dataset dimension."
208+
)
209+
if alleles != 2:
185210
raise NotImplementedError("HWE test only implemented for biallelic genotypes")
211+
186212
# Use precomputed genotype counts if provided
187213
if genotype_counts is not None:
188214
variables.validate(ds, {genotype_counts: variables.genotype_counts_spec})
@@ -196,12 +222,16 @@ def hardy_weinberg_test(
196222
call_genotype: variables.call_genotype_spec,
197223
},
198224
)
199-
# TODO: Use API genotype counting function instead, e.g.
200-
# https://github.com/pystatgen/sgkit/issues/29#issuecomment-656691069
201-
M = ds[call_genotype_mask].any(dim="ploidy")
202-
AC = xr.where(M, -1, ds[call_genotype].sum(dim="ploidy")) # type: ignore[no-untyped-call]
203-
cts = [1, 0, 2] # arg order: hets, hom1, hom2
204-
obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts]
225+
ds_ct = genotype_count(
226+
ds,
227+
dim="samples",
228+
call_genotype=call_genotype,
229+
call_genotype_mask=call_genotype_mask,
230+
)
231+
obs = [
232+
da.asarray(ds_ct[v])
233+
for v in ["variant_n_het", "variant_n_hom_ref", "variant_n_hom_alt"]
234+
]
205235
p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs)
206236
new_ds = xr.Dataset({variables.variant_hwe_p_value: ("variants", p)})
207237
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)

0 commit comments

Comments
 (0)