@@ -703,9 +703,9 @@ def test_abs(ctx, data):
703
703
abs , # type: ignore
704
704
res_stype = float if x .dtype in dh .complex_dtypes else None ,
705
705
expr_template = "abs({})={}" ,
706
- filter_ = lambda s : (
707
- s == float ("infinity" ) or (math .isfinite (s ) and not ph .is_neg_zero (s ))
708
- ),
706
+ # filter_=lambda s: (
707
+ # s == float("infinity") or (cmath .isfinite(s) and not ph.is_neg_zero(s))
708
+ # ),
709
709
)
710
710
711
711
@@ -714,8 +714,10 @@ def test_acos(x):
714
714
out = xp .acos (x )
715
715
ph .assert_dtype ("acos" , in_dtype = x .dtype , out_dtype = out .dtype )
716
716
ph .assert_shape ("acos" , out_shape = out .shape , expected = x .shape )
717
+ refimpl = cmath .acos if x .dtype in dh .complex_dtypes else math .acos
718
+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and - 1 <= s <= 1
717
719
unary_assert_against_refimpl (
718
- "acos" , x , out , math . acos , filter_ = lambda s : default_filter ( s ) and - 1 <= s <= 1
720
+ "acos" , x , out , refimpl , filter_ = filter_
719
721
)
720
722
721
723
@@ -724,8 +726,10 @@ def test_acosh(x):
724
726
out = xp .acosh (x )
725
727
ph .assert_dtype ("acosh" , in_dtype = x .dtype , out_dtype = out .dtype )
726
728
ph .assert_shape ("acosh" , out_shape = out .shape , expected = x .shape )
729
+ refimpl = cmath .acosh if x .dtype in dh .complex_dtypes else math .acosh
730
+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s >= 1
727
731
unary_assert_against_refimpl (
728
- "acosh" , x , out , math . acosh , filter_ = lambda s : default_filter ( s ) and s >= 1
732
+ "acosh" , x , out , refimpl , filter_ = filter_
729
733
)
730
734
731
735
@@ -748,8 +752,10 @@ def test_asin(x):
748
752
out = xp .asin (x )
749
753
ph .assert_dtype ("asin" , in_dtype = x .dtype , out_dtype = out .dtype )
750
754
ph .assert_shape ("asin" , out_shape = out .shape , expected = x .shape )
755
+ refimpl = cmath .asin if x .dtype in dh .complex_dtypes else math .asin
756
+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and - 1 <= s <= 1
751
757
unary_assert_against_refimpl (
752
- "asin" , x , out , math . asin , filter_ = lambda s : default_filter ( s ) and - 1 <= s <= 1
758
+ "asin" , x , out , refimpl , filter_ = filter_
753
759
)
754
760
755
761
@@ -758,36 +764,41 @@ def test_asinh(x):
758
764
out = xp .asinh (x )
759
765
ph .assert_dtype ("asinh" , in_dtype = x .dtype , out_dtype = out .dtype )
760
766
ph .assert_shape ("asinh" , out_shape = out .shape , expected = x .shape )
761
- unary_assert_against_refimpl ("asinh" , x , out , math .asinh )
767
+ refimpl = cmath .asinh if x .dtype in dh .complex_dtypes else math .asinh
768
+ unary_assert_against_refimpl ("asinh" , x , out , refimpl )
762
769
763
770
764
771
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
765
772
def test_atan (x ):
766
773
out = xp .atan (x )
767
774
ph .assert_dtype ("atan" , in_dtype = x .dtype , out_dtype = out .dtype )
768
775
ph .assert_shape ("atan" , out_shape = out .shape , expected = x .shape )
769
- unary_assert_against_refimpl ("atan" , x , out , math .atan )
776
+ refimpl = cmath .atan if x .dtype in dh .complex_dtypes else math .atan
777
+ unary_assert_against_refimpl ("atan" , x , out , refimpl )
770
778
771
779
772
780
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
773
781
def test_atan2 (x1 , x2 ):
774
782
out = xp .atan2 (x1 , x2 )
775
783
ph .assert_dtype ("atan2" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
776
784
ph .assert_result_shape ("atan2" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
777
- binary_assert_against_refimpl ("atan2" , x1 , x2 , out , math .atan2 )
785
+ refimpl = cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2
786
+ binary_assert_against_refimpl ("atan2" , x1 , x2 , out , refimpl )
778
787
779
788
780
789
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
781
790
def test_atanh (x ):
782
791
out = xp .atanh (x )
783
792
ph .assert_dtype ("atanh" , in_dtype = x .dtype , out_dtype = out .dtype )
784
793
ph .assert_shape ("atanh" , out_shape = out .shape , expected = x .shape )
794
+ refimpl = cmath .atanh if x .dtype in dh .complex_dtypes else math .atanh
795
+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and - 1 < s < 1
785
796
unary_assert_against_refimpl (
786
797
"atanh" ,
787
798
x ,
788
799
out ,
789
- math . atanh ,
790
- filter_ = lambda s : default_filter ( s ) and - 1 <= s <= 1 ,
800
+ refimpl ,
801
+ filter_ = filter_ ,
791
802
)
792
803
793
804
@@ -1065,15 +1076,17 @@ def test_cos(x):
1065
1076
out = xp .cos (x )
1066
1077
ph .assert_dtype ("cos" , in_dtype = x .dtype , out_dtype = out .dtype )
1067
1078
ph .assert_shape ("cos" , out_shape = out .shape , expected = x .shape )
1068
- unary_assert_against_refimpl ("cos" , x , out , math .cos )
1079
+ refimpl = cmath .cos if x .dtype in dh .complex_dtypes else math .cos
1080
+ unary_assert_against_refimpl ("cos" , x , out , refimpl )
1069
1081
1070
1082
1071
1083
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
1072
1084
def test_cosh (x ):
1073
1085
out = xp .cosh (x )
1074
1086
ph .assert_dtype ("cosh" , in_dtype = x .dtype , out_dtype = out .dtype )
1075
1087
ph .assert_shape ("cosh" , out_shape = out .shape , expected = x .shape )
1076
- unary_assert_against_refimpl ("cosh" , x , out , math .cosh )
1088
+ refimpl = cmath .cosh if x .dtype in dh .complex_dtypes else math .cosh
1089
+ unary_assert_against_refimpl ("cosh" , x , out , refimpl )
1077
1090
1078
1091
1079
1092
@pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , dh .all_float_dtypes ))
@@ -1097,7 +1110,7 @@ def test_divide(ctx, data):
1097
1110
res ,
1098
1111
"/" ,
1099
1112
operator .truediv ,
1100
- filter_ = lambda s : math .isfinite (s ) and s != 0 ,
1113
+ filter_ = lambda s : cmath .isfinite (s ) and s != 0 ,
1101
1114
)
1102
1115
1103
1116
@@ -1134,23 +1147,45 @@ def test_exp(x):
1134
1147
out = xp .exp (x )
1135
1148
ph .assert_dtype ("exp" , in_dtype = x .dtype , out_dtype = out .dtype )
1136
1149
ph .assert_shape ("exp" , out_shape = out .shape , expected = x .shape )
1137
- unary_assert_against_refimpl ("exp" , x , out , math .exp )
1150
+ refimpl = cmath .exp if x .dtype in dh .complex_dtypes else math .exp
1151
+ unary_assert_against_refimpl ("exp" , x , out , refimpl )
1138
1152
1139
1153
1140
1154
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
1141
1155
def test_expm1 (x ):
1142
1156
out = xp .expm1 (x )
1143
1157
ph .assert_dtype ("expm1" , in_dtype = x .dtype , out_dtype = out .dtype )
1144
1158
ph .assert_shape ("expm1" , out_shape = out .shape , expected = x .shape )
1145
- unary_assert_against_refimpl ("expm1" , x , out , math .expm1 )
1159
+ if x .dtype in dh .complex_dtypes :
1160
+ def refimpl (z ):
1161
+ # There's no cmath.expm1. Use
1162
+ #
1163
+ # exp(x+yi) - 1
1164
+ # = exp(x)exp(yi) - 1
1165
+ # = exp(x)(cos(y) + sin(y)i) - 1
1166
+ # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i
1167
+ # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i
1168
+ #
1169
+ # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of
1170
+ # significance near y = 0.
1171
+ re , im = z .real , z .imag
1172
+ return math .expm1 (re )* math .cos (im ) - 2 * math .sin (im / 2 )** 2 + 1j * math .exp (re )* math .sin (im )
1173
+ else :
1174
+ refimpl = math .expm1
1175
+ unary_assert_against_refimpl ("expm1" , x , out , refimpl )
1146
1176
1147
1177
1148
1178
@given (hh .arrays (dtype = hh .real_dtypes , shape = hh .shapes ()))
1149
1179
def test_floor (x ):
1150
1180
out = xp .floor (x )
1151
1181
ph .assert_dtype ("floor" , in_dtype = x .dtype , out_dtype = out .dtype )
1152
1182
ph .assert_shape ("floor" , out_shape = out .shape , expected = x .shape )
1153
- unary_assert_against_refimpl ("floor" , x , out , math .floor , strict_check = True )
1183
+ if x .dtype in dh .complex_dtypes :
1184
+ def refimpl (z ):
1185
+ return complex (math .floor (z .real ), math .floor (z .imag ))
1186
+ else :
1187
+ refimpl = math .floor
1188
+ unary_assert_against_refimpl ("floor" , x , out , refimpl , strict_check = True )
1154
1189
1155
1190
1156
1191
@pytest .mark .parametrize ("ctx" , make_binary_params ("floor_divide" , dh .real_dtypes ))
@@ -1236,23 +1271,26 @@ def test_isfinite(x):
1236
1271
out = xp .isfinite (x )
1237
1272
ph .assert_dtype ("isfinite" , in_dtype = x .dtype , out_dtype = out .dtype , expected = xp .bool )
1238
1273
ph .assert_shape ("isfinite" , out_shape = out .shape , expected = x .shape )
1239
- unary_assert_against_refimpl ("isfinite" , x , out , math .isfinite , res_stype = bool )
1274
+ refimpl = cmath .isfinite if x .dtype in dh .complex_dtypes else math .isfinite
1275
+ unary_assert_against_refimpl ("isfinite" , x , out , refimpl , res_stype = bool )
1240
1276
1241
1277
1242
1278
@given (hh .arrays (dtype = hh .numeric_dtypes , shape = hh .shapes ()))
1243
1279
def test_isinf (x ):
1244
1280
out = xp .isinf (x )
1245
1281
ph .assert_dtype ("isfinite" , in_dtype = x .dtype , out_dtype = out .dtype , expected = xp .bool )
1246
1282
ph .assert_shape ("isinf" , out_shape = out .shape , expected = x .shape )
1247
- unary_assert_against_refimpl ("isinf" , x , out , math .isinf , res_stype = bool )
1283
+ refimpl = cmath .isinf if x .dtype in dh .complex_dtypes else math .isinf
1284
+ unary_assert_against_refimpl ("isinf" , x , out , refimpl , res_stype = bool )
1248
1285
1249
1286
1250
1287
@given (hh .arrays (dtype = hh .numeric_dtypes , shape = hh .shapes ()))
1251
1288
def test_isnan (x ):
1252
1289
out = xp .isnan (x )
1253
1290
ph .assert_dtype ("isnan" , in_dtype = x .dtype , out_dtype = out .dtype , expected = xp .bool )
1254
1291
ph .assert_shape ("isnan" , out_shape = out .shape , expected = x .shape )
1255
- unary_assert_against_refimpl ("isnan" , x , out , math .isnan , res_stype = bool )
1292
+ refimpl = cmath .isnan if x .dtype in dh .complex_dtypes else math .isnan
1293
+ unary_assert_against_refimpl ("isnan" , x , out , refimpl , res_stype = bool )
1256
1294
1257
1295
1258
1296
@pytest .mark .parametrize ("ctx" , make_binary_params ("less" , dh .real_dtypes ))
@@ -1300,8 +1338,10 @@ def test_log(x):
1300
1338
out = xp .log (x )
1301
1339
ph .assert_dtype ("log" , in_dtype = x .dtype , out_dtype = out .dtype )
1302
1340
ph .assert_shape ("log" , out_shape = out .shape , expected = x .shape )
1341
+ refimpl = cmath .log if x .dtype in dh .complex_dtypes else math .log
1342
+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s > 0
1303
1343
unary_assert_against_refimpl (
1304
- "log" , x , out , math . log , filter_ = lambda s : default_filter ( s ) and s >= 1
1344
+ "log" , x , out , refimpl , filter_ = filter_
1305
1345
)
1306
1346
1307
1347
@@ -1310,8 +1350,19 @@ def test_log1p(x):
1310
1350
out = xp .log1p (x )
1311
1351
ph .assert_dtype ("log1p" , in_dtype = x .dtype , out_dtype = out .dtype )
1312
1352
ph .assert_shape ("log1p" , out_shape = out .shape , expected = x .shape )
1353
+ # There isn't a cmath.log1p, and implementing one isn't straightforward
1354
+ # (see
1355
+ # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy).
1356
+ # For now, just use log(1+p) for complex inputs, which should hopefully be
1357
+ # fine given the very loose numerical tolerances we use. If it isn't, we
1358
+ # can try using something like a series expansion for small p.
1359
+ if x .dtype in dh .complex_dtypes :
1360
+ refimpl = lambda z : cmath .log (1 + z )
1361
+ else :
1362
+ refimpl = math .log1p
1363
+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s > - 1
1313
1364
unary_assert_against_refimpl (
1314
- "log1p" , x , out , math . log1p , filter_ = lambda s : default_filter ( s ) and s >= 1
1365
+ "log1p" , x , out , refimpl , filter_ = filter_
1315
1366
)
1316
1367
1317
1368
@@ -1320,8 +1371,13 @@ def test_log2(x):
1320
1371
out = xp .log2 (x )
1321
1372
ph .assert_dtype ("log2" , in_dtype = x .dtype , out_dtype = out .dtype )
1322
1373
ph .assert_shape ("log2" , out_shape = out .shape , expected = x .shape )
1374
+ if x .dtype in dh .complex_dtypes :
1375
+ refimpl = lambda z : cmath .log (z )/ math .log (2 )
1376
+ else :
1377
+ refimpl = math .log2
1378
+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s > 0
1323
1379
unary_assert_against_refimpl (
1324
- "log2" , x , out , math . log2 , filter_ = lambda s : default_filter ( s ) and s > 1
1380
+ "log2" , x , out , refimpl , filter_ = filter_
1325
1381
)
1326
1382
1327
1383
@@ -1330,12 +1386,17 @@ def test_log10(x):
1330
1386
out = xp .log10 (x )
1331
1387
ph .assert_dtype ("log10" , in_dtype = x .dtype , out_dtype = out .dtype )
1332
1388
ph .assert_shape ("log10" , out_shape = out .shape , expected = x .shape )
1389
+ if x .dtype in dh .complex_dtypes :
1390
+ refimpl = lambda z : cmath .log (z )/ math .log (10 )
1391
+ else :
1392
+ refimpl = math .log10
1393
+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s > 0
1333
1394
unary_assert_against_refimpl (
1334
- "log10" , x , out , math . log10 , filter_ = lambda s : default_filter ( s ) and s > 0
1395
+ "log10" , x , out , refimpl , filter_ = filter_
1335
1396
)
1336
1397
1337
1398
1338
- def logaddexp (l : float , r : float ) -> float :
1399
+ def logaddexp_refimpl (l : float , r : float ) -> float :
1339
1400
return math .log (math .exp (l ) + math .exp (r ))
1340
1401
1341
1402
@@ -1344,7 +1405,7 @@ def test_logaddexp(x1, x2):
1344
1405
out = xp .logaddexp (x1 , x2 )
1345
1406
ph .assert_dtype ("logaddexp" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1346
1407
ph .assert_result_shape ("logaddexp" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1347
- binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp )
1408
+ binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp_refimpl )
1348
1409
1349
1410
1350
1411
@given (* hh .two_mutual_arrays ([xp .bool ]))
@@ -1521,7 +1582,11 @@ def test_round(x):
1521
1582
out = xp .round (x )
1522
1583
ph .assert_dtype ("round" , in_dtype = x .dtype , out_dtype = out .dtype )
1523
1584
ph .assert_shape ("round" , out_shape = out .shape , expected = x .shape )
1524
- unary_assert_against_refimpl ("round" , x , out , round , strict_check = True )
1585
+ if x .dtype in dh .complex_dtypes :
1586
+ refimpl = lambda z : complex (round (z .real ), round (z .imag ))
1587
+ else :
1588
+ refimpl = round
1589
+ unary_assert_against_refimpl ("round" , x , out , refimpl , strict_check = True )
1525
1590
1526
1591
1527
1592
@pytest .mark .min_version ("2023.12" )
@@ -1539,13 +1604,12 @@ def test_sign(x):
1539
1604
out = xp .sign (x )
1540
1605
ph .assert_dtype ("sign" , in_dtype = x .dtype , out_dtype = out .dtype )
1541
1606
ph .assert_shape ("sign" , out_shape = out .shape , expected = x .shape )
1542
- refimpl = lambda x : x / math . abs (x ) if x != 0 else 0
1607
+ refimpl = lambda x : x / abs (x ) if x != 0 else 0
1543
1608
unary_assert_against_refimpl (
1544
1609
"sign" ,
1545
1610
x ,
1546
1611
out ,
1547
1612
refimpl ,
1548
- filter_ = lambda s : s != 0 ,
1549
1613
strict_check = True ,
1550
1614
)
1551
1615
@@ -1555,15 +1619,17 @@ def test_sin(x):
1555
1619
out = xp .sin (x )
1556
1620
ph .assert_dtype ("sin" , in_dtype = x .dtype , out_dtype = out .dtype )
1557
1621
ph .assert_shape ("sin" , out_shape = out .shape , expected = x .shape )
1558
- unary_assert_against_refimpl ("sin" , x , out , math .sin )
1622
+ refimpl = cmath .sin if x .dtype in dh .complex_dtypes else math .sin
1623
+ unary_assert_against_refimpl ("sin" , x , out , refimpl )
1559
1624
1560
1625
1561
1626
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
1562
1627
def test_sinh (x ):
1563
1628
out = xp .sinh (x )
1564
1629
ph .assert_dtype ("sinh" , in_dtype = x .dtype , out_dtype = out .dtype )
1565
1630
ph .assert_shape ("sinh" , out_shape = out .shape , expected = x .shape )
1566
- unary_assert_against_refimpl ("sinh" , x , out , math .sinh )
1631
+ refimpl = cmath .sinh if x .dtype in dh .complex_dtypes else math .sinh
1632
+ unary_assert_against_refimpl ("sinh" , x , out , refimpl )
1567
1633
1568
1634
1569
1635
@given (hh .arrays (dtype = hh .numeric_dtypes , shape = hh .shapes ()))
@@ -1581,8 +1647,10 @@ def test_sqrt(x):
1581
1647
out = xp .sqrt (x )
1582
1648
ph .assert_dtype ("sqrt" , in_dtype = x .dtype , out_dtype = out .dtype )
1583
1649
ph .assert_shape ("sqrt" , out_shape = out .shape , expected = x .shape )
1650
+ refimpl = cmath .sqrt if x .dtype in dh .complex_dtypes else math .sqrt
1651
+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s >= 0
1584
1652
unary_assert_against_refimpl (
1585
- "sqrt" , x , out , math . sqrt , filter_ = lambda s : default_filter ( s ) and s >= 0
1653
+ "sqrt" , x , out , refimpl , filter_ = filter_
1586
1654
)
1587
1655
1588
1656
@@ -1605,15 +1673,17 @@ def test_tan(x):
1605
1673
out = xp .tan (x )
1606
1674
ph .assert_dtype ("tan" , in_dtype = x .dtype , out_dtype = out .dtype )
1607
1675
ph .assert_shape ("tan" , out_shape = out .shape , expected = x .shape )
1608
- unary_assert_against_refimpl ("tan" , x , out , math .tan )
1676
+ refimpl = cmath .tan if x .dtype in dh .complex_dtypes else math .tan
1677
+ unary_assert_against_refimpl ("tan" , x , out , refimpl )
1609
1678
1610
1679
1611
1680
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
1612
1681
def test_tanh (x ):
1613
1682
out = xp .tanh (x )
1614
1683
ph .assert_dtype ("tanh" , in_dtype = x .dtype , out_dtype = out .dtype )
1615
1684
ph .assert_shape ("tanh" , out_shape = out .shape , expected = x .shape )
1616
- unary_assert_against_refimpl ("tanh" , x , out , math .tanh )
1685
+ refimpl = cmath .tanh if x .dtype in dh .complex_dtypes else math .tanh
1686
+ unary_assert_against_refimpl ("tanh" , x , out , refimpl )
1617
1687
1618
1688
1619
1689
@given (hh .arrays (dtype = hh .real_dtypes , shape = xps .array_shapes ()))
0 commit comments