@@ -529,7 +529,6 @@ def shuffle(self) -> None:
529
529
"""
530
530
from xarray .core .dataarray import DataArray
531
531
from xarray .core .dataset import Dataset
532
- from xarray .core .duck_array_ops import shuffle_array
533
532
534
533
(grouper ,) = self .groupers
535
534
dim = self ._group_dim
@@ -538,6 +537,8 @@ def shuffle(self) -> None:
538
537
if all (isinstance (idx , slice ) for idx in self ._group_indices ):
539
538
return
540
539
540
+ indices : tuple [list [int ]] = self ._group_indices # type: ignore[assignment]
541
+
541
542
was_array = isinstance (self ._obj , DataArray )
542
543
as_dataset = self ._obj ._to_temp_dataset () if was_array else self ._obj
543
544
@@ -546,21 +547,22 @@ def shuffle(self) -> None:
546
547
if dim not in var .dims :
547
548
shuffled [name ] = var
548
549
continue
549
- shuffled_data = shuffle_array (
550
- var ._data , list (self ._group_indices ), axis = var .get_axis_num (dim )
551
- )
552
- shuffled [name ] = var ._replace (data = shuffled_data )
550
+ shuffled [name ] = var ._shuffle (indices = list (indices ), dim = dim )
553
551
554
552
# Replace self._group_indices with slices
555
553
slices = []
556
554
start = 0
557
555
for idxr in self ._group_indices :
556
+ if TYPE_CHECKING :
557
+ assert not isinstance (idxr , slice )
558
558
slices .append (slice (start , start + len (idxr )))
559
559
start += len (idxr )
560
560
# TODO: we have now broken the invariant
561
561
# self._group_indices ≠ self.groupers[0].group_indices
562
562
self ._group_indices = tuple (slices )
563
563
if was_array :
564
+ if TYPE_CHECKING :
565
+ assert isinstance (self ._obj , DataArray )
564
566
self ._obj = self ._obj ._from_temp_dataset (shuffled )
565
567
else :
566
568
self ._obj = shuffled
0 commit comments