@@ -690,6 +690,40 @@ def binary_param_assert_against_refimpl(
690
690
)
691
691
692
692
693
+ def _convert_scalars_helper (x1 , x2 ):
694
+ """Convert python scalar to arrays, record the shapes/dtypes of arrays.
695
+
696
+ For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
697
+ and all arguments converted to arrays.
698
+
699
+ dtypes are separate to help distinguishing between
700
+ `py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
701
+ """
702
+ if dh .is_scalar (x1 ):
703
+ in_dtypes = [x2 .dtype ]
704
+ in_shapes = [x2 .shape ]
705
+ x1a , x2a = xp .asarray (x1 ), x2
706
+ elif dh .is_scalar (x2 ):
707
+ in_dtypes = [x1 .dtype ]
708
+ in_shapes = [x1 .shape ]
709
+ x1a , x2a = x1 , xp .asarray (x2 )
710
+ else :
711
+ in_dtypes = [x1 .dtype , x2 .dtype ]
712
+ in_shapes = [x1 .shape , x2 .shape ]
713
+ x1a , x2a = x1 , x2
714
+
715
+ return in_dtypes , in_shapes , (x1a , x2a )
716
+
717
+
718
+ def _assert_correctness_binary (
719
+ name , func , in_dtypes , in_shapes , in_arrs , out , expected_dtype = None , ** kwargs
720
+ ):
721
+ x1a , x2a = in_arrs
722
+ ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype , expected = expected_dtype )
723
+ ph .assert_result_shape (name , in_shapes = in_shapes , out_shape = out .shape )
724
+ binary_assert_against_refimpl (name , x1a , x2a , out , func , ** kwargs )
725
+
726
+
693
727
@pytest .mark .parametrize ("ctx" , make_unary_params ("abs" , dh .numeric_dtypes ))
694
728
@given (data = st .data ())
695
729
def test_abs (ctx , data ):
@@ -789,10 +823,14 @@ def test_atan(x):
789
823
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
790
824
def test_atan2 (x1 , x2 ):
791
825
out = xp .atan2 (x1 , x2 )
792
- ph .assert_dtype ("atan2" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
793
- ph .assert_result_shape ("atan2" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
794
- refimpl = cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2
795
- binary_assert_against_refimpl ("atan2" , x1 , x2 , out , refimpl )
826
+ _assert_correctness_binary (
827
+ "atan" ,
828
+ cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2 ,
829
+ in_dtypes = [x1 .dtype , x2 .dtype ],
830
+ in_shapes = [x1 .shape , x2 .shape ],
831
+ in_arrs = [x1 , x2 ],
832
+ out = out ,
833
+ )
796
834
797
835
798
836
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
@@ -1258,10 +1296,14 @@ def test_greater_equal(ctx, data):
1258
1296
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1259
1297
def test_hypot (x1 , x2 ):
1260
1298
out = xp .hypot (x1 , x2 )
1261
- ph .assert_dtype ("hypot" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1262
- ph .assert_result_shape ("hypot" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1263
- binary_assert_against_refimpl ("hypot" , x1 , x2 , out , math .hypot )
1264
-
1299
+ _assert_correctness_binary (
1300
+ "hypot" ,
1301
+ math .hypot ,
1302
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1303
+ in_shapes = [x1 .shape , x2 .shape ],
1304
+ in_arrs = [x1 , x2 ],
1305
+ out = out
1306
+ )
1265
1307
1266
1308
1267
1309
@pytest .mark .min_version ("2022.12" )
@@ -1411,21 +1453,17 @@ def logaddexp_refimpl(l: float, r: float) -> float:
1411
1453
raise OverflowError
1412
1454
1413
1455
1456
+ @pytest .mark .min_version ("2023.12" )
1414
1457
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1415
1458
def test_logaddexp (x1 , x2 ):
1416
1459
out = xp .logaddexp (x1 , x2 )
1417
- ph .assert_dtype ("logaddexp" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1418
- ph .assert_result_shape ("logaddexp" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1419
- binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp_refimpl )
1420
-
1421
-
1422
- @given (* hh .two_mutual_arrays ([xp .bool ]))
1423
- def test_logical_and (x1 , x2 ):
1424
- out = xp .logical_and (x1 , x2 )
1425
- ph .assert_dtype ("logical_and" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1426
- ph .assert_result_shape ("logical_and" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1427
- binary_assert_against_refimpl (
1428
- "logical_and" , x1 , x2 , out , operator .and_ , expr_template = "({} and {})={}"
1460
+ _assert_correctness_binary (
1461
+ "logaddexp" ,
1462
+ logaddexp_refimpl ,
1463
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1464
+ in_shapes = [x1 .shape , x2 .shape ],
1465
+ in_arrs = [x1 , x2 ],
1466
+ out = out
1429
1467
)
1430
1468
1431
1469
@@ -1439,42 +1477,64 @@ def test_logical_not(x):
1439
1477
)
1440
1478
1441
1479
1480
+ @given (* hh .two_mutual_arrays ([xp .bool ]))
1481
+ def test_logical_and (x1 , x2 ):
1482
+ out = xp .logical_and (x1 , x2 )
1483
+ _assert_correctness_binary (
1484
+ "logical_and" ,
1485
+ operator .and_ ,
1486
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1487
+ in_shapes = [x1 .shape , x2 .shape ],
1488
+ in_arrs = [x1 , x2 ],
1489
+ out = out ,
1490
+ expr_template = "({} and {})={}"
1491
+ )
1492
+
1493
+
1442
1494
@given (* hh .two_mutual_arrays ([xp .bool ]))
1443
1495
def test_logical_or (x1 , x2 ):
1444
1496
out = xp .logical_or (x1 , x2 )
1445
- ph .assert_dtype ("logical_or" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1446
- ph .assert_result_shape ("logical_or" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1447
- binary_assert_against_refimpl (
1448
- "logical_or" , x1 , x2 , out , operator .or_ , expr_template = "({} or {})={}"
1497
+ _assert_correctness_binary (
1498
+ "logical_or" ,
1499
+ operator .or_ ,
1500
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1501
+ in_shapes = [x1 .shape , x2 .shape ],
1502
+ in_arrs = [x1 , x2 ],
1503
+ out = out ,
1504
+ expr_template = "({} or {})={}"
1449
1505
)
1450
1506
1451
1507
1452
1508
@given (* hh .two_mutual_arrays ([xp .bool ]))
1453
1509
def test_logical_xor (x1 , x2 ):
1454
1510
out = xp .logical_xor (x1 , x2 )
1455
- ph .assert_dtype ("logical_xor" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1456
- ph .assert_result_shape ("logical_xor" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1457
- binary_assert_against_refimpl (
1458
- "logical_xor" , x1 , x2 , out , operator .xor , expr_template = "({} ^ {})={}"
1511
+ _assert_correctness_binary (
1512
+ "logical_xor" ,
1513
+ operator .xor ,
1514
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1515
+ in_shapes = [x1 .shape , x2 .shape ],
1516
+ in_arrs = [x1 , x2 ],
1517
+ out = out ,
1518
+ expr_template = "({} ^ {})={}"
1459
1519
)
1460
1520
1461
1521
1462
1522
@pytest .mark .min_version ("2023.12" )
1463
1523
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1464
1524
def test_maximum (x1 , x2 ):
1465
1525
out = xp .maximum (x1 , x2 )
1466
- ph . assert_dtype ( "maximum" , in_dtype = [ x1 . dtype , x2 . dtype ], out_dtype = out . dtype )
1467
- ph . assert_result_shape ( "maximum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1468
- binary_assert_against_refimpl ( "maximum" , x1 , x2 , out , max , strict_check = True )
1526
+ _assert_correctness_binary (
1527
+ "maximum" , max , [x1 .dtype , x2 .dtype ], [ x1 .shape , x2 . shape ], ( x1 , x2 ), out , strict_check = True
1528
+ )
1469
1529
1470
1530
1471
1531
@pytest .mark .min_version ("2023.12" )
1472
1532
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1473
1533
def test_minimum (x1 , x2 ):
1474
1534
out = xp .minimum (x1 , x2 )
1475
- ph . assert_dtype ( "minimum" , in_dtype = [ x1 . dtype , x2 . dtype ], out_dtype = out . dtype )
1476
- ph . assert_result_shape ( "minimum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1477
- binary_assert_against_refimpl ( "minimum" , x1 , x2 , out , min , strict_check = True )
1535
+ _assert_correctness_binary (
1536
+ "minimum" , min , [x1 .dtype , x2 .dtype ], [ x1 .shape , x2 . shape ], ( x1 , x2 ), out , strict_check = True
1537
+ )
1478
1538
1479
1539
1480
1540
@pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
@@ -1719,3 +1779,88 @@ def test_trunc(x):
1719
1779
ph .assert_dtype ("trunc" , in_dtype = x .dtype , out_dtype = out .dtype )
1720
1780
ph .assert_shape ("trunc" , out_shape = out .shape , expected = x .shape )
1721
1781
unary_assert_against_refimpl ("trunc" , x , out , math .trunc , strict_check = True )
1782
+
1783
+
1784
+ def _check_binary_with_scalars (func_data , x1x2 ):
1785
+ x1 , x2 = x1x2
1786
+ func , name , refimpl , kwds , expected_dtype = func_data
1787
+ out = func (x1 , x2 )
1788
+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1789
+ _assert_correctness_binary (
1790
+ name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , expected_dtype , ** kwds
1791
+ )
1792
+
1793
+
1794
+ def _filter_zero (x ):
1795
+ return x != 0 if dh .is_scalar (x ) else (not xp .any (x == 0 ))
1796
+
1797
+
1798
+ @pytest .mark .min_version ("2024.12" )
1799
+ @pytest .mark .parametrize ('func_data' ,
1800
+ # xp_func, name, refimpl, kwargs, expected_dtype
1801
+ [
1802
+ (xp .add , "add" , operator .add , {}, None ),
1803
+ (xp .atan2 , "atan2" , math .atan2 , {}, None ),
1804
+ (xp .copysign , "copysign" , math .copysign , {}, None ),
1805
+ (xp .divide , "divide" , operator .truediv , {"filter_" : lambda s : s != 0 }, None ),
1806
+ (xp .hypot , "hypot" , math .hypot , {}, None ),
1807
+ (xp .logaddexp , "logaddexp" , logaddexp_refimpl , {}, None ),
1808
+ (xp .maximum , "maximum" , max , {'strict_check' : True }, None ),
1809
+ (xp .minimum , "minimum" , min , {'strict_check' : True }, None ),
1810
+ (xp .multiply , "mul" , operator .mul , {}, None ),
1811
+ (xp .subtract , "sub" , operator .sub , {}, None ),
1812
+
1813
+ (xp .equal , "equal" , operator .eq , {}, xp .bool ),
1814
+ (xp .not_equal , "neq" , operator .ne , {}, xp .bool ),
1815
+ (xp .less , "less" , operator .lt , {}, xp .bool ),
1816
+ (xp .less_equal , "les_equal" , operator .le , {}, xp .bool ),
1817
+ (xp .greater , "greater" , operator .gt , {}, xp .bool ),
1818
+ (xp .greater_equal , "greater_equal" , operator .ge , {}, xp .bool ),
1819
+ (xp .remainder , "remainder" , operator .mod , {}, None ),
1820
+ (xp .floor_divide , "floor_divide" , operator .floordiv , {}, None ),
1821
+ ],
1822
+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1823
+ )
1824
+ @given (x1x2 = hh .array_and_py_scalar (dh .real_float_dtypes ))
1825
+ def test_binary_with_scalars_real (func_data , x1x2 ):
1826
+
1827
+ if func_data [1 ] == "remainder" :
1828
+ assume (_filter_zero (x1x2 [1 ]))
1829
+ if func_data [1 ] == "floor_divide" :
1830
+ assume (_filter_zero (x1x2 [0 ]) and _filter_zero (x1x2 [1 ]))
1831
+
1832
+ _check_binary_with_scalars (func_data , x1x2 )
1833
+
1834
+
1835
+ @pytest .mark .min_version ("2024.12" )
1836
+ @pytest .mark .parametrize ('func_data' ,
1837
+ # xp_func, name, refimpl, kwargs, expected_dtype
1838
+ [
1839
+ (xp .logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }, None ),
1840
+ (xp .logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }, None ),
1841
+ (xp .logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }, None ),
1842
+ ],
1843
+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1844
+ )
1845
+ @given (x1x2 = hh .array_and_py_scalar ([xp .bool ]))
1846
+ def test_binary_with_scalars_bool (func_data , x1x2 ):
1847
+ _check_binary_with_scalars (func_data , x1x2 )
1848
+
1849
+
1850
+ @pytest .mark .min_version ("2024.12" )
1851
+ @pytest .mark .parametrize ('func_data' ,
1852
+ # xp_func, name, refimpl, kwargs, expected_dtype
1853
+ [
1854
+ (xp .bitwise_and , "bitwise_and" , operator .and_ , {}, None ),
1855
+ (xp .bitwise_or , "bitwise_or" , operator .or_ , {}, None ),
1856
+ (xp .bitwise_xor , "bitwise_xor" , operator .xor , {}, None ),
1857
+ ],
1858
+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1859
+ )
1860
+ @given (x1x2 = hh .array_and_py_scalar ([xp .int32 ]))
1861
+ def test_binary_with_scalars_bitwise (func_data , x1x2 ):
1862
+ xp_func , name , refimpl , kwargs , expected = func_data
1863
+ # repack the refimpl
1864
+ refimpl_ = lambda l , r : mock_int_dtype (refimpl (l , r ), xp .int32 )
1865
+ _check_binary_with_scalars ((xp_func , name , refimpl_ , kwargs ,expected ), x1x2 )
1866
+
0 commit comments