Skip to content

Commit c82eb99

Browse files
committed
Allow to periodically save with any function (#362)
* Allow to periodically save with pandas * Generalize saving * add future annotations * Fix defaults * remove typing_extensions again * rename
1 parent 441a6bf commit c82eb99

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

adaptive/runner.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import abc
24
import asyncio
35
import concurrent.futures as concurrent
@@ -11,11 +13,15 @@
1113
import traceback
1214
import warnings
1315
from contextlib import suppress
16+
from typing import TYPE_CHECKING, Any, Callable
1417

1518
import loky
1619

1720
from adaptive.notebook_integration import in_ipynb, live_info, live_plot
1821

22+
if TYPE_CHECKING:
23+
from adaptive import BaseLearner
24+
1925
try:
2026
import ipyparallel
2127

@@ -663,15 +669,26 @@ def elapsed_time(self):
663669
end_time = time.time()
664670
return end_time - self.start_time
665671

666-
def start_periodic_saving(self, save_kwargs, interval):
672+
def start_periodic_saving(
673+
self,
674+
save_kwargs: dict[str, Any] | None = None,
675+
interval: int = 30,
676+
method: Callable[[BaseLearner], None] | None = None,
677+
):
667678
"""Periodically save the learner's data.
668679
669680
Parameters
670681
----------
671682
save_kwargs : dict
672683
Key-word arguments for ``learner.save(**save_kwargs)``.
684+
Only used if ``method=None``.
673685
interval : int
674686
Number of seconds between saving the learner.
687+
method : callable
688+
The method to use for saving the learner. If None, the default
689+
saves the learner using "pickle" which calls
690+
``learner.save(**save_kwargs)``. Otherwise provide a callable
691+
that takes the learner and saves the learner.
675692
676693
Example
677694
-------
@@ -681,11 +698,19 @@ def start_periodic_saving(self, save_kwargs, interval):
681698
... interval=600)
682699
"""
683700

684-
async def _saver(save_kwargs=save_kwargs, interval=interval):
701+
def default_save(learner):
702+
learner.save(**save_kwargs)
703+
704+
if method is None:
705+
method = default_save
706+
if save_kwargs is None:
707+
raise ValueError("Must provide `save_kwargs` if method=None.")
708+
709+
async def _saver():
685710
while self.status() == "running":
686-
self.learner.save(**save_kwargs)
711+
method(self.learner)
687712
await asyncio.sleep(interval)
688-
self.learner.save(**save_kwargs) # one last time
713+
method(self.learner) # one last time
689714

690715
self.saving_task = self.ioloop.create_task(_saver())
691716
return self.saving_task

0 commit comments

Comments
 (0)