
Commit 067444f

Add adaptive.utils.daskify (#422)
* Add adaptive.utils.daskify
* Do not overwrite variables g, h
* Fix header level
* Add link to TutorialAdvancedTopics
1 parent f31d0a5 commit 067444f

File tree

3 files changed: +105 -13 lines changed


adaptive/utils.py

+46 -2
@@ -7,13 +7,17 @@
 import os
 import pickle
 import warnings
-from collections.abc import Iterator, Sequence
+from collections.abc import Awaitable, Iterator, Sequence
 from contextlib import contextmanager
+from functools import wraps
 from itertools import product
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, TypeVar
 
 import cloudpickle
 
+if TYPE_CHECKING:
+    from dask.distributed import Client as AsyncDaskClient
+
 
 def named_product(**items: Sequence[Any]):
     names = items.keys()
@@ -161,3 +165,43 @@ def map(self, fn, *iterable, timeout=None, chunksize=1):
 
     def shutdown(self, wait=True):
         pass
+
+
+def _cache_key(args: tuple[Any], kwargs: dict[str, Any]) -> str:
+    arg_strings = [str(a) for a in args]
+    kwarg_strings = [f"{k}={v}" for k, v in sorted(kwargs.items())]
+    return "_".join(arg_strings + kwarg_strings)
+
+
+T = TypeVar("T")
+
+
+def daskify(
+    client: AsyncDaskClient, cache: bool = False
+) -> Callable[[Callable[..., T]], Callable[..., Awaitable[T]]]:
+    from dask import delayed
+
+    def _daskify(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
+        if cache:
+            func.cache = {}  # type: ignore[attr-defined]
+
+        delayed_func = delayed(func)
+
+        @wraps(func)
+        async def wrapper(*args: Any, **kwargs: Any) -> T:
+            if cache:
+                key = _cache_key(args, kwargs)  # type: ignore[arg-type]
+                future = func.cache.get(key)  # type: ignore[attr-defined]
+
+                if future is None:
+                    future = client.compute(delayed_func(*args, **kwargs))
+                    func.cache[key] = future  # type: ignore[attr-defined]
+            else:
+                future = client.compute(delayed_func(*args, **kwargs))
+
+            result = await future
+            return result
+
+        return wrapper
+
+    return _daskify
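
Usage note: the `_cache_key` helper added above can be exercised on its own. A minimal sketch follows; the argument values are made up for illustration, and only `_cache_key` itself comes from this commit:

```python
from typing import Any


def _cache_key(args: tuple[Any], kwargs: dict[str, Any]) -> str:
    # Same logic as the helper added above: stringify positional arguments,
    # sort keyword arguments by name, and join everything with underscores.
    arg_strings = [str(a) for a in args]
    kwarg_strings = [f"{k}={v}" for k, v in sorted(kwargs.items())]
    return "_".join(arg_strings + kwarg_strings)


# Identical calls map to the same key, so daskify(..., cache=True) reuses
# one dask future per distinct call; keyword order does not matter.
assert _cache_key((1, 0.5), {"b": 2, "a": 1}) == "1_0.5_a=1_b=2"
```

Because keys are built with `str()`, two different arguments with identical string representations would share a cache entry.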

docs/source/tutorial/tutorial.advanced-topics.md

+57 -11
@@ -9,7 +9,7 @@ kernelspec:
 display_name: python3
 name: python3
 ---
-
+(TutorialAdvancedTopics)=
 # Advanced Topics
 
 ```{note}
@@ -365,22 +365,19 @@ await runner.task  # This is not needed in a notebook environment!
 # The result will only be set when the runner is done.
 timer.result()
 ```
-
+(CustomParallelization)=
 ## Custom parallelization using coroutines
 
 Adaptive by itself does not implement a way of sharing partial results between function executions.
 Instead its implementation of parallel computation using executors is minimal by design.
 The appropriate way to implement custom parallelization is by using coroutines (asynchronous functions).
 
+
 We illustrate this approach by using `dask.distributed` for parallel computations in part because it supports asynchronous operation out-of-the-box.
-Let us consider a function `f(x)` which is composed by two parts:
-a slow part `g` which can be reused by multiple inputs and shared across function evaluations and a fast part `h` that will be computed for every `x`.
+We will focus on a function `f(x)` that consists of two distinct components: a slow part `g` that can be reused across multiple inputs and shared among various function evaluations, and a fast part `h` that is calculated for each `x` value.
 
 ```{code-cell} ipython3
-import time
-
-
-def f(x):
+def f(x):  # example function without caching
     """
     Integer part of `x` repeats and should be reused
     Decimal part requires a new computation
@@ -390,7 +387,9 @@ def f(x):
 
 def g(x):
     """Slow but reusable function"""
-    time.sleep(random.randrange(5))
+    from time import sleep
+
+    sleep(random.randrange(5))
     return x**2
 
 
@@ -399,12 +398,59 @@ def h(x):
     return x**3
 ```
 
+### Using `adaptive.utils.daskify`
+
+To simplify the process of using coroutines and caching with dask and Adaptive, we provide the {func}`adaptive.utils.daskify` decorator. This decorator can be used to parallelize functions with caching as well as functions without caching, making it a powerful tool for custom parallelization in Adaptive.
+
+```{code-cell} ipython3
+from dask.distributed import Client
+
+import adaptive
+
+client = await Client(asynchronous=True)
+
+
+# The g function has caching enabled
+g_dask = adaptive.utils.daskify(client, cache=True)(g)
+
+# Can be used like a decorator too:
+# >>> @adaptive.utils.daskify(client, cache=True)
+# ... def g(x): ...
+
+# The h function does not use caching
+h_dask = adaptive.utils.daskify(client)(h)
+
+# Now we need to rewrite `f(x)` to use `g` and `h` as coroutines
+
+
+async def f_parallel(x):
+    g_result = await g_dask(int(x))
+    h_result = await h_dask(x % 1)
+    return (g_result + h_result) ** 2
+
+
+learner = adaptive.Learner1D(f_parallel, bounds=(-3.5, 3.5))
+runner = adaptive.AsyncRunner(learner, loss_goal=0.01, ntasks=20)
+runner.live_info()
+```
+
+Finally, we wait for the runner to finish, and then plot the result.
+
+```{code-cell} ipython3
+await runner.task
+learner.plot()
+```
+
+### Step-by-step explanation of custom parallelization
+
+Now let's dive into a detailed explanation of the process to understand how the {func}`adaptive.utils.daskify` decorator works.
+
 In order to combine reuse of values of `g` with adaptive, we need to convert `f` into a dask graph by using `dask.delayed`.
 
 ```{code-cell} ipython3
 from dask import delayed
 
-# Convert g and h to dask.Delayed objects
+# Convert g and h to dask.Delayed objects, such that they run in the Client
 g, h = delayed(g), delayed(h)
 
 
@@ -441,7 +487,7 @@ learner = adaptive.Learner1D(f_parallel, bounds=(-3.5, 3.5))
 runner = adaptive.AsyncRunner(learner, loss_goal=0.01, ntasks=20)
 ```
 
-Finally we await for the runner to finish, and then plot the result.
+Finally we wait for the runner to finish, and then plot the result.
 
 ```{code-cell} ipython3
 await runner.task
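
Usage note: a self-contained sketch of the caching behavior the new tutorial section documents. The function `slow_square` and the one-second delay are illustrative assumptions, not part of the commit; `adaptive.utils.daskify` and the asynchronous `Client` pattern are taken from it:

```python
import asyncio

from dask.distributed import Client

import adaptive


def slow_square(x):
    # Stand-in for an expensive computation (hypothetical example).
    from time import sleep

    sleep(1)
    return x**2


async def main():
    client = await Client(asynchronous=True)
    cached = adaptive.utils.daskify(client, cache=True)(slow_square)
    first = await cached(2)   # computed on the cluster
    second = await cached(2)  # reuses the cached future; no second sleep
    assert first == second == 4
    await client.close()


asyncio.run(main())
```

Without `cache=True`, every call submits a fresh computation to the client.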

docs/source/tutorial/tutorial.parallelism.md

+2
@@ -57,6 +57,8 @@ runner.live_info()
 runner.live_plot(update_interval=0.1)
 ```
 
+Also check out the {ref}`Custom parallelization<CustomParallelization>` section in the {ref}`advanced topics tutorial<TutorialAdvancedTopics>` for more control over caching and parallelization.
+
 ## `mpi4py.futures.MPIPoolExecutor`
 
 This makes sense if you want to run a `Learner` on a cluster non-interactively using a job script.
