|
| 1 | +import dask.array as da |
1 | 2 | import numpy as np
|
2 | 3 | import pandas as pd
|
3 | 4 | import pytest
|
4 | 5 |
|
5 |
| -from sgkit.cohorts import _cohorts_to_array, _tuple_len |
| 6 | +from sgkit.cohorts import _cohorts_to_array, _tuple_len, cohort_statistic |
6 | 7 |
|
7 | 8 |
|
8 | 9 | def test_tuple_len():
|
@@ -51,3 +52,53 @@ def test_cohorts_to_array__ids():
|
51 | 52 | ),
|
52 | 53 | np.array([[0, 1, 2], [3, 1, 2]]),
|
53 | 54 | )
|
| 55 | + |
| 56 | + |
@pytest.mark.parametrize(
    "statistic,expect",
    [
        (
            np.mean,
            [
                [1.0, 0.75, 0.5],
                [2 / 3, 0.25, 0.0],
                [2 / 3, 0.75, 0.5],
                [2 / 3, 0.5, 1.0],
                [1 / 3, 0.5, 0.0],
            ],
        ),
        (np.sum, [[3, 3, 1], [2, 1, 0], [2, 3, 1], [2, 2, 2], [1, 2, 0]]),
    ],
)
@pytest.mark.parametrize(
    "chunks",
    [
        ((5,), (10,)),
        ((3, 2), (10,)),
        ((3, 2), (5, 5)),
    ],
)
def test_cohort_statistic(statistic, expect, chunks):
    """Check ``cohort_statistic`` on a 5x10 dask array for several chunkings.

    The test is run for every combination of:

    * ``statistic`` / ``expect`` -- ``np.mean`` or ``np.sum`` together with
      the expected 5x3 result (one row per row of ``variables``, one column
      per cohort id 0, 1 and 2).
    * ``chunks`` -- the dask chunk layout of ``variables``: a single chunk,
      chunked along the first axis only, or chunked along both axes.

    The expected values are the per-row mean/sum over the columns that share
    a cohort id; the column labelled ``-1`` does not contribute to any of
    them.
    """
    variables = da.asarray(
        [
            [1, 1, 1, 0, 1, 1, 0, 0, 1, 1],
            [0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 0, 0, 0, 0, 1, 1],
            [0, 1, 0, 0, 1, 1, 1, 0, 0, 0],
        ],
        chunks=chunks,
    )
    # One cohort label per column of ``variables``; -1 marks a column that is
    # left out of every expected value above.
    cohorts = np.array([0, 1, 0, 2, 0, 1, -1, 1, 1, 2])
    # ``axis=1`` is handed to ``cohort_statistic`` as a keyword argument
    # (presumably forwarded to ``statistic`` -- confirm against the
    # ``cohort_statistic`` signature).
    result = cohort_statistic(variables, statistic, cohorts, axis=1)
    # numpy's convention is (actual, desired); keeping that order makes the
    # labels in a failure report point at the right array.
    np.testing.assert_array_equal(result, expect)
| 96 | + |
| 97 | + |
def test_cohort_statistic_axis0():
    """Check ``cohort_statistic`` on a 1-D dask array with ``sample_axis=0``.

    Ten values are labelled with cohort ids 0, 1, 2 or -1. The expected
    result holds the mean of the values in cohorts 0, 1 and 2 respectively
    (2.8, 4.0 and 1.0); the value labelled ``-1`` does not contribute to any
    of them.
    """
    variables = da.asarray([2, 3, 2, 4, 3, 1, 4, 5, 3, 1])
    cohorts = np.array([0, 0, 0, 0, 0, -1, 1, 1, 1, 2])
    result = cohort_statistic(variables, np.mean, cohorts, sample_axis=0, axis=0)
    # numpy's convention is (actual, desired); keeping that order makes the
    # labels in a failure report point at the right array.
    np.testing.assert_array_equal(result, [2.8, 4.0, 1.0])
0 commit comments