Skip to content

Raise an informative error message when object array has mixed types #4700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,32 @@ def _var_as_tuple(var: Variable) -> T_VarTuple:
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()


def _infer_dtype(array, name: T_Name = None) -> np.dtype:
"""Given an object array with no missing values, infer its dtype from its
first element
"""
def _infer_dtype(array, name=None):
"""Given an object array with no missing values, infer its dtype from all elements."""
if array.dtype.kind != "O":
raise TypeError("infer_type must be called on a dtype=object array")

if array.size == 0:
return np.dtype(float)

native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
raise ValueError(
"unable to infer dtype on variable {!r}; object array "
"contains mixed native types: {}".format(
name, ", ".join(x.__name__ for x in native_dtypes)
)
)

native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
raise ValueError(
"unable to infer dtype on variable {!r}; object array "
"contains mixed native types: {}".format(
name, ", ".join(x.__name__ for x in native_dtypes)
)
)

element = array[(0,) * array.ndim]
# We use the base types to avoid subclasses of bytes and str (which might
# not play nice with e.g. hdf5 datatypes), such as those from numpy
Expand Down
12 changes: 12 additions & 0 deletions xarray/tests/test_conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,18 @@ def test_encoding_kwarg_fixed_width_string(self) -> None:
pass


@pytest.mark.parametrize(
    "data",
    (
        # str, bytes and int values mixed inside one 2-D object array
        np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
        # each row pairs a str with an int
        np.array([["x", 1], ["y", 2]], dtype=object),
    ),
)
def test_infer_dtype_error_on_mixed_types(data):
    """Check that ``_infer_dtype`` raises ``ValueError`` for mixed-type object arrays.

    Each parametrized array holds elements of more than one Python type, and
    the raised message is matched against the expected prefix.
    """
    expected_message = "unable to infer dtype on variable"
    with pytest.raises(ValueError, match=expected_message):
        conventions._infer_dtype(data, name="test")


class TestDecodeCFVariableWithArrayUnits:
def test_decode_cf_variable_with_array_units(self) -> None:
v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})
Expand Down