Skip to content
This repository was archived by the owner on Sep 28, 2022. It is now read-only.

Commit f6ea60f

Browse files
kaushikb11Raalsky
authored andcommitted
Fix deadlocks for distributed training for RichProgressBar (Lightning-AI#10428)
1 parent abd7062 commit f6ea60f

File tree

3 files changed

+38
-28
lines changed

3 files changed

+38
-28
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
107107
- Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388))
108108

109109

110+
- Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428))
111+
112+
110113
- Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409))
111114

112115

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,19 @@ def render(self, task) -> RenderableType:
129129
class MetricsTextColumn(ProgressColumn):
130130
"""A column containing text."""
131131

132-
def __init__(self, trainer, pl_module):
132+
def __init__(self, trainer):
133133
self._trainer = trainer
134-
self._pl_module = pl_module
135134
self._tasks = {}
136135
self._current_task_id = 0
136+
self._metrics = {}
137137
super().__init__()
138138

139+
def update(self, metrics):
140+
# Called when metrics are ready to be rendered.
141+
# This is to prevent render from causing deadlock issues by requesting metrics
142+
# in separate threads.
143+
self._metrics = metrics
144+
139145
def render(self, task) -> Text:
140146
from pytorch_lightning.trainer.states import TrainerFn
141147

@@ -149,14 +155,8 @@ def render(self, task) -> Text:
149155
if self._trainer.training and task.id != self._current_task_id:
150156
return self._tasks[task.id]
151157
_text = ""
152-
# TODO(@daniellepintz): make this code cleaner
153-
progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None)
154-
if progress_bar_callback:
155-
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
156-
else:
157-
metrics = self._trainer.progress_bar_metrics
158-
159-
for k, v in metrics.items():
158+
159+
for k, v in self._metrics.items():
160160
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
161161
return Text(_text, justify="left")
162162

@@ -225,9 +225,9 @@ def __init__(
225225
self.progress: Optional[Progress] = None
226226
self.val_sanity_progress_bar_id: Optional[int] = None
227227
self._reset_progress_bar_ids()
228+
self._metric_component = None
228229
self._progress_stopped: bool = False
229230
self.theme = theme
230-
self._console: Console = Console()
231231

232232
@property
233233
def refresh_rate_per_second(self) -> float:
@@ -268,12 +268,15 @@ def test_description(self) -> str:
268268
def predict_description(self) -> str:
269269
return "Predicting"
270270

271-
def _init_progress(self, trainer, pl_module):
272-
if self.progress is None or self._progress_stopped:
271+
def _init_progress(self, trainer):
272+
if self.is_enabled and (self.progress is None or self._progress_stopped):
273273
self._reset_progress_bar_ids()
274+
self._console: Console = Console()
274275
self._console.clear_live()
276+
self._metric_component = MetricsTextColumn(trainer)
275277
self.progress = CustomProgress(
276-
*self.configure_columns(trainer, pl_module),
278+
*self.configure_columns(trainer),
279+
self._metric_component,
277280
refresh_per_second=self.refresh_rate_per_second,
278281
disable=self.is_disabled,
279282
console=self._console,
@@ -284,19 +287,19 @@ def _init_progress(self, trainer, pl_module):
284287

285288
def on_train_start(self, trainer, pl_module):
286289
super().on_train_start(trainer, pl_module)
287-
self._init_progress(trainer, pl_module)
290+
self._init_progress(trainer)
288291

289292
def on_predict_start(self, trainer, pl_module):
290293
super().on_predict_start(trainer, pl_module)
291-
self._init_progress(trainer, pl_module)
294+
self._init_progress(trainer)
292295

293296
def on_test_start(self, trainer, pl_module):
294297
super().on_test_start(trainer, pl_module)
295-
self._init_progress(trainer, pl_module)
298+
self._init_progress(trainer)
296299

297300
def on_validation_start(self, trainer, pl_module):
298301
super().on_validation_start(trainer, pl_module)
299-
self._init_progress(trainer, pl_module)
302+
self._init_progress(trainer)
300303

301304
def __getstate__(self):
302305
# can't pickle the rich progress objects
@@ -307,12 +310,11 @@ def __getstate__(self):
307310

308311
def __setstate__(self, state):
309312
self.__dict__ = state
310-
# reset console reference after loading progress
311-
self._console = Console()
313+
state["_console"] = Console()
312314

313315
def on_sanity_check_start(self, trainer, pl_module):
314316
super().on_sanity_check_start(trainer, pl_module)
315-
self._init_progress(trainer, pl_module)
317+
self._init_progress(trainer)
316318
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
317319

318320
def on_sanity_check_end(self, trainer, pl_module):
@@ -333,10 +335,10 @@ def on_train_epoch_start(self, trainer, pl_module):
333335
train_description = self._get_train_description(trainer.current_epoch)
334336
if self.main_progress_bar_id is not None and self._leave:
335337
self._stop_progress()
336-
self._init_progress(trainer, pl_module)
338+
self._init_progress(trainer)
337339
if self.main_progress_bar_id is None:
338340
self.main_progress_bar_id = self._add_task(total_batches, train_description)
339-
else:
341+
elif self.progress is not None:
340342
self.progress.reset(
341343
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
342344
)
@@ -377,6 +379,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
377379
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
378380
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
379381
self._update(self.main_progress_bar_id)
382+
self._update_metrics(trainer, pl_module)
380383

381384
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
382385
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
@@ -419,6 +422,11 @@ def _reset_progress_bar_ids(self):
419422
self.test_progress_bar_id: Optional[int] = None
420423
self.predict_progress_bar_id: Optional[int] = None
421424

425+
def _update_metrics(self, trainer, pl_module) -> None:
426+
metrics = self.get_metrics(trainer, pl_module)
427+
if self._metric_component:
428+
self._metric_component.update(metrics)
429+
422430
def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
423431
self._stop_progress()
424432

@@ -441,7 +449,7 @@ def main_progress_bar(self) -> Task:
441449
def test_progress_bar(self) -> Task:
442450
return self.progress.tasks[self.test_progress_bar_id]
443451

444-
def configure_columns(self, trainer, pl_module) -> list:
452+
def configure_columns(self, trainer) -> list:
445453
return [
446454
TextColumn("[progress.description]{task.description}"),
447455
CustomBarColumn(
@@ -452,5 +460,4 @@ def configure_columns(self, trainer, pl_module) -> list:
452460
BatchesProcessedColumn(style=self.theme.batch_process),
453461
CustomTimeColumn(style=self.theme.time),
454462
ProcessingSpeedColumn(style=self.theme.processing_speed),
455-
MetricsTextColumn(trainer, pl_module),
456463
]

tests/callbacks/test_rich_progress_bar.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,15 @@ def test_rich_progress_bar_configure_columns():
150150
custom_column = TextColumn("[progress.description]Testing Rich!")
151151

152152
class CustomRichProgressBar(RichProgressBar):
153-
def configure_columns(self, trainer, pl_module):
153+
def configure_columns(self, trainer):
154154
return [custom_column]
155155

156156
progress_bar = CustomRichProgressBar()
157157

158-
progress_bar._init_progress(Mock(), Mock())
158+
progress_bar._init_progress(Mock())
159159

160160
assert progress_bar.progress.columns[0] == custom_column
161-
assert len(progress_bar.progress.columns) == 1
161+
assert len(progress_bar.progress.columns) == 2
162162

163163

164164
@RunIf(rich=True)

0 commit comments

Comments
 (0)