1
+ from __future__ import annotations
2
+
1
3
import abc
2
4
import asyncio
3
5
import concurrent .futures as concurrent
11
13
import traceback
12
14
import warnings
13
15
from contextlib import suppress
16
+ from typing import TYPE_CHECKING , Any , Callable
14
17
15
18
import loky
16
19
17
20
from adaptive .notebook_integration import in_ipynb , live_info , live_plot
18
21
22
+ if TYPE_CHECKING :
23
+ from adaptive import BaseLearner
24
+
19
25
try :
20
26
import ipyparallel
21
27
@@ -663,15 +669,26 @@ def elapsed_time(self):
663
669
end_time = time .time ()
664
670
return end_time - self .start_time
665
671
666
- def start_periodic_saving (self , save_kwargs , interval ):
672
+ def start_periodic_saving (
673
+ self ,
674
+ save_kwargs : dict [str , Any ] | None = None ,
675
+ interval : int = 30 ,
676
+ method : Callable [[BaseLearner ], None ] | None = None ,
677
+ ):
667
678
"""Periodically save the learner's data.
668
679
669
680
Parameters
670
681
----------
671
682
save_kwargs : dict
672
683
Key-word arguments for ``learner.save(**save_kwargs)``.
684
+ Only used if ``method=None``.
673
685
interval : int
674
686
Number of seconds between saving the learner.
687
+ method : callable
688
+ The method to use for saving the learner. If None, the default
689
+ saves the learner using "pickle" which calls
690
+ ``learner.save(**save_kwargs)``. Otherwise provide a callable
691
+ that takes the learner and saves the learner.
675
692
676
693
Example
677
694
-------
@@ -681,11 +698,19 @@ def start_periodic_saving(self, save_kwargs, interval):
681
698
... interval=600)
682
699
"""
683
700
684
- async def _saver (save_kwargs = save_kwargs , interval = interval ):
701
+ def default_save (learner ):
702
+ learner .save (** save_kwargs )
703
+
704
+ if method is None :
705
+ method = default_save
706
+ if save_kwargs is None :
707
+ raise ValueError ("Must provide `save_kwargs` if method=None." )
708
+
709
+ async def _saver ():
685
710
while self .status () == "running" :
686
- self .learner . save ( ** save_kwargs )
711
+ method ( self .learner )
687
712
await asyncio .sleep (interval )
688
- self .learner . save ( ** save_kwargs ) # one last time
713
+ method ( self .learner ) # one last time
689
714
690
715
self .saving_task = self .ioloop .create_task (_saver ())
691
716
return self .saving_task
0 commit comments