Skip to content

Commit 551de70

Browse files
kmuehlbauer, Illviljan, pre-commit-ci[bot]
authored
Implement more Variable Coders (#7719)
* implement coders, adapt tests
* Apply suggestions from code review

  Co-authored-by: Illviljan <[email protected]>
* add whats-new.rst entry
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* fix whats-new.rst entry
* add PR link to whats-new.rst entry
* return early if no missing values defined
* fix check

---------

Co-authored-by: Illviljan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f8127fc commit 551de70

File tree

4 files changed

+187
-139
lines changed

4 files changed

+187
-139
lines changed

doc/whats-new.rst

+2
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,8 @@ Internal Changes
7373
- Remove internal support for reading GRIB files through the ``cfgrib`` backend. ``cfgrib`` now uses the external
7474
backend interface, so no existing code should break.
7575
By `Deepak Cherian <https://github.com/dcherian>`_.
76+
- Implement CF coding functions in ``VariableCoders`` (:pull:`7719`).
77+
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
7678

7779
- Added a config.yml file with messages for the welcome bot when a Github user creates their first ever issue or pull request or has their first PR merged. (:issue:`7685`, :pull:`7685`)
7880
By `Nishtha P <https://github.com/nishthap981>`_.

xarray/coding/variables.py

+171-8
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,71 @@ def __repr__(self) -> str:
7878
)
7979

8080

81+
class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
    """Lazily expose a non-native-endian array with a native-endian dtype.

    The wrapped array is kept as-is; values are converted to the dtype
    reported by :py:attr:`dtype` each time the wrapper is indexed.

    This is useful for decoding arrays from netCDF3 files (which are all
    big endian) into native endianness, so they can be used with Cython
    functions, such as those found in bottleneck and pandas.

    >>> x = np.arange(5, dtype=">i2")

    >>> x.dtype
    dtype('>i2')

    >>> NativeEndiannessArray(x).dtype
    dtype('int16')

    >>> indexer = indexing.BasicIndexer((slice(None),))
    >>> NativeEndiannessArray(x)[indexer].dtype
    dtype('int16')
    """

    __slots__ = ("array",)

    def __init__(self, array) -> None:
        # Normalise the input through ``indexing.as_indexable`` before storing.
        self.array = indexing.as_indexable(array)

    @property
    def dtype(self) -> np.dtype:
        # Rebuild the dtype from kind + item size only, which leaves out the
        # byte-order marker carried by the wrapped array's dtype.
        wrapped = self.array.dtype
        return np.dtype(f"{wrapped.kind}{wrapped.itemsize}")

    def __getitem__(self, key) -> np.ndarray:
        # Index the wrapped array first, then convert only the selection.
        selected = self.array[key]
        return np.asarray(selected, dtype=self.dtype)
112+
113+
114+
class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
    """Lazily expose an integer array as a boolean array.

    The wrapped array is kept as-is; values are converted to ``bool`` each
    time the wrapper is indexed.

    This is useful for decoding boolean arrays from integer typed netCDF
    variables.

    >>> x = np.array([1, 0, 1, 1, 0], dtype="i1")

    >>> x.dtype
    dtype('int8')

    >>> BoolTypeArray(x).dtype
    dtype('bool')

    >>> indexer = indexing.BasicIndexer((slice(None),))
    >>> BoolTypeArray(x)[indexer].dtype
    dtype('bool')
    """

    __slots__ = ("array",)

    def __init__(self, array) -> None:
        # Normalise the input through ``indexing.as_indexable`` before storing.
        self.array = indexing.as_indexable(array)

    @property
    def dtype(self) -> np.dtype:
        # The reported dtype is always boolean, whatever the wrapped dtype is.
        return np.dtype("bool")

    def __getitem__(self, key) -> np.ndarray:
        # Index the wrapped array first, then convert only the selection.
        selected = self.array[key]
        return np.asarray(selected, dtype=self.dtype)
