12
12
from hypothesis import strategies as st
13
13
from hypothesis .control import reject
14
14
15
- from . import _array_module as xp
15
+ from . import COMPLEX_VER , _array_module as xp , api_version
16
16
from . import array_helpers as ah
17
17
from . import dtype_helpers as dh
18
18
from . import hypothesis_helpers as hh
@@ -35,7 +35,10 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
35
35
36
36
37
37
def all_floating_dtypes () -> st .SearchStrategy [DataType ]:
38
- return xps .floating_dtypes () | xps .complex_dtypes ()
38
+ strat = xps .floating_dtypes ()
39
+ if api_version >= COMPLEX_VER :
40
+ strat |= xps .complex_dtypes ()
41
+ return strat
39
42
40
43
41
44
class OnewayPromotableDtypes (NamedTuple ):
@@ -492,10 +495,15 @@ def __repr__(self):
492
495
493
496
494
497
def make_unary_params (
495
- elwise_func_name : str , dtypes : Sequence [DataType ]
498
+ elwise_func_name : str ,
499
+ dtypes : Sequence [DataType ],
500
+ * ,
501
+ min_version : str = "2021.12" ,
496
502
) -> List [Param [UnaryParamContext ]]:
497
503
if hh .FILTER_UNDEFINED_DTYPES :
498
504
dtypes = [d for d in dtypes if not isinstance (d , xp ._UndefinedStub )]
505
+ if api_version < COMPLEX_VER :
506
+ dtypes = [d for d in dtypes if d not in dh .complex_dtypes ]
499
507
dtypes_strat = st .sampled_from (dtypes )
500
508
strat = xps .arrays (dtype = dtypes_strat , shape = hh .shapes ())
501
509
func_ctx = UnaryParamContext (
@@ -505,7 +513,16 @@ def make_unary_params(
505
513
op_ctx = UnaryParamContext (
506
514
func_name = op_name , func = lambda x : getattr (x , op_name )(), strat = strat
507
515
)
508
- return [pytest .param (func_ctx , id = func_ctx .id ), pytest .param (op_ctx , id = op_ctx .id )]
516
+ if api_version < min_version :
517
+ marks = pytest .mark .skip (
518
+ reason = f"requires ARRAY_API_TESTS_VERSION=>{ min_version } "
519
+ )
520
+ else :
521
+ marks = ()
522
+ return [
523
+ pytest .param (func_ctx , id = func_ctx .id , marks = marks ),
524
+ pytest .param (op_ctx , id = op_ctx .id , marks = marks ),
525
+ ]
509
526
510
527
511
528
class FuncType (Enum ):
@@ -948,12 +965,14 @@ def test_ceil(x):
948
965
unary_assert_against_refimpl ("ceil" , x , out , math .ceil , strict_check = True )
949
966
950
967
951
- @given (xps .arrays (dtype = xps .complex_dtypes (), shape = hh .shapes ()))
952
- def test_conj (x ):
953
- out = xp .conj (x )
954
- ph .assert_dtype ("conj" , x .dtype , out .dtype )
955
- ph .assert_shape ("conj" , out .shape , x .shape )
956
- unary_assert_against_refimpl ("conj" , x , out , operator .methodcaller ("conjugate" ))
968
+ if api_version >= COMPLEX_VER :
969
+
970
+ @given (xps .arrays (dtype = xps .complex_dtypes (), shape = hh .shapes ()))
971
+ def test_conj (x ):
972
+ out = xp .conj (x )
973
+ ph .assert_dtype ("conj" , x .dtype , out .dtype )
974
+ ph .assert_shape ("conj" , out .shape , x .shape )
975
+ unary_assert_against_refimpl ("conj" , x , out , operator .methodcaller ("conjugate" ))
957
976
958
977
959
978
@given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
@@ -1108,12 +1127,14 @@ def test_greater_equal(ctx, data):
1108
1127
)
1109
1128
1110
1129
1111
- @given (xps .arrays (dtype = xps .complex_dtypes (), shape = hh .shapes ()))
1112
- def test_imag (x ):
1113
- out = xp .imag (x )
1114
- ph .assert_dtype ("imag" , x .dtype , out .dtype , dh .dtype_components [x .dtype ])
1115
- ph .assert_shape ("imag" , out .shape , x .shape )
1116
- unary_assert_against_refimpl ("imag" , x , out , operator .attrgetter ("imag" ))
1130
+ if api_version >= COMPLEX_VER :
1131
+
1132
+ @given (xps .arrays (dtype = xps .complex_dtypes (), shape = hh .shapes ()))
1133
+ def test_imag (x ):
1134
+ out = xp .imag (x )
1135
+ ph .assert_dtype ("imag" , x .dtype , out .dtype , dh .dtype_components [x .dtype ])
1136
+ ph .assert_shape ("imag" , out .shape , x .shape )
1137
+ unary_assert_against_refimpl ("imag" , x , out , operator .attrgetter ("imag" ))
1117
1138
1118
1139
1119
1140
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
@@ -1357,12 +1378,14 @@ def test_pow(ctx, data):
1357
1378
# Values testing pow is too finicky
1358
1379
1359
1380
1360
- @given (xps .arrays (dtype = xps .complex_dtypes (), shape = hh .shapes ()))
1361
- def test_real (x ):
1362
- out = xp .real (x )
1363
- ph .assert_dtype ("real" , x .dtype , out .dtype , dh .dtype_components [x .dtype ])
1364
- ph .assert_shape ("real" , out .shape , x .shape )
1365
- unary_assert_against_refimpl ("real" , x , out , operator .attrgetter ("real" ))
1381
+ if api_version >= COMPLEX_VER :
1382
+
1383
+ @given (xps .arrays (dtype = xps .complex_dtypes (), shape = hh .shapes ()))
1384
+ def test_real (x ):
1385
+ out = xp .real (x )
1386
+ ph .assert_dtype ("real" , x .dtype , out .dtype , dh .dtype_components [x .dtype ])
1387
+ ph .assert_shape ("real" , out .shape , x .shape )
1388
+ unary_assert_against_refimpl ("real" , x , out , operator .attrgetter ("real" ))
1366
1389
1367
1390
1368
1391
@pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , dh .real_dtypes ))
0 commit comments