Skip to content

Commit fdc3c3d

Browse files
headtr1ckpre-commit-ci[bot]max-sixty
authored
Fix Dataset/DataArray.isel with drop=True and scalar DataArray indexes (#6579)
* apply drop argument in isel_fancy * use literal type for error handling * add test for drop support in isel * add isel fix to whats-new * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * correct isel unit tests * add link to issue * type most (all?) occurrences of errors/missing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Maximilian Roos <[email protected]>
1 parent 218e77a commit fdc3c3d

File tree

8 files changed

+79
-42
lines changed

8 files changed

+79
-42
lines changed

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ New Features
4444
- :py:meth:`xr.polyval` now supports :py:class:`Dataset` and :py:class:`DataArray` args of any shape,
4545
is faster and requires less memory. (:pull:`6548`)
4646
By `Michael Niklas <https://github.com/headtr1ck>`_.
47+
- Improved overall typing.
4748

4849
Breaking changes
4950
~~~~~~~~~~~~~~~~
@@ -119,6 +120,9 @@ Bug fixes
119120
:pull:`6489`). By `Spencer Clark <https://github.com/spencerkclark>`_.
120121
- Dark themes are now properly detected in Furo-themed Sphinx documents (:issue:`6500`, :pull:`6501`).
121122
By `Kevin Paul <https://github.com/kmpaul>`_.
123+
- :py:meth:`isel` with ``drop=True`` works as intended with scalar :py:class:`DataArray` indexers.
124+
(:issue:`6554`, :pull:`6579`)
125+
By `Michael Niklas <https://github.com/headtr1ck>`_.
122126

123127
Documentation
124128
~~~~~~~~~~~~~

xarray/core/dataarray.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
except ImportError:
7979
iris_Cube = None
8080

81-
from .types import T_DataArray, T_Xarray
81+
from .types import ErrorChoice, ErrorChoiceWithWarn, T_DataArray, T_Xarray
8282

8383

