
Commit 0b682b8

Mark logger_connector as protected (#12195)
1 parent 73bda54 commit 0b682b8

12 files changed (+44, -41 lines changed)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -332,6 +332,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed `is_global_zero` check in `training_epoch_loop` before `logger.save`. If you have a custom logger that implements `save` the Trainer will now call `save` on all ranks by default. To change this behavior add `@rank_zero_only` to your `save` implementation ([#12134](https://github.com/PyTorchLightning/pytorch-lightning/pull/12134))
 
+
+- Marked `trainer.logger_connector` as protected ([#12195](https://github.com/PyTorchLightning/pytorch-lightning/pull/12195))
+
 ### Deprecated
 
 - Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
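
Note on migration: within this commit the rename is a straight attribute change with no backward-compatible alias, so external code that reached the connector directly must use the protected name, while code that only logs or reads metrics through the public Trainer/LightningModule API is unaffected (the public accessors appear in the trainer.py hunk further down). A minimal before/after sketch, assuming a `trainer: pl.Trainer` object in scope:

# Before this commit:
#     if trainer.logger_connector.should_update_logs:
#         ...
# After this commit (protected attribute, internal API):
if trainer._logger_connector.should_update_logs:
    ...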

pytorch_lightning/callbacks/device_stats_monitor.py

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@ def on_train_batch_start(
         if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
 
-        if not trainer.logger_connector.should_update_logs:
+        if not trainer._logger_connector.should_update_logs:
             return
 
         device = trainer.strategy.root_device
@@ -80,7 +80,7 @@ def on_train_batch_end(
         if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
 
-        if not trainer.logger_connector.should_update_logs:
+        if not trainer._logger_connector.should_update_logs:
             return
 
         device = trainer.strategy.root_device
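
A user callback can apply the same gate so that it only emits logs on the steps where Lightning itself logs (every `log_every_n_steps` batches). Since `should_update_logs` now lives on a protected member, this couples the callback to internal API; a sketch under that caveat, where the class name and the logged key are illustrative:

from typing import Any

import pytorch_lightning as pl
from pytorch_lightning import Callback


class BatchIndexMonitor(Callback):
    """Illustrative callback mirroring DeviceStatsMonitor's rate limiting."""

    def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
        # Protected internal attribute after this commit; may change without deprecation.
        if not trainer._logger_connector.should_update_logs:
            return
        for logger in trainer.loggers:
            logger.log_metrics({"debug/train_batch_step": float(trainer.global_step)}, step=trainer.global_step)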

pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 2 additions & 2 deletions
@@ -150,7 +150,7 @@ def on_train_batch_start(
         if self._log_stats.intra_step_time:
             self._snap_intra_step_time = time.time()
 
-        if not trainer.logger_connector.should_update_logs:
+        if not trainer._logger_connector.should_update_logs:
             return
 
         gpu_stat_keys = self._get_gpu_stat_keys()
@@ -176,7 +176,7 @@ def on_train_batch_end(
         if self._log_stats.inter_step_time:
             self._snap_inter_step_time = time.time()
 
-        if not trainer.logger_connector.should_update_logs:
+        if not trainer._logger_connector.should_update_logs:
             return
 
         gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys()

pytorch_lightning/callbacks/lr_monitor.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def _check_no_key(key: str) -> bool:
         self.last_momentum_values = {name + "-momentum": None for name in names_flatten}
 
     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        if not trainer.logger_connector.should_update_logs:
+        if not trainer._logger_connector.should_update_logs:
             return
 
         if self.logging_interval != "epoch":

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 2 deletions
@@ -380,7 +380,7 @@ def log(
 
         value = apply_to_collection(value, numbers.Number, self.__to_tensor)
 
-        if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
+        if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
             # if we started a new epoch (running its first batch) the hook name has changed
             # reset any tensors for the new hook name
             results.reset(metrics=False, fx=self._current_fx_name)
@@ -433,7 +433,7 @@ def log(
             rank_zero_only=rank_zero_only,
         )
 
-        self.trainer.logger_connector._current_fx = self._current_fx_name
+        self.trainer._logger_connector._current_fx = self._current_fx_name
 
     def log_dict(
         self,
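
Only the internals of `LightningModule.log` change in this hunk; the user-facing call is untouched. A minimal sketch of the unchanged public API (the model and loss computation are placeholders):

import torch
from torch import nn
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)  # placeholder model

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        # Same call before and after this commit; routing through the (now
        # protected) logger connector is an internal detail.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)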

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
         void(batch)
         self.split_idx, split_batch = self._remaining_splits.pop(0)
 
-        self.trainer.logger_connector.on_train_split_start(self.split_idx)
+        self.trainer._logger_connector.on_train_split_start(self.split_idx)
 
         outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None  # for mypy
         # choose which loop will run the optimization

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 9 additions & 9 deletions
@@ -162,16 +162,16 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         self._has_run = True
 
     def on_advance_end(self) -> None:
-        self.trainer.logger_connector.epoch_end_reached()
+        self.trainer._logger_connector.epoch_end_reached()
 
-        self._logged_outputs.append(self.trainer.logger_connector.update_eval_epoch_metrics())
+        self._logged_outputs.append(self.trainer._logger_connector.update_eval_epoch_metrics())
 
         super().on_advance_end()
 
     def on_run_end(self) -> List[_OUT_DICT]:
         """Runs the ``_on_evaluation_epoch_end`` hook."""
         # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
-        self.trainer.logger_connector.epoch_end_reached()
+        self.trainer._logger_connector.epoch_end_reached()
 
         # hook
         self._evaluation_epoch_end(self._outputs)
@@ -182,12 +182,12 @@ def on_run_end(self) -> List[_OUT_DICT]:
 
         logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory
         # include any logged outputs on epoch_end
-        epoch_end_logged_outputs = self.trainer.logger_connector.update_eval_epoch_metrics()
+        epoch_end_logged_outputs = self.trainer._logger_connector.update_eval_epoch_metrics()
         for dl_outputs in logged_outputs:
             dl_outputs.update(epoch_end_logged_outputs)
 
         # log metrics
-        self.trainer.logger_connector.log_eval_end_metrics()
+        self.trainer._logger_connector.log_eval_end_metrics()
 
         # hook
         self._on_evaluation_end()
@@ -266,11 +266,11 @@ def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         self.trainer._call_strategy_hook("on_validation_end", *args, **kwargs)
 
         # reset the logger connector state
-        self.trainer.logger_connector.reset_results()
+        self.trainer._logger_connector.reset_results()
 
     def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
-        self.trainer.logger_connector.on_epoch_start()
+        self.trainer._logger_connector.on_epoch_start()
         self.trainer._call_callback_hooks("on_epoch_start", *args, **kwargs)
         self.trainer._call_lightning_module_hook("on_epoch_start", *args, **kwargs)
 
@@ -283,7 +283,7 @@ def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
 
     def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
         """Runs ``{validation/test}_epoch_end``"""
-        self.trainer.logger_connector._evaluation_epoch_end()
+        self.trainer._logger_connector._evaluation_epoch_end()
 
         # with a single dataloader don't pass a 2D list
         output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
@@ -304,7 +304,7 @@ def _on_evaluation_epoch_end(self) -> None:
 
         self.trainer._call_callback_hooks("on_epoch_end")
         self.trainer._call_lightning_module_hook("on_epoch_end")
-        self.trainer.logger_connector.on_epoch_end()
+        self.trainer._logger_connector.on_epoch_end()
 
     @staticmethod
     def _get_keys(data: dict) -> Iterable[str]:
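
The `_OUT_DICT` lists assembled here from `update_eval_epoch_metrics` are what surface as the return value of `Trainer.validate`/`Trainer.test`, so that public surface is unchanged by the rename. For example (`model` and `val_loader` are placeholders, and the metric values are illustrative):

results = trainer.validate(model, dataloaders=val_loader)
# One dict of logged metrics per validation dataloader, e.g.:
# [{"val_loss": 0.42, "val_acc": 0.87}]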

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 3 additions & 3 deletions
@@ -135,7 +135,7 @@ def advance( # type: ignore[override]
         self.batch_progress.increment_completed()
 
         # log batch metrics
-        self.trainer.logger_connector.update_eval_step_metrics()
+        self.trainer._logger_connector.update_eval_step_metrics()
 
         # track epoch level outputs
         if self._should_track_batch_outputs_for_epoch_end() and output is not None:
@@ -242,7 +242,7 @@ def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
         Raises:
             AssertionError: If the number of dataloaders is None (has not yet been set).
         """
-        self.trainer.logger_connector.on_batch_start(**kwargs)
+        self.trainer._logger_connector.on_batch_start(**kwargs)
 
         kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
         hook_name = "on_test_batch_start" if self.trainer.testing else "on_validation_batch_start"
@@ -263,7 +263,7 @@ def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any)
         self.trainer._call_callback_hooks(hook_name, output, *kwargs.values())
         self.trainer._call_lightning_module_hook(hook_name, output, *kwargs.values())
 
-        self.trainer.logger_connector.on_batch_end()
+        self.trainer._logger_connector.on_batch_end()
 
     def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
         """Helper function to build the arguments for the current step.

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 4 additions & 4 deletions
@@ -167,7 +167,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
 
         self.batch_progress.increment_ready()
 
-        self.trainer.logger_connector.on_batch_start(batch, batch_idx)
+        self.trainer._logger_connector.on_batch_start(batch, batch_idx)
 
         if batch is None:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
@@ -225,7 +225,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
             "on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs
         )
         self.trainer._call_callback_hooks("on_batch_end")
-        self.trainer.logger_connector.on_batch_end()
+        self.trainer._logger_connector.on_batch_end()
 
         self.batch_progress.increment_completed()
 
@@ -235,7 +235,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
         # -----------------------------------------
         # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
         # -----------------------------------------
-        self.trainer.logger_connector.update_train_step_metrics()
+        self.trainer._logger_connector.update_train_step_metrics()
 
     def on_advance_end(self) -> None:
         # -----------------------------------------
@@ -504,7 +504,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
     def _save_loggers_on_train_batch_end(self) -> None:
         """Flushes loggers to disk."""
         # when loggers should save to disk
-        should_flush_logs = self.trainer.logger_connector.should_flush_logs
+        should_flush_logs = self.trainer._logger_connector.should_flush_logs
         if should_flush_logs:
             for logger in self.trainer.loggers:
                 logger.save()
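
This is the flush path behind the CHANGELOG entry above: since #12134, `logger.save()` runs on every rank whenever `should_flush_logs` is true. A custom logger that should only write from rank zero can guard its `save`, roughly as below (the class name and buffer helper are illustrative, and the other abstract logger methods are omitted):

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only


class BufferedFileLogger(LightningLoggerBase):
    # name/version/experiment/log_hyperparams/log_metrics omitted in this sketch.

    @rank_zero_only
    def save(self) -> None:
        super().save()
        self._write_buffer_to_disk()  # illustrative helper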

pytorch_lightning/loops/fit_loop.py

Lines changed: 4 additions & 4 deletions
@@ -258,7 +258,7 @@ def on_advance_start(self) -> None: # type: ignore[override]
 
         self.epoch_progress.increment_ready()
 
-        self.trainer.logger_connector.on_epoch_start()
+        self.trainer._logger_connector.on_epoch_start()
 
         self.trainer._call_callback_hooks("on_epoch_start")
         self.trainer._call_lightning_module_hook("on_epoch_start")
@@ -282,7 +282,7 @@ def advance(self) -> None: # type: ignore[override]
 
     def on_advance_end(self) -> None:
         # inform logger the batch loop has finished
-        self.trainer.logger_connector.epoch_end_reached()
+        self.trainer._logger_connector.epoch_end_reached()
 
         # get the model and call model.training_epoch_end
         model = self.trainer.lightning_module
@@ -312,7 +312,7 @@ def on_advance_end(self) -> None:
         self.trainer._call_callback_hooks("on_epoch_end")
         self.trainer._call_lightning_module_hook("on_epoch_end")
 
-        self.trainer.logger_connector.on_epoch_end()
+        self.trainer._logger_connector.on_epoch_end()
 
         if self.epoch_loop._num_ready_batches_reached():
             self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)
@@ -325,7 +325,7 @@ def on_advance_end(self) -> None:
         # TODO(@carmocca): deprecate and rename so users don't get confused
         self.global_step -= 1
         # log epoch metrics
-        self.trainer.logger_connector.update_train_epoch_metrics()
+        self.trainer._logger_connector.update_train_epoch_metrics()
         self.global_step += 1
 
         # if fault tolerant is enabled and process has been notified, exit.

pytorch_lightning/trainer/trainer.py

Lines changed: 12 additions & 12 deletions
@@ -500,7 +500,7 @@ def __init__(
             amp_level=amp_level,
             plugins=plugins,
         )
-        self.logger_connector = LoggerConnector(self, log_gpu_memory)
+        self._logger_connector = LoggerConnector(self, log_gpu_memory)
         self._callback_connector = CallbackConnector(self)
         self._checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
         self._signal_connector = SignalConnector(self)
@@ -614,7 +614,7 @@ def __init__(
 
         # init logger flags
         self._loggers: List[LightningLoggerBase]
-        self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
+        self._logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
         self.val_check_interval: Union[int, float]
@@ -1210,8 +1210,8 @@ def _run(
         # ----------------------------
 
         # reset logger connector
-        self.logger_connector.reset_results()
-        self.logger_connector.reset_metrics()
+        self._logger_connector.reset_results()
+        self._logger_connector.reset_metrics()
 
         # strategy will configure model and move it to the device
         self.strategy.setup(self)
@@ -1302,7 +1302,7 @@ def _teardown(self):
         # loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn`
         if loop is not None:
             loop.teardown()
-        self.logger_connector.teardown()
+        self._logger_connector.teardown()
         self._signal_connector.teardown()
 
     def run_stage(self) -> None:
@@ -1397,8 +1397,8 @@ def _run_sanity_check(self) -> None:
         self.sanity_checking = True
 
         # reset logger connector
-        self.logger_connector.reset_results()
-        self.logger_connector.reset_metrics()
+        self._logger_connector.reset_results()
+        self._logger_connector.reset_metrics()
 
         self._call_callback_hooks("on_sanity_check_start")
 
@@ -1415,8 +1415,8 @@ def _run_sanity_check(self) -> None:
         self._call_callback_hooks("on_sanity_check_end")
 
         # reset logger connector
-        self.logger_connector.reset_results()
-        self.logger_connector.reset_metrics()
+        self._logger_connector.reset_results()
+        self._logger_connector.reset_metrics()
 
         # reset the progress tracking state after sanity checking. we don't need to set the state before
         # because sanity check only runs when we are not restarting
@@ -2646,15 +2646,15 @@ def loggers(self, loggers: Optional[List[LightningLoggerBase]]) -> None:
 
     @property
     def callback_metrics(self) -> dict:
-        return self.logger_connector.callback_metrics
+        return self._logger_connector.callback_metrics
 
     @property
    def logged_metrics(self) -> dict:
-        return self.logger_connector.logged_metrics
+        return self._logger_connector.logged_metrics
 
     @property
     def progress_bar_metrics(self) -> dict:
-        return self.logger_connector.progress_bar_metrics
+        return self._logger_connector.progress_bar_metrics
 
     @property
     def _results(self) -> Optional[_ResultCollection]:
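
For code outside the Trainer, the properties in the last hunk remain the supported way to read aggregated metrics, so the protected connector rarely needs to be touched directly. A short usage sketch (`model` and `datamodule` are placeholders):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1, log_every_n_steps=50)
trainer.fit(model, datamodule=datamodule)

# Public accessors backed by trainer._logger_connector:
callback_metrics = trainer.callback_metrics          # values visible to callbacks / checkpointing
logged_metrics = trainer.logged_metrics              # last values sent to the loggers
progress_bar_metrics = trainer.progress_bar_metrics  # values shown in the progress bar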

tests/loggers/test_tensorboard.py

Lines changed: 1 addition & 1 deletion
@@ -260,7 +260,7 @@ def __init__(self):
     def training_step(self, *args):
         self.log("foo", 1, on_step=True, on_epoch=True)
         if not self.trainer.fit_loop._should_accumulate():
-            if self.trainer.logger_connector.should_update_logs:
+            if self.trainer._logger_connector.should_update_logs:
                 self.indexes.append(self.trainer.global_step)
         return super().training_step(*args)
