Skip to content

Commit 94804ee

Browse files
committed
Merge remote-tracking branch 'origin/main' into pre-commit-ci-update-config
2 parents cf676bb + b883911 commit 94804ee

File tree

4 files changed

+104
-9
lines changed

4 files changed

+104
-9
lines changed

adaptive/runner.py

+94
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,14 @@ def cancel(self) -> None:
722722
"""
723723
self.task.cancel()
724724

725+
def block_until_done(self) -> None:
726+
if in_ipynb():
727+
raise RuntimeError(
728+
"Cannot block the event loop when running in a Jupyter notebook."
729+
" Use `await runner.task` instead."
730+
)
731+
self.ioloop.run_until_complete(self.task)
732+
725733
def live_plot(
726734
self,
727735
*,
@@ -768,6 +776,55 @@ def live_info(self, *, update_interval: float = 0.1) -> None:
768776
"""
769777
return live_info(self, update_interval=update_interval)
770778

779+
def live_info_terminal(
780+
self, *, update_interval: float = 0.5, overwrite_previous: bool = True
781+
) -> asyncio.Task:
782+
"""
783+
Display live information about the runner in the terminal.
784+
785+
This function provides a live update of the runner's status in the terminal.
786+
The update can either overwrite the previous status or be printed on a new line.
787+
788+
Parameters
789+
----------
790+
update_interval : float, optional
791+
The time interval (in seconds) at which the runner's status is updated in the terminal.
792+
Default is 0.5 seconds.
793+
overwrite_previous : bool, optional
794+
If True, each update will overwrite the previous status in the terminal.
795+
If False, each update will be printed on a new line.
796+
Default is True.
797+
798+
Returns
799+
-------
800+
asyncio.Task
801+
The asynchronous task responsible for updating the runner's status in the terminal.
802+
803+
Examples
804+
--------
805+
>>> runner = AsyncRunner(...)
806+
>>> runner.live_info_terminal(update_interval=1.0, overwrite_previous=False)
807+
808+
Notes
809+
-----
810+
This function uses ANSI escape sequences to control the terminal's cursor position.
811+
It might not work as expected on all terminal emulators.
812+
"""
813+
814+
async def _update(runner: AsyncRunner) -> None:
815+
try:
816+
while not runner.task.done():
817+
if overwrite_previous:
818+
# Clear the terminal
819+
print("\033[H\033[J", end="")
820+
print(_info_text(runner, separator="\t"))
821+
await asyncio.sleep(update_interval)
822+
823+
except asyncio.CancelledError:
824+
print("Live info display cancelled.")
825+
826+
return self.ioloop.create_task(_update(self))
827+
771828
async def _run(self) -> None:
772829
first_completed = asyncio.FIRST_COMPLETED
773830

@@ -847,6 +904,43 @@ async def _saver():
847904
return self.saving_task
848905

849906

907+
def _info_text(runner, separator: str = "\n"):
908+
status = runner.status()
909+
910+
color_map = {
911+
"cancelled": "\033[33m", # Yellow
912+
"failed": "\033[31m", # Red
913+
"running": "\033[34m", # Blue
914+
"finished": "\033[32m", # Green
915+
}
916+
917+
overhead = runner.overhead()
918+
if overhead < 50:
919+
overhead_color = "\033[32m" # Green
920+
else:
921+
overhead_color = "\033[31m" # Red
922+
923+
info = [
924+
("time", str(datetime.now())),
925+
("status", f"{color_map[status]}{status}\033[0m"),
926+
("elapsed time", str(timedelta(seconds=runner.elapsed_time()))),
927+
("overhead", f"{overhead_color}{overhead:.2f}%\033[0m"),
928+
]
929+
930+
with suppress(Exception):
931+
info.append(("# of points", runner.learner.npoints))
932+
933+
with suppress(Exception):
934+
info.append(("# of samples", runner.learner.nsamples))
935+
936+
with suppress(Exception):
937+
info.append(("latest loss", f'{runner.learner._cache["loss"]:.3f}'))
938+
939+
width = 30
940+
formatted_info = [f"{k}: {v}".ljust(width) for i, (k, v) in enumerate(info)]
941+
return separator.join(formatted_info)
942+
943+
850944
# Default runner
851945
Runner = AsyncRunner
852946

adaptive/tests/test_runner.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import platform
32
import sys
43
import time
@@ -34,7 +33,7 @@ def blocking_runner(learner, **kw):
3433

3534
def async_runner(learner, **kw):
3635
runner = AsyncRunner(learner, executor=SequentialExecutor(), **kw)
37-
asyncio.get_event_loop().run_until_complete(runner.task)
36+
runner.block_until_done()
3837

3938

4039
runners = [simple, blocking_runner, async_runner]
@@ -71,7 +70,7 @@ async def f(x):
7170

7271
learner = Learner1D(f, (-1, 1))
7372
runner = AsyncRunner(learner, npoints_goal=10)
74-
asyncio.get_event_loop().run_until_complete(runner.task)
73+
runner.block_until_done()
7574

7675

7776
# --- Test with different executors
@@ -158,7 +157,7 @@ def test_loky_executor(loky_executor):
158157
def test_default_executor():
159158
learner = Learner1D(linear, (-1, 1))
160159
runner = AsyncRunner(learner, npoints_goal=10)
161-
asyncio.get_event_loop().run_until_complete(runner.task)
160+
runner.block_until_done()
162161

163162

164163
def test_auto_goal():

docs/environment.yml

+6-4
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ dependencies:
99
- scikit-optimize=0.9.0
1010
- scikit-learn=1.2.2
1111
- scipy=1.10.1
12-
- holoviews=1.15.4
13-
- bokeh=2.4.3
14-
- panel=0.14.4
12+
- holoviews=1.18.3
13+
- bokeh=3.3.4
14+
- panel=1.3.8
1515
- pandas=2.0.0
1616
- plotly=5.14.1
1717
- ipywidgets=8.0.6
@@ -23,6 +23,8 @@ dependencies:
2323
- loky=3.3.0
2424
- furo=2023.3.27
2525
- myst-parser=0.18.1
26-
- dask=2023.3.2
26+
- dask=2024.2.0
2727
- emoji=2.2.0
2828
- versioningit=2.2.0
29+
- distributed=2024.2.0
30+
- param=2.0.2

docs/source/tutorial/tutorial.parallelism.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ if __name__ == "__main__":
8989
runner.start_periodic_saving(dict(fname=fname), interval=600)
9090

9191
# block until runner goal reached
92-
runner.ioloop.run_until_complete(runner.task)
92+
runner.block_until_done()
9393

9494
# save one final time before exiting
9595
learner.save(fname)

0 commit comments

Comments
 (0)