Skip to content

Commit cc2d9c9

Browse files
committed
Refactor casting of label indexers to the coordinate dtype
Make the fix in pydata#3153 specific to pandas indexes (i.e., do not apply it to other, custom indexes). See pydata#5697 for details. This should also fix pydata#5700, although no test has been added yet (we need to refactor set_index first).
1 parent 05c488d commit cc2d9c9

File tree

3 files changed

+28
-16
lines changed

3 files changed

+28
-16
lines changed

xarray/core/indexes.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,12 @@ def _is_nested_tuple(possible_tuple):
129129
)
130130

131131

132-
def normalize_label(value, extract_scalar=False):
132+
def normalize_label(value, extract_scalar=False, dtype=None):
133133
if getattr(value, "ndim", 1) <= 1:
134134
value = _asarray_tuplesafe(value)
135+
if dtype is not None and dtype.kind == "f":
136+
# see https://github.com/pydata/xarray/pull/3153 for details
137+
value = np.asarray(value, dtype=dtype)
135138
if extract_scalar:
136139
# see https://github.com/pydata/xarray/pull/4292 for details
137140
value = value[()] if value.dtype.kind in "mM" else value.item()
@@ -151,12 +154,16 @@ def get_indexer_nd(index, labels, method=None, tolerance=None):
151154
class PandasIndex(Index):
152155
"""Wrap a pandas.Index as an xarray compatible index."""
153156

154-
__slots__ = ("index", "dim")
157+
__slots__ = ("index", "dim", "coord_dtype")
155158

156-
def __init__(self, array: Any, dim: Hashable):
159+
def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None):
157160
self.index = utils.safe_cast_to_index(array)
158161
self.dim = dim
159162

163+
if coord_dtype is None:
164+
coord_dtype = self.index.dtype
165+
self.coord_dtype = coord_dtype
166+
160167
@classmethod
161168
def from_variables(cls, variables: Mapping[Hashable, "Variable"]):
162169
from .variable import IndexVariable
@@ -176,7 +183,7 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]):
176183

177184
dim = var.dims[0]
178185

179-
obj = cls(var.data, dim)
186+
obj = cls(var.data, dim, coord_dtype=var.dtype)
180187

181188
data = PandasIndexingAdapter(obj.index, dtype=var.dtype)
182189
index_var = IndexVariable(
@@ -219,7 +226,7 @@ def query(self, labels, method=None, tolerance=None):
219226
"a dimension that does not have a MultiIndex"
220227
)
221228
else:
222-
label = normalize_label(label)
229+
label = normalize_label(label, dtype=self.coord_dtype)
223230
if label.ndim == 0:
224231
label_value = normalize_label(label, extract_scalar=True)
225232
if isinstance(self.index, pd.CategoricalIndex):
@@ -289,6 +296,16 @@ def _create_variables_from_multiindex(index, dim, level_meta=None):
289296

290297

291298
class PandasMultiIndex(PandasIndex):
299+
300+
__slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype")
301+
302+
def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None):
303+
super().__init__(array, dim)
304+
305+
if level_coords_dtype is None:
306+
level_coords_dtype = {idx.name: idx.dtype for idx in self.index.levels}
307+
self.level_coords_dtype = level_coords_dtype
308+
292309
@classmethod
293310
def from_variables(cls, variables: Mapping[Hashable, "Variable"]):
294311
if any([var.ndim != 1 for var in variables.values()]):
@@ -305,7 +322,8 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]):
305322
index = pd.MultiIndex.from_arrays(
306323
[var.values for var in variables.values()], names=variables.keys()
307324
)
308-
obj = cls(index, dim)
325+
level_coords_dtype = {name: var.dtype for name, var in variables.items()}
326+
obj = cls(index, dim, level_coords_dtype=level_coords_dtype)
309327

310328
level_meta = {
311329
name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding}
@@ -346,7 +364,10 @@ def query(self, labels, method=None, tolerance=None):
346364
if all([lbl in self.index.names for lbl in labels]):
347365
is_nested_vals = _is_nested_tuple(tuple(labels.values()))
348366
labels = {
349-
k: normalize_label(v, extract_scalar=True) for k, v in labels.items()
367+
k: normalize_label(
368+
v, extract_scalar=True, dtype=self.level_coords_dtype[k]
369+
)
370+
for k, v in labels.items()
350371
}
351372

352373
if len(labels) == self.index.nlevels and not is_nested_vals:

xarray/core/indexing.py

-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
is_duck_dask_array,
3333
sparse_array_type,
3434
)
35-
from .utils import maybe_cast_to_coords_dtype
3635

3736
if TYPE_CHECKING:
3837
from .dataarray import DataArray
@@ -185,12 +184,10 @@ def group_indexers_by_index(
185184

186185
for key, label in indexers.items():
187186
index = obj.xindexes.get(key, None)
188-
coord = obj.coords.get(key, None)
189187

190188
if index is not None:
191189
index_id = id(index)
192190
unique_indexes[index_id] = index
193-
label = maybe_cast_to_coords_dtype(label, coord.dtype) # type: ignore
194191
grouped_indexers[index_id][key] = label
195192
elif key in obj.coords:
196193
raise KeyError(f"no index found for coordinate {key}")

xarray/core/utils.py

-6
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,6 @@ def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index:
7272
return index
7373

7474

75-
def maybe_cast_to_coords_dtype(label, coords_dtype):
76-
if coords_dtype.kind == "f" and not isinstance(label, slice):
77-
label = np.asarray(label, dtype=coords_dtype)
78-
return label
79-
80-
8175
def maybe_coerce_to_str(index, original_coords):
8276
"""maybe coerce a pandas Index back to a numpy array of type str
8377

0 commit comments

Comments
 (0)