From 5a2036b0aaaeef695c5a1808433f73100275adfc Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 20 Nov 2024 15:35:06 -0700 Subject: [PATCH 1/2] Use shared() in test_reshape --- array_api_tests/test_manipulation_functions.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index b8a919c4..16b9e1e8 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -349,7 +349,8 @@ def test_repeat(x, kw, data): start = end @st.composite -def reshape_shapes(draw, shape): +def reshape_shapes(draw, shapes): + shape = draw(shapes) size = 1 if len(shape) == 0 else math.prod(shape) rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) assume(all(side <= MAX_SIDE for side in rshape)) @@ -359,15 +360,14 @@ def reshape_shapes(draw, shape): return tuple(rshape) +reshape_shape = st.shared(hh.shapes(max_side=MAX_SIDE), key="reshape_shape") + @pytest.mark.unvectorized -@pytest.mark.skip("flaky") # TODO: fix! @given( - x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(max_side=MAX_SIDE)), - data=st.data(), + x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape), + shape=reshape_shapes(reshape_shape), ) -def test_reshape(x, data): - shape = data.draw(reshape_shapes(x.shape)) - +def test_reshape(x, shape): out = xp.reshape(x, shape) ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype) From bcfcdba88c5ac4b8ec4333ada6f18c6b9be52886 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 22 Nov 2024 15:44:41 -0700 Subject: [PATCH 2/2] Fix test_reshape - Fix the input strategy to not generate arrays that are too large, which was causing an error from hypothesis. - Rewrite the reshape_shapes() strategy to generate reshape tuples directly by distributing the prime factors of the array size, rather than by using filtering. --- array_api_tests/hypothesis_helpers.py | 62 +++++++++++++++++++ .../test_manipulation_functions.py | 21 +------ 2 files changed, 65 insertions(+), 18 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index eeb470f7..5814534a 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -236,6 +236,68 @@ def shapes(**kw): lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE ) +def _factorize(n: int) -> List[int]: + # Simple prime factorization. Only needs to handle n ~ MAX_ARRAY_SIZE + factors = [] + while n % 2 == 0: + factors.append(2) + n //= 2 + + for i in range(3, int(math.sqrt(n)) + 1, 2): + while n % i == 0: + factors.append(i) + n //= i + + if n > 1: # n is a prime number greater than 2 + factors.append(n) + + return factors + +MAX_SIDE = MAX_ARRAY_SIZE // 64 +# NumPy only supports up to 32 dims. TODO: Get this from the new inspection APIs +MAX_DIMS = min(MAX_ARRAY_SIZE // MAX_SIDE, 32) + + +@composite +def reshape_shapes(draw, arr_shape, ndims=integers(1, MAX_DIMS)): + """ + Generate shape tuples whose product equals the product of array_shape. + """ + shape = draw(arr_shape) + + array_size = math.prod(shape) + + n_dims = draw(ndims) + + # Handle special cases + if array_size == 0: + # Generate a random tuple, and ensure at least one of the entries is 0 + result = list(draw(shapes(min_dims=n_dims, max_dims=n_dims))) + pos = draw(integers(0, n_dims - 1)) + result[pos] = 0 + return tuple(result) + + if array_size == 1: + return tuple(1 for _ in range(n_dims)) + + # Get prime factorization + factors = _factorize(array_size) + + # Distribute prime factors randomly + result = [1] * n_dims + for factor in factors: + pos = draw(integers(0, n_dims - 1)) + result[pos] *= factor + + assert math.prod(result) == array_size + + # An element of the reshape tuple can be -1, which means it is a stand-in + # for the remaining factors. + if draw(booleans()): + pos = draw(integers(0, n_dims - 1)) + result[pos] = -1 + + return tuple(result) one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 16b9e1e8..d5344039 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -14,9 +14,6 @@ from . import xps from .typing import Array, Shape -MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 -MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims - def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: key = "shape" @@ -66,7 +63,7 @@ def test_concat(dtypes, base_shape, data): shape_strat = hh.shapes() else: _axis = axis if axis >= 0 else len(base_shape) + axis - shape_strat = st.integers(0, MAX_SIDE).map( + shape_strat = st.integers(0, hh.MAX_SIDE).map( lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :] ) arrays = [] @@ -348,24 +345,12 @@ def test_repeat(x, kw, data): kw=kw) start = end -@st.composite -def reshape_shapes(draw, shapes): - shape = draw(shapes) - size = 1 if len(shape) == 0 else math.prod(shape) - rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) - assume(all(side <= MAX_SIDE for side in rshape)) - if len(rshape) != 0 and size > 0 and draw(st.booleans()): - index = draw(st.integers(0, len(rshape) - 1)) - rshape[index] = -1 - return tuple(rshape) - - -reshape_shape = st.shared(hh.shapes(max_side=MAX_SIDE), key="reshape_shape") +reshape_shape = st.shared(hh.shapes(), key="reshape_shape") @pytest.mark.unvectorized @given( x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape), - shape=reshape_shapes(reshape_shape), + shape=hh.reshape_shapes(reshape_shape), ) def test_reshape(x, shape): out = xp.reshape(x, shape)