Skip to content

Commit 60d7619

Browse files
committed
Cleanup
1 parent 3bc51bd commit 60d7619

File tree

6 files changed

+39
-22
lines changed

6 files changed

+39
-22
lines changed

xarray/core/duck_array_ops.py

-15
Original file line numberDiff line numberDiff line change
@@ -831,18 +831,3 @@ def chunked_nanfirst(darray, axis):
831831

832832
def chunked_nanlast(darray, axis):
833833
return _chunked_first_or_last(darray, axis, op=nputils.nanlast)
834-
835-
836-
def shuffle_array(array, indices: list[list[int]], axis: int):
837-
# TODO: do chunk manager dance here.
838-
if is_duck_dask_array(array):
839-
if not module_available("dask", minversion="2024.08.0"):
840-
raise ValueError(
841-
"This method is very inefficient on dask<2024.08.0. Please upgrade."
842-
)
843-
# TODO: handle dimensions
844-
return array.shuffle(indexer=indices, axis=axis)
845-
else:
846-
indexer = np.concatenate(indices)
847-
# TODO: Do the array API thing here.
848-
return np.take(array, indices=indexer, axis=axis)

xarray/core/groupby.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,6 @@ def shuffle(self) -> None:
529529
"""
530530
from xarray.core.dataarray import DataArray
531531
from xarray.core.dataset import Dataset
532-
from xarray.core.duck_array_ops import shuffle_array
533532

534533
(grouper,) = self.groupers
535534
dim = self._group_dim
@@ -538,6 +537,8 @@ def shuffle(self) -> None:
538537
if all(isinstance(idx, slice) for idx in self._group_indices):
539538
return
540539

540+
indices: tuple[list[int]] = self._group_indices # type: ignore[assignment]
541+
541542
was_array = isinstance(self._obj, DataArray)
542543
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
543544

@@ -546,21 +547,22 @@ def shuffle(self) -> None:
546547
if dim not in var.dims:
547548
shuffled[name] = var
548549
continue
549-
shuffled_data = shuffle_array(
550-
var._data, list(self._group_indices), axis=var.get_axis_num(dim)
551-
)
552-
shuffled[name] = var._replace(data=shuffled_data)
550+
shuffled[name] = var._shuffle(indices=list(indices), dim=dim)
553551

554552
# Replace self._group_indices with slices
555553
slices = []
556554
start = 0
557555
for idxr in self._group_indices:
556+
if TYPE_CHECKING:
557+
assert not isinstance(idxr, slice)
558558
slices.append(slice(start, start + len(idxr)))
559559
start += len(idxr)
560560
# TODO: we have now broken the invariant
561561
# self._group_indices ≠ self.groupers[0].group_indices
562562
self._group_indices = tuple(slices)
563563
if was_array:
564+
if TYPE_CHECKING:
565+
assert isinstance(self._obj, DataArray)
564566
self._obj = self._obj._from_temp_dataset(shuffled)
565567
else:
566568
self._obj = shuffled

xarray/core/types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def copy(
297297
ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]
298298

299299
GroupKey = Any
300-
GroupIndex = Union[int, slice, list[int]]
300+
GroupIndex = Union[slice, list[int]]
301301
GroupIndices = tuple[GroupIndex, ...]
302302
Bins = Union[
303303
int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index

xarray/core/variable.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@
4444
maybe_coerce_to_str,
4545
)
4646
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
47-
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
47+
from xarray.namedarray.parallelcompat import get_chunked_array_type
48+
from xarray.namedarray.pycompat import (
49+
integer_types,
50+
is_0d_dask_array,
51+
is_chunked_array,
52+
to_duck_array,
53+
)
4854
from xarray.util.deprecation_helpers import deprecate_dims
4955

5056
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
@@ -998,6 +1004,16 @@ def compute(self, **kwargs):
9981004
new = self.copy(deep=False)
9991005
return new.load(**kwargs)
10001006

1007+
def _shuffle(self, indices: list[list[int]], dim: Hashable) -> Self:
1008+
array = self._data
1009+
if is_chunked_array(array):
1010+
chunkmanager = get_chunked_array_type(array)
1011+
return chunkmanager.shuffle(
1012+
array, indexer=indices, axis=self.get_axis_num(dim)
1013+
)
1014+
else:
1015+
return self.isel({dim: np.concatenate(indices)})
1016+
10011017
def isel(
10021018
self,
10031019
indexers: Mapping[Any, Any] | None = None,

xarray/namedarray/daskmanager.py

+9
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,12 @@ def store(
251251
targets=targets,
252252
**kwargs,
253253
)
254+
255+
def shuffle(self, x: DaskArray, indexer: list[list[int]], axis: int) -> DaskArray:
256+
import dask.array
257+
258+
if not module_available("dask", minversion="2024.08.0"):
259+
raise ValueError(
260+
"This method is very inefficient on dask<2024.08.0. Please upgrade."
261+
)
262+
return dask.array.shuffle(x, indexer, axis)

xarray/namedarray/parallelcompat.py

+5
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,11 @@ def compute(
364364
"""
365365
raise NotImplementedError()
366366

367+
def shuffle(
368+
self, x: T_ChunkedArray, indexer: list[list[int]], axis: int
369+
) -> T_ChunkedArray:
370+
raise NotImplementedError()
371+
367372
@property
368373
def array_api(self) -> Any:
369374
"""

0 commit comments

Comments
 (0)