Skip to content

Commit 74a9229

Browse files
committed
Cohort subsets for Garud H
1 parent 55e8e89 commit 74a9229

File tree

2 files changed

+45
-19
lines changed

2 files changed

+45
-19
lines changed

sgkit/stats/popgen.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import collections
2-
from typing import Hashable, Optional
2+
from typing import Hashable, Optional, Sequence, Union
33

44
import dask.array as da
55
import numpy as np
@@ -718,12 +718,12 @@ def _Garud_h(haplotypes: ArrayLike) -> ArrayLike:
718718

719719

720720
def _Garud_h_cohorts(
721-
gt: ArrayLike, sample_cohort: ArrayLike, n_cohorts: int
721+
gt: ArrayLike, sample_cohort: ArrayLike, n_cohorts: int, ct: ArrayLike
722722
) -> ArrayLike:
723723
# transpose to hash columns (haplotypes)
724724
haplotypes = hash_array(gt.transpose()).transpose().flatten()
725-
arr = np.empty((n_cohorts, N_GARUD_H_STATS))
726-
for c in range(n_cohorts):
725+
arr = np.full((n_cohorts, N_GARUD_H_STATS), np.nan)
726+
for c in np.nditer(ct):
727727
arr[c, :] = _Garud_h(haplotypes[sample_cohort == c])
728728
return arr
729729

@@ -732,6 +732,7 @@ def Garud_h(
732732
ds: Dataset,
733733
*,
734734
call_genotype: Hashable = variables.call_genotype,
735+
cohorts: Optional[Sequence[Union[int, str]]] = None,
735736
merge: bool = True,
736737
) -> Dataset:
737738
"""Compute the H1, H12, H123 and H2/H1 statistics for detecting signatures
@@ -749,6 +750,10 @@ def Garud_h(
749750
Input variable name holding call_genotype as defined by
750751
:data:`sgkit.variables.call_genotype_spec`.
751752
Must be present in ``ds``.
753+
cohorts
754+
The cohorts to compute statistics for, specified as a sequence of
755+
cohort indexes or IDs. None (the default) means compute statistics
756+
for all cohorts.
752757
merge
753758
If True (the default), merge the input dataset and the computed
754759
output variables into a single dataset, otherwise return only
@@ -824,10 +829,12 @@ def Garud_h(
824829
sc = ds.sample_cohort.values
825830
hsc = np.stack((sc, sc), axis=1).ravel() # TODO: assumes diploid
826831
n_cohorts = sc.max() + 1 # 0-based indexing
832+
cohorts = cohorts or range(n_cohorts)
833+
ct = _cohorts_to_array(cohorts, ds.indexes.get("cohorts", None))
827834

828835
gh = window_statistic(
829836
gt,
830-
lambda gt: _Garud_h_cohorts(gt, hsc, n_cohorts),
837+
lambda gt: _Garud_h_cohorts(gt, hsc, n_cohorts, ct),
831838
ds.window_start.values,
832839
ds.window_stop.values,
833840
dtype=np.float64,

sgkit/tests/test_popgen.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -362,18 +362,25 @@ def test_pbs__windowed(sample_size, n_cohorts, chunks):
362362
ac_j = ds.cohort_allele_count.values[:, j, :]
363363
ac_k = ds.cohort_allele_count.values[:, k, :]
364364

365-
ska_pbs_value = allel.pbs(ac_i, ac_j, ac_k, window_size=25, window_step=25)
365+
ska_pbs_value = allel.pbs(ac_i, ac_j, ac_k, window_size=25)
366366

367367
# scikit-allel has final window missing
368368
np.testing.assert_allclose(stat_pbs[:-1], ska_pbs_value)
369369

370370

371371
@pytest.mark.parametrize(
372-
"n_variants, n_samples, n_contigs, n_cohorts",
373-
[(9, 5, 1, 1), (9, 5, 1, 2)],
372+
"n_variants, n_samples, n_contigs, n_cohorts, cohorts, cohort_indexes",
373+
[
374+
(9, 5, 1, 1, None, None),
375+
(9, 5, 1, 2, None, None),
376+
(9, 5, 1, 2, [1], [1]),
377+
(9, 5, 1, 2, ["co_1"], [1]),
378+
],
374379
)
375380
@pytest.mark.parametrize("chunks", [(-1, -1), (5, -1)])
376-
def test_Garud_h(n_variants, n_samples, n_contigs, n_cohorts, chunks):
381+
def test_Garud_h(
382+
n_variants, n_samples, n_contigs, n_cohorts, cohorts, cohort_indexes, chunks
383+
):
377384
ds = simulate_genotype_call_dataset(
378385
n_variant=n_variants, n_sample=n_samples, n_contig=n_contigs
379386
)
@@ -383,25 +390,37 @@ def test_Garud_h(n_variants, n_samples, n_contigs, n_cohorts, chunks):
383390
[np.full_like(subset, i) for i, subset in enumerate(subsets)]
384391
)
385392
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
393+
cohort_names = [f"co_{i}" for i in range(n_cohorts)]
394+
coords = {k: cohort_names for k in ["cohorts"]}
395+
ds = ds.assign_coords(coords) # type: ignore[no-untyped-call]
386396
ds = window(ds, size=3)
387397

388-
gh = Garud_h(ds)
398+
gh = Garud_h(ds, cohorts=cohorts)
389399
h1 = gh.stat_Garud_h1.values
390400
h12 = gh.stat_Garud_h12.values
391401
h123 = gh.stat_Garud_h123.values
392402
h2_h1 = gh.stat_Garud_h2_h1.values
393403

394404
# scikit-allel
395405
for c in range(n_cohorts):
396-
gt = ds.call_genotype.values[:, sample_cohorts == c, :]
397-
ska_gt = allel.GenotypeArray(gt)
398-
ska_ha = ska_gt.to_haplotypes()
399-
ska_h = allel.moving_garud_h(ska_ha, size=3)
400-
401-
np.testing.assert_allclose(h1[:, c], ska_h[0])
402-
np.testing.assert_allclose(h12[:, c], ska_h[1])
403-
np.testing.assert_allclose(h123[:, c], ska_h[2])
404-
np.testing.assert_allclose(h2_h1[:, c], ska_h[3])
406+
if cohort_indexes is not None and c not in cohort_indexes:
407+
# cohorts that were not computed should be nan
408+
np.testing.assert_array_equal(h1[:, c], np.full_like(h1[:, c], np.nan))
409+
np.testing.assert_array_equal(h12[:, c], np.full_like(h12[:, c], np.nan))
410+
np.testing.assert_array_equal(h123[:, c], np.full_like(h123[:, c], np.nan))
411+
np.testing.assert_array_equal(
412+
h2_h1[:, c], np.full_like(h2_h1[:, c], np.nan)
413+
)
414+
else:
415+
gt = ds.call_genotype.values[:, sample_cohorts == c, :]
416+
ska_gt = allel.GenotypeArray(gt)
417+
ska_ha = ska_gt.to_haplotypes()
418+
ska_h = allel.moving_garud_h(ska_ha, size=3)
419+
420+
np.testing.assert_allclose(h1[:, c], ska_h[0])
421+
np.testing.assert_allclose(h12[:, c], ska_h[1])
422+
np.testing.assert_allclose(h123[:, c], ska_h[2])
423+
np.testing.assert_allclose(h2_h1[:, c], ska_h[3])
405424

406425

407426
def test_Garud_h__raise_on_non_diploid():

0 commit comments

Comments
 (0)