diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9c16fb74a7b..735159fb33a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Added support for `numpy.bool_` attributes in roundtrips using `h5netcdf` engine with `invalid_netcdf=True` [which casts `bool`s to `numpy.bool_`] (:issue:`4981`, :pull:`4986`). + By `Victor Negîrneac <https://github.com/caenrigen>`_. - Don't allow passing ``axis`` to :py:meth:`Dataset.reduce` methods (:issue:`3510`, :pull:`4940`). By `Justus Magin <https://github.com/keewis>`_. diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 4fa34b39925..aca6524a3bd 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -199,12 +199,21 @@ def check_name(name): check_name(k) -def _validate_attrs(dataset): +def _validate_attrs(dataset, invalid_netcdf=False): """`attrs` must have a string key and a value which is either: a number, - a string, an ndarray or a list/tuple of numbers/strings. + a string, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_. + + Notes + ----- + A numpy.bool_ is only allowed when using the h5netcdf engine with + `invalid_netcdf=True`. """ - def check_attr(name, value): + valid_types = (str, Number, np.ndarray, np.number, list, tuple) + if invalid_netcdf: + valid_types += (np.bool_,) + + def check_attr(name, value, valid_types): if isinstance(name, str): if not name: raise ValueError( @@ -218,22 +227,21 @@ def check_attr(name, value): "serialization to netCDF files" ) - if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)): + if not isinstance(value, valid_types): raise TypeError( - f"Invalid value for attr {name!r}: {value!r} must be a number, " - "a string, an ndarray or a list/tuple of " - "numbers/strings for serialization to netCDF " - "files" + f"Invalid value for attr {name!r}: {value!r}. 
For serialization to " + "netCDF files, its value must be of one of the following types: " + f"{', '.join([vtype.__name__ for vtype in valid_types])}" ) # Check attrs on the dataset itself for k, v in dataset.attrs.items(): - check_attr(k, v) + check_attr(k, v, valid_types) # Check attrs on each variable within the dataset for variable in dataset.variables.values(): for k, v in variable.attrs.items(): - check_attr(k, v) + check_attr(k, v, valid_types) def _protect_dataset_variables_inplace(dataset, cache): @@ -1056,7 +1064,7 @@ def to_netcdf( # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) - _validate_attrs(dataset) + _validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf") try: store_open = WRITEABLE_STORES[engine] diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index d15736e608d..f6c00a2a9a9 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2541,6 +2541,14 @@ def test_complex(self, invalid_netcdf, warntype, num_warns): assert recorded_num_warns == num_warns + def test_numpy_bool_(self): + # h5netcdf loads booleans as numpy.bool_, this type needs to be supported + # when writing invalid_netcdf datasets in order to support a roundtrip + expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})}) + save_kwargs = {"invalid_netcdf": True} + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_identical(expected, actual) + def test_cross_engine_read_write_netcdf4(self): # Drop dim3, because its labels include strings. These appear to be # not properly read with python-netCDF4, which converts them into