144+
145+
81146
def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike):
82147
"""Lazily apply an element-wise function to an array.
83148
Parameters
@@ -159,27 +224,29 @@ def encode(self, variable: Variable, name: T_Name = None):
159224
fv = encoding.get("_FillValue")
160225
mv = encoding.get("missing_value")
161226

162-
if (
163-
fv is not None
164-
and mv is not None
165-
and not duck_array_ops.allclose_or_equiv(fv, mv)
166-
):
227+
fv_exists = fv is not None
228+
mv_exists = mv is not None
229+
230+
if not fv_exists and not mv_exists:
231+
return variable
232+
233+
if fv_exists and mv_exists and not duck_array_ops.allclose_or_equiv(fv, mv):
167234
raise ValueError(
168235
f"Variable {name!r} has conflicting _FillValue ({fv}) and missing_value ({mv}). Cannot encode data."
169236
)
170237

171-
if fv is not None:
238+
if fv_exists:
172239
# Ensure _FillValue is cast to same dtype as data's
173240
encoding["_FillValue"] = dtype.type(fv)
174241
fill_value = pop_to(encoding, attrs, "_FillValue", name=name)
175242
if not pd.isnull(fill_value):
176243
data = duck_array_ops.fillna(data, fill_value)
177244

178-
if mv is not None:
245+
if mv_exists:
179246
# Ensure missing_value is cast to same dtype as data's
180247
encoding["missing_value"] = dtype.type(mv)
181248
fill_value = pop_to(encoding, attrs, "missing_value", name=name)
182-
if not pd.isnull(fill_value) and fv is None:
249+
if not pd.isnull(fill_value) and not fv_exists:
183250
data = duck_array_ops.fillna(data, fill_value)
184251

185252
return Variable(dims, data, attrs, encoding, fastpath=True)
@@ -349,3 +416,99 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
349416
return Variable(dims, data, attrs, encoding, fastpath=True)
350417
else:
351418
return variable
419+
420+
421+
class DefaultFillvalueCoder(VariableCoder):
    """Encode default _FillValue if needed."""

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Add a NaN ``_FillValue`` attribute to floating-point variables.

        The input variable is returned unchanged when ``_FillValue`` is
        already present in its attrs or encoding, or when its dtype is not a
        floating type. Otherwise a new ``Variable`` is returned whose attrs
        carry ``_FillValue`` set to NaN in the variable's own dtype.
        """
        dims, data, attrs, encoding = unpack_for_encoding(variable)

        fill_already_set = "_FillValue" in attrs or "_FillValue" in encoding
        if fill_already_set or not np.issubdtype(variable.dtype, np.floating):
            return variable

        # make NaN the fill value for float types
        attrs["_FillValue"] = variable.dtype.type(np.nan)
        return Variable(dims, data, attrs, encoding, fastpath=True)

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Decoding is not implemented for this coder."""
        raise NotImplementedError()
439+
440+
441+
class BooleanCoder(VariableCoder):
    """Code boolean values."""

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Store boolean data as ``"i1"`` and record ``dtype="bool"`` in attrs.

        Variables that are not boolean, or that already have a ``dtype`` entry
        in their encoding or attrs, are returned unchanged.
        """
        if (
            variable.dtype != bool
            or "dtype" in variable.encoding
            or "dtype" in variable.attrs
        ):
            return variable

        dims, data, attrs, encoding = unpack_for_encoding(variable)
        # Remember the original dtype so that decode() can restore it.
        attrs["dtype"] = "bool"
        as_int8 = duck_array_ops.astype(data, dtype="i1", copy=True)
        return Variable(dims, as_int8, attrs, encoding, fastpath=True)

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Wrap data marked with ``dtype="bool"`` in a ``BoolTypeArray``.

        The ``dtype`` attribute is removed from the returned variable's attrs.
        Variables without that marker are returned unchanged.
        """
        marked_as_bool = variable.attrs.get("dtype", False) == "bool"
        if not marked_as_bool:
            return variable

        dims, data, attrs, encoding = unpack_for_decoding(variable)
        del attrs["dtype"]
        wrapped = BoolTypeArray(data)
        return Variable(dims, wrapped, attrs, encoding, fastpath=True)
466+
467+
468+
class EndianCoder(VariableCoder):
    """Decode Endianness to native."""

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Encoding is not implemented for this coder.

        The signature matches the other coders' ``encode(variable, name=...)``
        so that a call made through the common coder interface raises
        ``NotImplementedError`` rather than a ``TypeError`` about unexpected
        arguments.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError()

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Wrap non-native-endian data in a ``NativeEndiannessArray``.

        Parameters
        ----------
        variable : Variable
            Variable whose data dtype is inspected.
        name : optional
            Accepted for interface compatibility; not used here.

        Returns
        -------
        Variable
            A new variable wrapping the data when its dtype is not native,
            otherwise the input variable itself.
        """
        dims, data, attrs, encoding = unpack_for_decoding(variable)
        if not data.dtype.isnative:
            data = NativeEndiannessArray(data)
            return Variable(dims, data, attrs, encoding, fastpath=True)
        else:
            return variable
