Skip to content

Commit aba096e

Browse files
committed
Fix reference implementations for complex elementwise functions
Some of these still have some issues that need to be addressed, like overflows.
1 parent 9efcad5 commit aba096e

File tree

1 file changed

+104
-34
lines changed

1 file changed

+104
-34
lines changed

Diff for: array_api_tests/test_operators_and_elementwise_functions.py

+104-34
Original file line numberDiff line numberDiff line change
@@ -703,9 +703,9 @@ def test_abs(ctx, data):
703703
abs, # type: ignore
704704
res_stype=float if x.dtype in dh.complex_dtypes else None,
705705
expr_template="abs({})={}",
706-
filter_=lambda s: (
707-
s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s))
708-
),
706+
# filter_=lambda s: (
707+
# s == float("infinity") or (cmath.isfinite(s) and not ph.is_neg_zero(s))
708+
# ),
709709
)
710710

711711

@@ -714,8 +714,10 @@ def test_acos(x):
714714
out = xp.acos(x)
715715
ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype)
716716
ph.assert_shape("acos", out_shape=out.shape, expected=x.shape)
717+
refimpl = cmath.acos if x.dtype in dh.complex_dtypes else math.acos
718+
filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1
717719
unary_assert_against_refimpl(
718-
"acos", x, out, math.acos, filter_=lambda s: default_filter(s) and -1 <= s <= 1
720+
"acos", x, out, refimpl, filter_=filter_
719721
)
720722

721723

@@ -724,8 +726,10 @@ def test_acosh(x):
724726
out = xp.acosh(x)
725727
ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype)
726728
ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape)
729+
refimpl = cmath.acosh if x.dtype in dh.complex_dtypes else math.acosh
730+
filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 1
727731
unary_assert_against_refimpl(
728-
"acosh", x, out, math.acosh, filter_=lambda s: default_filter(s) and s >= 1
732+
"acosh", x, out, refimpl, filter_=filter_
729733
)
730734

731735

@@ -748,8 +752,10 @@ def test_asin(x):
748752
out = xp.asin(x)
749753
ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype)
750754
ph.assert_shape("asin", out_shape=out.shape, expected=x.shape)
755+
refimpl = cmath.asin if x.dtype in dh.complex_dtypes else math.asin
756+
filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1
751757
unary_assert_against_refimpl(
752-
"asin", x, out, math.asin, filter_=lambda s: default_filter(s) and -1 <= s <= 1
758+
"asin", x, out, refimpl, filter_=filter_
753759
)
754760

755761

@@ -758,36 +764,41 @@ def test_asinh(x):
758764
out = xp.asinh(x)
759765
ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype)
760766
ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape)
761-
unary_assert_against_refimpl("asinh", x, out, math.asinh)
767+
refimpl = cmath.asinh if x.dtype in dh.complex_dtypes else math.asinh
768+
unary_assert_against_refimpl("asinh", x, out, refimpl)
762769

763770

764771
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
765772
def test_atan(x):
766773
out = xp.atan(x)
767774
ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype)
768775
ph.assert_shape("atan", out_shape=out.shape, expected=x.shape)
769-
unary_assert_against_refimpl("atan", x, out, math.atan)
776+
refimpl = cmath.atan if x.dtype in dh.complex_dtypes else math.atan
777+
unary_assert_against_refimpl("atan", x, out, refimpl)
770778

771779

772780
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
773781
def test_atan2(x1, x2):
774782
out = xp.atan2(x1, x2)
775783
ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
776784
ph.assert_result_shape("atan2", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
777-
binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2)
785+
refimpl = cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2
786+
binary_assert_against_refimpl("atan2", x1, x2, out, refimpl)
778787

779788

780789
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
781790
def test_atanh(x):
782791
out = xp.atanh(x)
783792
ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype)
784793
ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape)
794+
refimpl = cmath.atanh if x.dtype in dh.complex_dtypes else math.atanh
795+
filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 < s < 1
785796
unary_assert_against_refimpl(
786797
"atanh",
787798
x,
788799
out,
789-
math.atanh,
790-
filter_=lambda s: default_filter(s) and -1 <= s <= 1,
800+
refimpl,
801+
filter_=filter_,
791802
)
792803

