Skip to content

Commit 3668598

Browse files
awaelchlikaushikb11
authored andcommitted
Fix deadlocks for distributed training for RichProgressBar (#10428)
Co-authored-by: Kaushik Bokka <[email protected]>
1 parent 8d7712c commit 3668598

File tree

3 files changed

+37
-29
lines changed

3 files changed

+37
-29
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8-
## [1.5.1] - 2021-MM-DD
8+
## [1.5.1] - 2021-11-09
99

1010
### Fixed
1111

@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1515
- Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388))
1616
- Fixed an import error being caused by `PostLocalSGD` when `torch.distributed` not available ([#10359](https://github.com/PyTorchLightning/pytorch-lightning/pull/10359))
1717
- Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409))
18+
- Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428))
1819

1920

2021
## [1.5.0] - 2021-11-02

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,19 @@ def render(self, task) -> RenderableType:
129129
class MetricsTextColumn(ProgressColumn):
130130
"""A column containing text."""
131131

132-
def __init__(self, trainer, pl_module):
132+
def __init__(self, trainer):
133133
self._trainer = trainer
134-
self._pl_module = pl_module
135134
self._tasks = {}
136135
self._current_task_id = 0
136+
self._metrics = {}
137137
super().__init__()
138138

139+
def update(self, metrics):
140+
# Called when metrics are ready to be rendered.
141+
# This is to prevent render from causing deadlock issues by requesting metrics
142+
# in separate threads.
143+
self._metrics = metrics
144+
139145
def render(self, task) -> Text:
140146
from pytorch_lightning.trainer.states import TrainerFn
141147

@@ -149,14 +155,8 @@ def render(self, task) -> Text:
149155
if self._trainer.training and task.id != self._current_task_id:
150156
return self._tasks[task.id]
151157
_text = ""
152-
# TODO(@daniellepintz): make this code cleaner
153-
progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None)
154-
if progress_bar_callback:
155-
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
156-
else:
157-
metrics = self._trainer.progress_bar_metrics
158-
159-
for k, v in metrics.items():
158+
159+
for k, v in self._metrics.items():
160160
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
161161
return Text(_text, justify="left")
162162

