16
16
import pytest
17
17
from hypothesis import assume , given
18
18
from hypothesis .strategies import (booleans , composite , none , tuples , integers ,
19
- shared , sampled_from , data , just )
19
+ shared , sampled_from , one_of , data , just )
20
20
from ndindex import iter_indices
21
21
22
22
from .array_helpers import assert_exactly_equal , asarray
26
26
invertible_matrices , two_mutual_arrays ,
27
27
mutually_promotable_dtypes , one_d_shapes ,
28
28
two_mutually_broadcastable_shapes ,
29
- SQRT_MAX_ARRAY_SIZE , finite_matrices )
29
+ SQRT_MAX_ARRAY_SIZE , finite_matrices ,
30
+ rtol_shared_matrix_shapes , rtols )
30
31
from . import dtype_helpers as dh
31
32
from . import pytest_helpers as ph
32
33
from . import shape_helpers as sh
37
38
38
39
pytestmark = pytest .mark .ci
39
40
40
-
41
-
42
41
# Standin strategy for not yet implemented tests
43
42
todo = none ()
44
43
45
- def _test_stacks (f , * args , res = None , dims = 2 , true_val = None , matrix_axes = (- 2 , - 1 ),
44
+ def _test_stacks (f , * args , res = None , dims = 2 , true_val = None ,
45
+ matrix_axes = (- 2 , - 1 ),
46
46
assert_equal = assert_exactly_equal , ** kw ):
47
47
"""
48
48
Test that f(*args, **kw) maps across stacks of matrices
49
49
50
- dims is the number of dimensions f(*args) should have for a single n x m
51
- matrix stack.
50
+ dims is the number of dimensions f(*args, *kw ) should have for a single n
51
+ x m matrix stack.
52
52
53
53
matrix_axes are the axes along which matrices (or vectors) are stacked in
54
54
the input.
@@ -65,9 +65,13 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1)
65
65
66
66
shapes = [x .shape for x in args ]
67
67
68
+ # Assume the result is stacked along the last 'dims' axes of matrix_axes.
69
+ # This holds for all the functions tested in this file
70
+ res_axes = matrix_axes [::- 1 ][:dims ]
71
+
68
72
for (x_idxes , (res_idx ,)) in zip (
69
73
iter_indices (* shapes , skip_axes = matrix_axes ),
70
- iter_indices (res .shape , skip_axes = tuple ( range ( - dims , 0 )) )):
74
+ iter_indices (res .shape , skip_axes = res_axes )):
71
75
x_idxes = [x_idx .raw for x_idx in x_idxes ]
72
76
res_idx = res_idx .raw
73
77
@@ -159,26 +163,18 @@ def test_cross(x1_x2_kw):
159
163
assert res .dtype == dh .result_type (x1 .dtype , x2 .dtype ), "cross() did not return the correct dtype"
160
164
assert res .shape == shape , "cross() did not return the correct shape"
161
165
162
- # cross is too different from other functions to use _test_stacks, and it
163
- # is the only function that works the way it does, so it's not really
164
- # worth generalizing _test_stacks to handle it.
165
- a = axis if axis >= 0 else axis + len (shape )
166
- for _idx in sh .ndindex (shape [:a ] + shape [a + 1 :]):
167
- idx = _idx [:a ] + (slice (None ),) + _idx [a :]
168
- assert len (idx ) == len (shape ), "Invalid index. This indicates a bug in the test suite."
169
- res_stack = res [idx ]
170
- x1_stack = x1 [idx ]
171
- x2_stack = x2 [idx ]
172
- assert x1_stack .shape == x2_stack .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
173
- decomp_res_stack = linalg .cross (x1_stack , x2_stack )
174
- assert_exactly_equal (res_stack , decomp_res_stack )
175
-
176
- exact_cross = asarray ([
177
- x1_stack [1 ]* x2_stack [2 ] - x1_stack [2 ]* x2_stack [1 ],
178
- x1_stack [2 ]* x2_stack [0 ] - x1_stack [0 ]* x2_stack [2 ],
179
- x1_stack [0 ]* x2_stack [1 ] - x1_stack [1 ]* x2_stack [0 ],
180
- ], dtype = res .dtype )
181
- assert_exactly_equal (res_stack , exact_cross )
166
+ def exact_cross (a , b ):
167
+ assert a .shape == b .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
168
+ return asarray ([
169
+ a [1 ]* b [2 ] - a [2 ]* b [1 ],
170
+ a [2 ]* b [0 ] - a [0 ]* b [2 ],
171
+ a [0 ]* b [1 ] - a [1 ]* b [0 ],
172
+ ], dtype = res .dtype )
173
+
174
+ # We don't want to pass in **kw here because that would pass axis to
175
+ # cross() on a single stack, but the axis is not meaningful on unstacked
176
+ # vectors.
177
+ _test_stacks (linalg .cross , x1 , x2 , dims = 1 , matrix_axes = (axis ,), res = res , true_val = exact_cross )
182
178
183
179
@pytest .mark .xp_extension ('linalg' )
184
180
@given (
@@ -313,14 +309,30 @@ def test_matmul(x1, x2):
313
309
assert res .shape == stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ])
314
310
_test_stacks (_array_module .matmul , x1 , x2 , res = res )
315
311
312
+ matrix_norm_shapes = shared (matrix_shapes ())
313
+
316
314
@pytest .mark .xp_extension ('linalg' )
317
315
@given (
318
- x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ()),
319
- kw = kwargs (axis = todo , keepdims = todo , ord = todo )
316
+ x = finite_matrices (),
317
+ kw = kwargs (keepdims = booleans (),
318
+ ord = sampled_from ([- float ('inf' ), - 2 , - 2 , 1 , 2 , float ('inf' ), 'fro' , 'nuc' ]))
320
319
)
321
320
def test_matrix_norm (x , kw ):
322
- # res = linalg.matrix_norm(x, **kw)
323
- pass
321
+ res = linalg .matrix_norm (x , ** kw )
322
+
323
+ keepdims = kw .get ('keepdims' , False )
324
+ # TODO: Check that the ord values give the correct norms.
325
+ # ord = kw.get('ord', 'fro')
326
+
327
+ if keepdims :
328
+ expected_shape = x .shape [:- 2 ] + (1 , 1 )
329
+ else :
330
+ expected_shape = x .shape [:- 2 ]
331
+ assert res .shape == expected_shape , f"matrix_norm({ keepdims = } ) did not return the correct shape"
332
+ assert res .dtype == x .dtype , "matrix_norm() did not return the correct dtype"
333
+
334
+ _test_stacks (linalg .matrix_norm , x , ** kw , dims = 2 if keepdims else 0 ,
335
+ res = res )
324
336
325
337
matrix_power_n = shared (integers (- 1000 , 1000 ), key = 'matrix_power n' )
326
338
@pytest .mark .xp_extension ('linalg' )
@@ -347,12 +359,11 @@ def test_matrix_power(x, n):
347
359
348
360
@pytest .mark .xp_extension ('linalg' )
349
361
@given (
350
- x = xps . arrays ( dtype = xps . floating_dtypes (), shape = shapes () ),
351
- kw = kwargs (rtol = todo )
362
+ x = finite_matrices ( shape = rtol_shared_matrix_shapes ),
363
+ kw = kwargs (rtol = rtols )
352
364
)
353
365
def test_matrix_rank (x , kw ):
354
- # res = linalg.matrix_rank(x, **kw)
355
- pass
366
+ linalg .matrix_rank (x , ** kw )
356
367
357
368
@given (
358
369
x = xps .arrays (dtype = dtypes , shape = matrix_shapes ()),
@@ -397,12 +408,11 @@ def test_outer(x1, x2):
397
408
398
409
@pytest .mark .xp_extension ('linalg' )
399
410
@given (
400
- x = xps . arrays ( dtype = xps . floating_dtypes (), shape = shapes () ),
401
- kw = kwargs (rtol = todo )
411
+ x = finite_matrices ( shape = rtol_shared_matrix_shapes ),
412
+ kw = kwargs (rtol = rtols )
402
413
)
403
414
def test_pinv (x , kw ):
404
- # res = linalg.pinv(x, **kw)
405
- pass
415
+ linalg .pinv (x , ** kw )
406
416
407
417
@pytest .mark .xp_extension ('linalg' )
408
418
@given (
@@ -482,7 +492,7 @@ def solve_args():
482
492
Strategy for the x1 and x2 arguments to test_solve()
483
493
484
494
solve() takes x1, x2, where x1 is any stack of square invertible matrices
485
- of shape (..., M, M), and x2 is either shape (..., M ) or (..., M, K),
495
+ of shape (..., M, M), and x2 is either shape (M, ) or (..., M, K),
486
496
where the ... parts of x1 and x2 are broadcast compatible.
487
497
"""
488
498
stack_shapes = shared (two_mutually_broadcastable_shapes )
@@ -492,30 +502,22 @@ def solve_args():
492
502
pair [0 ])))
493
503
494
504
@composite
495
- def x2_shapes (draw ):
496
- end = draw (xps .array_shapes (min_dims = 0 , max_dims = 1 , min_side = 0 ,
497
- max_side = SQRT_MAX_ARRAY_SIZE ))
498
- return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + end
505
+ def _x2_shapes (draw ):
506
+ end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ))
507
+ return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
499
508
500
- x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = x2_shapes ())
509
+ x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
510
+ x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = x2_shapes )
501
511
return x1 , x2
502
512
503
513
@pytest .mark .xp_extension ('linalg' )
504
514
@given (* solve_args ())
505
515
def test_solve (x1 , x2 ):
506
- # TODO: solve() is currently ambiguous, in that some inputs can be
507
- # interpreted in two different ways. For example, if x1 is shape (2, 2, 2)
508
- # and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack
509
- # of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after
510
- # broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in
511
- # (2, 2, 2, 2).
512
-
513
- # res = linalg.solve(x1, x2)
514
- pass
516
+ linalg .solve (x1 , x2 )
515
517
516
518
@pytest .mark .xp_extension ('linalg' )
517
519
@given (
518
- x = finite_matrices ,
520
+ x = finite_matrices () ,
519
521
kw = kwargs (full_matrices = booleans ())
520
522
)
521
523
def test_svd (x , kw ):
@@ -551,7 +553,7 @@ def test_svd(x, kw):
551
553
552
554
@pytest .mark .xp_extension ('linalg' )
553
555
@given (
554
- x = finite_matrices ,
556
+ x = finite_matrices () ,
555
557
)
556
558
def test_svdvals (x ):
557
559
res = linalg .svdvals (x )
0 commit comments