793804

@@ -1065,15 +1076,17 @@ def test_cos(x):
10651076
out = xp.cos(x)
10661077
ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype)
10671078
ph.assert_shape("cos", out_shape=out.shape, expected=x.shape)
1068-
unary_assert_against_refimpl("cos", x, out, math.cos)
1079+
refimpl = cmath.cos if x.dtype in dh.complex_dtypes else math.cos
1080+
unary_assert_against_refimpl("cos", x, out, refimpl)
10691081

10701082

10711083
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
10721084
def test_cosh(x):
10731085
out = xp.cosh(x)
10741086
ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype)
10751087
ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape)
1076-
unary_assert_against_refimpl("cosh", x, out, math.cosh)
1088+
refimpl = cmath.cosh if x.dtype in dh.complex_dtypes else math.cosh
1089+
unary_assert_against_refimpl("cosh", x, out, refimpl)
10771090

10781091

10791092
@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes))
@@ -1097,7 +1110,7 @@ def test_divide(ctx, data):
10971110
res,
10981111
"/",
10991112
operator.truediv,
1100-
filter_=lambda s: math.isfinite(s) and s != 0,
1113+
filter_=lambda s: cmath.isfinite(s) and s != 0,
11011114
)
11021115

11031116

@@ -1134,23 +1147,45 @@ def test_exp(x):
11341147
out = xp.exp(x)
11351148
ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype)
11361149
ph.assert_shape("exp", out_shape=out.shape, expected=x.shape)
1137-
unary_assert_against_refimpl("exp", x, out, math.exp)
1150+
refimpl = cmath.exp if x.dtype in dh.complex_dtypes else math.exp
1151+
unary_assert_against_refimpl("exp", x, out, refimpl)
11381152

11391153

11401154
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
11411155
def test_expm1(x):
11421156
out = xp.expm1(x)
11431157
ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype)
11441158
ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape)
1145-
unary_assert_against_refimpl("expm1", x, out, math.expm1)
1159+
if x.dtype in dh.complex_dtypes:
1160+
def refimpl(z):
1161+
# There's no cmath.expm1. Use
1162+
#
1163+
# exp(x+yi) - 1
1164+
# = exp(x)exp(yi) - 1
1165+
# = exp(x)(cos(y) + sin(y)i) - 1
1166+
# = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i
1167+
# = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i
1168+
#
1169+
# where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of
1170+
# significance near y = 0.
1171+
re, im = z.real, z.imag
1172+
return math.expm1(re)*math.cos(im) - 2*math.sin(im/2)**2 + 1j*math.exp(re)*math.sin(im)
1173+
else:
1174+
refimpl = math.expm1
1175+
unary_assert_against_refimpl("expm1", x, out, refimpl)
11461176

11471177

11481178
@given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()))
11491179
def test_floor(x):
11501180
out = xp.floor(x)
11511181
ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype)
11521182
ph.assert_shape("floor", out_shape=out.shape, expected=x.shape)
1153-
unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True)
1183+
if x.dtype in dh.complex_dtypes:
1184+
def refimpl(z):
1185+
return complex(math.floor(z.real), math.floor(z.imag))
1186+
else:
1187+
refimpl = math.floor
1188+
unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True)
11541189

11551190

11561191
@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes))
@@ -1236,23 +1271,26 @@ def test_isfinite(x):
12361271
out = xp.isfinite(x)
12371272
ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
12381273
ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape)
1239-
unary_assert_against_refimpl("isfinite", x, out, math.isfinite, res_stype=bool)
1274+
refimpl = cmath.isfinite if x.dtype in dh.complex_dtypes else math.isfinite
1275+
unary_assert_against_refimpl("isfinite", x, out, refimpl, res_stype=bool)
12401276

