
Fix `self.log(on_epoch=True)` in `on_batch_start` #9780

Merged · 7 commits · Oct 18, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -527,11 +527,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `broadcast` in `DDPPlugin` and `DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))


- Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780))


- Fixed restoring training state during `trainer.fit` only ([#9413](https://github.com/PyTorchLightning/pytorch-lightning/pull/9413))


- Fixed DeepSpeed and Lightning both calling the scheduler ([#9788](https://github.com/PyTorchLightning/pytorch-lightning/pull/9788))


- Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))


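This PR's headline change, recorded above: values logged with `on_epoch=True` from the `on_batch_start` and `on_train_batch_start` hooks are now correctly accumulated and flushed at epoch end. A minimal sketch of the pattern the fix enables (the module and metric name are illustrative, not from this PR):

```python
from pytorch_lightning import LightningModule


class SketchModule(LightningModule):  # hypothetical module, for illustration
    def on_train_batch_start(self, batch, batch_idx):
        # Before this fix, requesting epoch-level aggregation from this hook
        # was broken; it now accumulates across batches and logs once per epoch.
        self.log("value_at_batch_start", 1.0, on_step=False, on_epoch=True)
```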
49 changes: 5 additions & 44 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -23,9 +23,6 @@
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import _get_active_optimizers
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache

_OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]

@@ -43,7 +40,6 @@ def __init__(self) -> None:
self.manual_loop = ManualOptimization()

self._outputs: _OUTPUTS_TYPE = []
self._warning_cache: WarningCache = WarningCache()
self._remaining_splits: Optional[List[Any]] = None

@property
@@ -59,42 +55,6 @@ def connect(
if manual_loop is not None:
self.manual_loop = manual_loop

def run(self, batch: Any, batch_idx: int) -> AttributeDict:
"""Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.

Args:
batch: the current batch to run the train step on
batch_idx: the index of the current batch
"""
if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
return AttributeDict(signal=0, outputs=[])

# hook
self.trainer.logger_connector.on_batch_start()
response = self.trainer.call_hook("on_batch_start")
if response == -1:
return AttributeDict(signal=-1)

# hook
# TODO: Update this in v1.7 (deprecation: #9816)
model_fx = self.trainer.lightning_module.on_train_batch_start
extra_kwargs = (
{"dataloader_idx": 0}
if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
else {}
)
response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
if response == -1:
return AttributeDict(signal=-1)

self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()

super().run(batch, batch_idx)

output, self._outputs = AttributeDict(signal=0, outputs=self._outputs), None # free memory
return output

def reset(self) -> None:
"""Resets the loop state."""
self._outputs = []
@@ -117,11 +77,10 @@ def advance(self, batch, batch_idx):
batch_idx: the index of the current batch
"""
void(batch)
split_idx, split_batch = self._remaining_splits.pop(0)
self.split_idx = split_idx
self.split_idx, split_batch = self._remaining_splits.pop(0)

# let logger connector extract current batch size
self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)
self.trainer.logger_connector.on_train_split_start(self.split_idx, split_batch)

# choose which loop will run the optimization
if self.trainer.lightning_module.automatic_optimization:
@@ -135,10 +94,12 @@ def advance(self, batch, batch_idx):
# then `advance` doesn't finish and an empty dict is returned
self._outputs.append(outputs)

def on_run_end(self) -> None:
def on_run_end(self) -> _OUTPUTS_TYPE:
self.optimizer_loop._hiddens = None
# this is not necessary as the manual loop runs for only 1 iteration, but just in case
self.manual_loop._hiddens = None
output, self._outputs = self._outputs, None # free memory
return output

def teardown(self) -> None:
# release memory
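With the `run` override gone, `TrainingBatchLoop` uses the base `Loop.run`, and `on_run_end` returns the raw `_OUTPUTS_TYPE` list directly: there is no `AttributeDict` with `signal` and `outputs` fields anymore, and `None` batches and `-1` hook responses are handled by the epoch loop before this loop runs. A sketch of the new calling convention, mirroring the updated tests further down (the `trainer` and `model` setup is assumed):

```python
# Assumes a configured `trainer` and `model`, as in the tests below.
batch_idx, batch = 0, next(iter(model.train_dataloader()))
outputs = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

# `outputs` is now the plain list of per-split results; no `.signal` field.
assert isinstance(outputs, list)
```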
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -233,10 +233,10 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
Raises:
AssertionError: If the number of dataloaders is None (has not yet been set).
"""
self.trainer.logger_connector.on_batch_start()
self.trainer.logger_connector.on_batch_start(batch_idx)

assert self._num_dataloaders is not None
self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders)
self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self._num_dataloaders)

if self.trainer.testing:
self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
41 changes: 33 additions & 8 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -28,6 +28,7 @@
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache

_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

@@ -57,6 +58,7 @@ def __init__(self, min_steps: int, max_steps: int):

self._results = ResultCollection(training=True)
self._outputs: _OUTPUTS_TYPE = []
self._warning_cache = WarningCache()
self._dataloader_iter: Optional[Iterator] = None
# caches the loaded dataloader state until dataloader objects are available
self._dataloader_state_dict: Dict[str, Any] = {}
@@ -151,14 +153,37 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

self.batch_progress.increment_ready()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)
if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
batch_output = []
else:
# hook
self.trainer.logger_connector.on_batch_start(batch_idx)
response = self.trainer.call_hook("on_batch_start")
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration

# TODO: Update this in v1.7 (deprecation: #9816)
model_fx = self.trainer.lightning_module.on_train_batch_start
extra_kwargs = (
{"dataloader_idx": 0}
if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
else {}
)

self.batch_progress.increment_processed()
# hook
response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration

# when returning -1 from train_step, we end epoch early
if batch_output.signal == -1:
raise StopIteration
self.batch_progress.increment_started()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)

self.batch_progress.increment_processed()

# update non-plateau LR schedulers
# update epoch-interval ones only when we are at the end of training epoch
@@ -167,7 +192,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

batch_end_outputs = self._prepare_outputs_training_batch_end(
batch_output.outputs,
batch_output,
automatic=self.trainer.lightning_module.automatic_optimization,
num_optimizers=len(self.trainer.optimizers),
)
@@ -186,7 +211,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
self.batch_progress.increment_completed()

if is_overridden("training_epoch_end", self.trainer.lightning_module):
self._outputs.append(batch_output.outputs)
self._outputs.append(batch_output)

# -----------------------------------------
# SAVE METRICS TO LOGGERS AND PROGRESS_BAR
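The hook orchestration that used to live in `TrainingBatchLoop.run` now happens here: `advance` warns on `None` batches, calls `on_batch_start` and `on_train_batch_start` itself, and converts a `-1` response into `StopIteration` to end the epoch. Early-exiting an epoch from the hook therefore still works, just through the new path; a sketch using the `BoringModel` test helper (the cutoff value is arbitrary):

```python
from tests.helpers.boring_model import BoringModel


class EarlyExit(BoringModel):  # hypothetical subclass, for illustration
    def on_train_batch_start(self, batch, batch_idx):
        if batch_idx >= 2:
            # the epoch loop now raises StopIteration for us, ending the epoch
            return -1
```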
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -138,15 +138,14 @@ def _increment_eval_log_step(self) -> None:
elif self.trainer.state.stage is RunningStage.TESTING:
self._test_log_step += 1

def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None:
def on_evaluation_batch_start(self, batch: Any, dataloader_idx: int, num_dataloaders: int) -> None:
model = self.trainer.lightning_module
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

# track batch_size
assert self.trainer._results is not None
self.trainer._results.extract_batch_size(batch)
self._batch_idx = batch_idx

def update_eval_step_metrics(self) -> None:
if self.trainer.sanity_checking:
@@ -213,14 +212,12 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
Train metric updates
"""

def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
assert self.trainer._results is not None
# when the user requests `dataloader_iter`, we can't track the batch_size
# and this is left to user responsibility.
if not isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
self.trainer._results.extract_batch_size(split_batch)

self._batch_idx = batch_idx
self._split_idx = split_idx

def update_train_step_metrics(self) -> None:
@@ -267,7 +264,8 @@ def _log_gpus_metrics(self) -> None:
def on_epoch_start(self) -> None:
self._epoch_end_reached = False

def on_batch_start(self) -> None:
def on_batch_start(self, batch_idx: int) -> None:
self._batch_idx = batch_idx
self._epoch_end_reached = False

def epoch_end_reached(self) -> None:
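The connector now receives the batch index once, in `on_batch_start`, instead of separately from each loop; the training and evaluation paths shown earlier both follow the same order. Roughly, simplified from the loop code above (internal API, not a public contract):

```python
# Simplified call order inside the loops shown above (internal API):
trainer.logger_connector.on_batch_start(batch_idx)  # records _batch_idx, resets the epoch-end flag
response = trainer.call_hook("on_batch_start")      # user hooks may call self.log(...) here
```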
8 changes: 2 additions & 6 deletions tests/loops/test_evaluation_loop_flow.py
@@ -64,10 +64,8 @@ def backward(self, loss, optimizer, optimizer_idx):
# simulate training manually
trainer.state.stage = RunningStage.TRAINING
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
assert out.signal == 0
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

train_step_out = out.outputs
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -129,10 +127,8 @@ def backward(self, loss, optimizer, optimizer_idx):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
assert out.signal == 0
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

train_step_out = out.outputs
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out["loss"], torch.Tensor)
30 changes: 10 additions & 20 deletions tests/loops/test_training_loop_flow_scalar.py
@@ -147,10 +147,8 @@ def backward(self, loss, optimizer, optimizer_idx):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
assert out.signal == 0
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

train_step_out = out.outputs
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -221,10 +219,8 @@ def backward(self, loss, optimizer, optimizer_idx):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
assert out.signal == 0
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

train_step_out = out.outputs
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -311,8 +307,7 @@ def training_step(self, batch, batch_idx):
for batch_idx, batch in enumerate(model.train_dataloader()):
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
if not batch_idx % 2:
assert out.outputs == []
assert out.signal == 0
assert out == []


def test_training_step_none_batches(tmpdir):
@@ -321,7 +316,6 @@
class TestModel(BoringModel):
def __init__(self):
super().__init__()

self.counter = 0

def collate_none_when_even(self, batch):
@@ -333,12 +327,17 @@ def collate_none_when_even(self, batch):
return result

def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), collate_fn=self.collate_none_when_even)
return DataLoader(RandomDataset(32, 4), collate_fn=self.collate_none_when_even)

def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
if batch_idx % 2 == 0:
assert outputs == []
else:
assert outputs

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=4,
limit_val_batches=1,
max_epochs=4,
enable_model_summary=False,
@@ -348,12 +347,3 @@ def train_dataloader(self):

with pytest.warns(UserWarning, match=r".*train_dataloader yielded None.*"):
trainer.fit(model)

trainer.state.stage = RunningStage.TRAINING

# manually check a few batches
for batch_idx, batch in enumerate(model.train_dataloader()):
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
if not batch_idx % 2:
assert out.outputs == []
assert out.signal == 0
12 changes: 12 additions & 0 deletions tests/trainer/logging_/test_train_loop_logging.py
@@ -276,11 +276,21 @@ def on_train_epoch_start(self, _, pl_module):
pl_module, "on_train_epoch_start", on_steps=self.choices, on_epochs=[True], prob_bars=self.choices
)

def on_batch_start(self, _, pl_module, *__):
self.make_logging(
pl_module, "on_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
)

def on_batch_end(self, _, pl_module):
self.make_logging(
pl_module, "on_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
)

def on_train_batch_start(self, _, pl_module, *__):
self.make_logging(
pl_module, "on_train_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
)

def on_train_batch_end(self, _, pl_module, *__):
self.make_logging(
pl_module, "on_train_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
@@ -323,7 +333,9 @@ def training_step(self, batch, batch_idx):
"on_train_start": 1,
"on_epoch_start": 1,
"on_train_epoch_start": 1,
"on_train_batch_start": 2,
"on_train_batch_end": 2,
"on_batch_start": 2,
"on_batch_end": 2,
"on_train_epoch_end": 1,
"on_epoch_end": 1,