481+
482+
483+
class NonStringCoder(VariableCoder):
    """Encode NonString variables if dtypes differ."""

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Cast the data to the dtype requested in ``encoding["dtype"]``.

        Nothing is done (the input variable is returned) when the encoding
        has no ``dtype`` entry or when that entry is ``"S1"`` or ``str``.
        Otherwise ``dtype`` is popped from the returned variable's encoding
        and, if it differs from the variable's dtype, the data is cast to it.
        When the target is an integer dtype the data is rounded first.

        Warns
        -----
        SerializationWarning
            When floating-point data is cast to an integer dtype and neither
            ``_FillValue`` nor ``missing_value`` is present in the attrs.
        """
        if "dtype" not in variable.encoding or variable.encoding["dtype"] in (
            "S1",
            str,
        ):
            return variable

        dims, data, attrs, encoding = unpack_for_encoding(variable)
        dtype = np.dtype(encoding.pop("dtype"))
        if dtype != variable.dtype:
            if np.issubdtype(dtype, np.integer):
                if (
                    np.issubdtype(variable.dtype, np.floating)
                    and "_FillValue" not in variable.attrs
                    and "missing_value" not in variable.attrs
                ):
                    warnings.warn(
                        f"saving variable {name} with floating "
                        "point data as an integer dtype without "
                        "any _FillValue to use for NaNs",
                        SerializationWarning,
                        stacklevel=10,
                    )
                # Round before the cast so values are not simply truncated.
                data = np.around(data)
            data = data.astype(dtype=dtype)
        # A new variable is returned even when no cast was needed, because
        # ``dtype`` has been removed from its encoding.
        return Variable(dims, data, attrs, encoding, fastpath=True)

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Decoding is not implemented for this coder.

        The signature matches the other coders' ``decode(variable, name=...)``
        so that a call made through the common coder interface raises
        ``NotImplementedError`` rather than a ``TypeError`` about unexpected
        arguments.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError()

xarray/conventions.py

+11-128
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111
from xarray.coding import strings, times, variables
1212
from xarray.coding.variables import SerializationWarning, pop_to
13-
from xarray.core import duck_array_ops, indexing
13+
from xarray.core import indexing
1414
from xarray.core.common import (
1515
_contains_datetime_like_objects,
1616
contains_cftime_datetimes,
@@ -48,123 +48,10 @@
4848
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]
4949

5050

51-
class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
52-
"""Decode arrays on the fly from non-native to native endianness
53-
54-
This is useful for decoding arrays from netCDF3 files (which are all
55-
big endian) into native endianness, so they can be used with Cython
56-
functions, such as those found in bottleneck and pandas.
57-
58-
>>> x = np.arange(5, dtype=">i2")
59-
60-
>>> x.dtype
61-
dtype('>i2')
62-
63-
>>> NativeEndiannessArray(x).dtype
64-
dtype('int16')
65-
66-
>>> indexer = indexing.BasicIndexer((slice(None),))
67-
>>> NativeEndiannessArray(x)[indexer].dtype
68-
dtype('int16')
69-
"""
70-
71-
__slots__ = ("array",)
72-
73-
def __init__(self, array):
74-
self.array = indexing.as_indexable(array)
75-
76-
@property
77-
def dtype(self):
78-
return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize))
79-
80-
def __getitem__(self, key):
81-
return np.asarray(self.array[key], dtype=self.dtype)
82-
83-
84-
class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
85-
"""Decode arrays on the fly from integer to boolean datatype
86-
87-
This is useful for decoding boolean arrays from integer typed netCDF
88-
variables.
89-
90-
>>> x = np.array([1, 0, 1, 1, 0], dtype="i1")
91-
92-
>>> x.dtype
93-
dtype('int8')
94-
95-
>>> BoolTypeArray(x).dtype
96-
dtype('bool')
97-
98-
>>> indexer = indexing.BasicIndexer((slice(None),))
99-
>>> BoolTypeArray(x)[indexer].dtype
100-
dtype('bool')
101-
"""
102-
103-
__slots__ = ("array",)
104-
105-
def __init__(self, array):
106-
self.array = indexing.as_indexable(array)
107-
108-
@property
109-
def dtype(self):
110-
return np.dtype("bool")
111-
112-
def __getitem__(self, key):
113-
return np.asarray(self.array[key], dtype=self.dtype)
114-
115-
11651
def _var_as_tuple(var: Variable) -> T_VarTuple:
    """Split ``var`` into a ``(dims, data, attrs, encoding)`` tuple.

    ``attrs`` and ``encoding`` are copies obtained from their ``copy()``
    methods; ``dims`` and ``data`` are passed through as-is.
    """
    dims = var.dims
    data = var.data
    attrs = var.attrs.copy()
    encoding = var.encoding.copy()
    return dims, data, attrs, encoding
11853

11954

120-
def maybe_encode_nonstring_dtype(var: Variable, name: T_Name = None) -> Variable:
121-
if "dtype" in var.encoding and var.encoding["dtype"] not in ("S1", str):
122-
dims, data, attrs, encoding = _var_as_tuple(var)
123-
dtype = np.dtype(encoding.pop("dtype"))
124-
if dtype != var.dtype:
125-
if np.issubdtype(dtype, np.integer):
126-
if (
127-
np.issubdtype(var.dtype, np.floating)
128-
and "_FillValue" not in var.attrs
129-
and "missing_value" not in var.attrs
130-
):
131-
warnings.warn(
132-
f"saving variable {name} with floating "
133-
"point data as an integer dtype without "
134-
"any _FillValue to use for NaNs",
135-
SerializationWarning,
136-
stacklevel=10,
137-
)
138-
data = np.around(data)
139-
data = data.astype(dtype=dtype)
140-
var = Variable(dims, data, attrs, encoding, fastpath=True)
141-
return var
142-
143-
144-
def maybe_default_fill_value(var: Variable) -> Variable:
145-
# make NaN the fill value for float types:
146-
if (
147-
"_FillValue" not in var.attrs
148-
and "_FillValue" not in var.encoding
149-
and np.issubdtype(var.dtype, np.floating)
150-
):
151-
var.attrs["_FillValue"] = var.dtype.type(np.nan)
152-
return var
153-
154-
155-
def maybe_encode_bools(var: Variable) -> Variable:
156-
if (
157-
(var.dtype == bool)
158-
and ("dtype" not in var.encoding)
159-
and ("dtype" not in var.attrs)
160-
):
161-
dims, data, attrs, encoding = _var_as_tuple(var)
162-
attrs["dtype"] = "bool"
163-
data = duck_array_ops.astype(data, dtype="i1", copy=True)
164-
var = Variable(dims, data, attrs, encoding, fastpath=True)
165-
return var
166-
167-
16855
def _infer_dtype(array, name: T_Name = None) -> np.dtype:
16956
"""Given an object array with no missing values, infer its dtype from its
17057
first element
@@ -292,13 +179,13 @@ def encode_cf_variable(
292179
variables.CFScaleOffsetCoder(),
293180
variables.CFMaskCoder(),
294181
variables.UnsignedIntegerCoder(),
182+
variables.NonStringCoder(),
183+
variables.DefaultFillvalueCoder(),
184+
variables.BooleanCoder(),
295185
]:
296186
var = coder.encode(var, name=name)
297187

