Skip to content

Commit d20ba0d

Browse files
authored
Add support for netCDF4.EnumType (#8147)
1 parent 33d51c8 commit d20ba0d

File tree

6 files changed

+214
-25
lines changed

6 files changed

+214
-25
lines changed

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@ New Features
224224

225225
- Use `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ for :py:func:`xarray.dot` by default if installed.
226226
By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`7764`, :pull:`8373`).
227+
- Decode/Encode netCDF4 enums and store the enum definition in dataarrays' dtype metadata.
228+
If multiple variables share the same enum in netCDF4, each dataarray will have its own
229+
enum definition in their respective dtype metadata.
230+
By `Abel Aoun <https://github.com/bzah>_`(:issue:`8144`, :pull:`8147`)
227231
- Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`).
228232
By `Ben Mares <https://github.com/maresb>`_.
229233
- Allow passing ``region="auto"`` in :py:meth:`Dataset.to_zarr` to automatically infer the

xarray/backends/netCDF4_.py

+58-16
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
# string used by netCDF4.
5050
_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"}
5151

52-
5352
NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK])
5453

5554

@@ -141,7 +140,9 @@ def _check_encoding_dtype_is_vlen_string(dtype):
141140
)
142141

143142

144-
def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False):
143+
def _get_datatype(
144+
var, nc_format="NETCDF4", raise_on_invalid_encoding=False
145+
) -> np.dtype:
145146
if nc_format == "NETCDF4":
146147
return _nc4_dtype(var)
147148
if "dtype" in var.encoding:
@@ -234,13 +235,13 @@ def _force_native_endianness(var):
234235

235236

236237
def _extract_nc4_variable_encoding(
237-
variable,
238+
variable: Variable,
238239
raise_on_invalid=False,
239240
lsd_okay=True,
240241
h5py_okay=False,
241242
backend="netCDF4",
242243
unlimited_dims=None,
243-
):
244+
) -> dict[str, Any]:
244245
if unlimited_dims is None:
245246
unlimited_dims = ()
246247

@@ -308,7 +309,7 @@ def _extract_nc4_variable_encoding(
308309
return encoding
309310

310311

311-
def _is_list_of_strings(value):
312+
def _is_list_of_strings(value) -> bool:
312313
arr = np.asarray(value)
313314
return arr.dtype.kind in ["U", "S"] and arr.size > 1
314315

@@ -414,13 +415,25 @@ def _acquire(self, needs_lock=True):
414415
def ds(self):
415416
return self._acquire()
416417

417-
def open_store_variable(self, name, var):
418+
def open_store_variable(self, name: str, var):
419+
import netCDF4
420+
418421
dimensions = var.dimensions
419-
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
420422
attributes = {k: var.getncattr(k) for k in var.ncattrs()}
423+
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
424+
encoding: dict[str, Any] = {}
425+
if isinstance(var.datatype, netCDF4.EnumType):
426+
encoding["dtype"] = np.dtype(
427+
data.dtype,
428+
metadata={
429+
"enum": var.datatype.enum_dict,
430+
"enum_name": var.datatype.name,
431+
},
432+
)
433+
else:
434+
encoding["dtype"] = var.dtype
421435
_ensure_fill_value_valid(data, attributes)
422436
# netCDF4 specific encoding; save _FillValue for later
423-
encoding = {}
424437
filters = var.filters()
425438
if filters is not None:
426439
encoding.update(filters)
@@ -440,7 +453,6 @@ def open_store_variable(self, name, var):
440453
# save source so __repr__ can detect if it's local or not
441454
encoding["source"] = self._filename
442455
encoding["original_shape"] = var.shape
443-
encoding["dtype"] = var.dtype
444456

445457
return Variable(dimensions, data, attributes, encoding)
446458

@@ -485,21 +497,24 @@ def encode_variable(self, variable):
485497
return variable
486498

487499
def prepare_variable(
488-
self, name, variable, check_encoding=False, unlimited_dims=None
500+
self, name, variable: Variable, check_encoding=False, unlimited_dims=None
489501
):
490502
_ensure_no_forward_slash_in_name(name)
491-
503+
attrs = variable.attrs.copy()
504+
fill_value = attrs.pop("_FillValue", None)
492505
datatype = _get_datatype(
493506
variable, self.format, raise_on_invalid_encoding=check_encoding
494507
)
495-
attrs = variable.attrs.copy()
496-
497-
fill_value = attrs.pop("_FillValue", None)
498-
508+
# check enum metadata and use netCDF4.EnumType
509+
if (
510+
(meta := np.dtype(datatype).metadata)
511+
and (e_name := meta.get("enum_name"))
512+
and (e_dict := meta.get("enum"))
513+
):
514+
datatype = self._build_and_get_enum(name, datatype, e_name, e_dict)
499515
encoding = _extract_nc4_variable_encoding(
500516
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
501517
)
502-
503518
if name in self.ds.variables:
504519
nc4_var = self.ds.variables[name]
505520
else:
@@ -527,6 +542,33 @@ def prepare_variable(
527542

528543
return target, variable.data
529544

545+
def _build_and_get_enum(
546+
self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int]
547+
) -> Any:
548+
"""
549+
Add or get the netCDF4 Enum based on the dtype in encoding.
550+
The return type should be ``netCDF4.EnumType``,
551+
but we avoid importing netCDF4 globally for performances.
552+
"""
553+
if enum_name not in self.ds.enumtypes:
554+
return self.ds.createEnumType(
555+
dtype,
556+
enum_name,
557+
enum_dict,
558+
)
559+
datatype = self.ds.enumtypes[enum_name]
560+
if datatype.enum_dict != enum_dict:
561+
error_msg = (
562+
f"Cannot save variable `{var_name}` because an enum"
563+
f" `{enum_name}` already exists in the Dataset but have"
564+
" a different definition. To fix this error, make sure"
565+
" each variable have a uniquely named enum in their"
566+
" `encoding['dtype'].metadata` or, if they should share"
567+
" the same enum type, make sure the enums are identical."
568+
)
569+
raise ValueError(error_msg)
570+
return datatype
571+
530572
def sync(self):
531573
self.ds.sync()
532574

xarray/coding/variables.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -566,11 +566,30 @@ def decode(self):
566566

567567
class ObjectVLenStringCoder(VariableCoder):
568568
def encode(self):
569-
return NotImplementedError
569+
raise NotImplementedError
570570

571571
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
572572
if variable.dtype == object and variable.encoding.get("dtype", False) == str:
573573
variable = variable.astype(variable.encoding["dtype"])
574574
return variable
575575
else:
576576
return variable
577+
578+
579+
class NativeEnumCoder(VariableCoder):
580+
"""Encode Enum into variable dtype metadata."""
581+
582+
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
583+
if (
584+
"dtype" in variable.encoding
585+
and np.dtype(variable.encoding["dtype"]).metadata
586+
and "enum" in variable.encoding["dtype"].metadata
587+
):
588+
dims, data, attrs, encoding = unpack_for_encoding(variable)
589+
data = data.astype(dtype=variable.encoding.pop("dtype"))
590+
return Variable(dims, data, attrs, encoding, fastpath=True)
591+
else:
592+
return variable
593+
594+
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
595+
raise NotImplementedError()

xarray/conventions.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,6 @@
4848
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]
4949

