Skip to content

Commit 778d71b

Browse files
committed
Add cohort_statistic function pystatgen#730
1 parent 31dc606 commit 778d71b

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

sgkit/cohorts.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Optional, Sequence, Tuple, Union
1+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
22

3+
import dask.array as da
34
import numpy as np
45
import pandas as pd
56

@@ -70,3 +71,40 @@ def _cohorts_to_array(
7071
for n, t in enumerate(cohorts):
7172
ct[n, :] = t
7273
return ct
74+
75+
76+
def cohort_statistic(
    values: ArrayLike,
    statistic: Callable[..., ArrayLike],
    cohorts: ArrayLike,
    sample_axis: int = 1,
    **kwargs: Any,
) -> da.Array:
    """Calculate a statistic for each cohort of samples.

    Parameters
    ----------
    values
        An n-dimensional array of sample values.
    statistic
        A callable to apply to the samples of each cohort. The callable is
        expected to consume the samples axis.
    cohorts
        An array of integers indicating which cohort each sample is assigned to.
        Negative integers indicate that a sample is not assigned to any cohort.
    sample_axis
        Integer indicating the samples axis of the values array.
    kwargs
        Keyword arguments to pass to the callable statistic.

    Returns
    -------
    Array of results for each cohort. The per-cohort results are stacked
    along ``sample_axis``, so index ``c`` on that axis holds the statistic
    for cohort ``c`` (cohorts ``0`` to ``max(cohorts)`` inclusive).

    Raises
    ------
    ValueError
        If ``cohorts`` is empty or contains no non-negative value, i.e. no
        sample is assigned to a cohort.
    """
    values = da.asarray(values)
    cohorts = np.asarray(cohorts)
    # Fail early with an explicit message: an empty ``cohorts`` array breaks
    # ``max()`` and an all-negative one would leave nothing to stack.
    if cohorts.size == 0 or cohorts.max() < 0:
        raise ValueError(
            "No samples are assigned to a cohort: 'cohorts' must contain at "
            "least one non-negative integer."
        )
    n_cohorts = cohorts.max() + 1
    # One boolean mask per cohort id; negative ids never match any mask, so
    # unassigned samples are excluded from every cohort.
    masks = [cohorts == c for c in range(n_cohorts)]
    per_cohort = [da.take(values, mask, axis=sample_axis) for mask in masks]
    # The statistic consumes the samples axis, so stacking along the same axis
    # puts the new cohorts dimension where the samples dimension used to be.
    return da.stack(
        [statistic(subset, **kwargs) for subset in per_cohort], axis=sample_axis
    )

sgkit/tests/test_cohorts.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import dask.array as da
12
import numpy as np
23
import pandas as pd
34
import pytest
45

5-
from sgkit.cohorts import _cohorts_to_array, _tuple_len
6+
from sgkit.cohorts import _cohorts_to_array, _tuple_len, cohort_statistic
67

78

89
def test_tuple_len():
@@ -51,3 +52,53 @@ def test_cohorts_to_array__ids():
5152
),
5253
np.array([[0, 1, 2], [3, 1, 2]]),
5354
)
55+
56+
57+
@pytest.mark.parametrize(
    "statistic,expect",
    [
        (
            np.mean,
            [
                [1.0, 0.75, 0.5],
                [2 / 3, 0.25, 0.0],
                [2 / 3, 0.75, 0.5],
                [2 / 3, 0.5, 1.0],
                [1 / 3, 0.5, 0.0],
            ],
        ),
        (np.sum, [[3, 3, 1], [2, 1, 0], [2, 3, 1], [2, 2, 2], [1, 2, 0]]),
    ],
)
@pytest.mark.parametrize(
    "chunks",
    [
        ((5,), (10,)),
        ((3, 2), (10,)),
        ((3, 2), (5, 5)),
    ],
)
def test_cohort_statistic(statistic, expect, chunks):
    """Check per-cohort statistics over axis 1 for several chunk layouts."""
    # Sample 6 carries a negative id and therefore belongs to no cohort.
    sample_cohorts = np.array([0, 1, 0, 2, 0, 1, -1, 1, 1, 2])
    rows = [
        [1, 1, 1, 0, 1, 1, 0, 0, 1, 1],
        [0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 1, 1],
        [0, 1, 0, 0, 1, 1, 1, 0, 0, 0],
    ]
    variables = da.asarray(rows, chunks=chunks)
    # ``axis=1`` is forwarded to the statistic so it reduces the samples axis.
    observed = cohort_statistic(variables, statistic, sample_cohorts, axis=1)
    np.testing.assert_array_equal(expect, observed)
96+
97+
98+
def test_cohort_statistic_axis0():
    """Check per-cohort means when the samples lie on axis 0 of a 1-d array."""
    # The sample at index 5 has a negative id and is left out of every cohort.
    sample_cohorts = np.array([0, 0, 0, 0, 0, -1, 1, 1, 1, 2])
    variables = da.asarray([2, 3, 2, 4, 3, 1, 4, 5, 3, 1])
    observed = cohort_statistic(
        variables, np.mean, sample_cohorts, sample_axis=0, axis=0
    )
    np.testing.assert_array_equal([2.8, 4.0, 1.0], observed)

0 commit comments

Comments
 (0)