diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 16795fc..c3c8462 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -204,35 +204,31 @@ def result_type( # required by the spec rather than using np.result_type. NumPy implements # too many extra type promotions like int64 + uint64 -> float64, and does # value-based casting on scalar arrays. - A = [] + dtypes = [] scalars = [] for a in arrays_and_dtypes: - if isinstance(a, Array): - a = a.dtype + if isinstance(a, DType): + dtypes.append(a) + elif isinstance(a, Array): + dtypes.append(a.dtype) elif isinstance(a, (bool, int, float, complex)): scalars.append(a) - elif isinstance(a, np.ndarray) or a not in _all_dtypes: - raise TypeError("result_type() inputs must be array_api arrays or dtypes") - A.append(a) - - # remove python scalars - B = [a for a in A if not isinstance(a, (bool, int, float, complex))] + else: + raise TypeError( + "result_type() inputs must be Array API arrays, dtypes, or scalars" + ) - if len(B) == 0: + if not dtypes: raise ValueError("at least one array or dtype is required") - elif len(B) == 1: - result = B[0] - else: - t = B[0] - for t2 in B[1:]: - t = _result_type(t, t2) - result = t + result = dtypes[0] + for t2 in dtypes[1:]: + result = _result_type(result, t2) - if len(scalars) == 0: + if not scalars: return result if get_array_api_strict_flags()['api_version'] <= '2023.12': - raise TypeError("result_type() inputs must be array_api arrays or dtypes") + raise TypeError("result_type() inputs must be Array API arrays or dtypes") # promote python scalars given the result_type for all arrays/dtypes from ._creation_functions import empty