@@ -722,6 +722,14 @@ def cancel(self) -> None:
722
722
"""
723
723
self .task .cancel ()
724
724
725
+ def block_until_done (self ) -> None :
726
+ if in_ipynb ():
727
+ raise RuntimeError (
728
+ "Cannot block the event loop when running in a Jupyter notebook."
729
+ " Use `await runner.task` instead."
730
+ )
731
+ self .ioloop .run_until_complete (self .task )
732
+
725
733
def live_plot (
726
734
self ,
727
735
* ,
@@ -768,6 +776,55 @@ def live_info(self, *, update_interval: float = 0.1) -> None:
768
776
"""
769
777
return live_info (self , update_interval = update_interval )
770
778
779
+ def live_info_terminal (
780
+ self , * , update_interval : float = 0.5 , overwrite_previous : bool = True
781
+ ) -> asyncio .Task :
782
+ """
783
+ Display live information about the runner in the terminal.
784
+
785
+ This function provides a live update of the runner's status in the terminal.
786
+ The update can either overwrite the previous status or be printed on a new line.
787
+
788
+ Parameters
789
+ ----------
790
+ update_interval : float, optional
791
+ The time interval (in seconds) at which the runner's status is updated in the terminal.
792
+ Default is 0.5 seconds.
793
+ overwrite_previous : bool, optional
794
+ If True, each update will overwrite the previous status in the terminal.
795
+ If False, each update will be printed on a new line.
796
+ Default is True.
797
+
798
+ Returns
799
+ -------
800
+ asyncio.Task
801
+ The asynchronous task responsible for updating the runner's status in the terminal.
802
+
803
+ Examples
804
+ --------
805
+ >>> runner = AsyncRunner(...)
806
+ >>> runner.live_info_terminal(update_interval=1.0, overwrite_previous=False)
807
+
808
+ Notes
809
+ -----
810
+ This function uses ANSI escape sequences to control the terminal's cursor position.
811
+ It might not work as expected on all terminal emulators.
812
+ """
813
+
814
+ async def _update (runner : AsyncRunner ) -> None :
815
+ try :
816
+ while not runner .task .done ():
817
+ if overwrite_previous :
818
+ # Clear the terminal
819
+ print ("\033 [H\033 [J" , end = "" )
820
+ print (_info_text (runner , separator = "\t " ))
821
+ await asyncio .sleep (update_interval )
822
+
823
+ except asyncio .CancelledError :
824
+ print ("Live info display cancelled." )
825
+
826
+ return self .ioloop .create_task (_update (self ))
827
+
771
828
async def _run (self ) -> None :
772
829
first_completed = asyncio .FIRST_COMPLETED
773
830
@@ -847,6 +904,43 @@ async def _saver():
847
904
return self .saving_task
848
905
849
906
907
+ def _info_text (runner , separator : str = "\n " ):
908
+ status = runner .status ()
909
+
910
+ color_map = {
911
+ "cancelled" : "\033 [33m" , # Yellow
912
+ "failed" : "\033 [31m" , # Red
913
+ "running" : "\033 [34m" , # Blue
914
+ "finished" : "\033 [32m" , # Green
915
+ }
916
+
917
+ overhead = runner .overhead ()
918
+ if overhead < 50 :
919
+ overhead_color = "\033 [32m" # Green
920
+ else :
921
+ overhead_color = "\033 [31m" # Red
922
+
923
+ info = [
924
+ ("time" , str (datetime .now ())),
925
+ ("status" , f"{ color_map [status ]} { status } \033 [0m" ),
926
+ ("elapsed time" , str (timedelta (seconds = runner .elapsed_time ()))),
927
+ ("overhead" , f"{ overhead_color } { overhead :.2f} %\033 [0m" ),
928
+ ]
929
+
930
+ with suppress (Exception ):
931
+ info .append (("# of points" , runner .learner .npoints ))
932
+
933
+ with suppress (Exception ):
934
+ info .append (("# of samples" , runner .learner .nsamples ))
935
+
936
+ with suppress (Exception ):
937
+ info .append (("latest loss" , f'{ runner .learner ._cache ["loss" ]:.3f} ' ))
938
+
939
+ width = 30
940
+ formatted_info = [f"{ k } : { v } " .ljust (width ) for i , (k , v ) in enumerate (info )]
941
+ return separator .join (formatted_info )
942
+
943
+
850
944
# Default runner
851
945
Runner = AsyncRunner
852
946
0 commit comments