Skip to content

Commit ca4ce65

Browse files
committed
Document BoundFromDtype with extensive examples
1 parent 9c4b26b commit ca4ce65

File tree

1 file changed

+57
-2
lines changed

1 file changed

+57
-2
lines changed

Diff for: array_api_tests/test_special_cases.py

+57-2
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,65 @@ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
204204
@dataclass
205205
class BoundFromDtype(FromDtypeFunc):
206206
"""
207-
A callable which bounds kwargs and strategy filters to xps.from_dtype() or
208-
equivalent function.
207+
A xps.from_dtype()-like callable with bounded kwargs, filters and base function.
209208
209+
We can bound:
210210
211+
1. Keyword arguments that xps.from_dtype() can use, e.g.
212+
213+
>>> from_dtype = BoundFromDtype(kwargs={'min_value': 0, 'allow_infinity': False})
214+
>>> strategy = from_dtype(xp.float64)
215+
216+
is equivalent to
217+
218+
>>> strategy = xps.from_dtype(xp.float64, min_value=0, allow_infinity=False)
219+
220+
i.e. a strategy that generates finite floats above 0
221+
222+
2. Functions that filter the elements strategy that xps.from_dtype() returns, e.g.
223+
224+
>>> from_dtype = BoundFromDtype(filter=lambda i: i != 0)
225+
>>> strategy = from_dtype(xp.float64)
226+
227+
is equivalent to
228+
229+
>>> strategy = xps.from_dtype(xp.float64).filter(lambda i: i != 0)
230+
231+
i.e. a strategy that generates any floats except 0
232+
233+
3. The underlying function that returns an elements strategy from a dtype, e.g.
234+
235+
>>> from_dtype = BoundFromDtype(
236+
... from_dtype=lambda d: st.integers(
237+
... math.ceil(xp.finfo(d).min), math.floor(xp.finfo(d).max)
238+
... )
239+
... )
240+
>>> strategy = from_dtype(xp.float64)
241+
242+
is equivalent to
243+
244+
>>> strategy = lambda d: st.integers(
245+
... math.ceil(xp.finfo(d).min), math.floor(xp.finfo(d).max)
246+
... )
247+
248+
i.e. a strategy that generates integers (within the dtypes range)
249+
250+
This is useful to avoid translating special case conditions into either a
251+
dict, filter or "base func", and instead allows us to generalise these three
252+
components into a callable equivalent of xps.from_dtype().
253+
254+
Additionally, BoundFromDtype instances can be added together. This allows us
255+
to keep parsing each condition individually - so we don't need to duplicate
256+
complicated parsing code - as ultimately we can represent (and subsequently
257+
test for) special cases which have more than one condition per array, e.g.
258+
259+
"If x1_i is greater than 0 and x1_i is not 42, ..."
260+
261+
could be translated as
262+
263+
>>> gt_0_from_dtype = BoundFromDtype(kwargs={'min_value': 0})
264+
>>> not_42_from_dtype = BoundFromDtype(filter=lambda i: i != 42)
265+
>>> from_dtype = gt_0_from_dtype + not_42_from_dtype
211266
212267
"""
213268

0 commit comments

Comments
 (0)