Skip to content

Commit f5723ad

Browse files
committed
min_version() marker and other versioning nicities
1 parent 2d6b2d8 commit f5723ad

File tree

5 files changed

+75
-27
lines changed

5 files changed

+75
-27
lines changed

README.md

+6
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ library to fail.
160160

161161
### Configuration
162162

163+
#### API version
164+
165+
You can specify the API version to use when testing via the
166+
ARRAY_API_TESTS_VERSION environment variable. Currently this defaults to
167+
`"2021.12"`.
168+
163169
#### CI flag
164170

165171
Use the `--ci` flag to run only the primary and special cases tests. You can

array_api_tests/__init__.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from functools import wraps
2+
from os import getenv
23

34
from hypothesis import strategies as st
45
from hypothesis.extra import array_api
56

7+
from . import _version
68
from ._array_module import mod as _xp
79

8-
__all__ = ["xps"]
10+
__all__ = ["COMPLEX_VER", "api_version", "xps"]
11+
12+
13+
COMPLEX_VER: str = "2022.12"
914

1015

1116
# We monkey patch floats() to always disable subnormals as they are out-of-scope
@@ -41,9 +46,7 @@ def _from_dtype(*a, **kw):
4146
pass
4247

4348

44-
xps = array_api.make_strategies_namespace(_xp, api_version="2021.12")
45-
46-
47-
from . import _version
49+
api_version = getenv("ARRAY_API_TESTS_VERSION", "2021.12")
50+
xps = array_api.make_strategies_namespace(_xp, api_version=api_version)
4851

4952
__version__ = _version.get_versions()["version"]

array_api_tests/_array_module.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __repr__(self):
5858
"uint8", "uint16", "uint32", "uint64",
5959
"int8", "int16", "int32", "int64",
6060
"float32", "float64",
61+
"complex64", "complex128",
6162
]
6263
_constants = ["e", "inf", "nan", "pi"]
6364
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]

array_api_tests/test_operators_and_elementwise_functions.py

+45-22
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from hypothesis import strategies as st
1313
from hypothesis.control import reject
1414

15-
from . import _array_module as xp
15+
from . import COMPLEX_VER, _array_module as xp, api_version
1616
from . import array_helpers as ah
1717
from . import dtype_helpers as dh
1818
from . import hypothesis_helpers as hh
@@ -35,7 +35,10 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3535

3636

3737
def all_floating_dtypes() -> st.SearchStrategy[DataType]:
38-
return xps.floating_dtypes() | xps.complex_dtypes()
38+
strat = xps.floating_dtypes()
39+
if api_version >= COMPLEX_VER:
40+
strat |= xps.complex_dtypes()
41+
return strat
3942

4043

4144
class OnewayPromotableDtypes(NamedTuple):
@@ -492,10 +495,15 @@ def __repr__(self):
492495

493496

494497
def make_unary_params(
495-
elwise_func_name: str, dtypes: Sequence[DataType]
498+
elwise_func_name: str,
499+
dtypes: Sequence[DataType],
500+
*,
501+
min_version: str = "2021.12",
496502
) -> List[Param[UnaryParamContext]]:
497503
if hh.FILTER_UNDEFINED_DTYPES:
498504
dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
505+
if api_version < COMPLEX_VER:
506+
dtypes = [d for d in dtypes if d not in dh.complex_dtypes]
499507
dtypes_strat = st.sampled_from(dtypes)
500508
strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes())
501509
func_ctx = UnaryParamContext(
@@ -505,7 +513,16 @@ def make_unary_params(
505513
op_ctx = UnaryParamContext(
506514
func_name=op_name, func=lambda x: getattr(x, op_name)(), strat=strat
507515
)
508-
return [pytest.param(func_ctx, id=func_ctx.id), pytest.param(op_ctx, id=op_ctx.id)]
516+
if api_version < min_version:
517+
marks = pytest.mark.skip(
518+
reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}"
519+
)
520+
else:
521+
marks = ()
522+
return [
523+
pytest.param(func_ctx, id=func_ctx.id, marks=marks),
524+
pytest.param(op_ctx, id=op_ctx.id, marks=marks),
525+
]
509526

510527

511528
class FuncType(Enum):
@@ -948,12 +965,14 @@ def test_ceil(x):
948965
unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True)
949966

