Skip to content

Commit 5d890cd

Browse files
authored
Merge pull request #327 from ev-br/api_version_guards
MAINT: simplify API version guards
2 parents 606cc4d + 0e91b5e commit 5d890cd

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

Diff for: array_api_tests/test_operators_and_elementwise_functions.py

+21-23
Original file line numberDiff line numberDiff line change
@@ -1061,14 +1061,13 @@ def refimpl(_x, _min, _max):
10611061
)
10621062

10631063

1064-
if api_version >= "2022.12":
1065-
1066-
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1067-
def test_conj(x):
1068-
out = xp.conj(x)
1069-
ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype)
1070-
ph.assert_shape("conj", out_shape=out.shape, expected=x.shape)
1071-
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
1064+
@pytest.mark.min_version("2022.12")
1065+
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1066+
def test_conj(x):
1067+
out = xp.conj(x)
1068+
ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype)
1069+
ph.assert_shape("conj", out_shape=out.shape, expected=x.shape)
1070+
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
10721071

10731072

10741073
@pytest.mark.min_version("2023.12")
@@ -1263,14 +1262,14 @@ def test_hypot(x1, x2):
12631262
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)
12641263

12651264

1266-
if api_version >= "2022.12":
12671265

1268-
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1269-
def test_imag(x):
1270-
out = xp.imag(x)
1271-
ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
1272-
ph.assert_shape("imag", out_shape=out.shape, expected=x.shape)
1273-
unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag"))
1266+
@pytest.mark.min_version("2022.12")
1267+
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1268+
def test_imag(x):
1269+
out = xp.imag(x)
1270+
ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
1271+
ph.assert_shape("imag", out_shape=out.shape, expected=x.shape)
1272+
unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag"))
12741273

12751274

12761275
@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
@@ -1559,14 +1558,13 @@ def test_pow(ctx, data):
15591558
# Values testing pow is too finicky
15601559

15611560

1562-
if api_version >= "2022.12":
1563-
1564-
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1565-
def test_real(x):
1566-
out = xp.real(x)
1567-
ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
1568-
ph.assert_shape("real", out_shape=out.shape, expected=x.shape)
1569-
unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
1561+
@pytest.mark.min_version("2022.12")
1562+
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1563+
def test_real(x):
1564+
out = xp.real(x)
1565+
ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
1566+
ph.assert_shape("real", out_shape=out.shape, expected=x.shape)
1567+
unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
15701568

15711569

15721570
@pytest.mark.skip(reason="flaky")

0 commit comments

Comments
 (0)