@@ -204,10 +204,65 @@ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
204
204
@dataclass
205
205
class BoundFromDtype (FromDtypeFunc ):
206
206
"""
207
- A callable which bounds kwargs and strategy filters to xps.from_dtype() or
208
- equivalent function.
207
+ A xps.from_dtype()-like callable with bounded kwargs, filters and base function.
209
208
209
+ We can bound:
210
210
211
+ 1. Keyword arguments that xps.from_dtype() can use, e.g.
212
+
213
+ >>> from_dtype = BoundFromDtype(kwargs={'min_value': 0, 'allow_infinity': False})
214
+ >>> strategy = from_dtype(xp.float64)
215
+
216
+ is equivalent to
217
+
218
+ >>> strategy = xps.from_dtype(xp.float64, min_value=0, allow_infinity=False)
219
+
220
+ i.e. a strategy that generates finite floats above 0
221
+
222
+ 2. Functions that filter the elements strategy that xps.from_dtype() returns, e.g.
223
+
224
+ >>> from_dtype = BoundFromDtype(filter=lambda i: i != 0)
225
+ >>> strategy = from_dtype(xp.float64)
226
+
227
+ is equivalent to
228
+
229
+ >>> strategy = xps.from_dtype(xp.float64).filter(lambda i: i != 0)
230
+
231
+ i.e. a strategy that generates any floats except 0
232
+
233
+ 3. The underlying function that returns an elements strategy from a dtype, e.g.
234
+
235
+ >>> from_dtype = BoundFromDtype(
236
+ ... from_dtype=lambda d: st.integers(
237
+ ... math.ceil(xp.finfo(d).min), math.floor(xp.finfo(d).max)
238
+ ... )
239
+ ... )
240
+ >>> strategy = from_dtype(xp.float64)
241
+
242
+ is equivalent to
243
+
244
+ >>> strategy = lambda d: st.integers(
245
+ ... math.ceil(xp.finfo(d).min), math.floor(xp.finfo(d).max)
246
+ ... )
247
+
248
+ i.e. a strategy that generates integers (within the dtypes range)
249
+
250
+ This is useful to avoid translating special case conditions into either a
251
+ dict, filter or "base func", and instead allows us to generalise these three
252
+ components into a callable equivalent of xps.from_dtype().
253
+
254
+ Additionally, BoundFromDtype instances can be added together. This allows us
255
+ to keep parsing each condition individually - so we don't need to duplicate
256
+ complicated parsing code - as ultimately we can represent (and subsequently
257
+ test for) special cases which have more than one condition per array, e.g.
258
+
259
+ "If x1_i is greater than 0 and x1_i is not 42, ..."
260
+
261
+ could be translated as
262
+
263
+ >>> gt_0_from_dtype = BoundFromDtype(kwargs={'min_value': 0})
264
+ >>> not_42_from_dtype = BoundFromDtype(filter=lambda i: i != 42)
265
+ >>> from_dtype = gt_0_from_dtype + not_42_from_dtype
211
266
212
267
"""
213
268
0 commit comments