@@ -176,6 +176,7 @@ def find_group_cohorts(labels, chunks, merge: bool = True):
     axis = range(-labels.ndim, 0)
     # Easier to create a dask array and use the .blocks property
     array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
+    labels = np.broadcast_to(labels, array.shape[-labels.ndim :])
 
     # Iterate over each block and create a new block of same shape with "chunk number"
     shape = tuple(array.blocks.shape[ax] for ax in axis)
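Note: `np.broadcast_to` returns a read-only view, so expanding a size-1 dimension of `labels` to the blocked array's trailing dims is free. A minimal standalone sketch (made-up shapes, not flox code):

```python
import numpy as np

# `labels` with a size-1 dim expands to the trailing dims of the blocked
# array; broadcast_to returns a view, so no copy is made.
labels = np.array([[0, 0, 1, 1]])  # shape (1, 4)
array_shape = (3, 4)
labels = np.broadcast_to(labels, array_shape[-labels.ndim :])
print(labels.shape)  # (3, 4)
```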
@@ -479,7 +480,7 @@ def factorize_(
         idx, groups = pd.factorize(flat, sort=sort)
 
         found_groups.append(np.array(groups))
-        factorized.append(idx)
+        factorized.append(idx.reshape(groupvar.shape))
 
     grp_shape = tuple(len(grp) for grp in found_groups)
     ngroups = math.prod(grp_shape)
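Note: `pd.factorize` only accepts 1D input, so each `by` variable is flattened first; reshaping the codes back to `groupvar.shape` preserves each variable's own shape (including broadcastable size-1 dims) for the later combining step. A minimal sketch with made-up data:

```python
import numpy as np
import pandas as pd

groupvar = np.array([["a", "b"], ["b", "a"]])  # a 2D by variable
idx, groups = pd.factorize(groupvar.reshape(-1))  # codes come back 1D
print(idx.reshape(groupvar.shape))  # [[0 1]
                                    #  [1 0]]
print(groups)  # ['a' 'b']
```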
@@ -489,20 +490,18 @@ def factorize_(
         # Restore these after the raveling
         nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
         group_idx[nan_by_mask] = -1
-        group_idx = group_idx.reshape(by[0].shape)
     else:
         group_idx = factorized[0]
 
     if fastpath:
-        return group_idx.reshape(by[0].shape), found_groups, grp_shape
+        return group_idx, found_groups, grp_shape
 
     if np.isscalar(axis) and groupvar.ndim > 1:
         # Not reducing along all dimensions of by
         # this is OK because for 3D by and axis=(1,2),
         # we collapse to a 2D by and axis=-1
         offset_group = True
         group_idx, size = offset_labels(group_idx.reshape(by[0].shape), ngroups)
-        group_idx = group_idx.reshape(-1)
     else:
         size = ngroups
         offset_group = False
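Note on the "Restore these after the raveling" comment: multiple codes are combined into one `group_idx` with `np.ravel_multi_index(..., mode="wrap")` just above this hunk, and `mode="wrap"` silently maps the -1 missing-value codes onto valid groups, hence the mask-and-restore. A standalone sketch of that mechanism (made-up codes):

```python
from functools import reduce

import numpy as np

factorized = [np.array([0, 1, -1, 1]), np.array([1, 0, 1, -1])]  # -1 == NaN
grp_shape = (2, 2)

# mode="wrap" lets the -1 codes through (wrapped to index 1) ...
group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap")
# ... so restore them afterwards
nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
group_idx[nan_by_mask] = -1
print(group_idx)  # [ 1  2 -1 -1]
```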
@@ -647,6 +646,8 @@ def chunk_reduce(
     else:
         nax = by.ndim
 
+    assert by.ndim <= array.ndim
+
     final_array_shape = array.shape[:-nax] + (1,) * (nax - 1)
     final_groups_shape = (1,) * (nax - 1)
 
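Note: the new assert just makes the shape arithmetic's standing assumption explicit — `by` can never have more dims than `array`, since its `nax` trailing axes are what get collapsed. Worked through with made-up shapes:

```python
array_shape = (10, 4, 5)  # reduce over the trailing by.ndim == 2 axes
nax = 2

# leading dims survive; the collapsed axes leave nax - 1 dummy axes behind
final_array_shape = array_shape[:-nax] + (1,) * (nax - 1)
final_groups_shape = (1,) * (nax - 1)
print(final_array_shape, final_groups_shape)  # (10, 1) (1,)
```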
@@ -667,9 +668,17 @@ def chunk_reduce(
     )
     groups = groups[0]
 
+    if isinstance(axis, Sequence):
+        needs_broadcast = any(
+            group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
+            for ax in range(-len(axis), 0)
+        )
+        if needs_broadcast:
+            group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
     # always reshape to 1D along group dimensions
     newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
     array = array.reshape(newshape)
+    group_idx = group_idx.reshape(-1)
 
     assert group_idx.ndim == 1
     empty = np.all(props.nanmask)
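Note: the broadcast-then-ravel added here ensures every array element gets a group label even when the factorized `by` kept a size-1 dimension. A minimal sketch of the two steps (made-up values):

```python
import numpy as np

array = np.arange(12).reshape(3, 4)
group_idx = np.array([[0, 0, 1, 1]])  # shape (1, 4), from a broadcastable by

# expand the size-1 dim to match the array's trailing dims, then flatten
group_idx = np.broadcast_to(group_idx, array.shape[-2:]).reshape(-1)
print(group_idx)  # [0 0 1 1 0 0 1 1 0 0 1 1]
```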
@@ -1219,7 +1228,9 @@ def dask_groupby_agg(
         # chunk numpy arrays like the input array
         # This removes an extra rechunk-merge layer that would be
         # added otherwise
-        by = dask.array.from_array(by, chunks=tuple(array.chunks[ax] for ax in range(-by.ndim, 0)))
+        chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))
+
+        by = dask.array.from_array(by, chunks=chunks)
     _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])
 
     # preprocess the array: for argreductions, this zips the index together with the array block
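Note: the new `chunks` expression keeps a size-1 dimension of a numpy `by` as a single length-1 chunk rather than copying the array's chunking there, which would be wrong for a broadcastable dimension. Standalone, with made-up sizes:

```python
import numpy as np

import dask.array

array = dask.array.ones((6, 8), chunks=(3, 4))
by = np.zeros((1, 8))  # broadcastable along the first axis

chunks = tuple(
    array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)
)
print(chunks)  # ((1,), (4, 4))
by = dask.array.from_array(by, chunks=chunks)
```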
@@ -1429,7 +1440,7 @@ def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
 
 
 def _validate_reindex(
-    reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
+    reindex: bool | None, func, method: T_Method, expected_groups, any_by_dask: bool
 ) -> bool:
     if reindex is True:
         if _is_arg_reduction(func):
@@ -1447,7 +1458,7 @@ def _validate_reindex(
         reindex = False
 
     elif method == "map-reduce":
-        if expected_groups is None and by_is_dask:
+        if expected_groups is None and any_by_dask:
             reindex = False
         else:
             reindex = True
@@ -1457,8 +1468,9 @@ def _validate_reindex(
 
 
 def _assert_by_is_aligned(shape, by):
+    assert all(b.ndim == by[0].ndim for b in by[1:])
     for idx, b in enumerate(by):
-        if shape[-b.ndim :] != b.shape:
+        if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
             raise ValueError(
                 "`array` and `by` arrays must be aligned "
                 "i.e. array.shape[-by.ndim :] == by.shape. "
@@ -1495,26 +1507,34 @@ def _lazy_factorize_wrapper(*by, **kwargs):
     return group_idx
 
 
-def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
+def _factorize_multiple(by, expected_groups, any_by_dask, reindex):
     kwargs = dict(
         expected_groups=expected_groups,
         axis=None,  # always None, we offset later if necessary.
         fastpath=True,
         reindex=reindex,
     )
-    if by_is_dask:
+    if any_by_dask:
         import dask.array
 
+        # unifying chunks will make sure all arrays in `by` are dask arrays
+        # with compatible chunks, even if there was originally a numpy array
+        inds = tuple(range(by[0].ndim))
+        chunks, by_ = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
+
         group_idx = dask.array.map_blocks(
             _lazy_factorize_wrapper,
-            *np.broadcast_arrays(*by),
+            *by_,
+            chunks=tuple(chunks.values()),
             meta=np.array((), dtype=np.int64),
             **kwargs,
         )
         found_groups = tuple(
             None if is_duck_dask_array(b) else pd.unique(b.reshape(-1)) for b in by
         )
-        grp_shape = tuple(len(e) for e in expected_groups)
+        grp_shape = tuple(
+            len(e) if e is not None else len(f) for e, f in zip(expected_groups, found_groups)
+        )
     else:
         group_idx, found_groups, grp_shape = factorize_(by, **kwargs)
 
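Note: `dask.array.unify_chunks` takes its arguments interleaved as `(array, indices, array, indices, ...)` and coerces numpy inputs to dask arrays, which is exactly what the `itertools.chain(*zip(...))` expression builds. A standalone sketch with two bys (made-up shapes):

```python
import itertools

import numpy as np

import dask.array

by = (
    dask.array.from_array(np.zeros((6, 8)), chunks=(3, 4)),
    np.ones((6, 8)),  # a plain numpy by gets coerced to dask here
)
inds = tuple(range(by[0].ndim))  # (0, 1)

# interleave into (array0, inds, array1, inds) as unify_chunks expects
chunks, by_ = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
print(tuple(chunks.values()))  # ((3, 3), (4, 4))
```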
@@ -1644,15 +1664,16 @@ def groupby_reduce(
 
     bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
     nby = len(bys)
-    by_is_dask = any(is_duck_dask_array(b) for b in bys)
+    by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
+    any_by_dask = any(by_is_dask)
 
-    if method in ["split-reduce", "cohorts"] and by_is_dask:
+    if method in ["split-reduce", "cohorts"] and any_by_dask:
         raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
 
     if method == "split-reduce":
         method = "cohorts"
 
-    reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)
+    reindex = _validate_reindex(reindex, func, method, expected_groups, any_by_dask)
 
     if not is_duck_array(array):
         array = np.asarray(array)
@@ -1667,6 +1688,11 @@ def groupby_reduce(
         expected_groups = (None,) * nby
 
     _assert_by_is_aligned(array.shape, bys)
+    for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
+        if is_dask and (reindex or nby > 1) and expect is None:
+            raise ValueError(
+                f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
+            )
 
     if nby == 1 and not isinstance(expected_groups, tuple):
         expected_groups = (np.asarray(expected_groups),)
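Note: the new loop fails fast where the reduction cannot work — with a lazy `by`, the set of labels is unknown until compute, so the output size is undeterminable whenever reindexing is requested or multiple bys must be combined. A sketch of the guard's logic with dummy flags:

```python
by_is_dask = (True, False)  # pretend by[0] is a dask array
expected_groups = (None, [1, 2])
nby, reindex = 2, False

for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
    if is_dask and (reindex or nby > 1) and expect is None:
        print(f"would raise ValueError for by[{idx}]")  # by[0] trips the check
```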
@@ -1686,7 +1712,7 @@ def groupby_reduce(
     )
     if factorize_early:
         bys, final_groups, grp_shape = _factorize_multiple(
-            bys, expected_groups, by_is_dask=by_is_dask, reindex=reindex
+            bys, expected_groups, any_by_dask=any_by_dask, reindex=reindex
         )
         expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)
 
@@ -1709,7 +1735,7 @@ def groupby_reduce(
 
     # TODO: make sure expected_groups is unique
     if nax == 1 and by_.ndim > 1 and expected_groups is None:
-        if not by_is_dask:
+        if not any_by_dask:
             expected_groups = _get_expected_groups(by_, sort)
         else:
             # When we reduce along all axes, we are guaranteed to see all