Skip to content

Cohort subsets #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions sgkit/cohorts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd

from sgkit.typing import ArrayLike


def _tuple_len(t: Union[int, Tuple[int, ...], str, Tuple[str, ...]]) -> int:
"""Return the length of a tuple, or 1 for an int or string value."""
if isinstance(t, int) or isinstance(t, str):
return 1
return len(t)


def _cohorts_to_array(
    cohorts: Sequence[Union[int, Tuple[int, ...], str, Tuple[str, ...]]],
    index: Optional[pd.Index] = None,
) -> ArrayLike:
    """Convert cohorts or cohort tuples specified as a sequence of values or
    tuples to an array of ints used to match samples in ``sample_cohorts``.

    Cohorts can be specified by index (as used in ``sample_cohorts``), or a label, in
    which case an ``index`` must be provided to find index locations for cohorts.

    Parameters
    ----------
    cohorts
        A sequence of values or tuple representing cohorts or cohort tuples.
    index
        An index to turn labels into index locations, by default None.

    Returns
    -------
    An array of shape ``(len(cohorts), tuple_len)``, where ``tuple_len`` is the length
    of the tuples, or 1 if ``cohorts`` is a sequence of values. An empty ``cohorts``
    sequence produces an empty one-dimensional int32 array.

    Raises
    ------
    ValueError
        If the cohort tuples are not all the same length.

    Examples
    --------

    >>> import pandas as pd
    >>> from sgkit.cohorts import _cohorts_to_array
    >>> _cohorts_to_array([(0, 1), (2, 1)])  # doctest: +SKIP
    array([[0, 1],
           [2, 1]], dtype=int32)
    >>> _cohorts_to_array([("c0", "c1"), ("c2", "c1")], pd.Index(["c0", "c1", "c2"]))  # doctest: +SKIP
    array([[0, 1],
           [2, 1]], dtype=int32)
    """
    if len(cohorts) == 0:
        return np.array([], np.int32)

    # All entries must be the same width so they can fill a rectangular array.
    tuple_len = _tuple_len(cohorts[0])
    if not all(_tuple_len(cohort) == tuple_len for cohort in cohorts):
        raise ValueError("Cohort tuples must all be the same length")

    # Convert cohort labels to positions using the index. Only string entries
    # need conversion; integer entries are already positional.
    if index is not None:
        if isinstance(cohorts[0], str):
            # `label` rather than `id` — avoids shadowing the builtin.
            cohorts = [index.get_loc(label) for label in cohorts]
        elif tuple_len > 1 and isinstance(cohorts[0][0], str):  # type: ignore
            cohorts = [tuple(index.get_loc(label) for label in t) for t in cohorts]  # type: ignore

    # Each row of the output is one cohort (tuple); scalars broadcast into a
    # single-column row.
    ct = np.empty((len(cohorts), tuple_len), np.int32)
    for n, t in enumerate(cohorts):
        ct[n, :] = t
    return ct
60 changes: 54 additions & 6 deletions sgkit/stats/popgen.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import collections
from typing import Hashable, Optional
import itertools
from typing import Hashable, Optional, Sequence, Tuple, Union

import dask.array as da
import numpy as np
from numba import guvectorize
from xarray import Dataset

