Skip to content

Commit 1ecb962

Browse files
Raahul-Singhkaushikb11
authored andcommitted
Change attributes of RichProgressBarTheme dataclass (#10454)
Co-authored-by: Kaushik B <[email protected]>
1 parent 9e45024 commit 1ecb962

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,12 @@ def render(self, task) -> RenderableType:
129129
class MetricsTextColumn(ProgressColumn):
130130
"""A column containing text."""
131131

132-
def __init__(self, trainer):
132+
def __init__(self, trainer, style):
133133
self._trainer = trainer
134134
self._tasks = {}
135135
self._current_task_id = 0
136136
self._metrics = {}
137+
self._style = style
137138
super().__init__()
138139

139140
def update(self, metrics):
@@ -158,23 +159,34 @@ def render(self, task) -> Text:
158159

159160
for k, v in self._metrics.items():
160161
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
161-
return Text(_text, justify="left")
162+
return Text(_text, justify="left", style=self._style)
162163

163164

164165
@dataclass
165166
class RichProgressBarTheme:
166167
"""Styles to associate to different base components.
167168
169+
Args:
170+
description: Style for the progress bar description. For eg., Epoch x, Testing, etc.
171+
progress_bar: Style for the bar in progress.
172+
progress_bar_finished: Style for the finished progress bar.
173+
progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
174+
batch_progress: Style for the progress tracker (i.e 10/50 batches completed).
175+
time: Style for the processed time and estimate time remaining.
176+
processing_speed: Style for the speed of the batches being processed.
177+
metrics: Style for the metrics
178+
168179
https://rich.readthedocs.io/en/stable/style.html
169180
"""
170181

171-
text_color: str = "white"
172-
progress_bar_complete: Union[str, Style] = "#6206E0"
182+
description: Union[str, Style] = "white"
183+
progress_bar: Union[str, Style] = "#6206E0"
173184
progress_bar_finished: Union[str, Style] = "#6206E0"
174185
progress_bar_pulse: Union[str, Style] = "#6206E0"
175-
batch_process: str = "white"
176-
time: str = "grey54"
177-
processing_speed: str = "grey70"
186+
batch_progress: Union[str, Style] = "white"
187+
time: Union[str, Style] = "grey54"
188+
processing_speed: Union[str, Style] = "grey70"
189+
metrics: Union[str, Style] = "white"
178190

179191

180192
class RichProgressBar(ProgressBarBase):
@@ -268,7 +280,7 @@ def _init_progress(self, trainer):
268280
self._reset_progress_bar_ids()
269281
self._console: Console = Console()
270282
self._console.clear_live()
271-
self._metric_component = MetricsTextColumn(trainer)
283+
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
272284
self.progress = CustomProgress(
273285
*self.configure_columns(trainer),
274286
self._metric_component,
@@ -351,7 +363,7 @@ def on_validation_epoch_start(self, trainer, pl_module):
351363
def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
352364
if self.progress is not None:
353365
return self.progress.add_task(
354-
f"[{self.theme.text_color}]{description}", total=total_batches, visible=visible
366+
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
355367
)
356368

357369
def _update(self, progress_bar_id: int, visible: bool = True) -> None:
@@ -448,11 +460,11 @@ def configure_columns(self, trainer) -> list:
448460
return [
449461
TextColumn("[progress.description]{task.description}"),
450462
CustomBarColumn(
451-
complete_style=self.theme.progress_bar_complete,
463+
complete_style=self.theme.progress_bar,
452464
finished_style=self.theme.progress_bar_finished,
453465
pulse_style=self.theme.progress_bar_pulse,
454466
),
455-
BatchesProcessedColumn(style=self.theme.batch_process),
467+
BatchesProcessedColumn(style=self.theme.batch_progress),
456468
CustomTimeColumn(style=self.theme.time),
457469
ProcessingSpeedColumn(style=self.theme.processing_speed),
458470
]

tests/callbacks/test_rich_progress_bar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ def test_rich_progress_bar_custom_theme(tmpdir):
106106

107107
assert progress_bar.theme == theme
108108
args, kwargs = mocks["CustomBarColumn"].call_args
109-
assert kwargs["complete_style"] == theme.progress_bar_complete
109+
assert kwargs["complete_style"] == theme.progress_bar
110110
assert kwargs["finished_style"] == theme.progress_bar_finished
111111

112112
args, kwargs = mocks["BatchesProcessedColumn"].call_args
113-
assert kwargs["style"] == theme.batch_process
113+
assert kwargs["style"] == theme.batch_progress
114114

115115
args, kwargs = mocks["CustomTimeColumn"].call_args
116116
assert kwargs["style"] == theme.time

0 commit comments

Comments
 (0)