298-
# TODO(shoyer): convert all of these to use coders, too:
299-
var = maybe_encode_nonstring_dtype(var, name=name)
300-
var = maybe_default_fill_value(var)
301-
var = maybe_encode_bools(var)
188+
# TODO(kmuehlbauer): check if ensure_dtype_not_object can be moved to backends:
302189
var = ensure_dtype_not_object(var, name=name)
303190

304191
for attr_name in CF_RELATED_DATA:
@@ -389,19 +276,15 @@ def decode_cf_variable(
389276
if decode_times:
390277
var = times.CFDatetimeCoder(use_cftime=use_cftime).decode(var, name=name)
391278

392-
dimensions, data, attributes, encoding = variables.unpack_for_decoding(var)
393-
# TODO(shoyer): convert everything below to use coders
279+
if decode_endianness and not var.dtype.isnative:
280+
var = variables.EndianCoder().decode(var)
281+
original_dtype = var.dtype
394282

395-
if decode_endianness and not data.dtype.isnative:
396-
# do this last, so it's only done if we didn't already unmask/scale
397-
data = NativeEndiannessArray(data)
398-
original_dtype = data.dtype
283+
var = variables.BooleanCoder().decode(var)
399284

400-
encoding.setdefault("dtype", original_dtype)
285+
dimensions, data, attributes, encoding = variables.unpack_for_decoding(var)
401286

402-
if "dtype" in attributes and attributes["dtype"] == "bool":
403-
del attributes["dtype"]
404-
data = BoolTypeArray(data)
287+
encoding.setdefault("dtype", original_dtype)
405288

406289
if not is_duck_dask_array(data):
407290
data = indexing.LazilyIndexedArray(data)

0 commit comments

Comments
 (0)