1
1
import collections
2
- from typing import Hashable , Optional , Sequence , Union
2
+ import itertools
3
+ from typing import Hashable , Optional , Sequence , Tuple , Union
3
4
4
5
import dask .array as da
5
6
import numpy as np
6
7
from numba import guvectorize
7
8
from xarray import Dataset
8
9
10
+ from sgkit .cohorts import _cohorts_to_array
9
11
from sgkit .stats .utils import assert_array_shape
10
12
from sgkit .typing import ArrayLike
11
13
from sgkit .utils import (
@@ -606,10 +608,39 @@ def _pbs(t: ArrayLike, out: ArrayLike) -> None: # pragma: no cover
606
608
out [i , j , k ] = ret
607
609
608
610
611
# Gufunc layout symbols: c = cohorts, ct = cohort_triples, i = index (size 3)
@guvectorize(  # type: ignore
    [
        "void(float32[:, :], int32[:, :], float32[:,:,:])",
        "void(float64[:, :], int32[:, :], float64[:,:,:])",
    ],
    "(c,c),(ct,i)->(c,c,c)",
    nopython=True,
    cache=True,
)
def _pbs_cohorts(
    t: ArrayLike, ct: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Generalized U-function for computing PBS over selected cohort triples.

    Parameters
    ----------
    t
        Pairwise matrix of transformed divergence values, shape (cohorts, cohorts).
    ct
        Integer array of cohort triples to evaluate, shape (cohort_triples, 3).
    out
        Output cube of shape (cohorts, cohorts, cohorts); only positions named
        in ``ct`` are written, everything else is left as NaN.
    """
    # Mark every triple as missing up front; requested triples overwrite below.
    out[:, :, :] = np.nan  # (cohorts, cohorts, cohorts)
    for row in range(ct.shape[0]):
        a = ct[row, 0]
        b = ct[row, 1]
        c = ct[row, 2]
        # Branch length of cohort `a` relative to `b` and `c` ...
        branch = (t[a, b] + t[a, c] - t[b, c]) / 2
        # ... normalized by the total tree depth for this triple.
        scale = 1 + (t[a, b] + t[a, c] + t[b, c]) / 2
        out[a, b, c] = branch / scale
636
+
609
637
def pbs (
610
638
ds : Dataset ,
611
639
* ,
612
640
stat_Fst : Hashable = variables .stat_Fst ,
641
+ cohorts : Optional [
642
+ Sequence [Union [Tuple [int , int , int ], Tuple [str , str , str ]]]
643
+ ] = None ,
613
644
merge : bool = True ,
614
645
) -> Dataset :
615
646
"""Compute the population branching statistic (PBS) between cohort triples.
@@ -627,6 +658,10 @@ def pbs(
627
658
:data:`sgkit.variables.stat_Fst_spec`.
628
659
If the variable is not present in ``ds``, it will be computed
629
660
using :func:`Fst`.
661
+ cohorts
662
+ The cohort triples to compute statistics for, specified as a sequence of
663
+ tuples of cohort indexes or IDs. None (the default) means compute statistics
664
+ for all cohort triples.
630
665
merge
631
666
If True (the default), merge the input dataset and the computed
632
667
output variables into a single dataset, otherwise return only
@@ -680,7 +715,13 @@ def pbs(
680
715
# calculate PBS triples
681
716
t = da .asarray (t )
682
717
shape = (t .chunks [0 ], n_cohorts , n_cohorts , n_cohorts )
683
- p = da .map_blocks (_pbs , t , chunks = shape , new_axis = 3 , dtype = np .float64 )
718
+
719
+ cohorts = cohorts or list (itertools .combinations (range (n_cohorts ), 3 )) # type: ignore
720
+ ct = _cohorts_to_array (cohorts , ds .indexes .get ("cohorts_0" , None ))
721
+
722
+ p = da .map_blocks (
723
+ lambda t : _pbs_cohorts (t , ct ), t , chunks = shape , new_axis = 3 , dtype = np .float64
724
+ )
684
725
assert_array_shape (p , n_windows , n_cohorts , n_cohorts , n_cohorts )
685
726
686
727
new_ds = Dataset (
0 commit comments