@@ -220,9 +220,9 @@ def __init__(
220220
self.progress: Optional[Progress] = None
221221
self.val_sanity_progress_bar_id: Optional[int] = None
222222
self._reset_progress_bar_ids()
223+
self._metric_component = None
223224
self._progress_stopped: bool = False
224225
self.theme = theme
225-
self._console: Console = Console()
226226

227227
@property
228228
def refresh_rate_per_second(self) -> float:
@@ -263,12 +263,15 @@ def test_description(self) -> str:
263263
def predict_description(self) -> str:
264264
return "Predicting"
265265

266-
def _init_progress(self, trainer, pl_module):
267-
if self.progress is None or self._progress_stopped:
266+
def _init_progress(self, trainer):
267+
if self.is_enabled and (self.progress is None or self._progress_stopped):
268268
self._reset_progress_bar_ids()
269+
self._console: Console = Console()
269270
self._console.clear_live()
271+
self._metric_component = MetricsTextColumn(trainer)
270272
self.progress = CustomProgress(
271-
*self.configure_columns(trainer, pl_module),
273+
*self.configure_columns(trainer),
274+
self._metric_component,
272275
refresh_per_second=self.refresh_rate_per_second,
273276
disable=self.is_disabled,
274277
console=self._console,
@@ -279,19 +282,19 @@ def _init_progress(self, trainer, pl_module):
279282

280283
def on_train_start(self, trainer, pl_module):
281284
super().on_train_start(trainer, pl_module)
282-
self._init_progress(trainer, pl_module)
285+
self._init_progress(trainer)
283286

284287
def on_predict_start(self, trainer, pl_module):
285288
super().on_predict_start(trainer, pl_module)
286-
self._init_progress(trainer, pl_module)
289+
self._init_progress(trainer)
287290

288291
def on_test_start(self, trainer, pl_module):
289292
super().on_test_start(trainer, pl_module)
290-
self._init_progress(trainer, pl_module)
293+
self._init_progress(trainer)
291294

292295
def on_validation_start(self, trainer, pl_module):
293296
super().on_validation_start(trainer, pl_module)
294-
self._init_progress(trainer, pl_module)
297+
self._init_progress(trainer)
295298

296299
def __getstate__(self):
297300
# can't pickle the rich progress objects
@@ -302,12 +305,11 @@ def __getstate__(self):
302305

303306
def __setstate__(self, state):
304307
self.__dict__ = state
305-
# reset console reference after loading progress
306-
self._console = Console()
308+
state["_console"] = Console()
307309

308310
def on_sanity_check_start(self, trainer, pl_module):
309311
super().on_sanity_check_start(trainer, pl_module)
310-
self._init_progress(trainer, pl_module)
312+
self._init_progress(trainer)
311313
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
312314

313315
def on_sanity_check_end(self, trainer, pl_module):
@@ -328,10 +330,10 @@ def on_train_epoch_start(self, trainer, pl_module):
328330
train_description = self._get_train_description(trainer.current_epoch)
329331
if self.main_progress_bar_id is not None and self._leave:
330332
self._stop_progress()
331-
self._init_progress(trainer, pl_module)
333+
self._init_progress(trainer)
332334
if self.main_progress_bar_id is None:
333335
self.main_progress_bar_id = self._add_task(total_batches, train_description)
334-
else:
336+
elif self.progress is not None:
335337
self.progress.reset(
336338
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
337339
)
@@ -372,6 +374,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
372374
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
373375
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
374376
self._update(self.main_progress_bar_id)
377+
self._update_metrics(trainer, pl_module)
375378

376379
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
377380
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
@@ -414,6 +417,11 @@ def _reset_progress_bar_ids(self):
414417
self.test_progress_bar_id: Optional[int] = None
415418
self.predict_progress_bar_id: Optional[int] = None
416419

420+
def _update_metrics(self, trainer, pl_module) -> None:
421+
metrics = self.get_metrics(trainer, pl_module)
422+
if self._metric_component:
423+
self._metric_component.update(metrics)
424+
417425
def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
418426
self._stop_progress()
419427

@@ -436,7 +444,7 @@ def main_progress_bar(self) -> Task:
436444
def test_progress_bar(self) -> Task:
437445
return self.progress.tasks[self.test_progress_bar_id]
438446

439-
def configure_columns(self, trainer, pl_module) -> list:
447+
def configure_columns(self, trainer) -> list:
440448
return [
441449
TextColumn("[progress.description]{task.description}"),
442450
CustomBarColumn(
@@ -447,5 +455,4 @@ def configure_columns(self, trainer, pl_module) -> list:
447455
BatchesProcessedColumn(style=self.theme.batch_process),
448456
CustomTimeColumn(style=self.theme.time),
449457
ProcessingSpeedColumn(style=self.theme.processing_speed),
450-
MetricsTextColumn(trainer, pl_module),
451458
]

tests/callbacks/test_rich_progress_bar.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,15 @@ def test_rich_progress_bar_configure_columns():
150150
custom_column = TextColumn("[progress.description]Testing Rich!")
151151

152152
class CustomRichProgressBar(RichProgressBar):
153-
def configure_columns(self, trainer, pl_module):
153+
def configure_columns(self, trainer):
154154
return [custom_column]
155155

156156
progress_bar = CustomRichProgressBar()
157157

158-
progress_bar._init_progress(Mock(), Mock())
158+
progress_bar._init_progress(Mock())
159159

160160
assert progress_bar.progress.columns[0] == custom_column
161-
assert len(progress_bar.progress.columns) == 1
161+
assert len(progress_bar.progress.columns) == 2
162162

163163

164164
@RunIf(rich=True)

0 commit comments

Comments
 (0)