Skip to content

Commit 5149cc7

Browse files
committed
Update thanks to @shoyer
1 parent b123215 commit 5149cc7

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

xarray/coding/strings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
def create_vlen_dtype(element_type):
2020
if element_type not in (str, bytes):
21-
raise TypeError("Unsupported type for vlen_dtype: `{}`".format(element_type))
21+
raise TypeError("unsupported type for vlen_dtype: {!r}".format(element_type))
2222
# based on h5py.special_dtype
2323
return np.dtype("O", metadata={"element_type": element_type})
2424

xarray/conventions.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,10 @@ def _infer_dtype(array, name=None):
157157
return np.dtype(float)
158158

159159
element = array[(0,) * array.ndim]
160+
# We use the base types to avoid subclasses of bytes and str (which might
161+
# not play nice with e.g. hdf5 datatypes), such as those from numpy
160162
if isinstance(element, bytes):
161-
return strings.create_vlen_dtype(type(element))
163+
return strings.create_vlen_dtype(bytes)
162164
elif isinstance(element, str):
163165
return strings.create_vlen_dtype(str)
164166

xarray/tests/test_coding_strings.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,10 @@ def test_vlen_dtype():
3636
assert strings.check_vlen_dtype(np.dtype(object)) is None
3737

3838

39-
@pytest.mark.parametrize("str_type", (str, np.str_))
40-
def test_numpy_str_handling(str_type):
41-
dtype = strings.create_vlen_dtype(str_type)
42-
assert dtype.metadata["element_type"] == str_type
43-
assert strings.is_unicode_dtype(dtype)
44-
assert not strings.is_bytes_dtype(dtype)
45-
assert strings.check_vlen_dtype(dtype) is str_type
39+
@pytest.mark.parametrize("numpy_str_type", (np.str_, np.bytes_))
40+
def test_numpy_subclass_handling(numpy_str_type):
41+
with pytest.raises(TypeError, match="unsupported type for vlen_dtype"):
42+
strings.create_vlen_dtype(numpy_str_type)
4643

4744

4845
@requires_netCDF4

0 commit comments

Comments
 (0)