1
1
import inspect
2
+ from itertools import chain
2
3
3
4
import pytest
4
5
5
6
from ._array_module import mod , mod_name , ones , eye , float64 , bool , int64 , _UndefinedStub
6
7
from .pytest_helpers import raises , doesnt_raise
7
8
from . import dtype_helpers as dh
8
9
9
- from . import function_stubs
10
10
from . import stubs
11
11
12
12
13
- def stub_module (name ):
14
- for m in stubs .extensions :
15
- if name in getattr (function_stubs , m ).__all__ :
16
- return m
13
+ def extension_module (name ) -> bool :
14
+ for funcs in stubs .extension_to_funcs .values ():
15
+ for func in funcs :
16
+ if name == func .__name__ :
17
+ return True
18
+ else :
19
+ return False
17
20
18
- def extension_module (name ):
19
- return name in stubs .extensions and name in function_stubs .__all__
20
21
21
- extension_module_names = []
22
- for n in function_stubs .__all__ :
23
- if extension_module (n ):
24
- extension_module_names .extend ([f'{ n } .{ i } ' for i in getattr (function_stubs , n ).__all__ ])
22
+ params = []
23
+ for name in [f .__name__ for funcs in stubs .category_to_funcs .values () for f in funcs ]:
24
+ if name in ["where" , "expand_dims" , "reshape" ]:
25
+ params .append (pytest .param (name , marks = pytest .mark .skip (reason = "faulty test" )))
26
+ else :
27
+ params .append (name )
25
28
26
29
27
- params = []
28
- for name in function_stubs .__all__ :
29
- marks = []
30
- if extension_module (name ):
31
- marks .append (pytest .mark .xp_extension (name ))
32
- params .append (pytest .param (name , marks = marks ))
33
- for name in extension_module_names :
34
- ext = name .split ('.' )[0 ]
35
- mark = pytest .mark .xp_extension (ext )
36
- params .append (pytest .param (name , marks = [mark ]))
30
+ for ext , name in [(ext , f .__name__ ) for ext , funcs in stubs .extension_to_funcs .items () for f in funcs ]:
31
+ params .append (pytest .param (name , marks = pytest .mark .xp_extension (ext )))
37
32
38
33
39
- def array_method (name ):
40
- return stub_module ( name ) == 'array_object'
34
+ def array_method (name ) -> bool :
35
+ return name in [ f . __name__ for f in stubs . array_methods ]
41
36
42
- def function_category (name ):
43
- return stub_module (name ).rsplit ('_' , 1 )[0 ].replace ('_' , ' ' )
37
+ def function_category (name ) -> str :
38
+ for category , funcs in chain (stubs .category_to_funcs .items (), stubs .extension_to_funcs .items ()):
39
+ for func in funcs :
40
+ if name == func .__name__ :
41
+ return category
44
42
45
43
def example_argument (arg , func_name , dtype ):
46
44
"""
@@ -138,7 +136,7 @@ def example_argument(arg, func_name, dtype):
138
136
return ones ((3 ,), dtype = dtype )
139
137
# Linear algebra functions tend to error if the input isn't "nice" as
140
138
# a matrix
141
- elif arg .startswith ('x' ) and func_name in function_stubs . linalg . __all__ :
139
+ elif arg .startswith ('x' ) and func_name in [ f . __name__ for f in stubs . extension_to_funcs [ "linalg" ]] :
142
140
return eye (3 )
143
141
return known_args [arg ]
144
142
else :
@@ -147,13 +145,15 @@ def example_argument(arg, func_name, dtype):
147
145
@pytest .mark .parametrize ('name' , params )
148
146
def test_has_names (name ):
149
147
if extension_module (name ):
150
- assert hasattr (mod , name ), f'{ mod_name } is missing the { name } extension'
151
- elif '.' in name :
152
- extension_mod , name = name .split ('.' )
153
- assert hasattr (getattr (mod , extension_mod ), name ), f"{ mod_name } is missing the { function_category (name )} extension function { name } ()"
148
+ ext = next (
149
+ ext for ext , funcs in stubs .extension_to_funcs .items ()
150
+ if name in [f .__name__ for f in funcs ]
151
+ )
152
+ ext_mod = getattr (mod , ext )
153
+ assert hasattr (ext_mod , name ), f"{ mod_name } is missing the { function_category (name )} extension function { name } ()"
154
154
elif array_method (name ):
155
155
arr = ones ((1 , 1 ))
156
- if getattr ( function_stubs . array_object , name ) is None :
156
+ if name not in [ f . __name__ for f in stubs . array_methods ] :
157
157
assert hasattr (arr , name ), f"The array object is missing the attribute { name } "
158
158
else :
159
159
assert hasattr (arr , name ), f"The array object is missing the method { name } ()"
@@ -192,14 +192,12 @@ def test_function_positional_args(name):
192
192
_mod = ones ((), dtype = float64 )
193
193
else :
194
194
_mod = example_argument ('self' , name , dtype )
195
- stub_func = getattr (function_stubs , name )
196
195
elif '.' in name :
197
196
extension_module_name , name = name .split ('.' )
198
197
_mod = getattr (mod , extension_module_name )
199
- stub_func = getattr (getattr (function_stubs , extension_module_name ), name )
200
198
else :
201
199
_mod = mod
202
- stub_func = getattr ( function_stubs , name )
200
+ stub_func = stubs . name_to_func [ name ]
203
201
204
202
if not hasattr (_mod , name ):
205
203
pytest .skip (f"{ mod_name } does not have { name } (), skipping." )
@@ -245,14 +243,12 @@ def test_function_keyword_only_args(name):
245
243
246
244
if array_method (name ):
247
245
_mod = ones ((1 , 1 ))
248
- stub_func = getattr (function_stubs , name )
249
246
elif '.' in name :
250
247
extension_module_name , name = name .split ('.' )
251
248
_mod = getattr (mod , extension_module_name )
252
- stub_func = getattr (getattr (function_stubs , extension_module_name ), name )
253
249
else :
254
250
_mod = mod
255
- stub_func = getattr ( function_stubs , name )
251
+ stub_func = stubs . name_to_func [ name ]
256
252
257
253
if not hasattr (_mod , name ):
258
254
pytest .skip (f"{ mod_name } does not have { name } (), skipping." )
0 commit comments