Skip to content

Commit 8e01e0c

Browse files
committed
Generic cohort reductions sgkit-dev#730
1 parent 000aa4f commit 8e01e0c

File tree

2 files changed

+262
-1
lines changed

2 files changed

+262
-1
lines changed

sgkit/stats/utils.py

+210-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from typing import Hashable, Tuple
1+
from functools import wraps
2+
from typing import Callable, Hashable, Tuple
23

34
import dask.array as da
45
import numpy as np
56
import xarray as xr
67
from dask.array import Array
8+
from numba import guvectorize
79
from xarray import DataArray, Dataset
810

911
from ..typing import ArrayLike
@@ -109,3 +111,210 @@ def map_blocks_asnumpy(x: Array) -> Array:
109111

110112
x = x.map_blocks(cp.asnumpy)
111113
return x
114+
115+
116+
def cohort_reduction(gufunc: Callable) -> Callable:
    """Decorator turning a cohort-reducing gufunc into an axis-aware function.

    Parameters
    ----------
    gufunc
        Generalized ufunc taking ``(x, cohort, placeholder)`` where the
        length of ``placeholder`` fixes the size of the output core
        dimension. It must expose its layout as ``gufunc.ufunc.signature``.

    Returns
    -------
    A function ``func(x, cohort, n, axis=-1)`` that applies ``gufunc``
    through ``dask.array.apply_gufunc`` along the chosen axis of ``x``.
    """

    @wraps(gufunc)
    def func(x: ArrayLike, cohort: ArrayLike, n: int, axis: int = -1) -> ArrayLike:
        signature = gufunc.ufunc.signature
        # The gufunc reduces over its trailing core dimension, so the
        # requested axis is swapped into the last position first.
        samples_last = da.swapaxes(x, axis, -1)
        # Only the length of this array is used: it tells the gufunc how
        # many cohorts the output must hold.
        placeholder = np.empty(n, np.int8)
        reduced = da.apply_gufunc(
            gufunc,
            signature,
            samples_last,
            cohort,
            placeholder,
        )
        # Swap the new cohort axis back to where the sample axis was.
        return da.swapaxes(reduced, axis, -1)

    return func
