Skip to content

Add refresh_rate to RichProgressBar #10497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 19, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))


- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))


Expand Down
59 changes: 36 additions & 23 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ class RichProgressBar(ProgressBarBase):
trainer = Trainer(callbacks=RichProgressBar())

Args:
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
Set it to ``0`` to disable the display.
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
theme: Contains styles used to stylize the progress bar.

Expand All @@ -222,7 +223,7 @@ class RichProgressBar(ProgressBarBase):

def __init__(
self,
refresh_rate_per_second: int = 10,
refresh_rate: int = 1,
leave: bool = False,
theme: RichProgressBarTheme = RichProgressBarTheme(),
) -> None:
Expand All @@ -231,7 +232,7 @@ def __init__(
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`."
)
super().__init__()
self._refresh_rate_per_second: int = refresh_rate_per_second
self._refresh_rate: int = refresh_rate
self._leave: bool = leave
self._enabled: bool = True
self.progress: Optional[Progress] = None
Expand All @@ -242,17 +243,12 @@ def __init__(
self.theme = theme

@property
def refresh_rate_per_second(self) -> float:
"""Refresh rate for Rich Progress.

Returns: Refresh rate for Progress Bar.
Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
"""
return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1
def refresh_rate(self) -> float:
return self._refresh_rate

@property
def is_enabled(self) -> bool:
return self._enabled and self._refresh_rate_per_second > 0
return self._enabled and self.refresh_rate > 0

@property
def is_disabled(self) -> bool:
Expand Down Expand Up @@ -289,14 +285,18 @@ def _init_progress(self, trainer):
self.progress = CustomProgress(
*self.configure_columns(trainer),
self._metric_component,
refresh_per_second=self.refresh_rate_per_second,
auto_refresh=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the fix I proposed in #9647 to @SeanNaren to prevent threading issues in the render function.
Did you check that this might enable #9647 to pass the CI now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Talked to Sean regarding it, it didn't work!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kaushikb11, I'm working on #13937 and debugging it has led me to this PR. Do you remember if there is any particular reason why we changed auto_refresh from True to False? Was it for #10362?

disable=self.is_disabled,
console=self._console,
)
self.progress.start()
# progress has started
self._progress_stopped = False

def refresh(self) -> None:
if self.progress:
self.progress.refresh()

def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self._init_progress(trainer)
Expand Down Expand Up @@ -328,10 +328,12 @@ def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
self.refresh()

def on_sanity_check_end(self, trainer, pl_module):
super().on_sanity_check_end(trainer, pl_module)
self._update(self.val_sanity_progress_bar_id, visible=False)
self.refresh()

def on_train_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
Expand All @@ -354,6 +356,7 @@ def on_train_epoch_start(self, trainer, pl_module):
self.progress.reset(
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
)
self.refresh()

def on_validation_epoch_start(self, trainer, pl_module):
super().on_validation_epoch_start(trainer, pl_module)
Expand All @@ -364,52 +367,62 @@ def on_validation_epoch_start(self, trainer, pl_module):
val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
total_val_batches = self.total_val_batches * val_checks_per_epoch
self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
self.refresh()

def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
if self.progress is not None:
return self.progress.add_task(
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)

def _update(self, progress_bar_id: int, visible: bool = True) -> None:
if self.progress is not None:
self.progress.update(progress_bar_id, advance=1.0, visible=visible)
def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
if self.progress is not None and self._should_update(current, total):
self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible)
self.refresh()

def _should_update(self, current: int, total: int) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def on_validation_epoch_end(self, trainer, pl_module):
super().on_validation_epoch_end(trainer, pl_module)
if self.val_progress_bar_id is not None:
self._update(self.val_progress_bar_id, visible=False)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)

def on_test_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
self.refresh()

def on_predict_epoch_start(self, trainer, pl_module):
super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
self.refresh()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
self._update(self.main_progress_bar_id)
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
self._update_metrics(trainer, pl_module)
self.refresh()

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if trainer.sanity_checking:
self._update(self.val_sanity_progress_bar_id)
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
elif self.val_progress_bar_id is not None:
# check to see if we should update the main training progress bar
if self.main_progress_bar_id is not None:
self._update(self.main_progress_bar_id)
self._update(self.val_progress_bar_id)
self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
self.refresh()

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.test_progress_bar_id)
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
self.refresh()

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.predict_progress_bar_id)
self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
self.refresh()

def _get_train_description(self, current_epoch: int) -> str:
train_description = f"Epoch {current_epoch}"
Expand Down
27 changes: 24 additions & 3 deletions tests/callbacks/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def test_rich_progress_bar_callback():


@RunIf(rich=True)
def test_rich_progress_bar_refresh_rate():
progress_bar = RichProgressBar(refresh_rate_per_second=1)
def test_rich_progress_bar_refresh_rate_enabled():
progress_bar = RichProgressBar(refresh_rate=1)
assert progress_bar.is_enabled
assert not progress_bar.is_disabled
progress_bar = RichProgressBar(refresh_rate_per_second=0)
progress_bar = RichProgressBar(refresh_rate=0)
assert not progress_bar.is_enabled
assert progress_bar.is_disabled

Expand Down Expand Up @@ -180,3 +180,24 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
)
trainer.fit(model)
assert mock_progress_reset.call_count == reset_call_count


@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)]))
def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):

model = BoringModel()

trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=6,
limit_val_batches=6,
max_epochs=1,
callbacks=RichProgressBar(refresh_rate=refresh_rate),
)

trainer.fit(model)

assert progress_update.call_count == expected_call_count