12411277

12421278
@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
12431279
def test_isinf(x):
12441280
out = xp.isinf(x)
12451281
ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
12461282
ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape)
1247-
unary_assert_against_refimpl("isinf", x, out, math.isinf, res_stype=bool)
1283+
refimpl = cmath.isinf if x.dtype in dh.complex_dtypes else math.isinf
1284+
unary_assert_against_refimpl("isinf", x, out, refimpl, res_stype=bool)
12481285

12491286

12501287
@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
12511288
def test_isnan(x):
12521289
out = xp.isnan(x)
12531290
ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
12541291
ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape)
1255-
unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool)
1292+
refimpl = cmath.isnan if x.dtype in dh.complex_dtypes else math.isnan
1293+
unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool)
12561294

12571295

12581296
@pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes))
@@ -1300,8 +1338,10 @@ def test_log(x):
13001338
out = xp.log(x)
13011339
ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype)
13021340
ph.assert_shape("log", out_shape=out.shape, expected=x.shape)
1341+
refimpl = cmath.log if x.dtype in dh.complex_dtypes else math.log
1342+
filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0
13031343
unary_assert_against_refimpl(
1304-
"log", x, out, math.log, filter_=lambda s: default_filter(s) and s >= 1
1344+
"log", x, out, refimpl, filter_=filter_
13051345
)
13061346

13071347

@@ -1310,8 +1350,19 @@ def test_log1p(x):
13101350
out = xp.log1p(x)
13111351
ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype)
13121352
ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape)
1353+
# There isn't a cmath.log1p, and implementing one isn't straightforward
1354+
# (see
1355+
# https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy).
1356+
# For now, just use log(1+p) for complex inputs, which should hopefully be
1357+
# fine given the very loose numerical tolerances we use. If it isn't, we
1358+
# can try using something like a series expansion for small p.
1359+
if x.dtype in dh.complex_dtypes:
1360+
refimpl = lambda z: cmath.log(1+z)
1361+
else:
1362+
refimpl = math.log1p
1363+
filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > -1
13131364
unary_assert_against_refimpl(
1314-
"log1p", x, out, math.log1p, filter_=lambda s: default_filter(s) and s >= 1
1365+
"log1p", x, out, refimpl, filter_=filter_
13151366
)
13161367

13171368

@@ -1320,8 +1371,13 @@ def test_log2(x):
13201371
out = xp.log2(x)
13211372
ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype)
13221373
ph.assert_shape("log2", out_shape=out.shape, expected=x.shape)
1374+
if x.dtype in dh.complex_dtypes:
1375+
refimpl = lambda z: cmath.log(z)/math.log(2)
1376+
else:
1377+
refimpl = math.log2
1378+
filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0
13231379
unary_assert_against_refimpl(
1324-
"log2", x, out, math.log2, filter_=lambda s: default_filter(s) and s > 1
1380+
"log2", x, out, refimpl, filter_=filter_
13251381
)
13261382

13271383

@@ -1330,12 +1386,17 @@ def test_log10(x):
13301386
out = xp.log10(x)
13311387
ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype)
13321388
ph.assert_shape("log10", out_shape=out.shape, expected=x.shape)
1389+
if x.dtype in dh.complex_dtypes:
1390+
refimpl = lambda z: cmath.log(z)/math.log(10)
1391+
else:
1392+
refimpl = math.log10
1393+
filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0
13331394
unary_assert_against_refimpl(
1334-
"log10", x, out, math.log10, filter_=lambda s: default_filter(s) and s > 0
1395+
"log10", x, out, refimpl, filter_=filter_
13351396
)
13361397

13371398

1338-
def logaddexp(l: float, r: float) -> float:
1399+
def logaddexp_refimpl(l: float, r: float) -> float:
13391400
return math.log(math.exp(l) + math.exp(r))
13401401

13411402

