25
25
else :
26
26
from tqdm import tqdm as _tqdm
27
27
28
+ import pytorch_lightning as pl
28
29
from pytorch_lightning .callbacks .progress .base import ProgressBarBase
29
30
from pytorch_lightning .utilities .distributed import rank_zero_debug
30
31
@@ -207,12 +208,10 @@ def init_test_tqdm(self) -> Tqdm:
207
208
return bar
208
209
209
210
def on_sanity_check_start (self , trainer , pl_module ):
210
- super ().on_sanity_check_start (trainer , pl_module )
211
211
self .val_progress_bar = self .init_sanity_tqdm ()
212
212
self .main_progress_bar = Tqdm (disable = True ) # dummy progress bar
213
213
214
214
def on_sanity_check_end (self , trainer , pl_module ):
215
- super ().on_sanity_check_end (trainer , pl_module )
216
215
self .main_progress_bar .close ()
217
216
self .val_progress_bar .close ()
218
217
@@ -234,49 +233,59 @@ def on_train_epoch_start(self, trainer, pl_module):
234
233
235
234
def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
236
235
super ().on_train_batch_end (trainer , pl_module , outputs , batch , batch_idx )
237
- total_batches = self .total_train_batches + self .total_val_batches
238
- total_batches = convert_inf (total_batches )
239
- if self ._should_update (self .train_batch_idx , total_batches ):
236
+ if self ._should_update (self .train_batch_idx ):
240
237
self ._update_bar (self .main_progress_bar )
241
238
self .main_progress_bar .set_postfix (self .get_metrics (trainer , pl_module ))
242
239
240
+ def on_train_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
241
+ if self .is_enabled :
242
+ self ._update_bar (self .main_progress_bar )
243
+ self .main_progress_bar .set_postfix (self .get_metrics (trainer , pl_module ))
244
+
245
+ def on_train_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
246
+ self .main_progress_bar .close ()
247
+
243
248
def on_validation_start (self , trainer , pl_module ):
244
249
super ().on_validation_start (trainer , pl_module )
245
250
if trainer .sanity_checking :
246
251
reset (self .val_progress_bar , total = sum (trainer .num_sanity_val_batches ), current = self .val_batch_idx )
247
252
else :
248
- self ._update_bar (self .main_progress_bar ) # fill up remaining
253
+ if trainer .state .fn == pl .trainer .states .TrainerFn .FITTING :
254
+ self ._update_bar (self .main_progress_bar ) # fill up remaining
249
255
self .val_progress_bar = self .init_validation_tqdm ()
250
256
reset (self .val_progress_bar , total = self .total_val_batches , current = self .val_batch_idx )
251
257
252
258
def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
253
259
super ().on_validation_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
254
- if self ._should_update (self .val_batch_idx , convert_inf (self .total_val_batches )):
260
+ if self ._should_update (self .val_batch_idx ):
261
+ self ._update_bar (self .val_progress_bar )
262
+ if trainer .state .fn == pl .trainer .states .TrainerFn .FITTING :
263
+ self ._update_bar (self .main_progress_bar )
264
+
265
+ def on_validation_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
266
+ if self .is_enabled :
255
267
self ._update_bar (self .val_progress_bar )
256
- self ._update_bar (self .main_progress_bar )
257
268
258
269
def on_validation_end (self , trainer , pl_module ):
259
- super ().on_validation_end (trainer , pl_module )
260
- if self .main_progress_bar is not None :
270
+ if self .main_progress_bar is not None and trainer .state .fn == pl .trainer .states .TrainerFn .FITTING :
261
271
self .main_progress_bar .set_postfix (self .get_metrics (trainer , pl_module ))
262
272
self .val_progress_bar .close ()
263
273
264
- def on_train_end (self , trainer , pl_module ):
265
- super ().on_train_end (trainer , pl_module )
266
- self .main_progress_bar .close ()
267
-
268
274
def on_test_start (self , trainer , pl_module ):
269
275
super ().on_test_start (trainer , pl_module )
270
276
self .test_progress_bar = self .init_test_tqdm ()
271
277
self .test_progress_bar .total = convert_inf (self .total_test_batches )
272
278
273
279
def on_test_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
274
280
super ().on_test_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
275
- if self ._should_update (self .test_batch_idx , self .total_test_batches ):
281
+ if self ._should_update (self .test_batch_idx ):
282
+ self ._update_bar (self .test_progress_bar )
283
+
284
+ def on_test_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
285
+ if self .is_enabled :
276
286
self ._update_bar (self .test_progress_bar )
277
287
278
288
def on_test_end (self , trainer , pl_module ):
279
- super ().on_test_end (trainer , pl_module )
280
289
self .test_progress_bar .close ()
281
290
282
291
def on_predict_epoch_start (self , trainer , pl_module ):
@@ -286,7 +295,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
286
295
287
296
def on_predict_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
288
297
super ().on_predict_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
289
- if self ._should_update (self .predict_batch_idx , self . total_predict_batches ):
298
+ if self ._should_update (self .predict_batch_idx ):
290
299
self ._update_bar (self .predict_progress_bar )
291
300
292
301
def on_predict_end (self , trainer , pl_module ):
@@ -310,8 +319,8 @@ def print(
310
319
s = sep .join (map (str , args ))
311
320
active_progress_bar .write (s , end = end , file = file , nolock = nolock )
312
321
313
- def _should_update (self , current , total ) -> bool :
314
- return self .is_enabled and (current % self .refresh_rate == 0 or current == total )
322
+ def _should_update (self , idx : int ) -> bool :
323
+ return self .is_enabled and (idx % self .refresh_rate == 0 )
315
324
316
325
def _update_bar (self , bar : Optional [Tqdm ]) -> None :
317
326
"""Updates the bar by the refresh rate without overshooting."""
0 commit comments