5050

51-
def _var_as_tuple(var: Variable) -> T_VarTuple:
52-
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()
53-
54-
5551
def _infer_dtype(array, name=None):
5652
"""Given an object array with no missing values, infer its dtype from all elements."""
5753
if array.dtype.kind != "O":
@@ -111,7 +107,7 @@ def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
111107
def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
112108
# TODO: move this from conventions to backends? (it's not CF related)
113109
if var.dtype.kind == "O":
114-
dims, data, attrs, encoding = _var_as_tuple(var)
110+
dims, data, attrs, encoding = variables.unpack_for_encoding(var)
115111

116112
# leave vlen dtypes unchanged
117113
if strings.check_vlen_dtype(data.dtype) is not None:
@@ -162,7 +158,7 @@ def encode_cf_variable(
162158
var: Variable, needs_copy: bool = True, name: T_Name = None
163159
) -> Variable:
164160
"""
165-
Converts an Variable into an Variable which follows some
161+
Converts a Variable into a Variable which follows some
166162
of the CF conventions:
167163
168164
- Nans are masked using _FillValue (or the deprecated missing_value)
@@ -188,6 +184,7 @@ def encode_cf_variable(
188184
variables.CFScaleOffsetCoder(),
189185
variables.CFMaskCoder(),
190186
variables.UnsignedIntegerCoder(),
187+
variables.NativeEnumCoder(),
191188
variables.NonStringCoder(),
192189
variables.DefaultFillvalueCoder(),
193190
variables.BooleanCoder(),
@@ -447,7 +444,7 @@ def stackable(dim: Hashable) -> bool:
447444
decode_timedelta=decode_timedelta,
448445
)
449446
except Exception as e:
450-
raise type(e)(f"Failed to decode variable {k!r}: {e}")
447+
raise type(e)(f"Failed to decode variable {k!r}: {e}") from e
451448
if decode_coords in [True, "coordinates", "all"]:
452449
var_attrs = new_vars[k].attrs
453450
if "coordinates" in var_attrs:
@@ -633,7 +630,11 @@ def cf_decoder(
633630
decode_cf_variable
634631
"""
635632
variables, attributes, _ = decode_cf_variables(
636-
variables, attributes, concat_characters, mask_and_scale, decode_times
633+
variables,
634+
attributes,
635+
concat_characters,
636+
mask_and_scale,
637+
decode_times,
637638
)
638639
return variables, attributes
639640

xarray/core/dataarray.py

