-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Add refresh_rate
to RichProgressBar
#10497
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7a31be7
24e387a
fb2ebff
ab01e3b
294ac7c
e97b07a
86c2c15
1916d32
c385fab
3e64f43
4f47b3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -206,7 +206,8 @@ class RichProgressBar(ProgressBarBase): | |
trainer = Trainer(callbacks=RichProgressBar()) | ||
|
||
Args: | ||
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled. | ||
refresh_rate: Determines at which rate (in number of batches) the progress bars get updated. | ||
Set it to ``0`` to disable the display. | ||
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False | ||
theme: Contains styles used to stylize the progress bar. | ||
|
||
|
@@ -222,7 +223,7 @@ class RichProgressBar(ProgressBarBase): | |
|
||
def __init__( | ||
self, | ||
refresh_rate_per_second: int = 10, | ||
refresh_rate: int = 1, | ||
leave: bool = False, | ||
theme: RichProgressBarTheme = RichProgressBarTheme(), | ||
) -> None: | ||
|
@@ -231,7 +232,7 @@ def __init__( | |
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`." | ||
) | ||
super().__init__() | ||
self._refresh_rate_per_second: int = refresh_rate_per_second | ||
self._refresh_rate: int = refresh_rate | ||
self._leave: bool = leave | ||
self._enabled: bool = True | ||
self.progress: Optional[Progress] = None | ||
|
@@ -242,17 +243,12 @@ def __init__( | |
self.theme = theme | ||
|
||
@property | ||
def refresh_rate_per_second(self) -> float: | ||
"""Refresh rate for Rich Progress. | ||
|
||
Returns: Refresh rate for Progress Bar. | ||
Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress). | ||
""" | ||
return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1 | ||
def refresh_rate(self) -> float: | ||
return self._refresh_rate | ||
|
||
@property | ||
def is_enabled(self) -> bool: | ||
return self._enabled and self._refresh_rate_per_second > 0 | ||
return self._enabled and self.refresh_rate > 0 | ||
|
||
@property | ||
def is_disabled(self) -> bool: | ||
|
@@ -289,14 +285,18 @@ def _init_progress(self, trainer): | |
self.progress = CustomProgress( | ||
*self.configure_columns(trainer), | ||
self._metric_component, | ||
refresh_per_second=self.refresh_rate_per_second, | ||
auto_refresh=False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is the fix I proposed in #9647 to @SeanNaren to prevent threading issues in the render function. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Talked to Sean regarding it, it didn't work! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @kaushikb11, I'm working on #13937 and debugging it has led me to this PR. Do you remember if there is any particular reason why we changed |
||
disable=self.is_disabled, | ||
console=self._console, | ||
) | ||
self.progress.start() | ||
# progress has started | ||
self._progress_stopped = False | ||
|
||
def refresh(self) -> None: | ||
if self.progress: | ||
self.progress.refresh() | ||
|
||
def on_train_start(self, trainer, pl_module): | ||
super().on_train_start(trainer, pl_module) | ||
self._init_progress(trainer) | ||
|
@@ -328,10 +328,12 @@ def on_sanity_check_start(self, trainer, pl_module): | |
super().on_sanity_check_start(trainer, pl_module) | ||
self._init_progress(trainer) | ||
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description) | ||
self.refresh() | ||
|
||
def on_sanity_check_end(self, trainer, pl_module): | ||
super().on_sanity_check_end(trainer, pl_module) | ||
self._update(self.val_sanity_progress_bar_id, visible=False) | ||
self.refresh() | ||
|
||
def on_train_epoch_start(self, trainer, pl_module): | ||
super().on_train_epoch_start(trainer, pl_module) | ||
|
@@ -354,6 +356,7 @@ def on_train_epoch_start(self, trainer, pl_module): | |
self.progress.reset( | ||
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True | ||
) | ||
self.refresh() | ||
|
||
def on_validation_epoch_start(self, trainer, pl_module): | ||
super().on_validation_epoch_start(trainer, pl_module) | ||
|
@@ -364,52 +367,62 @@ def on_validation_epoch_start(self, trainer, pl_module): | |
val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch | ||
total_val_batches = self.total_val_batches * val_checks_per_epoch | ||
self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False) | ||
self.refresh() | ||
|
||
def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]: | ||
if self.progress is not None: | ||
return self.progress.add_task( | ||
f"[{self.theme.description}]{description}", total=total_batches, visible=visible | ||
) | ||
|
||
def _update(self, progress_bar_id: int, visible: bool = True) -> None: | ||
if self.progress is not None: | ||
self.progress.update(progress_bar_id, advance=1.0, visible=visible) | ||
def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None: | ||
if self.progress is not None and self._should_update(current, total): | ||
self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible) | ||
self.refresh() | ||
|
||
def _should_update(self, current: int, total: int) -> bool: | ||
return self.is_enabled and (current % self.refresh_rate == 0 or current == total) | ||
|
||
def on_validation_epoch_end(self, trainer, pl_module): | ||
super().on_validation_epoch_end(trainer, pl_module) | ||
if self.val_progress_bar_id is not None: | ||
self._update(self.val_progress_bar_id, visible=False) | ||
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False) | ||
|
||
def on_test_epoch_start(self, trainer, pl_module): | ||
super().on_train_epoch_start(trainer, pl_module) | ||
self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description) | ||
self.refresh() | ||
|
||
def on_predict_epoch_start(self, trainer, pl_module): | ||
super().on_predict_epoch_start(trainer, pl_module) | ||
self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) | ||
self.refresh() | ||
|
||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): | ||
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) | ||
self._update(self.main_progress_bar_id) | ||
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches) | ||
self._update_metrics(trainer, pl_module) | ||
self.refresh() | ||
|
||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): | ||
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) | ||
if trainer.sanity_checking: | ||
self._update(self.val_sanity_progress_bar_id) | ||
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches) | ||
elif self.val_progress_bar_id is not None: | ||
# check to see if we should update the main training progress bar | ||
if self.main_progress_bar_id is not None: | ||
self._update(self.main_progress_bar_id) | ||
self._update(self.val_progress_bar_id) | ||
self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches) | ||
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches) | ||
self.refresh() | ||
|
||
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): | ||
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) | ||
self._update(self.test_progress_bar_id) | ||
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches) | ||
self.refresh() | ||
|
||
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): | ||
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) | ||
self._update(self.predict_progress_bar_id) | ||
self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches) | ||
self.refresh() | ||
|
||
def _get_train_description(self, current_epoch: int) -> str: | ||
train_description = f"Epoch {current_epoch}" | ||
|
Uh oh!
There was an error while loading. Please reload this page.