Skip to content

Commit 1f796cc

Browse files
timothymillarmergify[bot]
authored andcommitted
Generic cohort reductions #730
1 parent 375ad0f commit 1f796cc

File tree

2 files changed

+272
-1
lines changed

2 files changed

+272
-1
lines changed

sgkit/stats/utils.py

+215-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from typing import Hashable, Tuple
1+
from functools import wraps
2+
from typing import Callable, Hashable, Tuple
23

34
import dask.array as da
45
import numpy as np
56
import xarray as xr
67
from dask.array import Array
8+
from numba import guvectorize
79
from xarray import DataArray, Dataset
810

911
from ..typing import ArrayLike
@@ -109,3 +111,215 @@ def map_blocks_asnumpy(x: Array) -> Array:
109111

110112
x = x.map_blocks(cp.asnumpy)
111113
return x
114+
115+
116+
def cohort_reduction(gufunc: Callable) -> Callable:
117+
@wraps(gufunc)
118+
def func(x: ArrayLike, cohort: ArrayLike, n: int, axis: int = -1) -> ArrayLike:
119+
x = da.swapaxes(da.asarray(x), axis, -1)
120+
replaced = len(x.shape) - 1
121+
chunks = x.chunks[0:-1] + (n,)
122+
out = da.map_blocks(
123+
gufunc,
124+
x,
125+
cohort,
126+
np.empty(n, np.int8),
127+
chunks=chunks,
128+
drop_axis=replaced,
129+
new_axis=replaced,
130+
)
131+
return da.swapaxes(out, axis, -1)
132+
133+
return func
134+
135+
136+
@cohort_reduction
137+
@guvectorize(
138+
[
139+
"(uint8[:], int64[:], int8[:], uint64[:])",
140+
"(uint64[:], int64[:], int8[:], uint64[:])",
141+
"(int8[:], int64[:], int8[:], int64[:])",
142+
"(int64[:], int64[:], int8[:], int64[:])",
143+
"(float32[:], int64[:], int8[:], float32[:])",
144+
"(float64[:], int64[:], int8[:], float64[:])",
145+
],
146+
"(n),(n),(c)->(c)",
147+
)
148+
def cohort_sum(
149+
x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
150+
) -> None: # pragma: no cover
151+
"""Sum of values by cohort.
152+
153+
Parameters
154+
----------
155+
x
156+
Array of values corresponding to each sample.
157+
cohort
158+
Array of integers indicating the cohort membership of
159+
each sample with negative values indicating no cohort.
160+
n
161+
Number of cohorts.
162+
axis
163+
The axis of array x corresponding to samples (defaults
164+
to final axis).
165+
166+
Returns
167+
-------
168+
An array with the same number of dimensions as x in which
169+
the sample axis has been replaced with a cohort axis of
170+
size n.
171+
"""
172+
out[:] = 0
173+
n = len(x)
174+
for i in range(n):
175+
c = cohort[i]
176+
if c >= 0:
177+
out[c] += x[i]
178+
return
179+
180+
181+
@cohort_reduction
182+
@guvectorize(
183+
[
184+
"(uint8[:], int64[:], int8[:], uint64[:])",
185+
"(uint64[:], int64[:], int8[:], uint64[:])",
186+
"(int8[:], int64[:], int8[:], int64[:])",
187+
"(int64[:], int64[:], int8[:], int64[:])",
188+
"(float32[:], int64[:], int8[:], float32[:])",
189+
"(float64[:], int64[:], int8[:], float64[:])",
190+
],
191+
"(n),(n),(c)->(c)",
192+
)
193+
def cohort_nansum(
194+
x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
195+
) -> None: # pragma: no cover
196+
"""Sum of values by cohort ignoring nan values.
197+
198+
Parameters
199+
----------
200+
x
201+
Array of values corresponding to each sample.
202+
cohort
203+
Array of integers indicating the cohort membership of
204+
each sample with negative values indicating no cohort.
205+
n
206+
Number of cohorts.
207+
axis
208+
The axis of array x corresponding to samples (defaults
209+
to final axis).
210+
211+
Returns
212+
-------
213+
An array with the same number of dimensions as x in which
214+
the sample axis has been replaced with a cohort axis of
215+
size n.
216+
"""
217+
out[:] = 0
218+
n = len(x)
219+
for i in range(n):
220+
c = cohort[i]
221+
v = x[i]
222+
if (not np.isnan(v)) and (c >= 0):
223+
out[cohort[i]] += v
224+
return
225+
226+
227+
@cohort_reduction
228+
@guvectorize(
229+
[
230+
"(uint8[:], int64[:], int8[:], float64[:])",
231+
"(uint64[:], int64[:], int8[:], float64[:])",
232+
"(int8[:], int64[:], int8[:], float64[:])",
233+
"(int64[:], int64[:], int8[:], float64[:])",
234+
"(float32[:], int64[:], int8[:], float32[:])",
235+
"(float64[:], int64[:], int8[:], float64[:])",
236+
],
237+
"(n),(n),(c)->(c)",
238+
)
239+
def cohort_mean(
240+
x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
241+
) -> None: # pragma: no cover
242+
"""Mean of values by cohort.
243+
244+
Parameters
245+
----------
246+
x
247+
Array of values corresponding to each sample.
248+
cohort
249+
Array of integers indicating the cohort membership of
250+
each sample with negative values indicating no cohort.
251+
n
252+
Number of cohorts.
253+
axis
254+
The axis of array x corresponding to samples (defaults
255+
to final axis).
256+
257+
Returns
258+
-------
259+
An array with the same number of dimensions as x in which
260+
the sample axis has been replaced with a cohort axis of
261+
size n.
262+
"""
263+
out[:] = 0
264+
n = len(x)
265+
c = len(_)
266+
count = np.zeros(c)
267+
for i in range(n):
268+
j = cohort[i]
269+
if j >= 0:
270+
out[j] += x[i]
271+
count[j] += 1
272+
for j in range(c):
273+
out[j] /= count[j]
274+
return
275+
276+
277+
@cohort_reduction
278+
@guvectorize(
279+
[
280+
"(uint8[:], int64[:], int8[:], float64[:])",
281+
"(uint64[:], int64[:], int8[:], float64[:])",
282+
"(int8[:], int64[:], int8[:], float64[:])",
283+
"(int64[:], int64[:], int8[:], float64[:])",
284+
"(float32[:], int64[:], int8[:], float32[:])",
285+
"(float64[:], int64[:], int8[:], float64[:])",
286+
],
287+
"(n),(n),(c)->(c)",
288+
)
289+
def cohort_nanmean(
290+
x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
291+
) -> None: # pragma: no cover
292+
"""Mean of values by cohort ignoring nan values.
293+
294+
Parameters
295+
----------
296+
x
297+
Array of values corresponding to each sample.
298+
cohort
299+
Array of integers indicating the cohort membership of
300+
each sample with negative values indicating no cohort.
301+
n
302+
Number of cohorts.
303+
axis
304+
The axis of array x corresponding to samples (defaults
305+
to final axis).
306+
307+
Returns
308+
-------
309+
An array with the same number of dimensions as x in which
310+
the sample axis has been replaced with a cohort axis of
311+
size n.
312+
"""
313+
out[:] = 0
314+
n = len(x)
315+
c = len(_)
316+
count = np.zeros(c)
317+
for i in range(n):
318+
j = cohort[i]
319+
v = x[i]
320+
if (not np.isnan(v)) and (j >= 0):
321+
out[j] += v
322+
count[j] += 1
323+
for j in range(c):
324+
out[j] /= count[j]
325+
return

