
Fix deadlocks for distributed training for RichProgressBar #10428


Merged (5 commits) on Nov 9, 2021
Changes from 3 commits
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -106,7 +106,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388))

--
+- Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428))


-
57 changes: 32 additions & 25 deletions pytorch_lightning/callbacks/progress/rich_progress.py
@@ -129,13 +129,19 @@ def render(self, task) -> RenderableType:
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""

-def __init__(self, trainer, pl_module):
+def __init__(self, trainer):
self._trainer = trainer
-self._pl_module = pl_module
self._tasks = {}
self._current_task_id = 0
+self.metrics = {}
super().__init__()

+def update(self, metrics):
+# Called when metrics are ready to be rendered.
+# Metrics are pushed in here so that `render`, which runs in a separate thread,
+# does not have to request them itself and cause deadlocks.
+self.metrics = metrics

def render(self, task) -> Text:
from pytorch_lightning.trainer.states import TrainerFn

@@ -149,14 +155,8 @@ def render(self, task) -> Text:
if self._trainer.training and task.id != self._current_task_id:
return self._tasks[task.id]
_text = ""
-# TODO(@daniellepintz): make this code cleaner
-progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None)
-if progress_bar_callback:
-metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
-else:
-metrics = self._trainer.progress_bar_metrics

-for k, v in metrics.items():

+for k, v in self.metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
return Text(_text, justify="left")

@@ -225,9 +225,9 @@ def __init__(
self.progress: Optional[Progress] = None
self.val_sanity_progress_bar_id: Optional[int] = None
self._reset_progress_bar_ids()
+self._metric_component = None
self._progress_stopped: bool = False
self.theme = theme
-self._console: Console = Console()

@property
def refresh_rate_per_second(self) -> float:
@@ -268,12 +268,15 @@ def test_description(self) -> str:
def predict_description(self) -> str:
return "Predicting"

-def _init_progress(self, trainer, pl_module):
-if self.progress is None or self._progress_stopped:
+def _init_progress(self, trainer):
+if self.is_enabled and (self.progress is None or self._progress_stopped):
self._reset_progress_bar_ids()
+self._console: Console = Console()
self._console.clear_live()
+self._metric_component = MetricsTextColumn(trainer)
self.progress = CustomProgress(
-*self.configure_columns(trainer, pl_module),
+*self.configure_columns(trainer),
+self._metric_component,
refresh_per_second=self.refresh_rate_per_second,
disable=self.is_disabled,
console=self._console,
@@ -284,19 +287,19 @@ def _init_progress(self, trainer, pl_module):

def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
-self._init_progress(trainer, pl_module)
+self._init_progress(trainer)

def on_predict_start(self, trainer, pl_module):
super().on_predict_start(trainer, pl_module)
-self._init_progress(trainer, pl_module)
+self._init_progress(trainer)

def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
-self._init_progress(trainer, pl_module)
+self._init_progress(trainer)

def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
-self._init_progress(trainer, pl_module)
+self._init_progress(trainer)

def __getstate__(self):
# can't pickle the rich progress objects
@@ -307,12 +310,11 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = state
-# reset console reference after loading progress
-self._console = Console()
+state["_console"] = Console()

def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
-self._init_progress(trainer, pl_module)
+self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)

def on_sanity_check_end(self, trainer, pl_module):
@@ -333,10 +335,10 @@ def on_train_epoch_start(self, trainer, pl_module):
train_description = self._get_train_description(trainer.current_epoch)
if self.main_progress_bar_id is not None and self._leave:
self._stop_progress()
-self._init_progress(trainer, pl_module)
+self._init_progress(trainer)
if self.main_progress_bar_id is None:
self.main_progress_bar_id = self._add_task(total_batches, train_description)
-else:
+elif self.progress is not None:
self.progress.reset(
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
)
@@ -377,6 +379,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
self._update(self.main_progress_bar_id)
+self._update_metrics(trainer, pl_module)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
@@ -419,6 +422,11 @@ def _reset_progress_bar_ids(self):
self.test_progress_bar_id: Optional[int] = None
self.predict_progress_bar_id: Optional[int] = None

+def _update_metrics(self, trainer, pl_module) -> None:
+metrics = self.get_metrics(trainer, pl_module)
+if self._metric_component:
+self._metric_component.update(metrics)

def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
self._stop_progress()

@@ -441,7 +449,7 @@ def main_progress_bar(self) -> Task:
def test_progress_bar(self) -> Task:
return self.progress.tasks[self.test_progress_bar_id]

-def configure_columns(self, trainer, pl_module) -> list:
+def configure_columns(self, trainer) -> list:
return [
TextColumn("[progress.description]{task.description}"),
CustomBarColumn(
@@ -452,5 +460,4 @@ def configure_columns(self, trainer, pl_module) -> list:
BatchesProcessedColumn(style=self.theme.batch_process),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
-MetricsTextColumn(trainer, pl_module),
]
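The core of the change above: `MetricsTextColumn.render` used to pull metrics from the trainer, and Rich calls `render` from its background refresh thread, which could hang distributed runs; after the change, the callback fetches metrics on the main thread (`_update_metrics`, called from `on_train_batch_end`) and pushes them into the column, so `render` only reads a cached dict. Below is a minimal, self-contained sketch of that push-based pattern using plain `rich`, outside of Lightning; `CachedMetricsColumn`, `update_metrics`, and the toy loop are illustrative stand-ins, not part of either library.

```python
import time

from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn
from rich.text import Text


class CachedMetricsColumn(ProgressColumn):
    """Renders whatever metrics were last pushed in via `update_metrics`."""

    def __init__(self):
        super().__init__()
        self._metrics = {}

    def update_metrics(self, metrics: dict) -> None:
        # Called from the main loop, where it is safe to compute/collect metrics.
        self._metrics = metrics

    def render(self, task) -> Text:
        # Called from Rich's refresh thread: read-only, no expensive or synchronizing work.
        text = " ".join(
            f"{k}: {v:.3f}" if isinstance(v, float) else f"{k}: {v}" for k, v in self._metrics.items()
        )
        return Text(text, justify="left")


if __name__ == "__main__":
    metrics_column = CachedMetricsColumn()
    columns = (TextColumn("[progress.description]{task.description}"), BarColumn(), metrics_column)
    with Progress(*columns) as progress:
        task_id = progress.add_task("Training", total=20)
        for step in range(20):
            time.sleep(0.1)  # stand-in for a training step
            metrics_column.update_metrics({"loss": 1.0 / (step + 1), "step": step})
            progress.advance(task_id)
```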
6 changes: 3 additions & 3 deletions tests/callbacks/test_rich_progress_bar.py
@@ -150,15 +150,15 @@ def test_rich_progress_bar_configure_columns():
custom_column = TextColumn("[progress.description]Testing Rich!")

class CustomRichProgressBar(RichProgressBar):
-def configure_columns(self, trainer, pl_module):
+def configure_columns(self, trainer):
return [custom_column]

progress_bar = CustomRichProgressBar()

-progress_bar._init_progress(Mock(), Mock())
+progress_bar._init_progress(Mock())

assert progress_bar.progress.columns[0] == custom_column
-assert len(progress_bar.progress.columns) == 1
+assert len(progress_bar.progress.columns) == 2


@RunIf(rich=True)
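The expected column count in the test changes from 1 to 2 because `_init_progress` now appends the internal metrics component after the user-configured columns, so a custom `configure_columns` that returns a single column produces a two-column progress bar. A rough, self-contained illustration of that ordering (import paths assumed from the file paths in this diff; not the actual test):

```python
from unittest.mock import Mock

from rich.progress import TextColumn

from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import MetricsTextColumn

custom_column = TextColumn("[progress.description]Testing Rich!")


class CustomRichProgressBar(RichProgressBar):
    def configure_columns(self, trainer):
        return [custom_column]


progress_bar = CustomRichProgressBar()
progress_bar._init_progress(Mock())  # only the trainer is needed now, and it can be a Mock
# The user-configured column comes first; the internal metrics component is appended last,
# which is why the test now expects two columns instead of one.
assert progress_bar.progress.columns[0] == custom_column
assert isinstance(progress_bar.progress.columns[-1], MetricsTextColumn)
```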