8484
def _infer_coords_and_dims(
@@ -1171,7 +1171,7 @@ def isel(
11711171
self,
11721172
indexers: Mapping[Any, Any] = None,
11731173
drop: bool = False,
1174-
missing_dims: str = "raise",
1174+
missing_dims: ErrorChoiceWithWarn = "raise",
11751175
**indexers_kwargs: Any,
11761176
) -> DataArray:
11771177
"""Return a new DataArray whose data is given by integer indexing
@@ -1186,7 +1186,7 @@ def isel(
11861186
If DataArrays are passed as indexers, xarray-style indexing will be
11871187
carried out. See :ref:`indexing` for the details.
11881188
One of indexers or indexers_kwargs must be provided.
1189-
drop : bool, optional
1189+
drop : bool, default: False
11901190
If ``drop=True``, drop coordinate variables indexed by integers
11911191
instead of making them scalar.
11921192
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
@@ -2335,7 +2335,7 @@ def transpose(
23352335
self,
23362336
*dims: Hashable,
23372337
transpose_coords: bool = True,
2338-
missing_dims: str = "raise",
2338+
missing_dims: ErrorChoiceWithWarn = "raise",
23392339
) -> DataArray:
23402340
"""Return a new DataArray object with transposed dimensions.
23412341
@@ -2386,16 +2386,16 @@ def T(self) -> DataArray:
23862386
return self.transpose()
23872387

23882388
def drop_vars(
2389-
self, names: Hashable | Iterable[Hashable], *, errors: str = "raise"
2389+
self, names: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise"
23902390
) -> DataArray:
23912391
"""Returns an array with dropped variables.
23922392
23932393
Parameters
23942394
----------
23952395
names : hashable or iterable of hashable
23962396
Name(s) of variables to drop.
2397-
errors : {"raise", "ignore"}, optional
2398-
If 'raise' (default), raises a ValueError error if any of the variable
2397+
errors : {"raise", "ignore"}, default: "raise"
2398+
If 'raise', raises a ValueError if any of the variables
23992399
passed are not in the dataset. If 'ignore', any given names that are in the
24002400
DataArray are dropped and no error is raised.
24012401
@@ -2412,7 +2412,7 @@ def drop(
24122412
labels: Mapping = None,
24132413
dim: Hashable = None,
24142414
*,
2415-
errors: str = "raise",
2415+
errors: ErrorChoice = "raise",
24162416
**labels_kwargs,
24172417
) -> DataArray:
24182418
"""Backward compatible method based on `drop_vars` and `drop_sel`
@@ -2431,7 +2431,7 @@ def drop_sel(
24312431
self,
24322432
labels: Mapping[Any, Any] = None,
24332433
*,
2434-
errors: str = "raise",
2434+
errors: ErrorChoice = "raise",
24352435
**labels_kwargs,
24362436
) -> DataArray:
24372437
"""Drop index labels from this DataArray.
@@ -2440,8 +2440,8 @@ def drop_sel(
24402440
----------
24412441
labels : mapping of hashable to Any
24422442
Index labels to drop
2443-
errors : {"raise", "ignore"}, optional
2444-
If 'raise' (default), raises a ValueError error if
2443+
errors : {"raise", "ignore"}, default: "raise"
2444+
If 'raise', raises a ValueError if
24452445
any of the index labels passed are not
24462446
in the dataset. If 'ignore', any given labels that are in the
24472447
dataset are dropped and no error is raised.
@@ -4589,7 +4589,7 @@ def query(
45894589
queries: Mapping[Any, Any] = None,
45904590
parser: str = "pandas",
45914591
engine: str = None,
4592-
missing_dims: str = "raise",
4592+
missing_dims: ErrorChoiceWithWarn = "raise",
45934593
**queries_kwargs: Any,
45944594
) -> DataArray:
45954595
"""Return a new data array indexed along the specified

xarray/core/dataset.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106
from ..backends import AbstractDataStore, ZarrStore
107107
from .dataarray import DataArray
108108
from .merge import CoercibleMapping
109-
from .types import T_Xarray
109+
from .types import ErrorChoice, ErrorChoiceWithWarn, T_Xarray
110110

111111
try:
112112
from dask.delayed import Delayed
@@ -2059,7 +2059,7 @@ def chunk(
20592059
return self._replace(variables)
20602060

20612061
def _validate_indexers(
2062-
self, indexers: Mapping[Any, Any], missing_dims: str = "raise"
2062+
self, indexers: Mapping[Any, Any], missing_dims: ErrorChoiceWithWarn = "raise"
20632063
) -> Iterator[tuple[Hashable, int | slice | np.ndarray | Variable]]:
20642064
"""Here we make sure
20652065
+ indexer has valid keys
@@ -2164,7 +2164,7 @@ def isel(
21642164
self,
21652165
indexers: Mapping[Any, Any] = None,
21662166
drop: bool = False,
2167-
missing_dims: str = "raise",
2167+
missing_dims: ErrorChoiceWithWarn = "raise",
21682168
**indexers_kwargs: Any,
21692169
) -> Dataset:
21702170
"""Returns a new dataset with each array indexed along the specified
@@ -2183,14 +2183,14 @@ def isel(
21832183
If DataArrays are passed as indexers, xarray-style indexing will be
21842184
carried out. See :ref:`indexing` for the details.
21852185
One of indexers or indexers_kwargs must be provided.
2186-
drop : bool, optional
2186+
drop : bool, default: False
21872187
If ``drop=True``, drop coordinate variables indexed by integers
21882188
instead of making them scalar.
21892189
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
21902190
What to do if dimensions that should be selected from are not present in the
21912191
Dataset:
21922192
- "raise": raise an exception
2193-
- "warning": raise a warning, and ignore the missing dimensions
2193+
- "warn": raise a warning, and ignore the missing dimensions
21942194
- "ignore": ignore the missing dimensions
21952195
**indexers_kwargs : {dim: indexer, ...}, optional
21962196
The keyword arguments form of ``indexers``.
@@ -2255,7 +2255,7 @@ def _isel_fancy(
22552255
indexers: Mapping[Any, Any],
22562256
*,
22572257
drop: bool,
2258-
missing_dims: str = "raise",
2258+
missing_dims: ErrorChoiceWithWarn = "raise",
22592259
) -> Dataset:
22602260
valid_indexers = dict(self._validate_indexers(indexers, missing_dims))
22612261

@@ -2271,6 +2271,10 @@ def _isel_fancy(
22712271
}
22722272
if var_indexers:
22732273
new_var = var.isel(indexers=var_indexers)
2274+
# drop scalar coordinates
2275+
# https://github.com/pydata/xarray/issues/6554
2276+
if name in self.coords and drop and new_var.ndim == 0:
2277+
continue
22742278
else:
22752279
new_var = var.copy(deep=False)
22762280
if name not in indexes:
@@ -4521,16 +4525,16 @@ def _assert_all_in_dataset(
45214525
)
45224526

45234527
def drop_vars(
4524-
self, names: Hashable | Iterable[Hashable], *, errors: str = "raise"
4528+
self, names: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise"
45254529
) -> Dataset:
45264530
"""Drop variables from this dataset.
45274531
45284532
Parameters
45294533
----------
45304534
names : hashable or iterable of hashable
45314535
Name(s) of variables to drop.
4532-
errors : {"raise", "ignore"}, optional
4533-
If 'raise' (default), raises a ValueError error if any of the variable
4536+
errors : {"raise", "ignore"}, default: "raise"
4537+
If 'raise', raises a ValueError if any of the variables
45344538
passed are not in the dataset. If 'ignore', any given names that are in the
45354539
dataset are dropped and no error is raised.
45364540
@@ -4556,7 +4560,9 @@ def drop_vars(
45564560
variables, coord_names=coord_names, indexes=indexes
45574561
)
45584562

4559-
def drop(self, labels=None, dim=None, *, errors="raise", **labels_kwargs):
4563+
def drop(
4564+
self, labels=None, dim=None, *, errors: ErrorChoice = "raise", **labels_kwargs
4565+
):
45604566
"""Backward compatible method based on `drop_vars` and `drop_sel`
45614567
45624568
Using either `drop_vars` or `drop_sel` is encouraged
@@ -4605,15 +4611,15 @@ def drop(self, labels=None, dim=None, *, errors="raise", **labels_kwargs):
46054611
)
46064612
return self.drop_sel(labels, errors=errors)
46074613

4608-
def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
4614+
def drop_sel(self, labels=None, *, errors: ErrorChoice = "raise", **labels_kwargs):
46094615
"""Drop index labels from this dataset.
46104616
46114617
Parameters
46124618
----------
46134619
labels : mapping of hashable to Any
46144620
Index labels to drop
4615-
errors : {"raise", "ignore"}, optional
4616-
If 'raise' (default), raises a ValueError error if
4621+
errors : {"raise", "ignore"}, default: "raise"
4622+
If 'raise', raises a ValueError if
46174623
any of the index labels passed are not
46184624
in the dataset. If 'ignore', any given labels that are in the
46194625
dataset are dropped and no error is raised.
@@ -4740,7 +4746,7 @@ def drop_isel(self, indexers=None, **indexers_kwargs):
47404746
return ds
47414747

47424748
def drop_dims(
4743-
self, drop_dims: Hashable | Iterable[Hashable], *, errors: str = "raise"
4749+
self, drop_dims: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise"
47444750
) -> Dataset:
47454751
"""Drop dimensions and associated variables from this dataset.
47464752
@@ -4780,7 +4786,7 @@ def drop_dims(
47804786
def transpose(
47814787
self,
47824788
*dims: Hashable,
4783-
missing_dims: str = "raise",
4789+
missing_dims: ErrorChoiceWithWarn = "raise",
47844790
) -> Dataset:
47854791
"""Return a new Dataset object with all array dimensions transposed.
47864792
@@ -7714,7 +7720,7 @@ def query(
77147720
queries: Mapping[Any, Any] = None,
77157721
parser: str = "pandas",
77167722
engine: str = None,
7717-
missing_dims: str = "raise",
7723+
missing_dims: ErrorChoiceWithWarn = "raise",
77187724
**queries_kwargs: Any,
77197725
) -> Dataset:
77207726
"""Return a new dataset with each array indexed along the specified
@@ -7747,7 +7753,7 @@ def query(
77477753
Dataset:
77487754
77497755
- "raise": raise an exception
7750-
- "warning": raise a warning, and ignore the missing dimensions
7756+
- "warn": raise a warning, and ignore the missing dimensions
77517757
- "ignore": ignore the missing dimensions
77527758
77537759
**queries_kwargs : {dim: query, ...}, optional

xarray/core/indexes.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222

2323
from . import formatting, nputils, utils
2424
from .indexing import IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter
25-
from .types import T_Index
2625
from .utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar
2726

2827
if TYPE_CHECKING:
28+
from .types import ErrorChoice, T_Index
2929
from .variable import Variable
3030

3131
IndexVars = Dict[Any, "Variable"]
@@ -1098,15 +1098,15 @@ def is_multi(self, key: Hashable) -> bool:
10981098
return len(self._id_coord_names[self._coord_name_id[key]]) > 1
10991099

11001100
def get_all_coords(
1101-
self, key: Hashable, errors: str = "raise"
1101+
self, key: Hashable, errors: ErrorChoice = "raise"
11021102
) -> dict[Hashable, Variable]:
11031103
"""Return all coordinates having the same index.
11041104
11051105
Parameters
11061106
----------
11071107
key : hashable
11081108
Index key.
1109-
errors : {"raise", "ignore"}, optional
1109+
errors : {"raise", "ignore"}, default: "raise"
11101110
If "raise", raises a ValueError if `key` is not in indexes.
11111111
If "ignore", an empty dict is returned instead.
11121112
@@ -1129,15 +1129,15 @@ def get_all_coords(
11291129
return {k: self._variables[k] for k in all_coord_names}
11301130

11311131
def get_all_dims(
1132-
self, key: Hashable, errors: str = "raise"
1132+
self, key: Hashable, errors: ErrorChoice = "raise"
11331133
) -> Mapping[Hashable, int]:
11341134
"""Return all dimensions shared by an index.
11351135
11361136
Parameters
11371137
----------
11381138
key : hashable
11391139
Index key.
1140-
errors : {"raise", "ignore"}, optional
1140+
errors : {"raise", "ignore"}, default: "raise"
11411141
If "raise", raises a ValueError if `key` is not in indexes.
11421142
If "ignore", an empty mapping is returned instead.
11431143

xarray/core/types.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, TypeVar, Union
3+
from typing import TYPE_CHECKING, Literal, TypeVar, Union
44

55
import numpy as np
66

@@ -33,3 +33,6 @@
3333
DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"]
3434
VarCompatible = Union["Variable", "ScalarOrArray"]
3535
GroupByIncompatible = Union["Variable", "GroupBy"]
36+
37+
ErrorChoice = Literal["raise", "ignore"]
38+
ErrorChoiceWithWarn = Literal["raise", "warn", "ignore"]

xarray/core/utils.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import numpy as np
3030
import pandas as pd
3131

32+
if TYPE_CHECKING:
33+
from .types import ErrorChoiceWithWarn
34+
3235
K = TypeVar("K")
3336
V = TypeVar("V")
3437
T = TypeVar("T")
@@ -756,7 +759,9 @@ def __len__(self) -> int:
756759

757760

758761
def infix_dims(
759-
dims_supplied: Collection, dims_all: Collection, missing_dims: str = "raise"
762+
dims_supplied: Collection,
763+
dims_all: Collection,
764+
missing_dims: ErrorChoiceWithWarn = "raise",
760765
) -> Iterator:
761766
"""
762767
Resolves a supplied list containing an ellipsis representing other items, to
@@ -804,7 +809,7 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
804809
def drop_dims_from_indexers(
805810
indexers: Mapping[Any, Any],
806811
dims: list | Mapping[Any, int],
807-
missing_dims: str,
812+
missing_dims: ErrorChoiceWithWarn,
808813
) -> Mapping[Hashable, Any]:
809814
"""Depending on the setting of missing_dims, drop any dimensions from indexers that
810815
are not present in dims.
@@ -850,7 +855,7 @@ def drop_dims_from_indexers(
850855

851856

852857
def drop_missing_dims(
853-
supplied_dims: Collection, dims: Collection, missing_dims: str
858+
supplied_dims: Collection, dims: Collection, missing_dims: ErrorChoiceWithWarn
854859
) -> Collection:
855860
"""Depending on the setting of missing_dims, drop any dimensions from supplied_dims that
856861
are not present in dims.

xarray/core/variable.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
BASIC_INDEXING_TYPES = integer_types + (slice,)
6060

6161
if TYPE_CHECKING:
62-
from .types import T_Variable
62+
from .types import ErrorChoiceWithWarn, T_Variable
6363

6464

6565
class MissingDimensionsError(ValueError):
@@ -1159,7 +1159,7 @@ def _to_dense(self):
11591159
def isel(
11601160
self: T_Variable,
11611161
indexers: Mapping[Any, Any] = None,
1162-
missing_dims: str = "raise",
1162+
missing_dims: ErrorChoiceWithWarn = "raise",
11631163
**indexers_kwargs: Any,
11641164
) -> T_Variable:
11651165
"""Return a new array indexed along the specified dimension(s).
@@ -1173,7 +1173,7 @@ def isel(
11731173
What to do if dimensions that should be selected from are not present in the
11741174
DataArray:
11751175
- "raise": raise an exception
1176-
- "warning": raise a warning, and ignore the missing dimensions
1176+
- "warn": raise a warning, and ignore the missing dimensions
11771177
- "ignore": ignore the missing dimensions
11781178
11791179
Returns
@@ -1436,7 +1436,7 @@ def roll(self, shifts=None, **shifts_kwargs):
14361436
def transpose(
14371437
self,
14381438
*dims,
1439-
missing_dims: str = "raise",
1439+
missing_dims: ErrorChoiceWithWarn = "raise",
14401440
) -> Variable:
14411441
"""Return a new Variable object with transposed dimensions.
14421442

0 commit comments

Comments
 (0)