sgkit/tests/test_stats_utils.py

+57
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
assert_array_shape,
1313
assert_block_shape,
1414
assert_chunk_shape,
15+
cohort_mean,
16+
cohort_nanmean,
17+
cohort_nansum,
18+
cohort_sum,
1519
concat_2d,
1620
r2_score,
1721
)
@@ -164,3 +168,56 @@ def _col_shape_sum(ds: Dataset) -> int:
164168

165169
def _rename_dim(ds: Dataset, prefix: str, name: str) -> Dataset:
166170
return ds.rename_dims({d: name for d in ds.dims if str(d).startswith(prefix)})
171+
172+
173+
def _random_cohort_data(chunks, n, axis, missing=0.0, scale=1, dtype=float, seed=0):
174+
shape = tuple(np.sum(tup) for tup in chunks)
175+
np.random.seed(seed)
176+
x = np.random.rand(*shape) * scale
177+
idx = np.random.choice([1, 0], shape, p=[missing, 1 - missing]).astype(bool)
178+
x[idx] = np.nan
179+
x = da.asarray(x, chunks=chunks, dtype=dtype)
180+
cohort = np.random.randint(-1, n, size=shape[axis])
181+
return x, cohort, n, axis
182+
183+
184+
def _cohort_reduction(func, x, cohort, n, axis=-1):
185+
# reference implementation
186+
out = []
187+
for i in range(n):
188+
idx = np.where(cohort == i)[0]
189+
x_c = np.take(x, idx, axis=axis)
190+
out.append(func(x_c, axis=axis))
191+
out = np.swapaxes(np.array(out), 0, axis)
192+
return out
193+
194+
195+
@pytest.mark.parametrize(
196+
"x, cohort, n, axis",
197+
[
198+
_random_cohort_data((20,), n=3, axis=0),
199+
_random_cohort_data((20, 20), n=2, axis=0, dtype=np.float32),
200+
_random_cohort_data((10, 10), n=2, axis=-1, scale=30, dtype=np.int16),
201+
_random_cohort_data((20, 20), n=3, axis=-1, missing=0.3),
202+
_random_cohort_data((7, 103, 4), n=5, axis=1, scale=7, missing=0.3),
203+
_random_cohort_data(
204+
((3, 4), (50, 50, 3), 4), n=5, axis=1, scale=7, dtype=np.uint8
205+
),
206+
_random_cohort_data(
207+
((6, 6), (50, 50, 7), (3, 1)), n=5, axis=1, scale=7, missing=0.3
208+
),
209+
],
210+
)
211+
@pytest.mark.parametrize(
212+
"reduction, func",
213+
[
214+
(cohort_sum, np.sum),
215+
(cohort_nansum, np.nansum),
216+
(cohort_mean, np.mean),
217+
(cohort_nanmean, np.nanmean),
218+
],
219+
)
220+
def test_cohort_reductions(reduction, func, x, cohort, n, axis):
221+
expect = _cohort_reduction(func, x, cohort, n, axis=axis)
222+
actual = reduction(x, cohort, n, axis=axis)
223+
np.testing.assert_array_almost_equal(expect, actual)

0 commit comments

Comments
 (0)