Skip to content

Commit cb2e7d0

Browse files
authored
Merge pull request #96 from honno/more-linalg2
Implementing the remaining linalg tests w/ additional fixes
2 parents aae17fc + 0572275 commit cb2e7d0

File tree

2 files changed

+74
-61
lines changed

2 files changed

+74
-61
lines changed

array_api_tests/hypothesis_helpers.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,21 @@ def matrix_shapes(draw, stack_shapes=shapes()):
158158

159159
square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])
160160

161-
finite_matrices = xps.arrays(dtype=xps.floating_dtypes(),
162-
shape=matrix_shapes(),
163-
elements=dict(allow_nan=False,
164-
allow_infinity=False))
161+
@composite
162+
def finite_matrices(draw, shape=matrix_shapes()):
163+
return draw(xps.arrays(dtype=xps.floating_dtypes(),
164+
shape=shape,
165+
elements=dict(allow_nan=False,
166+
allow_infinity=False)))
167+
168+
rtol_shared_matrix_shapes = shared(matrix_shapes())
169+
# Should we set a max_value here?
170+
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
171+
rtols = one_of(floats(**_rtol_float_kw),
172+
xps.arrays(dtype=xps.floating_dtypes(),
173+
shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),
174+
elements=_rtol_float_kw))
175+
165176

166177
def mutually_broadcastable_shapes(
167178
num_shapes: int,

array_api_tests/test_linalg.py

+59-57
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717
from hypothesis import assume, given
1818
from hypothesis.strategies import (booleans, composite, none, tuples, integers,
19-
shared, sampled_from, data, just)
19+
shared, sampled_from, one_of, data, just)
2020
from ndindex import iter_indices
2121

2222
from .array_helpers import assert_exactly_equal, asarray
@@ -26,7 +26,8 @@
2626
invertible_matrices, two_mutual_arrays,
2727
mutually_promotable_dtypes, one_d_shapes,
2828
two_mutually_broadcastable_shapes,
29-
SQRT_MAX_ARRAY_SIZE, finite_matrices)
29+
SQRT_MAX_ARRAY_SIZE, finite_matrices,
30+
rtol_shared_matrix_shapes, rtols)
3031
from . import dtype_helpers as dh
3132
from . import pytest_helpers as ph
3233
from . import shape_helpers as sh
@@ -37,18 +38,17 @@
3738

3839
pytestmark = pytest.mark.ci
3940

40-
41-
4241
# Standin strategy for not yet implemented tests
4342
todo = none()
4443

