|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from functools import reduce |
| 3 | +from functools import partial, reduce |
4 | 4 | from typing import TYPE_CHECKING
|
5 | 5 |
|
6 | 6 | import numpy as np
|
|
13 | 13 | _convert_expected_groups_to_index,
|
14 | 14 | _get_optimal_chunks_for_groups,
|
15 | 15 | _normalize_indexes,
|
| 16 | + _validate_reindex, |
16 | 17 | factorize_,
|
17 | 18 | find_group_cohorts,
|
18 | 19 | groupby_reduce,
|
@@ -221,14 +222,26 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
|
221 | 222 | if not has_dask:
|
222 | 223 | continue
|
223 | 224 | for method in ["map-reduce", "cohorts", "split-reduce"]:
|
224 |
| - if "arg" in func and method != "map-reduce": |
225 |
| - continue |
226 |
| - actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs) |
227 |
| - for actual_group, expect in zip(groups, expected_groups): |
228 |
| - assert_equal(actual_group, expect, tolerance) |
229 |
| - if "arg" in func: |
230 |
| - assert actual.dtype.kind == "i" |
231 |
| - assert_equal(actual, expected, tolerance) |
| 225 | + if method == "map-reduce": |
| 226 | + reindexes = [True, False, None] |
| 227 | + else: |
| 228 | + reindexes = [None] |
| 229 | + for reindex in reindexes: |
| 230 | + call = partial( |
| 231 | + groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs |
| 232 | + ) |
| 233 | + if "arg" in func: |
| 234 | + if method != "map-reduce" or reindex is True: |
| 235 | + with pytest.raises(NotImplementedError): |
| 236 | + call() |
| 237 | + continue |
| 238 | + |
| 239 | + actual, *groups = call() |
| 240 | + for actual_group, expect in zip(groups, expected_groups): |
| 241 | + assert_equal(actual_group, expect, tolerance) |
| 242 | + if "arg" in func: |
| 243 | + assert actual.dtype.kind == "i" |
| 244 | + assert_equal(actual, expected, tolerance) |
232 | 245 |
|
233 | 246 |
|
234 | 247 | @requires_dask
|
@@ -1125,3 +1138,33 @@ def test_subset_block_2d(flatblocks, expectidx):
|
1125 | 1138 | subset = subset_to_blocks(array, flatblocks)
|
1126 | 1139 | assert len(subset.dask.layers) == 2
|
1127 | 1140 | assert_equal(subset, array.compute()[expectidx])
|
| 1141 | + |
| 1142 | + |
@pytest.mark.parametrize("method", ["map-reduce", "cohorts"])
@pytest.mark.parametrize(
    "expected, reindex, func, expected_groups, by_is_dask",
    [
        # arg reduction with reindex=None -> expected result is False
        (False, None, "argmax", None, False),
        # numpy `by` and no expected_groups -> True
        (True, None, "sum", None, False),
        # dask `by` and no expected_groups -> False
        (False, None, "sum", None, True),
        # any expected_groups supplied -> True, for numpy and dask `by` alike
        (True, None, "sum", [1, 2, 3], False),
        (True, None, "sum", ([1], [2]), False),
        (True, None, "sum", ([1], [2]), True),
        (True, None, "sum", ([1], None), False),
        (True, None, "sum", ([1], None), True),
    ],
)
def test_validate_reindex(expected, reindex, func, method, expected_groups, by_is_dask):
    """Check what ``_validate_reindex`` returns for each parametrized case.

    Every row is run against both ``"map-reduce"`` and ``"cohorts"``.
    Arg reductions combined with ``method="cohorts"`` are expected to raise
    ``NotImplementedError``; every other case must return ``expected``.
    """
    if by_is_dask and method == "cohorts":
        # This should error elsewhere
        pytest.skip()

    unsupported = method == "cohorts" and "arg" in func
    if unsupported:
        with pytest.raises(NotImplementedError):
            _validate_reindex(reindex, func, method, expected_groups, by_is_dask)
        return

    resolved = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)
    assert resolved == expected
0 commit comments