diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 652e12e..0cee0b4 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -521,7 +521,7 @@ def test_xp(self, xp: ModuleType): class TestExpandDims: def test_single_axis(self, xp: ModuleType): """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims""" - a = xp.empty((2, 3, 4, 5)) + a = xp.asarray(np.reshape(np.arange(2 * 3 * 4 * 5), (2, 3, 4, 5))) for axis in range(-5, 4): b = expand_dims(a, axis=axis) xp_assert_equal(b, xp.expand_dims(a, axis=axis))