1
1
import operator
2
2
from builtins import all as all_
3
3
4
- import numpy .testing
4
+ from numpy .testing import assert_raises , suppress_warnings
5
5
import numpy as np
6
6
import pytest
7
7
29
29
30
30
import array_api_strict
31
31
32
- def assert_raises (exception , func , msg = None ):
33
- with numpy .testing .assert_raises (exception , msg = msg ):
34
- func ()
35
-
36
32
def test_validate_index ():
37
33
# The indexing tests in the official array API test suite test that the
38
34
# array object correctly handles the subset of indices that are required
@@ -94,7 +90,7 @@ def test_validate_index():
94
90
95
91
def test_operators ():
96
92
# For every operator, we test that it works for the required type
97
- # combinations and assert_raises TypeError otherwise
93
+ # combinations and raises TypeError otherwise
98
94
binary_op_dtypes = {
99
95
"__add__" : "numeric" ,
100
96
"__and__" : "integer_or_boolean" ,
@@ -115,7 +111,6 @@ def test_operators():
115
111
"__truediv__" : "floating" ,
116
112
"__xor__" : "integer_or_boolean" ,
117
113
}
118
- comparison_ops = ["__eq__" , "__ne__" , "__le__" , "__ge__" , "__lt__" , "__gt__" ]
119
114
# Recompute each time because of in-place ops
120
115
def _array_vals ():
121
116
for d in _integer_dtypes :
@@ -129,7 +124,7 @@ def _array_vals():
129
124
BIG_INT = int (1e30 )
130
125
for op , dtypes in binary_op_dtypes .items ():
131
126
ops = [op ]
132
- if op not in comparison_ops :
127
+ if op not in [ "__eq__" , "__ne__" , "__le__" , "__ge__" , "__lt__" , "__gt__" ] :
133
128
rop = "__r" + op [2 :]
134
129
iop = "__i" + op [2 :]
135
130
ops += [rop , iop ]
@@ -160,16 +155,16 @@ def _array_vals():
160
155
or a .dtype in _complex_floating_dtypes and type (s ) in [complex , float , int ]
161
156
)):
162
157
if a .dtype in _integer_dtypes and s == BIG_INT :
163
- assert_raises (OverflowError , lambda : getattr (a , _op )(s ), _op )
158
+ assert_raises (OverflowError , lambda : getattr (a , _op )(s ))
164
159
else :
165
160
# Only test for no error
166
- with numpy . testing . suppress_warnings () as sup :
161
+ with suppress_warnings () as sup :
167
162
# ignore warnings from pow(BIG_INT)
168
163
sup .filter (RuntimeWarning ,
169
164
"invalid value encountered in power" )
170
165
getattr (a , _op )(s )
171
166
else :
172
- assert_raises (TypeError , lambda : getattr (a , _op )(s ), _op )
167
+ assert_raises (TypeError , lambda : getattr (a , _op )(s ))
173
168
174
169
# Test array op array.
175
170
for _op in ops :
@@ -178,25 +173,25 @@ def _array_vals():
178
173
# See the promotion table in NEP 47 or the array
179
174
# API spec page on type promotion. Mixed kind
180
175
# promotion is not defined.
181
- if (op not in comparison_ops and
182
- (x .dtype == uint64 and y .dtype in [int8 , int16 , int32 , int64 ]
183
- or y .dtype == uint64 and x .dtype in [int8 , int16 , int32 , int64 ]
184
- or x .dtype in _integer_dtypes and y .dtype not in _integer_dtypes
185
- or y .dtype in _integer_dtypes and x .dtype not in _integer_dtypes
186
- or x .dtype in _boolean_dtypes and y .dtype not in _boolean_dtypes
187
- or y .dtype in _boolean_dtypes and x .dtype not in _boolean_dtypes
188
- or x .dtype in _floating_dtypes and y .dtype not in _floating_dtypes
189
- or y .dtype in _floating_dtypes and x .dtype not in _floating_dtypes
190
- )):
191
- assert_raises (TypeError , lambda : getattr (x , _op )(y ), _op )
176
+ if (x .dtype == uint64 and y .dtype in [int8 , int16 , int32 , int64 ]
177
+ or y .dtype == uint64 and x .dtype in [int8 , int16 , int32 , int64 ]
178
+ or x .dtype in _integer_dtypes and y .dtype not in _integer_dtypes
179
+ or y .dtype in _integer_dtypes and x .dtype not in _integer_dtypes
180
+ or x .dtype in _boolean_dtypes and y .dtype not in _boolean_dtypes
181
+ or y .dtype in _boolean_dtypes and x .dtype not in _boolean_dtypes
182
+ or x .dtype in _floating_dtypes and y .dtype not in _floating_dtypes
183
+ or y .dtype in _floating_dtypes and x .dtype not in _floating_dtypes
184
+ ):
185
+ assert_raises (TypeError , lambda : getattr (x , _op )(y ))
192
186
# Ensure in-place operators only promote to the same dtype as the left operand.
193
187
elif (
194
188
_op .startswith ("__i" )
195
189
and result_type (x .dtype , y .dtype ) != x .dtype
196
190
):
197
- assert_raises (TypeError , lambda : getattr (x , _op )(y ), _op )
191
+ assert_raises (TypeError , lambda : getattr (x , _op )(y ))
198
192
# Ensure only those dtypes that are required for every operator are allowed.
199
- elif (dtypes == "all"
193
+ elif (dtypes == "all" and (x .dtype in _boolean_dtypes and y .dtype in _boolean_dtypes
194
+ or x .dtype in _numeric_dtypes and y .dtype in _numeric_dtypes )
200
195
or (dtypes == "real numeric" and x .dtype in _real_numeric_dtypes and y .dtype in _real_numeric_dtypes )
201
196
or (dtypes == "numeric" and x .dtype in _numeric_dtypes and y .dtype in _numeric_dtypes )
202
197
or dtypes == "integer" and x .dtype in _integer_dtypes and y .dtype in _integer_dtypes
@@ -207,7 +202,7 @@ def _array_vals():
207
202
):
208
203
getattr (x , _op )(y )
209
204
else :
210
- assert_raises (TypeError , lambda : getattr (x , _op )(y ), ( x , _op , y ) )
205
+ assert_raises (TypeError , lambda : getattr (x , _op )(y ))
211
206
212
207
unary_op_dtypes = {
213
208
"__abs__" : "numeric" ,
@@ -226,7 +221,7 @@ def _array_vals():
226
221
# Only test for no error
227
222
getattr (a , op )()
228
223
else :
229
- assert_raises (TypeError , lambda : getattr (a , op )(), _op )
224
+ assert_raises (TypeError , lambda : getattr (a , op )())
230
225
231
226
# Finally, matmul() must be tested separately, because it works a bit
232
227
# different from the other operations.
@@ -245,9 +240,9 @@ def _matmul_array_vals():
245
240
or type (s ) == int and a .dtype in _integer_dtypes ):
246
241
# Type promotion is valid, but @ is not allowed on 0-D
247
242
# inputs, so the error is a ValueError
248
- assert_raises (ValueError , lambda : getattr (a , _op )(s ), _op )
243
+ assert_raises (ValueError , lambda : getattr (a , _op )(s ))
249
244
else :
250
- assert_raises (TypeError , lambda : getattr (a , _op )(s ), _op )
245
+ assert_raises (TypeError , lambda : getattr (a , _op )(s ))
251
246
252
247
for x in _matmul_array_vals ():
253
248
for y in _matmul_array_vals ():
@@ -361,17 +356,20 @@ def test_allow_newaxis():
361
356
362
357
def test_disallow_flat_indexing_with_newaxis ():
363
358
a = ones ((3 , 3 , 3 ))
364
- assert_raises (IndexError , lambda : a [None , 0 , 0 ])
359
+ with pytest .raises (IndexError ):
360
+ a [None , 0 , 0 ]
365
361
366
362
def test_disallow_mask_with_newaxis ():
367
363
a = ones ((3 , 3 , 3 ))
368
- assert_raises (IndexError , lambda : a [None , asarray (True )])
364
+ with pytest .raises (IndexError ):
365
+ a [None , asarray (True )]
369
366
370
367
@pytest .mark .parametrize ("shape" , [(), (5 ,), (3 , 3 , 3 )])
371
368
@pytest .mark .parametrize ("index" , ["string" , False , True ])
372
369
def test_error_on_invalid_index (shape , index ):
373
370
a = ones (shape )
374
- assert_raises (IndexError , lambda : a [index ])
371
+ with pytest .raises (IndexError ):
372
+ a [index ]
375
373
376
374
def test_mask_0d_array_without_errors ():
377
375
a = ones (())
@@ -382,8 +380,10 @@ def test_mask_0d_array_without_errors():
382
380
)
383
381
def test_error_on_invalid_index_with_ellipsis (i ):
384
382
a = ones ((3 , 3 , 3 ))
385
- assert_raises (IndexError , lambda : a [..., i ])
386
- assert_raises (IndexError , lambda : a [i , ...])
383
+ with pytest .raises (IndexError ):
384
+ a [..., i ]
385
+ with pytest .raises (IndexError ):
386
+ a [i , ...]
387
387
388
388
def test_array_keys_use_private_array ():
389
389
"""
@@ -400,7 +400,8 @@ def test_array_keys_use_private_array():
400
400
401
401
a = ones ((0 ,), dtype = bool_ )
402
402
key = ones ((0 , 0 ), dtype = bool_ )
403
- assert_raises (IndexError , lambda : a [key ])
403
+ with pytest .raises (IndexError ):
404
+ a [key ]
404
405
405
406
def test_array_namespace ():
406
407
a = ones ((3 , 3 ))
@@ -421,16 +422,16 @@ def test_array_namespace():
421
422
assert a .__array_namespace__ (api_version = "2021.12" ) is array_api_strict
422
423
assert array_api_strict .__array_api_version__ == "2021.12"
423
424
424
- assert_raises (ValueError , lambda : a .__array_namespace__ (api_version = "2021.11" ))
425
- assert_raises (ValueError , lambda : a .__array_namespace__ (api_version = "2024.12" ))
425
+ pytest . raises (ValueError , lambda : a .__array_namespace__ (api_version = "2021.11" ))
426
+ pytest . raises (ValueError , lambda : a .__array_namespace__ (api_version = "2024.12" ))
426
427
427
428
def test_iter ():
428
- assert_raises (TypeError , lambda : iter (asarray (3 )))
429
+ pytest . raises (TypeError , lambda : iter (asarray (3 )))
429
430
assert list (ones (3 )) == [asarray (1. ), asarray (1. ), asarray (1. )]
430
431
assert all_ (isinstance (a , Array ) for a in iter (ones (3 )))
431
432
assert all_ (a .shape == () for a in iter (ones (3 )))
432
433
assert all_ (a .dtype == float64 for a in iter (ones (3 )))
433
- assert_raises (TypeError , lambda : iter (ones ((3 , 3 ))))
434
+ pytest . raises (TypeError , lambda : iter (ones ((3 , 3 ))))
434
435
435
436
@pytest .mark .parametrize ("api_version" , ['2021.12' , '2022.12' , '2023.12' ])
436
437
def dlpack_2023_12 (api_version ):
@@ -446,17 +447,17 @@ def dlpack_2023_12(api_version):
446
447
447
448
448
449
exception = NotImplementedError if api_version >= '2023.12' else ValueError
449
- assert_raises (exception , lambda :
450
+ pytest . raises (exception , lambda :
450
451
a .__dlpack__ (dl_device = CPU_DEVICE ))
451
- assert_raises (exception , lambda :
452
+ pytest . raises (exception , lambda :
452
453
a .__dlpack__ (dl_device = None ))
453
- assert_raises (exception , lambda :
454
+ pytest . raises (exception , lambda :
454
455
a .__dlpack__ (max_version = (1 , 0 )))
455
- assert_raises (exception , lambda :
456
+ pytest . raises (exception , lambda :
456
457
a .__dlpack__ (max_version = None ))
457
- assert_raises (exception , lambda :
458
+ pytest . raises (exception , lambda :
458
459
a .__dlpack__ (copy = False ))
459
- assert_raises (exception , lambda :
460
+ pytest . raises (exception , lambda :
460
461
a .__dlpack__ (copy = True ))
461
- assert_raises (exception , lambda :
462
+ pytest . raises (exception , lambda :
462
463
a .__dlpack__ (copy = None ))
0 commit comments