Skip to content
forked from pydata/xarray

Commit d1a3fc1

Browse files
committed
Some refactoring
1 parent 978fad9 commit d1a3fc1

File tree

3 files changed

+21
-11
lines changed

3 files changed

+21
-11
lines changed

xarray/core/dataarray.py

+8
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
Bins,
6464
DaCompatible,
6565
NetcdfWriteModes,
66+
T_Chunks,
6667
T_DataArray,
6768
T_DataArrayOrSet,
6869
ZarrWriteModes,
@@ -105,6 +106,7 @@
105106
Dims,
106107
ErrorOptions,
107108
ErrorOptionsWithWarn,
109+
GroupIndices,
108110
GroupInput,
109111
InterpOptions,
110112
PadModeOptions,
@@ -1687,6 +1689,12 @@ def sel(
16871689
)
16881690
return self._from_temp_dataset(ds)
16891691

1692+
def _shuffle(
1693+
self, dim: Hashable, *, indices: GroupIndices, chunks: T_Chunks
1694+
) -> Self:
1695+
ds = self._to_temp_dataset()._shuffle(dim=dim, indices=indices, chunks=chunks)
1696+
return self._from_temp_dataset(ds)
1697+
16901698
def head(
16911699
self,
16921700
indexers: Mapping[Any, int] | int | None = None,

xarray/core/dataset.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
DsCompatible,
156156
ErrorOptions,
157157
ErrorOptionsWithWarn,
158+
GroupIndices,
158159
GroupInput,
159160
InterpOptions,
160161
JoinOptions,
@@ -3238,7 +3239,7 @@ def sel(
32383239
result = self.isel(indexers=query_results.dim_indexers, drop=drop)
32393240
return result._overwrite_indexes(*query_results.as_tuple()[1:])
32403241

3241-
def _shuffle(self, dim, *, indices: list[list[int]], chunks: T_Chunks) -> Self:
3242+
def _shuffle(self, dim, *, indices: GroupIndices, chunks: T_Chunks) -> Self:
32423243
# Shuffling is only different from `isel` for chunked arrays.
32433244
# Extract them out, and treat them specially. The rest, we route through isel.
32443245
# This makes it easy to ensure correct handling of indexes.
@@ -3249,14 +3250,22 @@ def _shuffle(self, dim, *, indices: list[list[int]], chunks: T_Chunks) -> Self:
32493250
}
32503251
subset = self[[name for name in self._variables if name not in is_chunked]]
32513252

3253+
no_slices: list[list[int]] = [
3254+
list(range(*idx.indices(self.sizes[dim])))
3255+
if isinstance(idx, slice)
3256+
else idx
3257+
for idx in indices
3258+
]
3259+
no_slices = [idx for idx in no_slices if idx]
3260+
32523261
shuffled = (
32533262
subset
32543263
if dim not in subset.dims
3255-
else subset.isel({dim: np.concatenate(indices)})
3264+
else subset.isel({dim: np.concatenate(no_slices)})
32563265
)
32573266
for name, var in is_chunked.items():
32583267
shuffled[name] = var._shuffle(
3259-
indices=indices,
3268+
indices=no_slices,
32603269
dim=dim,
32613270
chunks=chunks,
32623271
)

xarray/core/groupby.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -743,19 +743,12 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
743743
was_array = isinstance(self._obj, DataArray)
744744
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
745745

746-
size = self._obj.sizes[self._group_dim]
747-
no_slices: list[list[int]] = [
748-
list(range(*idx.indices(size))) if isinstance(idx, slice) else idx
749-
for idx in self.encoded.group_indices
750-
]
751-
no_slices = [idx for idx in no_slices if idx]
752-
753746
for grouper in self.groupers:
754747
if grouper.name not in as_dataset._variables:
755748
as_dataset.coords[grouper.name] = grouper.group
756749

757750
shuffled = as_dataset._shuffle(
758-
dim=self._group_dim, indices=no_slices, chunks=chunks
751+
dim=self._group_dim, indices=self.encoded.group_indices, chunks=chunks
759752
)
760753
shuffled = self._maybe_unstack(shuffled)
761754
new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled

0 commit comments

Comments
 (0)