@@ -129,13 +129,19 @@ def render(self, task) -> RenderableType:
129
129
class MetricsTextColumn (ProgressColumn ):
130
130
"""A column containing text."""
131
131
132
- def __init__ (self , trainer , pl_module ):
132
+ def __init__ (self , trainer ):
133
133
self ._trainer = trainer
134
- self ._pl_module = pl_module
135
134
self ._tasks = {}
136
135
self ._current_task_id = 0
136
+ self ._metrics = {}
137
137
super ().__init__ ()
138
138
139
+ def update (self , metrics ):
140
+ # Called when metrics are ready to be rendered.
141
+ # This is to prevent render from causing deadlock issues by requesting metrics
142
+ # in separate threads.
143
+ self ._metrics = metrics
144
+
139
145
def render (self , task ) -> Text :
140
146
from pytorch_lightning .trainer .states import TrainerFn
141
147
@@ -149,14 +155,8 @@ def render(self, task) -> Text:
149
155
if self ._trainer .training and task .id != self ._current_task_id :
150
156
return self ._tasks [task .id ]
151
157
_text = ""
152
- # TODO(@daniellepintz): make this code cleaner
153
- progress_bar_callback = getattr (self ._trainer , "progress_bar_callback" , None )
154
- if progress_bar_callback :
155
- metrics = self ._trainer .progress_bar_callback .get_metrics (self ._trainer , self ._pl_module )
156
- else :
157
- metrics = self ._trainer .progress_bar_metrics
158
-
159
- for k , v in metrics .items ():
158
+
159
+ for k , v in self ._metrics .items ():
160
160
_text += f"{ k } : { round (v , 3 ) if isinstance (v , float ) else v } "
161
161
return Text (_text , justify = "left" )
162
162
@@ -220,9 +220,9 @@ def __init__(
220
220
self .progress : Optional [Progress ] = None
221
221
self .val_sanity_progress_bar_id : Optional [int ] = None
222
222
self ._reset_progress_bar_ids ()
223
+ self ._metric_component = None
223
224
self ._progress_stopped : bool = False
224
225
self .theme = theme
225
- self ._console : Console = Console ()
226
226
227
227
@property
228
228
def refresh_rate_per_second (self ) -> float :
@@ -263,12 +263,15 @@ def test_description(self) -> str:
263
263
def predict_description (self ) -> str :
264
264
return "Predicting"
265
265
266
- def _init_progress (self , trainer , pl_module ):
267
- if self .progress is None or self ._progress_stopped :
266
+ def _init_progress (self , trainer ):
267
+ if self .is_enabled and ( self . progress is None or self ._progress_stopped ) :
268
268
self ._reset_progress_bar_ids ()
269
+ self ._console : Console = Console ()
269
270
self ._console .clear_live ()
271
+ self ._metric_component = MetricsTextColumn (trainer )
270
272
self .progress = CustomProgress (
271
- * self .configure_columns (trainer , pl_module ),
273
+ * self .configure_columns (trainer ),
274
+ self ._metric_component ,
272
275
refresh_per_second = self .refresh_rate_per_second ,
273
276
disable = self .is_disabled ,
274
277
console = self ._console ,
@@ -279,19 +282,19 @@ def _init_progress(self, trainer, pl_module):
279
282
280
283
def on_train_start (self , trainer , pl_module ):
281
284
super ().on_train_start (trainer , pl_module )
282
- self ._init_progress (trainer , pl_module )
285
+ self ._init_progress (trainer )
283
286
284
287
def on_predict_start (self , trainer , pl_module ):
285
288
super ().on_predict_start (trainer , pl_module )
286
- self ._init_progress (trainer , pl_module )
289
+ self ._init_progress (trainer )
287
290
288
291
def on_test_start (self , trainer , pl_module ):
289
292
super ().on_test_start (trainer , pl_module )
290
- self ._init_progress (trainer , pl_module )
293
+ self ._init_progress (trainer )
291
294
292
295
def on_validation_start (self , trainer , pl_module ):
293
296
super ().on_validation_start (trainer , pl_module )
294
- self ._init_progress (trainer , pl_module )
297
+ self ._init_progress (trainer )
295
298
296
299
def __getstate__ (self ):
297
300
# can't pickle the rich progress objects
@@ -302,12 +305,11 @@ def __getstate__(self):
302
305
303
306
def __setstate__ (self , state ):
304
307
self .__dict__ = state
305
- # reset console reference after loading progress
306
- self ._console = Console ()
308
+ state ["_console" ] = Console ()
307
309
308
310
def on_sanity_check_start (self , trainer , pl_module ):
309
311
super ().on_sanity_check_start (trainer , pl_module )
310
- self ._init_progress (trainer , pl_module )
312
+ self ._init_progress (trainer )
311
313
self .val_sanity_progress_bar_id = self ._add_task (trainer .num_sanity_val_steps , self .sanity_check_description )
312
314
313
315
def on_sanity_check_end (self , trainer , pl_module ):
@@ -328,10 +330,10 @@ def on_train_epoch_start(self, trainer, pl_module):
328
330
train_description = self ._get_train_description (trainer .current_epoch )
329
331
if self .main_progress_bar_id is not None and self ._leave :
330
332
self ._stop_progress ()
331
- self ._init_progress (trainer , pl_module )
333
+ self ._init_progress (trainer )
332
334
if self .main_progress_bar_id is None :
333
335
self .main_progress_bar_id = self ._add_task (total_batches , train_description )
334
- else :
336
+ elif self . progress is not None :
335
337
self .progress .reset (
336
338
self .main_progress_bar_id , total = total_batches , description = train_description , visible = True
337
339
)
@@ -372,6 +374,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
372
374
def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
373
375
super ().on_train_batch_end (trainer , pl_module , outputs , batch , batch_idx )
374
376
self ._update (self .main_progress_bar_id )
377
+ self ._update_metrics (trainer , pl_module )
375
378
376
379
def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
377
380
super ().on_validation_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
@@ -414,6 +417,11 @@ def _reset_progress_bar_ids(self):
414
417
self .test_progress_bar_id : Optional [int ] = None
415
418
self .predict_progress_bar_id : Optional [int ] = None
416
419
420
+ def _update_metrics (self , trainer , pl_module ) -> None :
421
+ metrics = self .get_metrics (trainer , pl_module )
422
+ if self ._metric_component :
423
+ self ._metric_component .update (metrics )
424
+
417
425
def teardown (self , trainer , pl_module , stage : Optional [str ] = None ) -> None :
418
426
self ._stop_progress ()
419
427
@@ -436,7 +444,7 @@ def main_progress_bar(self) -> Task:
436
444
def test_progress_bar (self ) -> Task :
437
445
return self .progress .tasks [self .test_progress_bar_id ]
438
446
439
- def configure_columns (self , trainer , pl_module ) -> list :
447
+ def configure_columns (self , trainer ) -> list :
440
448
return [
441
449
TextColumn ("[progress.description]{task.description}" ),
442
450
CustomBarColumn (
@@ -447,5 +455,4 @@ def configure_columns(self, trainer, pl_module) -> list:
447
455
BatchesProcessedColumn (style = self .theme .batch_process ),
448
456
CustomTimeColumn (style = self .theme .time ),
449
457
ProcessingSpeedColumn (style = self .theme .processing_speed ),
450
- MetricsTextColumn (trainer , pl_module ),
451
458
]
0 commit comments