1
1
import operator
2
2
from builtins import all as all_
3
3
4
- from numpy .testing import assert_raises , suppress_warnings
4
+ import numpy .testing
5
5
import numpy as np
6
6
import pytest
7
7
29
29
30
30
import array_api_strict
31
31
32
+ def assert_raises (exception , func , msg = None ):
33
+ with numpy .testing .assert_raises (exception , msg = msg ):
34
+ func ()
35
+
32
36
def test_validate_index ():
33
37
# The indexing tests in the official array API test suite test that the
34
38
# array object correctly handles the subset of indices that are required
@@ -90,7 +94,7 @@ def test_validate_index():
90
94
91
95
def test_operators ():
92
96
# For every operator, we test that it works for the required type
93
- # combinations and raises TypeError otherwise
97
+ # combinations and assert_raises TypeError otherwise
94
98
binary_op_dtypes = {
95
99
"__add__" : "numeric" ,
96
100
"__and__" : "integer_or_boolean" ,
@@ -111,6 +115,7 @@ def test_operators():
111
115
"__truediv__" : "floating" ,
112
116
"__xor__" : "integer_or_boolean" ,
113
117
}
118
+ comparison_ops = ["__eq__" , "__ne__" , "__le__" , "__ge__" , "__lt__" , "__gt__" ]
114
119
# Recompute each time because of in-place ops
115
120
def _array_vals ():
116
121
for d in _integer_dtypes :
@@ -124,7 +129,7 @@ def _array_vals():
124
129
BIG_INT = int (1e30 )
125
130
for op , dtypes in binary_op_dtypes .items ():
126
131
ops = [op ]
127
- if op not in [ "__eq__" , "__ne__" , "__le__" , "__ge__" , "__lt__" , "__gt__" ] :
132
+ if op not in comparison_ops :
128
133
rop = "__r" + op [2 :]
129
134
iop = "__i" + op [2 :]
130
135
ops += [rop , iop ]
@@ -155,16 +160,16 @@ def _array_vals():
155
160
or a .dtype in _complex_floating_dtypes and type (s ) in [complex , float , int ]
156
161
)):
157
162
if a .dtype in _integer_dtypes and s == BIG_INT :
158
- assert_raises (OverflowError , lambda : getattr (a , _op )(s ))
163
+ assert_raises (OverflowError , lambda : getattr (a , _op )(s ), _op )
159
164
else :
160
165
# Only test for no error
161
- with suppress_warnings () as sup :
166
+ with numpy . testing . suppress_warnings () as sup :
162
167
# ignore warnings from pow(BIG_INT)
163
168
sup .filter (RuntimeWarning ,
164
169
"invalid value encountered in power" )
165
170
getattr (a , _op )(s )
166
171
else :
167
- assert_raises (TypeError , lambda : getattr (a , _op )(s ))
172
+ assert_raises (TypeError , lambda : getattr (a , _op )(s ), _op )
168
173
169
174
# Test array op array.
170
175
for _op in ops :
@@ -173,25 +178,25 @@ def _array_vals():
173
178
# See the promotion table in NEP 47 or the array
174
179
# API spec page on type promotion. Mixed kind
175
180
# promotion is not defined.
176
- if (x .dtype == uint64 and y .dtype in [int8 , int16 , int32 , int64 ]
177
- or y .dtype == uint64 and x .dtype in [int8 , int16 , int32 , int64 ]
178
- or x .dtype in _integer_dtypes and y .dtype not in _integer_dtypes
179
- or y .dtype in _integer_dtypes and x .dtype not in _integer_dtypes
180
- or x .dtype in _boolean_dtypes and y .dtype not in _boolean_dtypes
181
- or y .dtype in _boolean_dtypes and x .dtype not in _boolean_dtypes
182
- or x .dtype in _floating_dtypes and y .dtype not in _floating_dtypes
183
- or y .dtype in _floating_dtypes and x .dtype not in _floating_dtypes
184
- ):
185
- assert_raises (TypeError , lambda : getattr (x , _op )(y ))
181
+ if (op not in comparison_ops and
182
+ (x .dtype == uint64 and y .dtype in [int8 , int16 , int32 , int64 ]
183
+ or y .dtype == uint64 and x .dtype in [int8 , int16 , int32 , int64 ]
184
+ or x .dtype in _integer_dtypes and y .dtype not in _integer_dtypes
185
+ or y .dtype in _integer_dtypes and x .dtype not in _integer_dtypes
186
+ or x .dtype in _boolean_dtypes and y .dtype not in _boolean_dtypes
187
+ or y .dtype in _boolean_dtypes and x .dtype not in _boolean_dtypes
188
+ or x .dtype in _floating_dtypes and y .dtype not in _floating_dtypes
189
+ or y .dtype in _floating_dtypes and x .dtype not in _floating_dtypes
190
+ )):
191
+ assert_raises (TypeError , lambda : getattr (x , _op )(y ), _op )
186
192
# Ensure in-place operators only promote to the same dtype as the left operand.
187
193
elif (
188
194
_op .startswith ("__i" )
189
195
and result_type (x .dtype , y .dtype ) != x .dtype
190
196
):
191
- assert_raises (TypeError , lambda : getattr (x , _op )(y ))
197
+ assert_raises (TypeError , lambda : getattr (x , _op )(y ), _op )
192
198
# Ensure only those dtypes that are required for every operator are allowed.
193
- elif (dtypes == "all" and (x .dtype in _boolean_dtypes and y .dtype in _boolean_dtypes
194
- or x .dtype in _numeric_dtypes and y .dtype in _numeric_dtypes )
199
+ elif (dtypes == "all"
195
200
or (dtypes == "real numeric" and x .dtype in _real_numeric_dtypes and y .dtype in _real_numeric_dtypes )
196
201
or (dtypes == "numeric" and x .dtype in _numeric_dtypes and y .dtype in _numeric_dtypes )
197
202
or dtypes == "integer" and x .dtype in _integer_dtypes and y .dtype in _integer_dtypes
@@ -202,7 +207,7 @@ def _array_vals():
202
207
):
203
208
getattr (x , _op )(y )
204
209
else :
205
- assert_raises (TypeError , lambda : getattr (x , _op )(y ))
210
+ assert_raises (TypeError , lambda : getattr (x , _op )(y ), ( x , _op , y ) )
206
211
207
212
unary_op_dtypes = {
208
213
"__abs__" : "numeric" ,
@@ -221,7 +226,7 @@ def _array_vals():
221
226
# Only test for no error
222
227
getattr (a , op )()
223
228
else :
224
- assert_raises (TypeError , lambda : getattr (a , op )())
229
+ assert_raises (TypeError , lambda : getattr (a , op )(), _op )
225
230
226
231
# Finally, matmul() must be tested separately, because it works a bit
227
232
# different from the other operations.
@@ -240,9 +245,9 @@ def _matmul_array_vals():
240
245
or type (s ) == int and a .dtype in _integer_dtypes ):
241
246
# Type promotion is valid, but @ is not allowed on 0-D
242
247
# inputs, so the error is a ValueError
243
- assert_raises (ValueError , lambda : getattr (a , _op )(s ))
248
+ assert_raises (ValueError , lambda : getattr (a , _op )(s ), _op )
244
249
else :
245
- assert_raises (TypeError , lambda : getattr (a , _op )(s ))
250
+ assert_raises (TypeError , lambda : getattr (a , _op )(s ), _op )
246
251
247
252
for x in _matmul_array_vals ():
248
253
for y in _matmul_array_vals ():
@@ -356,20 +361,17 @@ def test_allow_newaxis():
356
361
357
362
def test_disallow_flat_indexing_with_newaxis ():
358
363
a = ones ((3 , 3 , 3 ))
359
- with pytest .raises (IndexError ):
360
- a [None , 0 , 0 ]
364
+ assert_raises (IndexError , lambda : a [None , 0 , 0 ])
361
365
362
366
def test_disallow_mask_with_newaxis ():
363
367
a = ones ((3 , 3 , 3 ))
364
- with pytest .raises (IndexError ):
365
- a [None , asarray (True )]
368
+ assert_raises (IndexError , lambda : a [None , asarray (True )])
366
369
367
370
@pytest .mark .parametrize ("shape" , [(), (5 ,), (3 , 3 , 3 )])
368
371
@pytest .mark .parametrize ("index" , ["string" , False , True ])
369
372
def test_error_on_invalid_index (shape , index ):
370
373
a = ones (shape )
371
- with pytest .raises (IndexError ):
372
- a [index ]
374
+ assert_raises (IndexError , lambda : a [index ])
373
375
374
376
def test_mask_0d_array_without_errors ():
375
377
a = ones (())
@@ -380,10 +382,8 @@ def test_mask_0d_array_without_errors():
380
382
)
381
383
def test_error_on_invalid_index_with_ellipsis (i ):
382
384
a = ones ((3 , 3 , 3 ))
383
- with pytest .raises (IndexError ):
384
- a [..., i ]
385
- with pytest .raises (IndexError ):
386
- a [i , ...]
385
+ assert_raises (IndexError , lambda : a [..., i ])
386
+ assert_raises (IndexError , lambda : a [i , ...])
387
387
388
388
def test_array_keys_use_private_array ():
389
389
"""
@@ -400,8 +400,7 @@ def test_array_keys_use_private_array():
400
400
401
401
a = ones ((0 ,), dtype = bool_ )
402
402
key = ones ((0 , 0 ), dtype = bool_ )
403
- with pytest .raises (IndexError ):
404
- a [key ]
403
+ assert_raises (IndexError , lambda : a [key ])
405
404
406
405
def test_array_namespace ():
407
406
a = ones ((3 , 3 ))
@@ -422,16 +421,16 @@ def test_array_namespace():
422
421
assert a .__array_namespace__ (api_version = "2021.12" ) is array_api_strict
423
422
assert array_api_strict .__array_api_version__ == "2021.12"
424
423
425
- pytest . raises (ValueError , lambda : a .__array_namespace__ (api_version = "2021.11" ))
426
- pytest . raises (ValueError , lambda : a .__array_namespace__ (api_version = "2024.12" ))
424
+ assert_raises (ValueError , lambda : a .__array_namespace__ (api_version = "2021.11" ))
425
+ assert_raises (ValueError , lambda : a .__array_namespace__ (api_version = "2024.12" ))
427
426
428
427
def test_iter ():
429
- pytest . raises (TypeError , lambda : iter (asarray (3 )))
428
+ assert_raises (TypeError , lambda : iter (asarray (3 )))
430
429
assert list (ones (3 )) == [asarray (1. ), asarray (1. ), asarray (1. )]
431
430
assert all_ (isinstance (a , Array ) for a in iter (ones (3 )))
432
431
assert all_ (a .shape == () for a in iter (ones (3 )))
433
432
assert all_ (a .dtype == float64 for a in iter (ones (3 )))
434
- pytest . raises (TypeError , lambda : iter (ones ((3 , 3 ))))
433
+ assert_raises (TypeError , lambda : iter (ones ((3 , 3 ))))
435
434
436
435
@pytest .mark .parametrize ("api_version" , ['2021.12' , '2022.12' , '2023.12' ])
437
436
def dlpack_2023_12 (api_version ):
@@ -447,17 +446,17 @@ def dlpack_2023_12(api_version):
447
446
448
447
449
448
exception = NotImplementedError if api_version >= '2023.12' else ValueError
450
- pytest . raises (exception , lambda :
449
+ assert_raises (exception , lambda :
451
450
a .__dlpack__ (dl_device = CPU_DEVICE ))
452
- pytest . raises (exception , lambda :
451
+ assert_raises (exception , lambda :
453
452
a .__dlpack__ (dl_device = None ))
454
- pytest . raises (exception , lambda :
453
+ assert_raises (exception , lambda :
455
454
a .__dlpack__ (max_version = (1 , 0 )))
456
- pytest . raises (exception , lambda :
455
+ assert_raises (exception , lambda :
457
456
a .__dlpack__ (max_version = None ))
458
- pytest . raises (exception , lambda :
457
+ assert_raises (exception , lambda :
459
458
a .__dlpack__ (copy = False ))
460
- pytest . raises (exception , lambda :
459
+ assert_raises (exception , lambda :
461
460
a .__dlpack__ (copy = True ))
462
- pytest . raises (exception , lambda :
461
+ assert_raises (exception , lambda :
463
462
a .__dlpack__ (copy = None ))
0 commit comments