950967

951-
@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
952-
def test_conj(x):
953-
out = xp.conj(x)
954-
ph.assert_dtype("conj", x.dtype, out.dtype)
955-
ph.assert_shape("conj", out.shape, x.shape)
956-
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
968+
if api_version >= COMPLEX_VER:
969+
970+
@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
971+
def test_conj(x):
972+
out = xp.conj(x)
973+
ph.assert_dtype("conj", x.dtype, out.dtype)
974+
ph.assert_shape("conj", out.shape, x.shape)
975+
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
957976

958977

959978
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@@ -1108,12 +1127,14 @@ def test_greater_equal(ctx, data):
11081127
)
11091128

11101129

1111-
@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
1112-
def test_imag(x):
1113-
out = xp.imag(x)
1114-
ph.assert_dtype("imag", x.dtype, out.dtype, dh.dtype_components[x.dtype])
1115-
ph.assert_shape("imag", out.shape, x.shape)
1116-
unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag"))
1130+
if api_version >= COMPLEX_VER:
1131+
1132+
@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
1133+
def test_imag(x):
1134+
out = xp.imag(x)
1135+
ph.assert_dtype("imag", x.dtype, out.dtype, dh.dtype_components[x.dtype])
1136+
ph.assert_shape("imag", out.shape, x.shape)
1137+
unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag"))
11171138

11181139

11191140
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
@@ -1357,12 +1378,14 @@ def test_pow(ctx, data):
13571378
# Values testing pow is too finicky
13581379

13591380

1360-
@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
1361-
def test_real(x):
1362-
out = xp.real(x)
1363-
ph.assert_dtype("real", x.dtype, out.dtype, dh.dtype_components[x.dtype])
1364-
ph.assert_shape("real", out.shape, x.shape)
1365-
unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
1381+
if api_version >= COMPLEX_VER:
1382+
1383+
@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
1384+
def test_real(x):
1385+
out = xp.real(x)
1386+
ph.assert_dtype("real", x.dtype, out.dtype, dh.dtype_components[x.dtype])
1387+
ph.assert_shape("real", out.shape, x.shape)
1388+
unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
13661389

13671390

13681391
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes))

conftest.py

+15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytest import mark
66

77
from array_api_tests import _array_module as xp
8+
from array_api_tests import api_version
89
from array_api_tests._array_module import _UndefinedStub
910

1011
from reporting import pytest_metadata, pytest_json_modifyreport, add_extra_json_metadata # noqa
@@ -59,6 +60,10 @@ def pytest_configure(config):
5960
"markers", "data_dependent_shapes: output shapes are dependent on inputs"
6061
)
6162
config.addinivalue_line("markers", "ci: primary test")
63+
config.addinivalue_line(
64+
"markers",
65+
"min_version(api_version): run when greater or equal to api_version",
66+
)
6267
# Hypothesis
6368
hypothesis_max_examples = config.getoption("--hypothesis-max-examples")
6469
disable_deadline = config.getoption("--hypothesis-disable-deadline")
@@ -126,3 +131,13 @@ def pytest_collection_modifyitems(config, items):
126131
ci_mark = next((m for m in markers if m.name == "ci"), None)
127132
if ci_mark is None:
128133
item.add_marker(mark.skip(reason="disabled via --ci"))
134+
# skip if test is for greater api_version
135+
ver_mark = next((m for m in markers if m.name == "min_version"), None)
136+
if ver_mark is not None:
137+
min_version = ver_mark.args[0]
138+
if api_version < min_version:
139+
item.add_marker(
140+
mark.skip(
141+
reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}"
142+
)
143+
)

0 commit comments

Comments
 (0)