diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index 763cef14..3b2c38e1 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -1,10 +1,24 @@ +from functools import wraps + +from hypothesis import strategies as st from hypothesis.extra.array_api import make_strategies_namespace from ._array_module import mod as _xp +__all__ = ["xps"] xps = make_strategies_namespace(_xp) -del _xp -del make_strategies_namespace +# We monkey patch floats() to always disable subnormals as they are out-of-scope + +_floats = st.floats + + +@wraps(_floats) +def floats(*a, **kw): + kw["allow_subnormal"] = False + return _floats(*a, **kw) + + +st.floats = floats