Skip to content

Commit a1d7701

Browse files
authored
Merge pull request #101 from asmeurer/more-linalg2
More improvements to test_linalg
2 parents 4f83bb3 + 0ddb0cd commit a1d7701

9 files changed

+555
-221
lines changed

Diff for: array_api_tests/array_helpers.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
# These are exported here so that they can be included in the special cases
88
# tests from this file.
99
from ._array_module import logical_not, subtract, floor, ceil, where
10+
from . import _array_module as xp
1011
from . import dtype_helpers as dh
1112

12-
1313
__all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less',
1414
'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil',
1515
'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN',
@@ -164,19 +164,21 @@ def notequal(x, y):
164164

165165
return not_equal(x, y)
166166

167-
def assert_exactly_equal(x, y):
167+
def assert_exactly_equal(x, y, msg_extra=None):
168168
"""
169169
Test that the arrays x and y are exactly equal.
170170
171171
If x and y do not have the same shape and dtype, they are not considered
172172
equal.
173173
174174
"""
175-
assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})"
175+
extra = '' if not msg_extra else f' ({msg_extra})'
176+
177+
assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}"
176178

177-
assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})"
179+
assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}"
178180

179-
assert all(exactly_equal(x, y)), "The input arrays have different values"
181+
assert all(exactly_equal(x, y)), f"The input arrays have different values ({x!r} != {y!r}){extra}"
180182

