Skip to content

Commit 55e8e89

Browse files
committed
Cohort utilities
1 parent 8078bf8 commit 55e8e89

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

sgkit/cohorts.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import Optional, Sequence, Tuple, Union
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
from sgkit.typing import ArrayLike
7+
8+
9+
def _tuple_len(t: Union[int, Tuple[int, ...], str, Tuple[str, ...]]) -> int:
10+
"""Return the length of a tuple, or 1 for an int or string value."""
11+
if isinstance(t, int) or isinstance(t, str):
12+
return 1
13+
return len(t)
14+
15+
16+
def _cohorts_to_array(
17+
cohorts: Sequence[Union[int, Tuple[int, ...], str, Tuple[str, ...]]],
18+
index: Optional[pd.Index] = None,
19+
) -> ArrayLike:
20+
"""Convert cohorts or cohort tuples specified as a sequence of values or
21+
tuples to an array of ints used to match samples in ``sample_cohorts``.
22+
23+
Cohorts can be specified by index (as used in ``sample_cohorts``), or a label, in
24+
which case an ``index`` must be provided to find index locations for cohorts.
25+
26+
Parameters
27+
----------
28+
cohorts
29+
A sequence of values or tuple representing cohorts or cohort tuples.
30+
index
31+
An index to turn labels into index locations, by default None.
32+
33+
Returns
34+
-------
35+
An array of shape ``(len(cohorts), tuple_len)``, where ``tuple_len`` is the length
36+
of the tuples, or 1 if ``cohorts`` is a sequence of values.,
37+
38+
Raises
39+
------
40+
ValueError
41+
If the cohort tuples are not all the same length.
42+
"""
43+
if len(cohorts) == 0:
44+
return np.array([], np.int32)
45+
46+
tuple_len = _tuple_len(cohorts[0])
47+
if not all(_tuple_len(cohort) == tuple_len for cohort in cohorts):
48+
raise ValueError("Cohort tuples must all be the same length")
49+
50+
# convert cohort IDs using an index
51+
if index is not None:
52+
if isinstance(cohorts[0], str):
53+
cohorts = [index.get_loc(id) for id in cohorts]
54+
elif tuple_len > 1 and isinstance(cohorts[0][0], str): # type: ignore
55+
cohorts = [tuple(index.get_loc(id) for id in t) for t in cohorts] # type: ignore
56+
57+
ct = np.empty((len(cohorts), tuple_len), np.int32)
58+
for n, t in enumerate(cohorts):
59+
ct[n, :] = t
60+
return ct

sgkit/tests/test_cohorts.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from sgkit.cohorts import _cohorts_to_array, _tuple_len
6+
7+
8+
def test_tuple_len():
9+
assert _tuple_len(tuple()) == 0
10+
assert _tuple_len(1) == 1
11+
assert _tuple_len("a") == 1
12+
assert _tuple_len("ab") == 1
13+
assert _tuple_len((1,)) == 1
14+
assert _tuple_len(("a",)) == 1
15+
assert _tuple_len(("ab",)) == 1
16+
assert _tuple_len((1, 2)) == 2
17+
assert _tuple_len(("a", "b")) == 2
18+
assert _tuple_len(("ab", "cd")) == 2
19+
20+
21+
def test_cohorts_to_array__indexes():
22+
with pytest.raises(ValueError, match="Cohort tuples must all be the same length"):
23+
_cohorts_to_array([(0, 1), (0, 1, 2)])
24+
25+
np.testing.assert_equal(_cohorts_to_array([]), np.array([]))
26+
np.testing.assert_equal(_cohorts_to_array([0, 1]), np.array([[0], [1]]))
27+
np.testing.assert_equal(
28+
_cohorts_to_array([(0, 1), (2, 1)]), np.array([[0, 1], [2, 1]])
29+
)
30+
np.testing.assert_equal(
31+
_cohorts_to_array([(0, 1, 2), (3, 1, 2)]), np.array([[0, 1, 2], [3, 1, 2]])
32+
)
33+
34+
35+
def test_cohorts_to_array__ids():
36+
with pytest.raises(ValueError, match="Cohort tuples must all be the same length"):
37+
_cohorts_to_array([("c0", "c1"), ("c0", "c1", "c2")])
38+
39+
np.testing.assert_equal(_cohorts_to_array([]), np.array([]))
40+
np.testing.assert_equal(
41+
_cohorts_to_array(["c0", "c1"], pd.Index(["c0", "c1"])),
42+
np.array([[0], [1]]),
43+
)
44+
np.testing.assert_equal(
45+
_cohorts_to_array([("c0", "c1"), ("c2", "c1")], pd.Index(["c0", "c1", "c2"])),
46+
np.array([[0, 1], [2, 1]]),
47+
)
48+
np.testing.assert_equal(
49+
_cohorts_to_array(
50+
[("c0", "c1", "c2"), ("c3", "c1", "c2")], pd.Index(["c0", "c1", "c2", "c3"])
51+
),
52+
np.array([[0, 1, 2], [3, 1, 2]]),
53+
)

0 commit comments

Comments
 (0)