@@ -208,7 +208,53 @@ def test_isdtype(dtype, kind):
208
208
assert out == expected , f"{ out = } , but should be { expected } [isdtype()]"
209
209
210
210
211
- @given (hh .mutually_promotable_dtypes (None ))
212
- def test_result_type (dtypes ):
213
- out = xp .result_type (* dtypes )
214
- ph .assert_dtype ("result_type" , in_dtype = dtypes , out_dtype = out , repr_name = "out" )
211
+ @pytest .mark .min_version ("2024.12" )
212
+ class TestResultType :
213
+ @given (dtypes = hh .mutually_promotable_dtypes (None ))
214
+ def test_result_type (self , dtypes ):
215
+ out = xp .result_type (* dtypes )
216
+ ph .assert_dtype ("result_type" , in_dtype = dtypes , out_dtype = out , repr_name = "out" )
217
+
218
+ @given (pair = hh .pair_of_mutually_promotable_dtypes (None ))
219
+ def test_shuffled (self , pair ):
220
+ """Test that result_type is insensitive to the order of arguments."""
221
+ s1 , s2 = pair
222
+ out1 = xp .result_type (* s1 )
223
+ out2 = xp .result_type (* s2 )
224
+ assert out1 == out2
225
+
226
+ @given (pair = hh .pair_of_mutually_promotable_dtypes (2 ), data = st .data ())
227
+ def test_arrays_and_dtypes (self , pair , data ):
228
+ s1 , s2 = pair
229
+ a2 = tuple (xp .empty (1 , dtype = dt ) for dt in s2 )
230
+ a_and_dt = data .draw (st .permutations (s1 + a2 ))
231
+ out = xp .result_type (* a_and_dt )
232
+ ph .assert_dtype ("result_type" , in_dtype = s1 + s2 , out_dtype = out , repr_name = "out" )
233
+
234
+ @given (dtypes = hh .mutually_promotable_dtypes (2 ), data = st .data ())
235
+ def test_with_scalars (self , dtypes , data ):
236
+ out = xp .result_type (* dtypes )
237
+
238
+ if out == xp .bool :
239
+ scalars = [True ]
240
+ elif out in dh .all_int_dtypes :
241
+ scalars = [1 ]
242
+ elif out in dh .real_dtypes :
243
+ scalars = [1 , 1.0 ]
244
+ elif out in dh .numeric_dtypes :
245
+ scalars = [1 , 1.0 , 1j ] # numeric_types - real_types == complex_types
246
+ else :
247
+ raise ValueError (f"unknown dtype { out = } ." )
248
+
249
+ scalar = data .draw (st .sampled_from (scalars ))
250
+ inputs = data .draw (st .permutations (dtypes + (scalar ,)))
251
+
252
+ out_scalar = xp .result_type (* inputs )
253
+ assert out_scalar == out
254
+
255
+ # retry with arrays
256
+ arrays = tuple (xp .empty (1 , dtype = dt ) for dt in dtypes )
257
+ inputs = data .draw (st .permutations (arrays + (scalar ,)))
258
+ out_scalar = xp .result_type (* inputs )
259
+ assert out_scalar == out
260
+
0 commit comments