Skip to content

Commit a97b043

Browse files
committed
Cohort subsets for PBS
1 parent 74a9229 commit a97b043

File tree

2 files changed

+60
-12
lines changed

2 files changed

+60
-12
lines changed

sgkit/stats/popgen.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import collections
2-
from typing import Hashable, Optional, Sequence, Union
2+
import itertools
3+
from typing import Hashable, Optional, Sequence, Tuple, Union
34

45
import dask.array as da
56
import numpy as np
67
from numba import guvectorize
78
from xarray import Dataset
89

10+
from sgkit.cohorts import _cohorts_to_array
911
from sgkit.stats.utils import assert_array_shape
1012
from sgkit.typing import ArrayLike
1113
from sgkit.utils import (
@@ -606,10 +608,39 @@ def _pbs(t: ArrayLike, out: ArrayLike) -> None: # pragma: no cover
606608
out[i, j, k] = ret
607609

608610

611+
# c = cohorts, ct = cohort_triples, i = index (size 3)
612+
@guvectorize(  # type: ignore
    [
        "void(float32[:, :], int32[:, :], float32[:,:,:])",
        "void(float64[:, :], int32[:, :], float64[:,:,:])",
    ],
    "(c,c),(ct,i)->(c,c,c)",
    nopython=True,
    cache=True,
)
def _pbs_cohorts(
    t: ArrayLike, ct: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Generalized U-function for computing PBS on selected cohort triples.

    Parameters
    ----------
    t
        Pairwise values indexed as ``t[cohort, cohort]``.
    ct
        Cohort triples, one ``(i, j, k)`` row of cohort indexes per triple.
    out
        Output indexed as ``out[cohort, cohort, cohort]``. Every entry is
        first set to NaN; only the positions named by a row of ``ct`` are
        then overwritten with the computed statistic.
    """
    # Triples that were not requested stay NaN in the result.
    out[:, :, :] = np.nan  # (cohorts, cohorts, cohorts)
    for row in range(ct.shape[0]):
        i = ct[row, 0]
        j = ct[row, 1]
        k = ct[row, 2]
        t_ij = t[i, j]
        t_ik = t[i, k]
        t_jk = t[j, k]
        numerator = (t_ij + t_ik - t_jk) / 2
        denominator = 1 + (t_ij + t_ik + t_jk) / 2
        out[i, j, k] = numerator / denominator
635+
636+
609637
def pbs(
610638
ds: Dataset,
611639
*,
612640
stat_Fst: Hashable = variables.stat_Fst,
641+
cohorts: Optional[
642+
Sequence[Union[Tuple[int, int, int], Tuple[str, str, str]]]
643+
] = None,
613644
merge: bool = True,
614645
) -> Dataset:
615646
"""Compute the population branching statistic (PBS) between cohort triples.
@@ -627,6 +658,10 @@ def pbs(
627658
:data:`sgkit.variables.stat_Fst_spec`.
628659
If the variable is not present in ``ds``, it will be computed
629660
using :func:`Fst`.
661+
cohorts
662+
The cohort triples to compute statistics for, specified as a sequence of
663+
tuples of cohort indexes or IDs. None (the default) means compute statistics
664+
for all cohorts.
630665
merge
631666
If True (the default), merge the input dataset and the computed
632667
output variables into a single dataset, otherwise return only
@@ -680,7 +715,13 @@ def pbs(
680715
# calculate PBS triples
681716
t = da.asarray(t)
682717
shape = (t.chunks[0], n_cohorts, n_cohorts, n_cohorts)
683-
p = da.map_blocks(_pbs, t, chunks=shape, new_axis=3, dtype=np.float64)
718+
719+
cohorts = cohorts or list(itertools.combinations(range(n_cohorts), 3)) # type: ignore
720+
ct = _cohorts_to_array(cohorts, ds.indexes.get("cohorts_0", None))
721+
722+
p = da.map_blocks(
723+
lambda t: _pbs_cohorts(t, ct), t, chunks=shape, new_axis=3, dtype=np.float64
724+
)
684725
assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts)
685726

686727
new_ds = Dataset(

sgkit/tests/test_popgen.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -338,17 +338,21 @@ def test_pbs(sample_size, n_cohorts):
338338

339339

340340
@pytest.mark.parametrize(
341-
"sample_size, n_cohorts",
342-
[(10, 3), (20, 4)],
341+
"sample_size, n_cohorts, cohorts, cohort_indexes",
342+
[
343+
(10, 3, None, None),
344+
(20, 4, None, None),
345+
(20, 4, [(0, 1, 2), (3, 1, 2)], [(0, 1, 2), (3, 1, 2)]),
346+
],
343347
)
344348
@pytest.mark.parametrize("chunks", [(-1, -1), (50, -1)])
345-
def test_pbs__windowed(sample_size, n_cohorts, chunks):
349+
def test_pbs__windowed(sample_size, n_cohorts, cohorts, cohort_indexes, chunks):
346350
ts = msprime.simulate(sample_size, length=200, mutation_rate=0.05, random_seed=42)
347351
ds = ts_to_dataset(ts, chunks) # type: ignore[no-untyped-call]
348352
ds, subsets = add_cohorts(ds, ts, n_cohorts, cohort_key_names=["cohorts_0", "cohorts_1", "cohorts_2"]) # type: ignore[no-untyped-call]
349353
ds = window(ds, size=25)
350354

351-
ds = pbs(ds)
355+
ds = pbs(ds, cohorts=cohorts)
352356

353357
# scikit-allel
354358
for i, j, k in itertools.combinations(range(n_cohorts), 3):
@@ -358,14 +362,17 @@ def test_pbs__windowed(sample_size, n_cohorts, chunks):
358362
.values
359363
)
360364

361-
ac_i = ds.cohort_allele_count.values[:, i, :]
362-
ac_j = ds.cohort_allele_count.values[:, j, :]
363-
ac_k = ds.cohort_allele_count.values[:, k, :]
365+
if cohort_indexes is not None and (i, j, k) not in cohort_indexes:
366+
np.testing.assert_array_equal(stat_pbs, np.full_like(stat_pbs, np.nan))
367+
else:
368+
ac_i = ds.cohort_allele_count.values[:, i, :]
369+
ac_j = ds.cohort_allele_count.values[:, j, :]
370+
ac_k = ds.cohort_allele_count.values[:, k, :]
364371

365-
ska_pbs_value = allel.pbs(ac_i, ac_j, ac_k, window_size=25)
372+
ska_pbs_value = allel.pbs(ac_i, ac_j, ac_k, window_size=25)
366373

367-
# scikit-allel has final window missing
368-
np.testing.assert_allclose(stat_pbs[:-1], ska_pbs_value)
374+
# scikit-allel has final window missing
375+
np.testing.assert_allclose(stat_pbs[:-1], ska_pbs_value)
369376

370377

371378
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)