45-
def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1),
44+
def _test_stacks(f, *args, res=None, dims=2, true_val=None,
45+
matrix_axes=(-2, -1),
4646
assert_equal=assert_exactly_equal, **kw):
4747
"""
4848
Test that f(*args, **kw) maps across stacks of matrices
4949
50-
dims is the number of dimensions f(*args) should have for a single n x m
51-
matrix stack.
50+
dims is the number of dimensions f(*args, *kw) should have for a single n
51+
x m matrix stack.
5252
5353
matrix_axes are the axes along which matrices (or vectors) are stacked in
5454
the input.
@@ -65,9 +65,13 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1)
6565

6666
shapes = [x.shape for x in args]
6767

68+
# Assume the result is stacked along the last 'dims' axes of matrix_axes.
69+
# This holds for all the functions tested in this file
70+
res_axes = matrix_axes[::-1][:dims]
71+
6872
for (x_idxes, (res_idx,)) in zip(
6973
iter_indices(*shapes, skip_axes=matrix_axes),
70-
iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))):
74+
iter_indices(res.shape, skip_axes=res_axes)):
7175
x_idxes = [x_idx.raw for x_idx in x_idxes]
7276
res_idx = res_idx.raw
7377

@@ -159,26 +163,18 @@ def test_cross(x1_x2_kw):
159163
assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype"
160164
assert res.shape == shape, "cross() did not return the correct shape"
161165

162-
# cross is too different from other functions to use _test_stacks, and it
163-
# is the only function that works the way it does, so it's not really
164-
# worth generalizing _test_stacks to handle it.
165-
a = axis if axis >= 0 else axis + len(shape)
166-
for _idx in sh.ndindex(shape[:a] + shape[a+1:]):
167-
idx = _idx[:a] + (slice(None),) + _idx[a:]
168-
assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite."
169-
res_stack = res[idx]
170-
x1_stack = x1[idx]
171-
x2_stack = x2[idx]
172-
assert x1_stack.shape == x2_stack.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
173-
decomp_res_stack = linalg.cross(x1_stack, x2_stack)
174-
assert_exactly_equal(res_stack, decomp_res_stack)
175-
176-
exact_cross = asarray([
177-
x1_stack[1]*x2_stack[2] - x1_stack[2]*x2_stack[1],
178-
x1_stack[2]*x2_stack[0] - x1_stack[0]*x2_stack[2],
179-
x1_stack[0]*x2_stack[1] - x1_stack[1]*x2_stack[0],
180-
], dtype=res.dtype)
181-
assert_exactly_equal(res_stack, exact_cross)
166+
def exact_cross(a, b):
167+
assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
168+
return asarray([
169+
a[1]*b[2] - a[2]*b[1],
170+
a[2]*b[0] - a[0]*b[2],
171+
a[0]*b[1] - a[1]*b[0],
172+
], dtype=res.dtype)
173+
174+
# We don't want to pass in **kw here because that would pass axis to
175+
# cross() on a single stack, but the axis is not meaningful on unstacked
176+
# vectors.
177+
_test_stacks(linalg.cross, x1, x2, dims=1, matrix_axes=(axis,), res=res, true_val=exact_cross)
182178

183179
@pytest.mark.xp_extension('linalg')
184180
@given(
@@ -313,14 +309,30 @@ def test_matmul(x1, x2):
313309
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
314310
_test_stacks(_array_module.matmul, x1, x2, res=res)
315311

312+
matrix_norm_shapes = shared(matrix_shapes())
313+
316314
@pytest.mark.xp_extension('linalg')
317315
@given(
318-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
319-
kw=kwargs(axis=todo, keepdims=todo, ord=todo)
316+
x=finite_matrices(),
317+
kw=kwargs(keepdims=booleans(),
318+
ord=sampled_from([-float('inf'), -2, -2, 1, 2, float('inf'), 'fro', 'nuc']))
320319
)
321320
def test_matrix_norm(x, kw):
322-
# res = linalg.matrix_norm(x, **kw)
323-
pass
321+
res = linalg.matrix_norm(x, **kw)
322+
323+
keepdims = kw.get('keepdims', False)
324+
# TODO: Check that the ord values give the correct norms.
325+
# ord = kw.get('ord', 'fro')
326+
327+
if keepdims:
328+
expected_shape = x.shape[:-2] + (1, 1)
329+
else:
330+
expected_shape = x.shape[:-2]
331+
assert res.shape == expected_shape, f"matrix_norm({keepdims=}) did not return the correct shape"
332+
assert res.dtype == x.dtype, "matrix_norm() did not return the correct dtype"
333+
334+
_test_stacks(linalg.matrix_norm, x, **kw, dims=2 if keepdims else 0,
335+
res=res)
324336

325337
matrix_power_n = shared(integers(-1000, 1000), key='matrix_power n')
326338
@pytest.mark.xp_extension('linalg')
@@ -347,12 +359,11 @@ def test_matrix_power(x, n):
347359

348360
@pytest.mark.xp_extension('linalg')
349361
@given(
350-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
351-
kw=kwargs(rtol=todo)
362+
x=finite_matrices(shape=rtol_shared_matrix_shapes),
363+
kw=kwargs(rtol=rtols)
352364
)
353365
def test_matrix_rank(x, kw):
354-
# res = linalg.matrix_rank(x, **kw)
355-
pass
366+
linalg.matrix_rank(x, **kw)
356367

357368
@given(
358369
x=xps.arrays(dtype=dtypes, shape=matrix_shapes()),
@@ -397,12 +408,11 @@ def test_outer(x1, x2):
397408

398409
@pytest.mark.xp_extension('linalg')
399410
@given(
400-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
401-
kw=kwargs(rtol=todo)
411+
x=finite_matrices(shape=rtol_shared_matrix_shapes),
412+
kw=kwargs(rtol=rtols)
402413
)
403414
def test_pinv(x, kw):
404-
# res = linalg.pinv(x, **kw)
405-
pass
415+
linalg.pinv(x, **kw)
406416

407417
@pytest.mark.xp_extension('linalg')
408418
@given(
@@ -482,7 +492,7 @@ def solve_args():
482492
Strategy for the x1 and x2 arguments to test_solve()
483493
484494
solve() takes x1, x2, where x1 is any stack of square invertible matrices
485-
of shape (..., M, M), and x2 is either shape (..., M) or (..., M, K),
495+
of shape (..., M, M), and x2 is either shape (M,) or (..., M, K),
486496
where the ... parts of x1 and x2 are broadcast compatible.
487497
"""
488498
stack_shapes = shared(two_mutually_broadcastable_shapes)
@@ -492,30 +502,22 @@ def solve_args():
492502
pair[0])))
493503

494504
@composite
495-
def x2_shapes(draw):
496-
end = draw(xps.array_shapes(min_dims=0, max_dims=1, min_side=0,
497-
max_side=SQRT_MAX_ARRAY_SIZE))
498-
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + end
505+
def _x2_shapes(draw):
506+
end = draw(integers(0, SQRT_MAX_ARRAY_SIZE))
507+
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,)
499508

500-
x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes())
509+
x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes())
510+
x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes)
501511
return x1, x2
502512

503513
@pytest.mark.xp_extension('linalg')
504514
@given(*solve_args())
505515
def test_solve(x1, x2):
506-
# TODO: solve() is currently ambiguous, in that some inputs can be
507-
# interpreted in two different ways. For example, if x1 is shape (2, 2, 2)
508-
# and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack
509-
# of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after
510-
# broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in
511-
# (2, 2, 2, 2).
512-
513-
# res = linalg.solve(x1, x2)
514-
pass
516+
linalg.solve(x1, x2)
515517

516518
@pytest.mark.xp_extension('linalg')
517519
@given(
518-
x=finite_matrices,
520+
x=finite_matrices(),
519521
kw=kwargs(full_matrices=booleans())
520522
)
521523
def test_svd(x, kw):
@@ -551,7 +553,7 @@ def test_svd(x, kw):
551553

552554
@pytest.mark.xp_extension('linalg')
553555
@given(
554-
x=finite_matrices,
556+
x=finite_matrices(),
555557
)
556558
def test_svdvals(x):
557559
res = linalg.svdvals(x)

0 commit comments

Comments
 (0)