@@ -129,11 +129,12 @@ def render(self, task) -> RenderableType:
129
129
class MetricsTextColumn (ProgressColumn ):
130
130
"""A column containing text."""
131
131
132
- def __init__ (self , trainer ):
132
+ def __init__ (self , trainer , style ):
133
133
self ._trainer = trainer
134
134
self ._tasks = {}
135
135
self ._current_task_id = 0
136
136
self ._metrics = {}
137
+ self ._style = style
137
138
super ().__init__ ()
138
139
139
140
def update (self , metrics ):
@@ -158,23 +159,34 @@ def render(self, task) -> Text:
158
159
159
160
for k , v in self ._metrics .items ():
160
161
_text += f"{ k } : { round (v , 3 ) if isinstance (v , float ) else v } "
161
- return Text (_text , justify = "left" )
162
+ return Text (_text , justify = "left" , style = self . _style )
162
163
163
164
164
165
@dataclass
165
166
class RichProgressBarTheme :
166
167
"""Styles to associate to different base components.
167
168
169
+ Args:
170
+ description: Style for the progress bar description. For eg., Epoch x, Testing, etc.
171
+ progress_bar: Style for the bar in progress.
172
+ progress_bar_finished: Style for the finished progress bar.
173
+ progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
174
+ batch_progress: Style for the progress tracker (i.e 10/50 batches completed).
175
+ time: Style for the processed time and estimate time remaining.
176
+ processing_speed: Style for the speed of the batches being processed.
177
+ metrics: Style for the metrics
178
+
168
179
https://rich.readthedocs.io/en/stable/style.html
169
180
"""
170
181
171
- text_color : str = "white"
172
- progress_bar_complete : Union [str , Style ] = "#6206E0"
182
+ description : Union [ str , Style ] = "white"
183
+ progress_bar : Union [str , Style ] = "#6206E0"
173
184
progress_bar_finished : Union [str , Style ] = "#6206E0"
174
185
progress_bar_pulse : Union [str , Style ] = "#6206E0"
175
- batch_process : str = "white"
176
- time : str = "grey54"
177
- processing_speed : str = "grey70"
186
+ batch_progress : Union [str , Style ] = "white"
187
+ time : Union [str , Style ] = "grey54"
188
+ processing_speed : Union [str , Style ] = "grey70"
189
+ metrics : Union [str , Style ] = "white"
178
190
179
191
180
192
class RichProgressBar (ProgressBarBase ):
@@ -268,7 +280,7 @@ def _init_progress(self, trainer):
268
280
self ._reset_progress_bar_ids ()
269
281
self ._console : Console = Console ()
270
282
self ._console .clear_live ()
271
- self ._metric_component = MetricsTextColumn (trainer )
283
+ self ._metric_component = MetricsTextColumn (trainer , self . theme . metrics )
272
284
self .progress = CustomProgress (
273
285
* self .configure_columns (trainer ),
274
286
self ._metric_component ,
@@ -351,7 +363,7 @@ def on_validation_epoch_start(self, trainer, pl_module):
351
363
def _add_task (self , total_batches : int , description : str , visible : bool = True ) -> Optional [int ]:
352
364
if self .progress is not None :
353
365
return self .progress .add_task (
354
- f"[{ self .theme .text_color } ]{ description } " , total = total_batches , visible = visible
366
+ f"[{ self .theme .description } ]{ description } " , total = total_batches , visible = visible
355
367
)
356
368
357
369
def _update (self , progress_bar_id : int , visible : bool = True ) -> None :
@@ -448,11 +460,11 @@ def configure_columns(self, trainer) -> list:
448
460
return [
449
461
TextColumn ("[progress.description]{task.description}" ),
450
462
CustomBarColumn (
451
- complete_style = self .theme .progress_bar_complete ,
463
+ complete_style = self .theme .progress_bar ,
452
464
finished_style = self .theme .progress_bar_finished ,
453
465
pulse_style = self .theme .progress_bar_pulse ,
454
466
),
455
- BatchesProcessedColumn (style = self .theme .batch_process ),
467
+ BatchesProcessedColumn (style = self .theme .batch_progress ),
456
468
CustomTimeColumn (style = self .theme .time ),
457
469
ProcessingSpeedColumn (style = self .theme .processing_speed ),
458
470
]
0 commit comments