@@ -1344,7 +1405,7 @@ def test_logaddexp(x1, x2):
13441405
out = xp.logaddexp(x1, x2)
13451406
ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
13461407
ph.assert_result_shape("logaddexp", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1347-
binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp)
1408+
binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp_refimpl)
13481409

13491410

13501411
@given(*hh.two_mutual_arrays([xp.bool]))
@@ -1521,7 +1582,11 @@ def test_round(x):
15211582
out = xp.round(x)
15221583
ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype)
15231584
ph.assert_shape("round", out_shape=out.shape, expected=x.shape)
1524-
unary_assert_against_refimpl("round", x, out, round, strict_check=True)
1585+
if x.dtype in dh.complex_dtypes:
1586+
refimpl = lambda z: complex(round(z.real), round(z.imag))
1587+
else:
1588+
refimpl = round
1589+
unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True)
15251590

15261591

15271592
@pytest.mark.min_version("2023.12")
@@ -1539,13 +1604,12 @@ def test_sign(x):
15391604
out = xp.sign(x)
15401605
ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype)
15411606
ph.assert_shape("sign", out_shape=out.shape, expected=x.shape)
1542-
refimpl = lambda x: x / math.abs(x) if x != 0 else 0
1607+
refimpl = lambda x: x / abs(x) if x != 0 else 0
15431608
unary_assert_against_refimpl(
15441609
"sign",
15451610
x,
15461611
out,
15471612
refimpl,
1548-
filter_=lambda s: s != 0,
15491613
strict_check=True,
15501614
)
15511615

@@ -1555,15 +1619,17 @@ def test_sin(x):
15551619
out = xp.sin(x)
15561620
ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype)
15571621
ph.assert_shape("sin", out_shape=out.shape, expected=x.shape)
1558-
unary_assert_against_refimpl("sin", x, out, math.sin)
1622+
refimpl = cmath.sin if x.dtype in dh.complex_dtypes else math.sin
1623+
unary_assert_against_refimpl("sin", x, out, refimpl)
15591624

15601625

15611626
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
15621627
def test_sinh(x):
15631628
out = xp.sinh(x)
15641629
ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype)
15651630
ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape)
1566-
unary_assert_against_refimpl("sinh", x, out, math.sinh)
1631+
refimpl = cmath.sinh if x.dtype in dh.complex_dtypes else math.sinh
1632+
unary_assert_against_refimpl("sinh", x, out, refimpl)
15671633

15681634

15691635
@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
@@ -1581,8 +1647,10 @@ def test_sqrt(x):
15811647
out = xp.sqrt(x)
15821648
ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype)
15831649
ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape)
1650+
refimpl = cmath.sqrt if x.dtype in dh.complex_dtypes else math.sqrt
1651+
filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 0
15841652
unary_assert_against_refimpl(
1585-
"sqrt", x, out, math.sqrt, filter_=lambda s: default_filter(s) and s >= 0
1653+
"sqrt", x, out, refimpl, filter_=filter_
15861654
)
15871655

15881656

@@ -1605,15 +1673,17 @@ def test_tan(x):
16051673
out = xp.tan(x)
16061674
ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype)
16071675
ph.assert_shape("tan", out_shape=out.shape, expected=x.shape)
1608-
unary_assert_against_refimpl("tan", x, out, math.tan)
1676+
refimpl = cmath.tan if x.dtype in dh.complex_dtypes else math.tan
1677+
unary_assert_against_refimpl("tan", x, out, refimpl)
16091678

16101679

16111680
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
16121681
def test_tanh(x):
16131682
out = xp.tanh(x)
16141683
ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype)
16151684
ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape)
1616-
unary_assert_against_refimpl("tanh", x, out, math.tanh)
1685+
refimpl = cmath.tanh if x.dtype in dh.complex_dtypes else math.tanh
1686+
unary_assert_against_refimpl("tanh", x, out, refimpl)
16171687

16181688

16191689
@given(hh.arrays(dtype=hh.real_dtypes, shape=xps.array_shapes()))

0 commit comments

Comments
 (0)