Skip to content

Commit 1a154fb

Browse files
committed
ENH: add a result_type test with python scalars
1 parent 59a8135 commit 1a154fb

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

Diff for: array_api_tests/test_data_type_functions.py

+28
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def test_isdtype(dtype, kind):
208208
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
209209

210210

211+
@pytest.mark.min_version("2024.12")
211212
class TestResultType:
212213
@given(dtypes=hh.mutually_promotable_dtypes(None))
213214
def test_result_type(self, dtypes):
@@ -230,3 +231,30 @@ def test_arrays_and_dtypes(self, pair, data):
230231
out = xp.result_type(*a_and_dt)
231232
ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out")
232233

234+
@given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data())
235+
def test_with_scalars(self, dtypes, data):
236+
out = xp.result_type(*dtypes)
237+
238+
if out == xp.bool:
239+
scalars = [True]
240+
elif out in dh.all_int_dtypes:
241+
scalars = [1]
242+
elif out in dh.real_dtypes:
243+
scalars = [1, 1.0]
244+
elif out in dh.numeric_dtypes:
245+
scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types
246+
else:
247+
raise ValueError(f"unknown dtype {out = }.")
248+
249+
scalar = data.draw(st.sampled_from(scalars))
250+
inputs = data.draw(st.permutations(dtypes + (scalar,)))
251+
252+
out_scalar = xp.result_type(*inputs)
253+
assert out_scalar == out
254+
255+
# retry with arrays
256+
arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes)
257+
inputs = data.draw(st.permutations(arrays + (scalar,)))
258+
out_scalar = xp.result_type(*inputs)
259+
assert out_scalar == out
260+

0 commit comments

Comments
 (0)