Skip to content

Commit 88b92a0

Browse files
authored
Merge pull request #349 from ev-br/stress_test_result_type
ENH: improve testing of `result_type`
2 parents 835a9ca + d53cb39 commit 88b92a0

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

Diff for: array_api_tests/hypothesis_helpers.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from hypothesis import assume, reject
1111
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
1212
integers, complex_numbers, just, lists, none, one_of,
13-
sampled_from, shared, builds, nothing)
13+
sampled_from, shared, builds, nothing, permutations)
1414

1515
from . import _array_module as xp, api_version
1616
from . import array_helpers as ah
@@ -148,6 +148,13 @@ def mutually_promotable_dtypes(
148148
return one_of(strats).map(tuple)
149149

150150

151+
@composite
152+
def pair_of_mutually_promotable_dtypes(draw, max_size=2, *, dtypes=dh.all_dtypes):
153+
sample = draw(mutually_promotable_dtypes( max_size, dtypes=dtypes))
154+
permuted = draw(permutations(sample))
155+
return sample, tuple(permuted)
156+
157+
151158
class OnewayPromotableDtypes(NamedTuple):
152159
input_dtype: DataType
153160
result_dtype: DataType

Diff for: array_api_tests/test_data_type_functions.py

+50-4
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,53 @@ def test_isdtype(dtype, kind):
208208
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
209209

210210

211-
@given(hh.mutually_promotable_dtypes(None))
212-
def test_result_type(dtypes):
213-
out = xp.result_type(*dtypes)
214-
ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
211+
@pytest.mark.min_version("2024.12")
212+
class TestResultType:
213+
@given(dtypes=hh.mutually_promotable_dtypes(None))
214+
def test_result_type(self, dtypes):
215+
out = xp.result_type(*dtypes)
216+
ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
217+
218+
@given(pair=hh.pair_of_mutually_promotable_dtypes(None))
219+
def test_shuffled(self, pair):
220+
"""Test that result_type is insensitive to the order of arguments."""
221+
s1, s2 = pair
222+
out1 = xp.result_type(*s1)
223+
out2 = xp.result_type(*s2)
224+
assert out1 == out2
225+
226+
@given(pair=hh.pair_of_mutually_promotable_dtypes(2), data=st.data())
227+
def test_arrays_and_dtypes(self, pair, data):
228+
s1, s2 = pair
229+
a2 = tuple(xp.empty(1, dtype=dt) for dt in s2)
230+
a_and_dt = data.draw(st.permutations(s1 + a2))
231+
out = xp.result_type(*a_and_dt)
232+
ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out")
233+
234+
@given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data())
235+
def test_with_scalars(self, dtypes, data):
236+
out = xp.result_type(*dtypes)
237+
238+
if out == xp.bool:
239+
scalars = [True]
240+
elif out in dh.all_int_dtypes:
241+
scalars = [1]
242+
elif out in dh.real_dtypes:
243+
scalars = [1, 1.0]
244+
elif out in dh.numeric_dtypes:
245+
scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types
246+
else:
247+
raise ValueError(f"unknown dtype {out = }.")
248+
249+
scalar = data.draw(st.sampled_from(scalars))
250+
inputs = data.draw(st.permutations(dtypes + (scalar,)))
251+
252+
out_scalar = xp.result_type(*inputs)
253+
assert out_scalar == out
254+
255+
# retry with arrays
256+
arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes)
257+
inputs = data.draw(st.permutations(arrays + (scalar,)))
258+
out_scalar = xp.result_type(*inputs)
259+
assert out_scalar == out
260+

0 commit comments

Comments
 (0)