Skip to content

Commit 5e15269

Browse files
committed
Update thanks to @shoyer
1 parent bf5edd0 commit 5e15269

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

xarray/coding/strings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
def create_vlen_dtype(element_type):
2020
if element_type not in (str, bytes):
21-
raise TypeError("Unsupported type for vlen_dtype: `{}`".format(element_type))
21+
raise TypeError("unsupported type for vlen_dtype: {!r}".format(element_type))
2222
# based on h5py.special_dtype
2323
return np.dtype("O", metadata={"element_type": element_type})
2424

xarray/conventions.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,10 @@ def _infer_dtype(array, name=None):
157157
return np.dtype(float)
158158

159159
element = array[(0,) * array.ndim]
160+
# Pass the base types (bytes / str) rather than type(element): subclasses of
161+
# bytes and str, such as numpy's, might not play nice with e.g. hdf5 datatypes
160162
if isinstance(element, bytes):
161-
return strings.create_vlen_dtype(type(element))
163+
return strings.create_vlen_dtype(bytes)
162164
elif isinstance(element, str):
163165
return strings.create_vlen_dtype(str)
164166

xarray/tests/test_coding_strings.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,15 @@ def test_vlen_dtype() -> None:
3636
assert strings.check_vlen_dtype(np.dtype(object)) is None
3737

3838

39-
@pytest.mark.parametrize("str_type", (str, np.str_))
40-
def test_numpy_str_handling(str_type) -> None:
41-
dtype = strings.create_vlen_dtype(str_type)
42-
assert dtype.metadata["element_type"] == str_type
43-
assert strings.is_unicode_dtype(dtype)
44-
assert not strings.is_bytes_dtype(dtype)
45-
assert strings.check_vlen_dtype(dtype) is str_type
39+
@pytest.mark.parametrize("numpy_str_type", (np.str_, np.bytes_))
40+
def test_numpy_subclass_handling(numpy_str_type) -> None:
41+
with pytest.raises(TypeError, match="unsupported type for vlen_dtype"):
42+
strings.create_vlen_dtype(numpy_str_type)
4643

4744

4845
@requires_netCDF4
4946
@pytest.mark.parametrize("str_type", (str, np.str_))
50-
def test_write_file_from_np_str(str_type):
47+
def test_write_file_from_np_str(str_type) -> None:
5148
# TODO: this test should probably be moved elsewhere
5249
scenarios = [str_type(v) for v in ["scenario_a", "scenario_b", "scenario_c"]]
5350
years = range(2015, 2100 + 1)

0 commit comments

Comments
 (0)