@@ -129,13 +129,19 @@ def render(self, task) -> RenderableType:
129
129
class MetricsTextColumn (ProgressColumn ):
130
130
"""A column containing text."""
131
131
132
def __init__(self, trainer):
    """Set up an empty metrics column bound to ``trainer``.

    Args:
        trainer: Trainer object stored on the column; ``render`` reads its state.
    """
    self._trainer = trainer
    # Metrics to display; replaced wholesale by ``update`` and empty until then.
    self._metrics = {}
    self._current_task_id = 0
    # Looked up by task id in ``render``.
    self._tasks = {}
    super().__init__()
138
138
139
def update(self, metrics):
    """Store ``metrics`` so that ``render`` can display them.

    Called when metrics are ready to be rendered. Handing the metrics over
    here keeps ``render`` from requesting them itself in separate threads,
    which is what caused deadlock issues.

    Args:
        metrics: Mapping of metric name to value; kept by reference.
    """
    self._metrics = metrics
144
+
139
145
def render (self , task ) -> Text :
140
146
from pytorch_lightning .trainer .states import TrainerFn
141
147
@@ -149,14 +155,8 @@ def render(self, task) -> Text:
149
155
if self ._trainer .training and task .id != self ._current_task_id :
150
156
return self ._tasks [task .id ]
151
157
_text = ""
152
- # TODO(@daniellepintz): make this code cleaner
153
- progress_bar_callback = getattr (self ._trainer , "progress_bar_callback" , None )
154
- if progress_bar_callback :
155
- metrics = self ._trainer .progress_bar_callback .get_metrics (self ._trainer , self ._pl_module )
156
- else :
157
- metrics = self ._trainer .progress_bar_metrics
158
-
159
- for k , v in metrics .items ():
158
+
159
+ for k , v in self ._metrics .items ():
160
160
_text += f"{ k } : { round (v , 3 ) if isinstance (v , float ) else v } "
161
161
return Text (_text , justify = "left" )
162
162
@@ -225,9 +225,9 @@ def __init__(
225
225
self .progress : Optional [Progress ] = None
226
226
self .val_sanity_progress_bar_id : Optional [int ] = None
227
227
self ._reset_progress_bar_ids ()
228
+ self ._metric_component = None
228
229
self ._progress_stopped : bool = False
229
230
self .theme = theme
230
- self ._console : Console = Console ()
231
231
232
232
@property
233
233
def refresh_rate_per_second (self ) -> float :
@@ -268,12 +268,15 @@ def test_description(self) -> str:
268
268
def predict_description(self) -> str:
    """Return the fixed label ``Predicting``."""
    label = "Predicting"
    return label
270
270
271
- def _init_progress (self , trainer , pl_module ):
272
- if self .progress is None or self ._progress_stopped :
271
+ def _init_progress (self , trainer ):
272
+ if self .is_enabled and ( self . progress is None or self ._progress_stopped ) :
273
273
self ._reset_progress_bar_ids ()
274
+ self ._console : Console = Console ()
274
275
self ._console .clear_live ()
276
+ self ._metric_component = MetricsTextColumn (trainer )
275
277
self .progress = CustomProgress (
276
- * self .configure_columns (trainer , pl_module ),
278
+ * self .configure_columns (trainer ),
279
+ self ._metric_component ,
277
280
refresh_per_second = self .refresh_rate_per_second ,
278
281
disable = self .is_disabled ,
279
282
console = self ._console ,
@@ -284,19 +287,19 @@ def _init_progress(self, trainer, pl_module):
284
287
285
288
def on_train_start(self, trainer, pl_module):
    """Run the parent ``on_train_start`` hook, then call ``_init_progress`` with the trainer."""
    super().on_train_start(trainer, pl_module)
    # Only the trainer is forwarded; the module is not needed to build the display.
    self._init_progress(trainer)
288
291
289
292
def on_predict_start(self, trainer, pl_module):
    """Run the parent ``on_predict_start`` hook, then call ``_init_progress`` with the trainer."""
    super().on_predict_start(trainer, pl_module)
    # Only the trainer is forwarded; the module is not needed to build the display.
    self._init_progress(trainer)
292
295
293
296
def on_test_start(self, trainer, pl_module):
    """Run the parent ``on_test_start`` hook, then call ``_init_progress`` with the trainer."""
    super().on_test_start(trainer, pl_module)
    # Only the trainer is forwarded; the module is not needed to build the display.
    self._init_progress(trainer)
296
299
297
300
def on_validation_start(self, trainer, pl_module):
    """Run the parent ``on_validation_start`` hook, then call ``_init_progress`` with the trainer."""
    super().on_validation_start(trainer, pl_module)
    # Only the trainer is forwarded; the module is not needed to build the display.
    self._init_progress(trainer)
300
303
301
304
def __getstate__ (self ):
302
305
# can't pickle the rich progress objects
@@ -307,12 +310,11 @@ def __getstate__(self):
307
310
308
311
def __setstate__(self, state):
    """Restore the instance attributes from ``state`` and attach a new console.

    Args:
        state: Attribute dictionary; it is modified in place and becomes the
            instance ``__dict__``.
    """
    # Insert the new console into the mapping first, then adopt the mapping as
    # the instance namespace, so the result does not depend on aliasing.
    state["_console"] = Console()
    self.__dict__ = state
312
314
313
315
def on_sanity_check_start(self, trainer, pl_module):
    """Run the parent hook, initialise the progress display and add the sanity-check task.

    The id returned by ``_add_task`` is stored in ``val_sanity_progress_bar_id``.
    """
    super().on_sanity_check_start(trainer, pl_module)
    self._init_progress(trainer)
    sanity_steps = trainer.num_sanity_val_steps
    self.val_sanity_progress_bar_id = self._add_task(sanity_steps, self.sanity_check_description)
317
319
318
320
def on_sanity_check_end (self , trainer , pl_module ):
@@ -333,10 +335,10 @@ def on_train_epoch_start(self, trainer, pl_module):
333
335
train_description = self ._get_train_description (trainer .current_epoch )
334
336
if self .main_progress_bar_id is not None and self ._leave :
335
337
self ._stop_progress ()
336
- self ._init_progress (trainer , pl_module )
338
+ self ._init_progress (trainer )
337
339
if self .main_progress_bar_id is None :
338
340
self .main_progress_bar_id = self ._add_task (total_batches , train_description )
339
- else :
341
+ elif self . progress is not None :
340
342
self .progress .reset (
341
343
self .main_progress_bar_id , total = total_batches , description = train_description , visible = True
342
344
)
@@ -377,6 +379,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
377
379
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    """Run the parent hook, then update the main progress bar and the displayed metrics."""
    super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
    main_bar_id = self.main_progress_bar_id
    self._update(main_bar_id)
    # Metrics are handed over here so that rendering never has to request them.
    self._update_metrics(trainer, pl_module)
380
383
381
384
def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
382
385
super ().on_validation_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
@@ -419,6 +422,11 @@ def _reset_progress_bar_ids(self):
419
422
self .test_progress_bar_id : Optional [int ] = None
420
423
self .predict_progress_bar_id : Optional [int ] = None
421
424
425
def _update_metrics(self, trainer, pl_module) -> None:
    """Hand the current metrics to the metrics column, if one exists.

    Args:
        trainer: Passed through to ``self.get_metrics``.
        pl_module: Passed through to ``self.get_metrics``.
    """
    column = self._metric_component
    if column is None:
        # Without a column there is nothing to display on, so the metrics
        # are not collected at all.
        return
    column.update(self.get_metrics(trainer, pl_module))
429
+
422
430
def teardown (self , trainer , pl_module , stage : Optional [str ] = None ) -> None :
423
431
self ._stop_progress ()
424
432
@@ -441,7 +449,7 @@ def main_progress_bar(self) -> Task:
441
449
def test_progress_bar(self) -> Task:
    """Return the entry of ``self.progress.tasks`` at index ``self.test_progress_bar_id``."""
    all_tasks = self.progress.tasks
    return all_tasks[self.test_progress_bar_id]
443
451
444
- def configure_columns (self , trainer , pl_module ) -> list :
452
+ def configure_columns (self , trainer ) -> list :
445
453
return [
446
454
TextColumn ("[progress.description]{task.description}" ),
447
455
CustomBarColumn (
@@ -452,5 +460,4 @@ def configure_columns(self, trainer, pl_module) -> list:
452
460
BatchesProcessedColumn (style = self .theme .batch_process ),
453
461
CustomTimeColumn (style = self .theme .time ),
454
462
ProcessingSpeedColumn (style = self .theme .processing_speed ),
455
- MetricsTextColumn (trainer , pl_module ),
456
463
]
0 commit comments