Skip to content

Commit 3df0af6

Browse files
committed
Cleanup for HWE function
1 parent 07c98db commit 3df0af6

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

sgkit/stats/hwe.py

+29-12
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from xarray import Dataset
99

1010
from sgkit import variables
11+
from sgkit.stats.aggregation import genotype_count
1112
from sgkit.utils import conditional_merge_datasets
1213

1314

@@ -153,11 +154,15 @@ def hardy_weinberg_test(
153154
Input variable name holding call_genotype_mask.
154155
Defined by :data:`sgkit.variables.call_genotype_mask_spec`
155156
ploidy
156-
Genotype ploidy, defaults to ``ploidy`` dimension of genotype
157-
call array (:data:`sgkit.variables.call_genotype_spec`) if present.
158-
If that variable is not present, then this value must be set.
157+
Genotype ploidy, defaults to ``ploidy`` dimension of provided dataset.
158+
If the ``ploidy`` dimension is not present, then this value must be set explicitly.
159159
Currently HWE calculations are only supported for diploid datasets,
160160
i.e. ``ploidy`` must equal 2.
161+
alleles
162+
Genotype allele count, defaults to ``alleles`` dimension of provided dataset.
163+
If the ``alleles`` dimension is not present, then this value must be set explicitly.
164+
Currently HWE calculations are only supported for biallelic datasets,
165+
i.e. ``alleles`` must equal 2.
161166
merge
162167
If True (the default), merge the input dataset and the computed
163168
output variables into a single dataset, otherwise return only
@@ -171,8 +176,9 @@ def hardy_weinberg_test(
171176
Returns
172177
-------
173178
Dataset containing (N = num variants):
179+
174180
variant_hwe_p_value : [array-like, shape: (N,)]
175-
P values from HWE test for each variant as float in [0, 1].
181+
P values from HWE test for each variant as float in [0, 1].
176182
177183
References
178184
----------
@@ -190,12 +196,19 @@ def hardy_weinberg_test(
190196
ploidy = ploidy or ds.dims.get("ploidy")
191197
if not ploidy:
192198
raise ValueError(
193-
"`ploidy` parameter must be set when not present as array dimension."
199+
"`ploidy` parameter must be set when not present as dataset dimension."
194200
)
195201
if ploidy != 2:
196202
raise NotImplementedError("HWE test only implemented for diploid genotypes")
197-
if ds.dims["alleles"] != 2:
203+
204+
alleles = alleles or ds.dims.get("alleles")
205+
if not alleles:
206+
raise ValueError(
207+
"`alleles` parameter must be set when not present as dataset dimension."
208+
)
209+
if alleles != 2:
198210
raise NotImplementedError("HWE test only implemented for biallelic genotypes")
211+
199212
# Use precomputed genotype counts if provided
200213
if genotype_counts is not None:
201214
variables.validate(ds, {genotype_counts: variables.genotype_counts_spec})
@@ -209,12 +222,16 @@ def hardy_weinberg_test(
209222
call_genotype: variables.call_genotype_spec,
210223
},
211224
)
212-
# TODO: Use API genotype counting function instead, e.g.
213-
# https://github.com/pystatgen/sgkit/issues/29#issuecomment-656691069
214-
M = ds[call_genotype_mask].any(dim="ploidy")
215-
AC = xr.where(M, -1, ds[call_genotype].sum(dim="ploidy")) # type: ignore[no-untyped-call]
216-
cts = [1, 0, 2] # arg order: hets, hom1, hom2
217-
obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts]
225+
ds_ct = genotype_count(
226+
ds,
227+
dim="samples",
228+
call_genotype=call_genotype,
229+
call_genotype_mask=call_genotype_mask,
230+
)
231+
obs = [
232+
da.asarray(ds_ct[v])
233+
for v in ["variant_n_het", "variant_n_hom_ref", "variant_n_hom_alt"]
234+
]
218235
p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs)
219236
new_ds = xr.Dataset({variables.variant_hwe_p_value: ("variants", p)})
220237
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)

0 commit comments

Comments
 (0)