25
25
else :
26
26
from tqdm import tqdm as _tqdm
27
27
28
+ import pytorch_lightning as pl
28
29
from pytorch_lightning .callbacks .progress .base import ProgressBarBase
29
30
30
31
_PAD_SIZE = 5
@@ -206,12 +207,10 @@ def init_test_tqdm(self) -> Tqdm:
206
207
return bar
207
208
208
209
def on_sanity_check_start (self , trainer , pl_module ):
209
- super ().on_sanity_check_start (trainer , pl_module )
210
210
self .val_progress_bar = self .init_sanity_tqdm ()
211
211
self .main_progress_bar = Tqdm (disable = True ) # dummy progress bar
212
212
213
213
def on_sanity_check_end (self , trainer , pl_module ):
214
- super ().on_sanity_check_end (trainer , pl_module )
215
214
self .main_progress_bar .close ()
216
215
self .val_progress_bar .close ()
217
216
@@ -233,49 +232,59 @@ def on_train_epoch_start(self, trainer, pl_module):
233
232
234
233
def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
235
234
super ().on_train_batch_end (trainer , pl_module , outputs , batch , batch_idx )
236
- total_batches = self .total_train_batches + self .total_val_batches
237
- total_batches = convert_inf (total_batches )
238
- if self ._should_update (self .train_batch_idx , total_batches ):
235
+ if self ._should_update (self .train_batch_idx ):
239
236
self ._update_bar (self .main_progress_bar )
240
237
self .main_progress_bar .set_postfix (self .get_metrics (trainer , pl_module ))
241
238
239
+ def on_train_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
240
+ if self .is_enabled :
241
+ self ._update_bar (self .main_progress_bar )
242
+ self .main_progress_bar .set_postfix (self .get_metrics (trainer , pl_module ))
243
+
244
+ def on_train_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
245
+ self .main_progress_bar .close ()
246
+
242
247
def on_validation_start (self , trainer , pl_module ):
243
248
super ().on_validation_start (trainer , pl_module )
244
249
if trainer .sanity_checking :
245
250
reset (self .val_progress_bar , total = sum (trainer .num_sanity_val_batches ), current = self .val_batch_idx )
246
251
else :
247
- self ._update_bar (self .main_progress_bar ) # fill up remaining
252
+ if trainer .state .fn == pl .trainer .states .TrainerFn .FITTING :
253
+ self ._update_bar (self .main_progress_bar ) # fill up remaining
248
254
self .val_progress_bar = self .init_validation_tqdm ()
249
255
reset (self .val_progress_bar , total = self .total_val_batches , current = self .val_batch_idx )
250
256
251
257
def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
252
258
super ().on_validation_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
253
- if self ._should_update (self .val_batch_idx , convert_inf (self .total_val_batches )):
259
+ if self ._should_update (self .val_batch_idx ):
260
+ self ._update_bar (self .val_progress_bar )
261
+ if trainer .state .fn == pl .trainer .states .TrainerFn .FITTING :
262
+ self ._update_bar (self .main_progress_bar )
263
+
264
+ def on_validation_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
265
+ if self .is_enabled :
254
266
self ._update_bar (self .val_progress_bar )
255
- self ._update_bar (self .main_progress_bar )
256
267
257
268
def on_validation_end (self , trainer , pl_module ):
258
- super ().on_validation_end (trainer , pl_module )
259
- if self .main_progress_bar is not None :
269
+ if self .main_progress_bar is not None and trainer .state .fn == pl .trainer .states .TrainerFn .FITTING :
260
270
self .main_progress_bar .set_postfix (self .get_metrics (trainer , pl_module ))
261
271
self .val_progress_bar .close ()
262
272
263
- def on_train_end (self , trainer , pl_module ):
264
- super ().on_train_end (trainer , pl_module )
265
- self .main_progress_bar .close ()
266
-
267
273
def on_test_start (self , trainer , pl_module ):
268
274
super ().on_test_start (trainer , pl_module )
269
275
self .test_progress_bar = self .init_test_tqdm ()
270
276
self .test_progress_bar .total = convert_inf (self .total_test_batches )
271
277
272
278
def on_test_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
273
279
super ().on_test_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
274
- if self ._should_update (self .test_batch_idx , self .total_test_batches ):
280
+ if self ._should_update (self .test_batch_idx ):
281
+ self ._update_bar (self .test_progress_bar )
282
+
283
+ def on_test_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
284
+ if self .is_enabled :
275
285
self ._update_bar (self .test_progress_bar )
276
286
277
287
def on_test_end (self , trainer , pl_module ):
278
- super ().on_test_end (trainer , pl_module )
279
288
self .test_progress_bar .close ()
280
289
281
290
def on_predict_epoch_start (self , trainer , pl_module ):
@@ -285,7 +294,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
285
294
286
295
def on_predict_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
287
296
super ().on_predict_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
288
- if self ._should_update (self .predict_batch_idx , self . total_predict_batches ):
297
+ if self ._should_update (self .predict_batch_idx ):
289
298
self ._update_bar (self .predict_progress_bar )
290
299
291
300
def on_predict_end (self , trainer , pl_module ):
@@ -309,8 +318,8 @@ def print(
309
318
s = sep .join (map (str , args ))
310
319
active_progress_bar .write (s , end = end , file = file , nolock = nolock )
311
320
312
- def _should_update (self , current , total ) -> bool :
313
- return self .is_enabled and (current % self .refresh_rate == 0 or current == total )
321
+ def _should_update (self , idx : int ) -> bool :
322
+ return self .is_enabled and (idx % self .refresh_rate == 0 )
314
323
315
324
def _update_bar (self , bar : Optional [Tqdm ]) -> None :
316
325
"""Updates the bar by the refresh rate without overshooting."""
0 commit comments