@@ -129,9 +129,12 @@ def _is_nested_tuple(possible_tuple):
129
129
)
130
130
131
131
132
- def normalize_label (value , extract_scalar = False ):
132
+ def normalize_label (value , extract_scalar = False , dtype = None ):
133
133
if getattr (value , "ndim" , 1 ) <= 1 :
134
134
value = _asarray_tuplesafe (value )
135
+ if dtype is not None and dtype .kind == "f" :
136
+ # see https://github.com/pydata/xarray/pull/3153 for details
137
+ value = np .asarray (value , dtype = dtype )
135
138
if extract_scalar :
136
139
# see https://github.com/pydata/xarray/pull/4292 for details
137
140
value = value [()] if value .dtype .kind in "mM" else value .item ()
@@ -151,12 +154,16 @@ def get_indexer_nd(index, labels, method=None, tolerance=None):
151
154
class PandasIndex (Index ):
152
155
"""Wrap a pandas.Index as an xarray compatible index."""
153
156
154
- __slots__ = ("index" , "dim" )
157
+ __slots__ = ("index" , "dim" , "coord_dtype" )
155
158
156
- def __init__ (self , array : Any , dim : Hashable ):
159
+ def __init__ (self , array : Any , dim : Hashable , coord_dtype : Any = None ):
157
160
self .index = utils .safe_cast_to_index (array )
158
161
self .dim = dim
159
162
163
+ if coord_dtype is None :
164
+ coord_dtype = self .index .dtype
165
+ self .coord_dtype = coord_dtype
166
+
160
167
@classmethod
161
168
def from_variables (cls , variables : Mapping [Hashable , "Variable" ]):
162
169
from .variable import IndexVariable
@@ -176,7 +183,7 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]):
176
183
177
184
dim = var .dims [0 ]
178
185
179
- obj = cls (var .data , dim )
186
+ obj = cls (var .data , dim , coord_dtype = var . dtype )
180
187
181
188
data = PandasIndexingAdapter (obj .index , dtype = var .dtype )
182
189
index_var = IndexVariable (
@@ -219,7 +226,7 @@ def query(self, labels, method=None, tolerance=None):
219
226
"a dimension that does not have a MultiIndex"
220
227
)
221
228
else :
222
- label = normalize_label (label )
229
+ label = normalize_label (label , dtype = self . coord_dtype )
223
230
if label .ndim == 0 :
224
231
label_value = normalize_label (label , extract_scalar = True )
225
232
if isinstance (self .index , pd .CategoricalIndex ):
@@ -289,6 +296,16 @@ def _create_variables_from_multiindex(index, dim, level_meta=None):
289
296
290
297
291
298
class PandasMultiIndex (PandasIndex ):
299
+
300
+ __slots__ = ("index" , "dim" , "coord_dtype" , "level_coords_dtype" )
301
+
302
+ def __init__ (self , array : Any , dim : Hashable , level_coords_dtype : Any = None ):
303
+ super ().__init__ (array , dim )
304
+
305
+ if level_coords_dtype is None :
306
+ level_coords_dtype = {idx .name : idx .dtype for idx in self .index .levels }
307
+ self .level_coords_dtype = level_coords_dtype
308
+
292
309
@classmethod
293
310
def from_variables (cls , variables : Mapping [Hashable , "Variable" ]):
294
311
if any ([var .ndim != 1 for var in variables .values ()]):
@@ -305,7 +322,8 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]):
305
322
index = pd .MultiIndex .from_arrays (
306
323
[var .values for var in variables .values ()], names = variables .keys ()
307
324
)
308
- obj = cls (index , dim )
325
+ level_coords_dtype = {name : var .dtype for name , var in variables .items ()}
326
+ obj = cls (index , dim , level_coords_dtype = level_coords_dtype )
309
327
310
328
level_meta = {
311
329
name : {"dtype" : var .dtype , "attrs" : var .attrs , "encoding" : var .encoding }
@@ -346,7 +364,10 @@ def query(self, labels, method=None, tolerance=None):
346
364
if all ([lbl in self .index .names for lbl in labels ]):
347
365
is_nested_vals = _is_nested_tuple (tuple (labels .values ()))
348
366
labels = {
349
- k : normalize_label (v , extract_scalar = True ) for k , v in labels .items ()
367
+ k : normalize_label (
368
+ v , extract_scalar = True , dtype = self .level_coords_dtype [k ]
369
+ )
370
+ for k , v in labels .items ()
350
371
}
351
372
352
373
if len (labels ) == self .index .nlevels and not is_nested_vals :
0 commit comments