+3
Original file line numberDiff line numberDiff line change
@@ -4062,6 +4062,9 @@ def to_netcdf(
40624062
name is the same as a coordinate name, then it is given the name
40634063
``"__xarray_dataarray_variable__"``.
40644064
4065+
[netCDF4 backend only] netCDF4 enums are decoded into the
4066+
dataarray dtype metadata.
4067+
40654068
See Also
40664069
--------
40674070
Dataset.to_netcdf

xarray/tests/test_backends.py

+120
Original file line numberDiff line numberDiff line change
@@ -1704,6 +1704,126 @@ def test_raise_on_forward_slashes_in_names(self) -> None:
17041704
with self.roundtrip(ds):
17051705
pass
17061706

1707+
@requires_netCDF4
1708+
def test_encoding_enum__no_fill_value(self):
1709+
with create_tmp_file() as tmp_file:
1710+
cloud_type_dict = {"clear": 0, "cloudy": 1}
1711+
with nc4.Dataset(tmp_file, mode="w") as nc:
1712+
nc.createDimension("time", size=2)
1713+
cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
1714+
v = nc.createVariable(
1715+
"clouds",
1716+
cloud_type,
1717+
"time",
1718+
fill_value=None,
1719+
)
1720+
v[:] = 1
1721+
with open_dataset(tmp_file) as original:
1722+
save_kwargs = {}
1723+
if self.engine == "h5netcdf":
1724+
save_kwargs["invalid_netcdf"] = True
1725+
with self.roundtrip(original, save_kwargs=save_kwargs) as actual:
1726+
assert_equal(original, actual)
1727+
assert (
1728+
actual.clouds.encoding["dtype"].metadata["enum"]
1729+
== cloud_type_dict
1730+
)
1731+
if self.engine != "h5netcdf":
1732+
# not implemented in h5netcdf yet
1733+
assert (
1734+
actual.clouds.encoding["dtype"].metadata["enum_name"]
1735+
== "cloud_type"
1736+
)
1737+
1738+
@requires_netCDF4
1739+
def test_encoding_enum__multiple_variable_with_enum(self):
1740+
with create_tmp_file() as tmp_file:
1741+
cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
1742+
with nc4.Dataset(tmp_file, mode="w") as nc:
1743+
nc.createDimension("time", size=2)
1744+
cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
1745+
nc.createVariable(
1746+
"clouds",
1747+
cloud_type,
1748+
"time",
1749+
fill_value=255,
1750+
)
1751+
nc.createVariable(
1752+
"tifa",
1753+
cloud_type,
1754+
"time",
1755+
fill_value=255,
1756+
)
1757+
with open_dataset(tmp_file) as original:
1758+
save_kwargs = {}
1759+
if self.engine == "h5netcdf":
1760+
save_kwargs["invalid_netcdf"] = True
1761+
with self.roundtrip(original, save_kwargs=save_kwargs) as actual:
1762+
assert_equal(original, actual)
1763+
assert (
1764+
actual.clouds.encoding["dtype"] == actual.tifa.encoding["dtype"]
1765+
)
1766+
assert (
1767+
actual.clouds.encoding["dtype"].metadata
1768+
== actual.tifa.encoding["dtype"].metadata
1769+
)
1770+
assert (
1771+
actual.clouds.encoding["dtype"].metadata["enum"]
1772+
== cloud_type_dict
1773+
)
1774+
if self.engine != "h5netcdf":
1775+
# not implemented in h5netcdf yet
1776+
assert (
1777+
actual.clouds.encoding["dtype"].metadata["enum_name"]
1778+
== "cloud_type"
1779+
)
1780+
1781+
@requires_netCDF4
1782+
def test_encoding_enum__error_multiple_variable_with_changing_enum(self):
1783+
"""
1784+
Given 2 variables, if they share the same enum type,
1785+
the 2 enum definition should be identical.
1786+
"""
1787+
with create_tmp_file() as tmp_file:
1788+
cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
1789+
with nc4.Dataset(tmp_file, mode="w") as nc:
1790+
nc.createDimension("time", size=2)
1791+
cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
1792+
nc.createVariable(
1793+
"clouds",
1794+
cloud_type,
1795+
"time",
1796+
fill_value=255,
1797+
)
1798+
nc.createVariable(
1799+
"tifa",
1800+
cloud_type,
1801+
"time",
1802+
fill_value=255,
1803+
)
1804+
with open_dataset(tmp_file) as original:
1805+
assert (
1806+
original.clouds.encoding["dtype"].metadata
1807+
== original.tifa.encoding["dtype"].metadata
1808+
)
1809+
modified_enum = original.clouds.encoding["dtype"].metadata["enum"]
1810+
modified_enum.update({"neblig": 2})
1811+
original.clouds.encoding["dtype"] = np.dtype(
1812+
"u1",
1813+
metadata={"enum": modified_enum, "enum_name": "cloud_type"},
1814+
)
1815+
if self.engine != "h5netcdf":
1816+
# not implemented yet in h5netcdf
1817+
with pytest.raises(
1818+
ValueError,
1819+
match=(
1820+
"Cannot save variable .*"
1821+
" because an enum `cloud_type` already exists in the Dataset .*"
1822+
),
1823+
):
1824+
with self.roundtrip(original):
1825+
pass
1826+
17071827

17081828
@requires_netCDF4
17091829
class TestNetCDF4Data(NetCDF4Base):

0 commit comments

Comments
 (0)