Skip to content

Commit 17e1728

Browse files
committed
preserve vlen string dtypes, allow vlen string fill_values
1 parent 1411474 commit 17e1728

File tree

5 files changed

+51
-46
lines changed

5 files changed

+51
-46
lines changed

xarray/backends/h5netcdf_.py

-9
Original file line numberDiff line numberDiff line change
@@ -271,15 +271,6 @@ def prepare_variable(
271271
dtype = _get_datatype(variable, raise_on_invalid_encoding=check_encoding)
272272

273273
fillvalue = attrs.pop("_FillValue", None)
274-
if dtype is str and fillvalue is not None:
275-
raise NotImplementedError(
276-
"h5netcdf does not yet support setting a fill value for "
277-
"variable-length strings "
278-
"(https://github.com/h5netcdf/h5netcdf/issues/37). "
279-
f"Either remove '_FillValue' from encoding on variable {name!r} "
280-
"or set {'dtype': 'S1'} in encoding to use the fixed width "
281-
"NC_CHAR type."
282-
)
283274

284275
if dtype is str:
285276
dtype = h5py.special_dtype(vlen=str)

xarray/backends/netCDF4_.py

-10
Original file line numberDiff line numberDiff line change
@@ -490,16 +490,6 @@ def prepare_variable(
490490

491491
fill_value = attrs.pop("_FillValue", None)
492492

493-
if datatype is str and fill_value is not None:
494-
raise NotImplementedError(
495-
"netCDF4 does not yet support setting a fill value for "
496-
"variable-length strings "
497-
"(https://github.com/Unidata/netcdf4-python/issues/730). "
498-
f"Either remove '_FillValue' from encoding on variable {name!r} "
499-
"or set {'dtype': 'S1'} in encoding to use the fixed width "
500-
"NC_CHAR type."
501-
)
502-
503493
encoding = _extract_nc4_variable_encoding(
504494
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
505495
)

xarray/coding/variables.py

+12
Original file line numberDiff line numberDiff line change
@@ -562,3 +562,15 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
562562

563563
def decode(self):
564564
raise NotImplementedError()
565+
566+
567+
class ObjectStringCoder(VariableCoder):
568+
def encode(self):
569+
return NotImplementedError
570+
571+
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
572+
if variable.dtype == object and variable.encoding.get("dtype", False) == str:
573+
variable = variable.astype(variable.encoding["dtype"])
574+
return variable
575+
else:
576+
return variable

xarray/conventions.py

+4
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,10 @@ def decode_cf_variable(
265265
var = strings.CharacterArrayCoder().decode(var, name=name)
266266
var = strings.EncodedStringCoder().decode(var)
267267

268+
if original_dtype == object:
269+
var = variables.ObjectStringCoder().decode(var)
270+
original_dtype = var.dtype
271+
268272
if mask_and_scale:
269273
for coder in [
270274
variables.UnsignedIntegerCoder(),

xarray/tests/test_backends.py

+35-27
Original file line numberDiff line numberDiff line change
@@ -864,12 +864,13 @@ def test_roundtrip_empty_vlen_string_array(self) -> None:
864864
assert check_vlen_dtype(original["a"].dtype) == str
865865
with self.roundtrip(original) as actual:
866866
assert_identical(original, actual)
867-
assert object == actual["a"].dtype
868-
assert actual["a"].dtype == original["a"].dtype
869-
# only check metadata for capable backends
870-
# eg. NETCDF3 based backends do not roundtrip metadata
871-
if actual["a"].dtype.metadata is not None:
872-
assert check_vlen_dtype(actual["a"].dtype) == str
867+
if np.issubdtype(actual["a"].dtype, object):
868+
# only check metadata for capable backends
869+
# eg. NETCDF3 based backends do not roundtrip metadata
870+
if actual["a"].dtype.metadata is not None:
871+
assert check_vlen_dtype(actual["a"].dtype) == str
872+
else:
873+
assert actual["a"].dtype == np.dtype("<U1")
873874

874875
@pytest.mark.parametrize(
875876
"decoded_fn, encoded_fn",
@@ -1374,32 +1375,39 @@ def test_write_groups(self) -> None:
13741375
with self.open(tmp_file, group="data/2") as actual2:
13751376
assert_identical(data2, actual2)
13761377

1377-
def test_encoding_kwarg_vlen_string(self) -> None:
1378-
for input_strings in [[b"foo", b"bar", b"baz"], ["foo", "bar", "baz"]]:
1379-
original = Dataset({"x": input_strings})
1380-
expected = Dataset({"x": ["foo", "bar", "baz"]})
1381-
kwargs = dict(encoding={"x": {"dtype": str}})
1382-
with self.roundtrip(original, save_kwargs=kwargs) as actual:
1383-
assert actual["x"].encoding["dtype"] is str
1384-
assert_identical(actual, expected)
1385-
1386-
def test_roundtrip_string_with_fill_value_vlen(self) -> None:
1378+
@pytest.mark.parametrize(
1379+
"input_strings, is_bytes",
1380+
[
1381+
([b"foo", b"bar", b"baz"], True),
1382+
(["foo", "bar", "baz"], False),
1383+
(["foó", "bár", "baź"], False),
1384+
],
1385+
)
1386+
def test_encoding_kwarg_vlen_string(
1387+
self, input_strings: list[str], is_bytes: bool
1388+
) -> None:
1389+
original = Dataset({"x": input_strings})
1390+
1391+
expected_string = ["foo", "bar", "baz"] if is_bytes else input_strings
1392+
expected = Dataset({"x": expected_string})
1393+
kwargs = dict(encoding={"x": {"dtype": str}})
1394+
with self.roundtrip(original, save_kwargs=kwargs) as actual:
1395+
assert actual["x"].encoding["dtype"] == "<U3"
1396+
assert actual["x"].dtype == "<U3"
1397+
assert_identical(actual, expected)
1398+
1399+
@pytest.mark.parametrize("fill_value", ["XXX", "", "bár"])
1400+
def test_roundtrip_string_with_fill_value_vlen(self, fill_value: str) -> None:
13871401
values = np.array(["ab", "cdef", np.nan], dtype=object)
13881402
expected = Dataset({"x": ("t", values)})
13891403

1390-
# netCDF4-based backends don't support an explicit fillvalue
1391-
# for variable length strings yet.
1392-
# https://github.com/Unidata/netcdf4-python/issues/730
1393-
# https://github.com/h5netcdf/h5netcdf/issues/37
1394-
original = Dataset({"x": ("t", values, {}, {"_FillValue": "XXX"})})
1395-
with pytest.raises(NotImplementedError):
1396-
with self.roundtrip(original) as actual:
1397-
assert_identical(expected, actual)
1404+
original = Dataset({"x": ("t", values, {}, {"_FillValue": fill_value})})
1405+
with self.roundtrip(original) as actual:
1406+
assert_identical(expected, actual)
13981407

13991408
original = Dataset({"x": ("t", values, {}, {"_FillValue": ""})})
1400-
with pytest.raises(NotImplementedError):
1401-
with self.roundtrip(original) as actual:
1402-
assert_identical(expected, actual)
1409+
with self.roundtrip(original) as actual:
1410+
assert_identical(expected, actual)
14031411

14041412
def test_roundtrip_character_array(self) -> None:
14051413
with create_tmp_file() as tmp_file:

0 commit comments

Comments
 (0)