129+
130+
131+
@cohort_reduction
@guvectorize(
    [
        "(uint8[:], int64[:], int8[:], uint64[:])",
        "(uint64[:], int64[:], int8[:], uint64[:])",
        "(int8[:], int64[:], int8[:], int64[:])",
        "(int64[:], int64[:], int8[:], int64[:])",
        "(float32[:], int64[:], int8[:], float32[:])",
        "(float64[:], int64[:], int8[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
)
def cohort_sum(
    x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
) -> ArrayLike:
    """Sum of values by cohort.

    The parameters below describe the decorated function, which is
    called as ``cohort_sum(x, cohort, n, axis=-1)``.

    Parameters
    ----------
    x
        Array of values corresponding to each sample.
    cohort
        Array of integers indicating the cohort membership of
        each sample with negative values indicating no cohort.
    n
        Number of cohorts.
    axis
        The axis of array x corresponding to samples (defaults
        to final axis).

    Returns
    -------
    An array with the same number of dimensions as x in which
    the sample axis has been replaced with a cohort axis of
    size n.
    """
    # Kernel view: ``_`` is a placeholder whose length sizes ``out``.
    out[:] = 0
    for sample in range(len(x)):
        group = cohort[sample]
        if group < 0:
            # Negative labels mark samples that belong to no cohort.
            continue
        out[group] += x[sample]
174+
175+
176+
@cohort_reduction
@guvectorize(
    [
        "(uint8[:], int64[:], int8[:], uint64[:])",
        "(uint64[:], int64[:], int8[:], uint64[:])",
        "(int8[:], int64[:], int8[:], int64[:])",
        "(int64[:], int64[:], int8[:], int64[:])",
        "(float32[:], int64[:], int8[:], float32[:])",
        "(float64[:], int64[:], int8[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
)
def cohort_nansum(
    x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
) -> ArrayLike:
    """Sum of values by cohort ignoring nan values.

    The parameters below describe the decorated function, which is
    called as ``cohort_nansum(x, cohort, n, axis=-1)``.

    Parameters
    ----------
    x
        Array of values corresponding to each sample.
    cohort
        Array of integers indicating the cohort membership of
        each sample with negative values indicating no cohort.
    n
        Number of cohorts.
    axis
        The axis of array x corresponding to samples (defaults
        to final axis).

    Returns
    -------
    An array with the same number of dimensions as x in which
    the sample axis has been replaced with a cohort axis of
    size n.
    """
    # Kernel view: ``_`` is a placeholder whose length sizes ``out``.
    out[:] = 0
    for sample in range(len(x)):
        group = cohort[sample]
        value = x[sample]
        if np.isnan(value) or group < 0:
            # Skip nan values and samples that belong to no cohort.
            continue
        out[group] += value
220+
221+
222+
@cohort_reduction
@guvectorize(
    [
        "(uint8[:], int64[:], int8[:], float64[:])",
        "(uint64[:], int64[:], int8[:], float64[:])",
        "(int8[:], int64[:], int8[:], float64[:])",
        "(int64[:], int64[:], int8[:], float64[:])",
        "(float32[:], int64[:], int8[:], float32[:])",
        "(float64[:], int64[:], int8[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
)
def cohort_mean(
    x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
) -> ArrayLike:
    """Mean of values by cohort.

    The parameters below describe the decorated function, which is
    called as ``cohort_mean(x, cohort, n, axis=-1)``.

    Parameters
    ----------
    x
        Array of values corresponding to each sample.
    cohort
        Array of integers indicating the cohort membership of
        each sample with negative values indicating no cohort.
    n
        Number of cohorts.
    axis
        The axis of array x corresponding to samples (defaults
        to final axis).

    Returns
    -------
    An array with the same number of dimensions as x in which
    the sample axis has been replaced with a cohort axis of
    size n.
    """
    # Kernel view: the length of the placeholder ``_`` is the cohort count.
    n_cohorts = len(_)
    n_samples = len(x)
    sizes = np.zeros(n_cohorts)
    out[:] = 0
    # First pass: accumulate per-cohort totals and member counts.
    for sample in range(n_samples):
        group = cohort[sample]
        if group < 0:
            # Negative labels mark samples that belong to no cohort.
            continue
        out[group] += x[sample]
        sizes[group] += 1
    # Second pass: turn totals into means.
    # NOTE(review): a cohort with no members divides by a zero count here;
    # confirm the resulting value is what callers expect.
    for group in range(n_cohorts):
        out[group] /= sizes[group]
270+
271+
272+
@cohort_reduction
@guvectorize(
    [
        "(uint8[:], int64[:], int8[:], float64[:])",
        "(uint64[:], int64[:], int8[:], float64[:])",
        "(int8[:], int64[:], int8[:], float64[:])",
        "(int64[:], int64[:], int8[:], float64[:])",
        "(float32[:], int64[:], int8[:], float32[:])",
        "(float64[:], int64[:], int8[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
)
def cohort_nanmean(
    x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
) -> ArrayLike:
    """Mean of values by cohort ignoring nan values.

    The parameters below describe the decorated function, which is
    called as ``cohort_nanmean(x, cohort, n, axis=-1)``.

    Parameters
    ----------
    x
        Array of values corresponding to each sample.
    cohort
        Array of integers indicating the cohort membership of
        each sample with negative values indicating no cohort.
    n
        Number of cohorts.
    axis
        The axis of array x corresponding to samples (defaults
        to final axis).

    Returns
    -------
    An array with the same number of dimensions as x in which
    the sample axis has been replaced with a cohort axis of
    size n.
    """
    # Kernel view: the length of the placeholder ``_`` is the cohort count.
    n_cohorts = len(_)
    n_samples = len(x)
    sizes = np.zeros(n_cohorts)
    out[:] = 0
    # First pass: accumulate totals and counts over the non-nan members.
    for sample in range(n_samples):
        group = cohort[sample]
        value = x[sample]
        if np.isnan(value) or group < 0:
            # Skip nan values and samples that belong to no cohort.
            continue
        out[group] += value
        sizes[group] += 1
    # Second pass: turn totals into means.
    # NOTE(review): a cohort with no non-nan members divides by a zero
    # count here; confirm the resulting value is what callers expect.
    for group in range(n_cohorts):
        out[group] /= sizes[group]

sgkit/tests/test_stats_utils.py

+52
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
assert_array_shape,
1313
assert_block_shape,
1414
assert_chunk_shape,
15+
cohort_mean,
16+
cohort_nanmean,
17+
cohort_nansum,
18+
cohort_sum,
1519
concat_2d,
1620
r2_score,
1721
)
@@ -164,3 +168,51 @@ def _col_shape_sum(ds: Dataset) -> int:
164168

165169
def _rename_dim(ds: Dataset, prefix: str, name: str) -> Dataset:
    """Rename every dimension of ``ds`` whose name starts with ``prefix``.

    All matching dimensions are mapped to the single new name ``name``
    and the result of ``ds.rename_dims`` is returned.
    """
    renames = {}
    for dim in ds.dims:
        # Dimension names are compared through their string form.
        if str(dim).startswith(prefix):
            renames[dim] = name
    return ds.rename_dims(renames)
171+
172+
173+
def _random_cohort_data(shape, n, axis, missing=0.0, scale=1, dtype=float, seed=0):
    """Build random values and cohort labels for the cohort reduction tests.

    Parameters
    ----------
    shape
        Shape of the value array to generate.
    n
        Number of cohorts; labels are drawn from ``-1`` to ``n - 1``.
    axis
        Axis of ``shape`` that corresponds to samples; one label is
        generated per position along it.
    missing
        Probability with which each value is replaced by nan (before the
        cast to ``dtype``).
    scale
        Factor applied to the uniform ``[0, 1)`` random values.
    dtype
        Type the value array is cast to.
    seed
        Seed of the random generator.

    Returns
    -------
    The tuple ``(x, cohort, n, axis)``.
    """
    # A private generator is used instead of ``np.random.seed`` because this
    # helper runs while the parametrize arguments are being built; reseeding
    # the global generator there would silently change the random state seen
    # by every other test. ``RandomState(seed)`` yields the same stream the
    # reseeded global generator would.
    rng = np.random.RandomState(seed)
    x = rng.rand(*shape) * scale
    idx = rng.choice([1, 0], shape, p=[missing, 1 - missing]).astype(bool)
    x[idx] = np.nan
    x = x.astype(dtype)
    cohort = rng.randint(-1, n, size=shape[axis])
    return x, cohort, n, axis
181+
182+
183+
def _cohort_reduction(func, x, cohort, n, axis=-1):
    """Reference implementation of a per-cohort reduction in plain NumPy.

    Parameters
    ----------
    func
        Reduction called as ``func(values, axis=axis)``, e.g. ``np.sum``.
    x
        Array of values.
    cohort
        Integer cohort label for each position along ``axis`` of ``x``;
        labels outside ``0 .. n - 1`` are never selected.
    n
        Number of cohorts.
    axis
        Axis of ``x`` that is reduced within each cohort.

    Returns
    -------
    An array shaped like ``x`` except that ``axis`` has size ``n``.
    """
    out = []
    for i in range(n):
        idx = np.where(cohort == i)[0]
        x_c = np.take(x, idx, axis=axis)
        out.append(func(x_c, axis=axis))
    # The stacked results carry the cohort axis first. It has to be *moved*
    # to ``axis``: swapping it instead would also relocate the first
    # remaining axis and scramble the layout whenever ``axis`` is not one of
    # the two leading axes.
    return np.moveaxis(np.array(out), 0, axis)
192+
193+
194+
@pytest.mark.parametrize(
    "x, cohort, n, axis",
    [
        # 1-D inputs
        _random_cohort_data((20,), n=3, axis=0),
        _random_cohort_data((20,), n=3, axis=0, dtype=np.float32),
        _random_cohort_data((20,), n=3, axis=-1, missing=0.3),
        _random_cohort_data((20,), n=3, axis=-1, scale=30, dtype=np.int16),
        # 3-D inputs reduced along the middle axis
        _random_cohort_data((7, 103, 4), n=5, axis=1, scale=7, missing=0.3),
        _random_cohort_data((7, 103, 4), n=5, axis=1, scale=7, dtype=np.uint8),
    ],
)
@pytest.mark.parametrize(
    "reduction, func",
    [
        (cohort_sum, np.sum),
        (cohort_nansum, np.nansum),
        (cohort_mean, np.mean),
        (cohort_nanmean, np.nanmean),
    ],
)
def test_cohort_reductions(reduction, func, x, cohort, n, axis):
    """Check each cohort reduction against the NumPy reference implementation.

    Both the result dtype and the values (to the default tolerance of
    ``assert_array_almost_equal``) must agree.
    """
    actual = reduction(x, cohort, n, axis=axis)
    expect = _cohort_reduction(func, x, cohort, n, axis=axis)
    assert expect.dtype == actual.dtype
    np.testing.assert_array_almost_equal(expect, actual)

0 commit comments

Comments
 (0)