from sgkit.cohorts import _cohorts_to_array
from sgkit.stats.utils import assert_array_shape
from sgkit.typing import ArrayLike
from sgkit.utils import (
Expand Down Expand Up @@ -607,10 +609,39 @@ def _pbs(t: ArrayLike, out: ArrayLike) -> None: # pragma: no cover
out[i, j, k] = ret


# c = cohorts, ct = cohort_triples, i = index (size 3)
@guvectorize(  # type: ignore
    [
        "void(float32[:, :], int32[:, :], float32[:,:,:])",
        "void(float64[:, :], int32[:, :], float64[:,:,:])",
    ],
    # Consumes a (cohorts, cohorts) pairwise matrix and a (triples, 3) array of
    # cohort index triples; produces a (cohorts, cohorts, cohorts) result.
    "(c,c),(ct,i)->(c,c,c)",
    nopython=True,
    cache=True,
)
def _pbs_cohorts(
    t: ArrayLike, ct: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Generalized U-function for computing PBS.

    Parameters
    ----------
    t
        Pairwise (cohort x cohort) statistic matrix; presumably a divergence
        transform of Fst as used by ``pbs`` — confirm against the caller.
    ct
        Array of shape ``(n_triples, 3)`` giving the cohort index triples
        (i, j, k) to compute.
    out
        Output array of shape ``(cohorts, cohorts, cohorts)``; entries for
        triples not listed in ``ct`` are left as NaN.
    """
    out[:, :, :] = np.nan  # (cohorts, cohorts, cohorts)
    n_cohort_triples = ct.shape[0]
    for n in range(n_cohort_triples):
        i = ct[n, 0]
        j = ct[n, 1]
        k = ct[n, 2]
        # PBS numerator for cohort i relative to j and k.
        ret = (t[i, j] + t[i, k] - t[j, k]) / 2
        # Normalization term; NOTE(review): this appears to be a normalized
        # PBS variant — confirm against the intended definition.
        norm = 1 + (t[i, j] + t[i, k] + t[j, k]) / 2
        ret = ret / norm
        out[i, j, k] = ret


def pbs(
ds: Dataset,
*,
stat_Fst: Hashable = variables.stat_Fst,
cohorts: Optional[
Sequence[Union[Tuple[int, int, int], Tuple[str, str, str]]]
] = None,
merge: bool = True,
) -> Dataset:
"""Compute the population branching statistic (PBS) between cohort triples.
Expand All @@ -628,6 +659,10 @@ def pbs(
:data:`sgkit.variables.stat_Fst_spec`.
If the variable is not present in ``ds``, it will be computed
using :func:`Fst`.
cohorts
The cohort triples to compute statistics for, specified as a sequence of
tuples of cohort indexes or IDs. None (the default) means compute statistics
for all cohorts.
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand Down Expand Up @@ -681,7 +716,13 @@ def pbs(
# calculate PBS triples
t = da.asarray(t)
shape = (t.chunks[0], n_cohorts, n_cohorts, n_cohorts)
p = da.map_blocks(_pbs, t, chunks=shape, new_axis=3, dtype=np.float64)

cohorts = cohorts or list(itertools.combinations(range(n_cohorts), 3)) # type: ignore
ct = _cohorts_to_array(cohorts, ds.indexes.get("cohorts_0", None))

p = da.map_blocks(
lambda t: _pbs_cohorts(t, ct), t, chunks=shape, new_axis=3, dtype=np.float64
)
assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts)

new_ds = Dataset(
Expand Down Expand Up @@ -719,12 +760,12 @@ def _Garud_h(haplotypes: ArrayLike) -> ArrayLike:


def _Garud_h_cohorts(
gt: ArrayLike, sample_cohort: ArrayLike, n_cohorts: int
gt: ArrayLike, sample_cohort: ArrayLike, n_cohorts: int, ct: ArrayLike
) -> ArrayLike:
# transpose to hash columns (haplotypes)
haplotypes = hash_array(gt.transpose()).transpose().flatten()
arr = np.empty((n_cohorts, N_GARUD_H_STATS))
for c in range(n_cohorts):
arr = np.full((n_cohorts, N_GARUD_H_STATS), np.nan)
for c in np.nditer(ct):
arr[c, :] = _Garud_h(haplotypes[sample_cohort == c])
return arr

Expand All @@ -733,6 +774,7 @@ def Garud_h(
ds: Dataset,
*,
call_genotype: Hashable = variables.call_genotype,
cohorts: Optional[Sequence[Union[int, str]]] = None,
merge: bool = True,
) -> Dataset:
"""Compute the H1, H12, H123 and H2/H1 statistics for detecting signatures
Expand All @@ -750,6 +792,10 @@ def Garud_h(
Input variable name holding call_genotype as defined by
:data:`sgkit.variables.call_genotype_spec`.
Must be present in ``ds``.
cohorts
The cohorts to compute statistics for, specified as a sequence of
cohort indexes or IDs. None (the default) means compute statistics
for all cohorts.
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand Down Expand Up @@ -825,10 +871,12 @@ def Garud_h(
sc = ds.sample_cohort.values
hsc = np.stack((sc, sc), axis=1).ravel() # TODO: assumes diploid
n_cohorts = sc.max() + 1 # 0-based indexing
cohorts = cohorts or range(n_cohorts)
ct = _cohorts_to_array(cohorts, ds.indexes.get("cohorts", None))

gh = window_statistic(
gt,
lambda gt: _Garud_h_cohorts(gt, hsc, n_cohorts),
lambda gt: _Garud_h_cohorts(gt, hsc, n_cohorts, ct),
ds.window_start.values,
ds.window_stop.values,
dtype=np.float64,
Expand Down
53 changes: 53 additions & 0 deletions sgkit/tests/test_cohorts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import pandas as pd
import pytest

from sgkit.cohorts import _cohorts_to_array, _tuple_len


def test_tuple_len():
    """Scalars (int/str) have length 1; tuples report their element count."""
    cases = [
        (tuple(), 0),
        (1, 1),
        ("a", 1),
        ("ab", 1),
        ((1,), 1),
        (("a",), 1),
        (("ab",), 1),
        ((1, 2), 2),
        (("a", "b"), 2),
        (("ab", "cd"), 2),
    ]
    for value, expected in cases:
        assert _tuple_len(value) == expected


def test_cohorts_to_array__indexes():
    """Integer cohort specs convert to arrays without needing an index."""
    with pytest.raises(ValueError, match="Cohort tuples must all be the same length"):
        _cohorts_to_array([(0, 1), (0, 1, 2)])

    cases = [
        ([], np.array([])),
        ([0, 1], np.array([[0], [1]])),
        ([(0, 1), (2, 1)], np.array([[0, 1], [2, 1]])),
        ([(0, 1, 2), (3, 1, 2)], np.array([[0, 1, 2], [3, 1, 2]])),
    ]
    for spec, expected in cases:
        np.testing.assert_equal(_cohorts_to_array(spec), expected)


def test_cohorts_to_array__ids():
    """Label cohort specs are mapped to positions via the supplied index."""
    with pytest.raises(ValueError, match="Cohort tuples must all be the same length"):
        _cohorts_to_array([("c0", "c1"), ("c0", "c1", "c2")])

    np.testing.assert_equal(_cohorts_to_array([]), np.array([]))

    cases = [
        (["c0", "c1"], ["c0", "c1"], np.array([[0], [1]])),
        (
            [("c0", "c1"), ("c2", "c1")],
            ["c0", "c1", "c2"],
            np.array([[0, 1], [2, 1]]),
        ),
        (
            [("c0", "c1", "c2"), ("c3", "c1", "c2")],
            ["c0", "c1", "c2", "c3"],
            np.array([[0, 1, 2], [3, 1, 2]]),
        ),
    ]
    for spec, labels, expected in cases:
        np.testing.assert_equal(_cohorts_to_array(spec, pd.Index(labels)), expected)
Loading