Commit efec3d4

Binh Tang authored
Move logger and profiler finalization to trainer's teardown (#8685)
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 963c267 commit efec3d4
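
In short, the commit moves `logger.finalize()` and `profiler.describe()` out of the individual fit/evaluation/prediction loops and into `Trainer._call_teardown_hook`, so finalization happens exactly once per entry point, including `trainer.predict()` (the case reported in #8333). The snippet below is a self-contained toy sketch of that pattern, not PyTorch Lightning code; the class names are invented for illustration.

# Toy illustration of the pattern adopted by this commit (not the actual
# PyTorch Lightning source): finalization lives in one teardown step instead
# of being sprinkled across the loops.

class ToyLogger:
    def __init__(self):
        self.buffer, self.logs = {}, {}

    def log_metrics(self, metrics):
        self.buffer.update(metrics)          # metrics accumulate until finalize()

    def finalize(self, status):
        self.logs.update(self.buffer)        # flush buffered metrics exactly once
        self.buffer = {}


class ToyTrainer:
    def __init__(self, logger):
        self.logger = logger

    def predict(self):
        self.logger.log_metrics({"predict": 1})
        self._call_teardown_hook()           # teardown now owns finalization

    def _call_teardown_hook(self):
        if self.logger is not None:
            self.logger.finalize("success")


trainer = ToyTrainer(ToyLogger())
trainer.predict()
assert trainer.logger.logs == {"predict": 1}  # finalized after predict, cf. #8333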

File tree

8 files changed: +90 additions, -19 deletions


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -117,6 +117,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))
 
 
+- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
+
 
 ## [1.4.0] - 2021-07-27

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 0 additions & 5 deletions
@@ -20,7 +20,6 @@
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -206,10 +205,6 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         else:
             self.trainer.call_hook("on_validation_end", *args, **kwargs)
 
-        if self.trainer.state.fn != TrainerFn.FITTING:
-            # summarize profile results
-            self.trainer.profiler.describe()
-
         # reset any `torchmetrics.Metric` and the logger connector state
         self.trainer.logger_connector.reset(metrics=True)

pytorch_lightning/loops/dataloader/prediction_loop.py

Lines changed: 0 additions & 2 deletions
@@ -119,8 +119,6 @@ def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
         Returns:
             the results for all dataloaders
         """
-        self.trainer.profiler.describe()
-
         results = self.predictions
 
         self.trainer.call_hook("on_predict_epoch_end", results)

pytorch_lightning/loops/fit_loop.py

Lines changed: 0 additions & 9 deletions
@@ -225,15 +225,6 @@ def on_run_end(self) -> None:
         # hook
         self.trainer.call_hook("on_train_end")
 
-        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
-        # It might be related to xla tensors blocked when moving the cpu
-        # kill loggers
-        if self.trainer.logger is not None:
-            self.trainer.logger.finalize("success")
-
-        # summarize profile results
-        self.trainer.profiler.describe()
-
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 3 additions & 0 deletions
@@ -203,6 +203,9 @@ def new_process(self, process_idx, trainer, mp_queue):
         # persist info in ddp_spawn
         self.transfer_distrib_spawn_state_on_fit_end(results)
 
+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     def post_dispatch(self):
         # restore main state with best weights
         best_path = self.mp_queue.get()

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 3 additions & 0 deletions
@@ -172,6 +172,9 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.local_rank == 0:
             time.sleep(2)
 
+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     @parameter_validation
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
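
Taken together with the `trainer.py` hunks below, these two spawn-plugin changes split the teardown responsibility: spawn-style plugins tear down inside each worker before it joins, while every other configuration tears down in `Trainer._run`. A rough, hedged sketch of that decision follows; `run_loops` and `is_spawn_plugin` are invented names, and the real check in `trainer.py` is `self._distrib_type not in DistributedType.interactive_compatible_types()`.

# Rough sketch (illustrative only) of where teardown runs after this commit.

def run_loops(trainer) -> None:
    ...  # stand-in for the fit / validate / test / predict loop dispatch


def run(trainer, is_spawn_plugin: bool) -> None:
    # main process: tears down itself only when no spawn plugin is involved
    run_loops(trainer)
    if not is_spawn_plugin:
        trainer._call_teardown_hook()


def new_process(trainer) -> None:
    # spawned worker (ddp_spawn / tpu_spawn): runs the loops and tears down
    # before the process joins, since the main process skips it
    run_loops(trainer)
    trainer._call_teardown_hook()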

pytorch_lightning/trainer/trainer.py

Lines changed: 14 additions & 3 deletions
@@ -76,6 +76,7 @@
 )
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
 from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -944,8 +945,10 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         if self.state.fn == TrainerFn.FITTING:
             self.call_hook("on_fit_end")
 
-        # teardown
-        self._call_teardown_hook()
+        # teardown if necessary (similar calls for spawn plugins are excluded as they have
+        # been included at the end of `new_process` functions)
+        if self._distrib_type not in DistributedType.interactive_compatible_types():
+            self._call_teardown_hook()
 
         if self.state.status != TrainerStatus.INTERRUPTED:
             self.state.status = TrainerStatus.FINISHED

@@ -1211,7 +1214,7 @@ def _call_teardown_hook(self) -> None:
 
         if self.datamodule is not None:
             self.datamodule.teardown(stage=fn)
-        self.profiler.teardown(stage=fn)
+
         self.teardown(stage=fn)
         self.lightning_module.teardown(stage=fn)
 

@@ -1220,6 +1223,14 @@ def _call_teardown_hook(self) -> None:
         # these could have become stale if metrics are defined in `setup`
         self.lightning_module._metric_attributes = None
 
+        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
+        # It might be related to xla tensors blocked when moving the cpu kill loggers.
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+        # summarize profile results
+        self.profiler.describe()
+
     def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
         if self.lightning_module:
             prev_fx_name = self.lightning_module._current_fx_name
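
For readability, here is a condensed view of how `Trainer._call_teardown_hook` reads once the hunks above are applied; only the lines touched by this diff are shown, and the elided parts ("...") stand for unchanged surrounding code.

# Condensed reading of Trainer._call_teardown_hook after this commit (sketch,
# not the full method body).
def _call_teardown_hook(self) -> None:
    ...
    if self.datamodule is not None:
        self.datamodule.teardown(stage=fn)

    self.teardown(stage=fn)
    self.lightning_module.teardown(stage=fn)
    ...
    # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
    if self.logger is not None:
        self.logger.finalize("success")

    # summarize profile results
    self.profiler.describe()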

tests/trainer/logging_/test_distributed_logging.py

Lines changed: 68 additions & 0 deletions
@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Any, Dict, Optional, Union
 from unittest import mock
 from unittest.mock import Mock
 
+import pytorch_lightning as pl
 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.loggers.base import LightningLoggerBase
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf

@@ -101,3 +104,68 @@ def on_train_start(self, trainer, pl_module):
         callbacks=[LoggerCallsObserver()],
     )
     trainer.fit(model)
+
+
+def test_logger_after_fit_predict_test_calls(tmpdir):
+    """
+    Make sure logger outputs are finalized after fit, prediction, and test calls.
+    """
+
+    class BufferLogger(LightningLoggerBase):
+        def __init__(self):
+            super().__init__()
+            self.buffer = {}
+            self.logs = {}
+
+        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+            self.buffer.update(metrics)
+
+        def finalize(self, status: str) -> None:
+            self.logs.update(self.buffer)
+            self.buffer = {}
+
+        @property
+        def experiment(self) -> Any:
+            return None
+
+        @property
+        def version(self) -> Union[int, str]:
+            return 1
+
+        @property
+        def name(self) -> str:
+            return "BufferLogger"
+
+        def log_hyperparams(self, *args, **kwargs) -> None:
+            return None
+
+    class LoggerCallsObserver(Callback):
+        def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"fit": 1})
+
+        def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"validate": 1})
+
+        def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"predict": 1})
+
+        def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"test": 1})
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=1,
+        logger=BufferLogger(),
+        callbacks=[LoggerCallsObserver()],
+    )
+
+    assert not trainer.logger.logs
+    trainer.fit(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1}
+    trainer.test(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1}
+    trainer.predict(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1, "predict": 1}
