[bugfix] Fix dataloading for iterable datasets and limit_train_batches #7306
Changes from all commits
```diff
@@ -484,9 +484,9 @@ def run_training_epoch(self):
             self.trainer.logger_connector.log_train_step_metrics(batch_output)
 
             # -----------------------------------------
-            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
+            # VALIDATE IF NEEDED
             # -----------------------------------------
-            should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
+            should_check_val = self._should_check_val_fx(batch_idx, is_last_batch)
             if should_check_val:
                 self.trainer.validating = True
                 self.trainer.run_evaluation()
@@ -535,7 +535,7 @@ def run_training_epoch(self):
         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
 
-        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
         should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
         should_train_only = self.trainer.disable_validation or should_skip_eval
@@ -825,19 +825,34 @@ def should_accumulate(self):
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)
 
-    def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
-        # decide if we should run validation
-        is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
-        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
-        can_check_val = self.trainer.enable_validation and is_val_check_epoch
-        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
-
-        should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
-                            or is_last_batch_for_infinite_dataset
-                            ) if on_epoch else (is_val_check_batch and not epoch_end_val_check)
-
-        return should_check_val and can_check_val
+    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
+        """ Decide if we should run validation. """
+
+        if not self.trainer.enable_validation:
+            return False
+
+        # check if this epoch is eligible to run validation
+        if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
+            return False
+
+        # val_check_batch is inf for iterable datasets with no length defined
+        # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
+        is_val_check_batch = False
+        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
+            is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
+        elif self.trainer.val_check_batch != float('inf'):
+            is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
+
+        # Note: num_training_batches is also inf for iterable datasets with no length defined
+        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
+        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
+
+        if on_epoch:
+            return (
+                is_val_check_batch and epoch_end_val_check
+            ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
+        else:
+            return is_val_check_batch and not epoch_end_val_check
 
     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
```

Review thread on the `isinstance(self.trainer.limit_train_batches, int)` line:

> Love this refactor!

> @kaushikb11 thanks! It still feels complicated to me. Part of that is that I'm wondering what a better way would be to split "when to stop training mid-epoch" from "when to run validation", or whether a split is needed at all.
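To make the new control flow easier to follow in isolation, here is a minimal standalone sketch of the same decision logic. `TrainerState`, `should_check_val`, and the default values are hypothetical stand-ins for the attributes the method reads off `self.trainer`; only the predicates themselves mirror the diff above. The defaults assume an iterable dataset with no `__len__` capped at 100 batches per epoch, where the trainer would resolve `num_training_batches` to the integer cap and leave `val_check_batch` at `float('inf')`.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class TrainerState:
    """Hypothetical stand-in for the trainer attributes _should_check_val_fx reads."""
    enable_validation: bool = True
    current_epoch: int = 0
    check_val_every_n_epoch: int = 1
    val_check_batch: Union[int, float] = float("inf")   # inf: dataloader has no length
    limit_train_batches: Union[int, float] = 100        # int: absolute batch cap
    num_training_batches: Union[int, float] = 100       # resolved from the cap
    should_stop: bool = False


def should_check_val(t: TrainerState, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
    """Pure-function transcription of the decision logic in the diff above."""
    if not t.enable_validation:
        return False
    # only every check_val_every_n_epoch-th epoch is eligible
    if (t.current_epoch + 1) % t.check_val_every_n_epoch != 0:
        return False

    # val_check_batch is inf when the dataloader defines no length;
    # an integer limit_train_batches then stands in as the check interval
    is_val_check_batch = False
    if isinstance(t.limit_train_batches, int) and t.val_check_batch == float("inf"):
        is_val_check_batch = (batch_idx + 1) % t.limit_train_batches == 0
    elif t.val_check_batch != float("inf"):
        is_val_check_batch = (batch_idx + 1) % t.val_check_batch == 0

    epoch_end_val_check = (batch_idx + 1) % t.num_training_batches == 0
    is_last_batch_for_infinite_dataset = is_last_batch and t.val_check_batch == float("inf")

    if on_epoch:
        return (is_val_check_batch and epoch_end_val_check) or t.should_stop or is_last_batch_for_infinite_dataset
    return is_val_check_batch and not epoch_end_val_check


state = TrainerState()
# batch 100 coincides with the capped epoch end: validation fires on the
# on_epoch=True call at epoch end, not on the mid-epoch call
assert should_check_val(state, batch_idx=99, is_last_batch=False, on_epoch=True)
assert not should_check_val(state, batch_idx=99, is_last_batch=False)
assert not should_check_val(state, batch_idx=42, is_last_batch=False)
```

The asserts illustrate the split the review thread asks about: a batch index that coincides with the end of the epoch returns `False` on the mid-epoch (`on_epoch=False`) call and `True` on the epoch-end (`on_epoch=True`) call, so the same boundary never triggers validation twice.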
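For context, a minimal sketch of the setup the PR title describes: a dataloader with no length combined with an integer `limit_train_batches`. The dataset, tensor shapes, and the omitted model are illustrative only; `limit_train_batches` and `max_epochs` are real `Trainer` arguments.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset
import pytorch_lightning as pl


class InfiniteStream(IterableDataset):
    """Toy stream with no __len__: the trainer cannot infer an epoch length,
    which is the val_check_batch == float('inf') case handled in the diff."""

    def __iter__(self):
        while True:
            yield torch.randn(8), torch.randint(0, 2, ()).float()


train_loader = DataLoader(InfiniteStream(), batch_size=4)

# An integer limit_train_batches caps the otherwise endless epoch at 100 batches,
# giving the validation-scheduling logic a concrete boundary to key off.
trainer = pl.Trainer(limit_train_batches=100, max_epochs=2)
# trainer.fit(model, train_loader)  # `model` would be some LightningModule, omitted here
```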