181183
def assert_finite(x):
182184
"""
@@ -306,3 +308,13 @@ def same_sign(x, y):
306308
def assert_same_sign(x, y):
307309
assert all(same_sign(x, y)), "The input arrays do not have the same sign"
308310

311+
def _matrix_transpose(x):
312+
if not isinstance(xp.matrix_transpose, xp._UndefinedStub):
313+
return xp.matrix_transpose(x)
314+
if hasattr(x, 'mT'):
315+
return x.mT
316+
if not isinstance(xp.permute_dims, xp._UndefinedStub):
317+
perm = list(range(x.ndim))
318+
perm[-1], perm[-2] = perm[-2], perm[-1]
319+
return xp.permute_dims(x, axes=tuple(perm))
320+
raise NotImplementedError("No way to compute matrix transpose")

Diff for: array_api_tests/dtype_helpers.py

+51
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,57 @@ class MinMax(NamedTuple):
231231
{"complex64": xp.float32, "complex128": xp.float64}
232232
)
233233

234+
def as_real_dtype(dtype):
235+
"""
236+
Return the corresponding real dtype for a given floating-point dtype.
237+
"""
238+
if dtype in real_float_dtypes:
239+
return dtype
240+
elif dtype_to_name[dtype] in complex_names:
241+
return dtype_components[dtype]
242+
else:
243+
raise ValueError("as_real_dtype requires a floating-point dtype")
244+
245+
def accumulation_result_dtype(x_dtype, dtype_kwarg):
246+
"""
247+
Result dtype logic for sum(), prod(), and trace()
248+
249+
Note: may return None if a default uint cannot exist (e.g., for pytorch
250+
which doesn't support uint32 or uint64). See https://github.com/data-apis/array-api-tests/issues/106
251+
252+
"""
253+
if dtype_kwarg is None:
254+
if is_int_dtype(x_dtype):
255+
if x_dtype in uint_dtypes:
256+
default_dtype = default_uint
257+
else:
258+
default_dtype = default_int
259+
if default_dtype is None:
260+
_dtype = None
261+
else:
262+
m, M = dtype_ranges[x_dtype]
263+
d_m, d_M = dtype_ranges[default_dtype]
264+
if m < d_m or M > d_M:
265+
_dtype = x_dtype
266+
else:
267+
_dtype = default_dtype
268+
elif is_float_dtype(x_dtype, include_complex=False):
269+
if dtype_nbits[x_dtype] > dtype_nbits[default_float]:
270+
_dtype = x_dtype
271+
else:
272+
_dtype = default_float
273+
elif api_version > "2021.12":
274+
# Complex dtype
275+
if dtype_nbits[x_dtype] > dtype_nbits[default_complex]:
276+
_dtype = x_dtype
277+
else:
278+
_dtype = default_complex
279+
else:
280+
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
281+
else:
282+
_dtype = dtype_kwarg
283+
284+
return _dtype
234285

235286
if not hasattr(xp, "asarray"):
236287
default_int = xp.int32

Diff for: array_api_tests/hypothesis_helpers.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
sampled_from, shared, builds)
1313

1414
from . import _array_module as xp, api_version
15+
from . import array_helpers as ah
1516
from . import dtype_helpers as dh
1617
from . import shape_helpers as sh
1718
from . import xps
@@ -211,6 +212,7 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
211212

212213
# Use this to avoid memory errors with NumPy.
213214
# See https://github.com/numpy/numpy/issues/15753
215+
# Note, the hypothesis default for max_dims is min_dims + 2 (i.e., 0 + 2)
214216
def shapes(**kw):
215217
kw.setdefault('min_dims', 0)
216218
kw.setdefault('min_side', 0)
@@ -280,25 +282,29 @@ def mutually_broadcastable_shapes(
280282

281283
# Note: This should become hermitian_matrices when complex dtypes are added
282284
@composite
283-
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
285+
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10.):
284286
shape = draw(square_matrix_shapes)
285287
dtype = draw(dtypes)
286288
if not isinstance(finite, bool):
287289
finite = draw(finite)
288290
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
289291
a = draw(arrays(dtype=dtype, shape=shape, elements=elements))
290-
upper = xp.triu(a)
291-
lower = xp.triu(a, k=1).mT
292-
return upper + lower
292+
at = ah._matrix_transpose(a)
293+
H = (a + at)*0.5
294+
if finite:
295+
assume(not xp.any(xp.isinf(H)))
296+
assume(xp.all((H == 0.) | ((1/bound <= xp.abs(H)) & (xp.abs(H) <= bound))))
297+
return H
293298

294299
@composite
295300
def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
296301
# For now just generate stacks of identity matrices
297302
# TODO: Generate arbitrary positive definite matrices, for instance, by
298303
# using something like
299304
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
300-
n = draw(integers(0))
301-
shape = draw(shapes()) + (n, n)
305+
base_shape = draw(shapes())
306+
n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value
307+
shape = base_shape + (n, n)
302308
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
303309
dtype = draw(dtypes)
304310
return broadcast_to(eye(n, dtype=dtype), shape)
@@ -308,12 +314,18 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
308314
# For now, just generate stacks of diagonal matrices.
309315
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
310316
stack_shape = draw(stack_shapes)
311-
d = draw(arrays(dtypes, shape=(*stack_shape, 1, n),
312-
elements=dict(allow_nan=False, allow_infinity=False)))
317+
dtype = draw(dtypes)
318+
elements = one_of(
319+
from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False),
320+
from_dtype(dtype, max_value=-0.5, allow_nan=False, allow_infinity=False),
321+
)
322+
d = draw(arrays(dtype, shape=(*stack_shape, 1, n), elements=elements))
323+
313324
# Functions that require invertible matrices may do anything when it is
314325
# singular, including raising an exception, so we make sure the diagonals
315326
# are sufficiently nonzero to avoid any numerical issues.
316-
assume(xp.all(xp.abs(d) > 0.5))
327+
assert xp.all(xp.abs(d) >= 0.5)
328+
317329
diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1))
318330
return xp.where(diag_mask, d, xp.zeros_like(d))
319331

Diff for: array_api_tests/meta/test_linalg.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pytest
2+
3+
from hypothesis import given
4+
5+
from ..hypothesis_helpers import symmetric_matrices
6+
from .. import array_helpers as ah
7+
from .. import _array_module as xp
8+
9+
@pytest.mark.xp_extension('linalg')
10+
@given(x=symmetric_matrices(finite=True))
11+
def test_symmetric_matrices(x):
12+
upper = xp.triu(x)
13+
lower = xp.tril(x)
14+
lowerT = ah._matrix_transpose(lower)
15+
16+
ah.assert_exactly_equal(upper, lowerT)

0 commit comments

Comments
 (0)