Mark logger_connector as protected #12195

Merged · 4 commits · Mar 5, 2022 · Changes from 3 commits
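
This PR renames the `Trainer`'s public `logger_connector` attribute to `_logger_connector`, marking the logger connector as protected, and updates every internal call site (callbacks, loops, `Trainer` properties, and tests). Code outside the `Trainer` can keep reading metrics through the public properties that delegate to the connector (`trainer.callback_metrics`, `trainer.logged_metrics`, `trainer.progress_bar_metrics`; see the `trainer.py` diff below). A minimal sketch of the migration, assuming a hypothetical user callback that previously reached into the connector; the callback name and printed output are illustrative only.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class MetricPrinter(Callback):
    """Hypothetical callback illustrating the rename."""

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Before this PR, code like this often touched the connector directly:
        #     metrics = trainer.logger_connector.callback_metrics
        # The connector is now protected (`trainer._logger_connector`), so prefer
        # the public property, which delegates to it:
        metrics = trainer.callback_metrics
        print({name: float(value) for name, value in metrics.items()})
```
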
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -332,6 +332,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed `is_global_zero` check in `training_epoch_loop` before `logger.save`. If you have a custom logger that implements `save` the Trainer will now call `save` on all ranks by default. To change this behavior add `@rank_zero_only` to your `save` implementation ([#12134](https://github.com/PyTorchLightning/pytorch-lightning/pull/12134))


- Marked `trainer.logger_connector` as protected ([#12195](https://github.com/PyTorchLightning/pytorch-lightning/pull/12195))

### Deprecated

- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -58,7 +58,7 @@ def on_train_batch_start(
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

if not trainer.logger_connector.should_update_logs:
if not trainer._logger_connector.should_update_logs:
return

device = trainer.strategy.root_device
@@ -80,7 +80,7 @@ def on_train_batch_end(
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

if not trainer.logger_connector.should_update_logs:
if not trainer._logger_connector.should_update_logs:
return

device = trainer.strategy.root_device
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -150,7 +150,7 @@ def on_train_batch_start(
if self._log_stats.intra_step_time:
self._snap_intra_step_time = time.time()

if not trainer.logger_connector.should_update_logs:
if not trainer._logger_connector.should_update_logs:
return

gpu_stat_keys = self._get_gpu_stat_keys()
@@ -176,7 +176,7 @@ def on_train_batch_end(
if self._log_stats.inter_step_time:
self._snap_inter_step_time = time.time()

if not trainer.logger_connector.should_update_logs:
if not trainer._logger_connector.should_update_logs:
return

gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys()
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/lr_monitor.py
@@ -149,7 +149,7 @@ def _check_no_key(key: str) -> bool:
self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
if not trainer.logger_connector.should_update_logs:
if not trainer._logger_connector.should_update_logs:
return

if self.logging_interval != "epoch":
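
The call-site update above is the same for any callback that rate-limits its logging the way `LearningRateMonitor` and `DeviceStatsMonitor` do. A minimal sketch of that pattern, assuming a hypothetical third-party callback; note that `_logger_connector` is now a protected attribute, so code depending on it relies on internal API that may change without a deprecation cycle.

```python
from typing import Any

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class StepMarker(Callback):
    """Hypothetical callback that only emits a value on steps where logs are updated."""

    def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
        # Same gate the built-in monitors use after this PR
        # (`logger_connector` -> `_logger_connector`):
        if not trainer._logger_connector.should_update_logs:
            return
        for logger in trainer.loggers:
            logger.log_metrics({"step_marker": 1.0}, step=trainer.global_step)
```
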
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -380,7 +380,7 @@ def log(

value = apply_to_collection(value, numbers.Number, self.__to_tensor)

if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running its first batch) the hook name has changed
# reset any tensors for the new hook name
results.reset(metrics=False, fx=self._current_fx_name)
@@ -433,7 +433,7 @@ def log(
rank_zero_only=rank_zero_only,
)

self.trainer.logger_connector._current_fx = self._current_fx_name
self.trainer._logger_connector._current_fx = self._current_fx_name

def log_dict(
self,
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/batch/training_batch_loop.py
@@ -79,7 +79,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
void(batch)
self.split_idx, split_batch = self._remaining_splits.pop(0)

self.trainer.logger_connector.on_train_split_start(self.split_idx)
self.trainer._logger_connector.on_train_split_start(self.split_idx)

outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None # for mypy
# choose which loop will run the optimization
18 changes: 9 additions & 9 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -162,16 +162,16 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
self._has_run = True

def on_advance_end(self) -> None:
self.trainer.logger_connector.epoch_end_reached()
self.trainer._logger_connector.epoch_end_reached()

self._logged_outputs.append(self.trainer.logger_connector.update_eval_epoch_metrics())
self._logged_outputs.append(self.trainer._logger_connector.update_eval_epoch_metrics())

super().on_advance_end()

def on_run_end(self) -> List[_OUT_DICT]:
"""Runs the ``_on_evaluation_epoch_end`` hook."""
# if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
self.trainer.logger_connector.epoch_end_reached()
self.trainer._logger_connector.epoch_end_reached()

# hook
self._evaluation_epoch_end(self._outputs)
@@ -182,12 +182,12 @@ def on_run_end(self) -> List[_OUT_DICT]:

logged_outputs, self._logged_outputs = self._logged_outputs, [] # free memory
# include any logged outputs on epoch_end
epoch_end_logged_outputs = self.trainer.logger_connector.update_eval_epoch_metrics()
epoch_end_logged_outputs = self.trainer._logger_connector.update_eval_epoch_metrics()
for dl_outputs in logged_outputs:
dl_outputs.update(epoch_end_logged_outputs)

# log metrics
self.trainer.logger_connector.log_eval_end_metrics()
self.trainer._logger_connector.log_eval_end_metrics()

# hook
self._on_evaluation_end()
@@ -266,11 +266,11 @@ def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
self.trainer._call_strategy_hook("on_validation_end", *args, **kwargs)

# reset the logger connector state
self.trainer.logger_connector.reset_results()
self.trainer._logger_connector.reset_results()

def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
self.trainer.logger_connector.on_epoch_start()
self.trainer._logger_connector.on_epoch_start()
self.trainer._call_callback_hooks("on_epoch_start", *args, **kwargs)
self.trainer._call_lightning_module_hook("on_epoch_start", *args, **kwargs)

@@ -283,7 +283,7 @@ def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:

def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
"""Runs ``{validation/test}_epoch_end``"""
self.trainer.logger_connector._evaluation_epoch_end()
self.trainer._logger_connector._evaluation_epoch_end()

# with a single dataloader don't pass a 2D list
output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
@@ -304,7 +304,7 @@ def _on_evaluation_epoch_end(self) -> None:

self.trainer._call_callback_hooks("on_epoch_end")
self.trainer._call_lightning_module_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()
self.trainer._logger_connector.on_epoch_end()

@staticmethod
def _get_keys(data: dict) -> Iterable[str]:
6 changes: 3 additions & 3 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -135,7 +135,7 @@ def advance( # type: ignore[override]
self.batch_progress.increment_completed()

# log batch metrics
self.trainer.logger_connector.update_eval_step_metrics()
self.trainer._logger_connector.update_eval_step_metrics()

# track epoch level outputs
if self._should_track_batch_outputs_for_epoch_end() and output is not None:
@@ -242,7 +242,7 @@ def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
Raises:
AssertionError: If the number of dataloaders is None (has not yet been set).
"""
self.trainer.logger_connector.on_batch_start(**kwargs)
self.trainer._logger_connector.on_batch_start(**kwargs)

kwargs.setdefault("dataloader_idx", 0) # TODO: the argument should be keyword for these
hook_name = "on_test_batch_start" if self.trainer.testing else "on_validation_batch_start"
@@ -263,7 +263,7 @@ def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any)
self.trainer._call_callback_hooks(hook_name, output, *kwargs.values())
self.trainer._call_lightning_module_hook(hook_name, output, *kwargs.values())

self.trainer.logger_connector.on_batch_end()
self.trainer._logger_connector.on_batch_end()

def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
"""Helper function to build the arguments for the current step.
8 changes: 4 additions & 4 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -167,7 +167,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov

self.batch_progress.increment_ready()

self.trainer.logger_connector.on_batch_start(batch, batch_idx)
self.trainer._logger_connector.on_batch_start(batch, batch_idx)

if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
@@ -225,7 +225,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
"on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs
)
self.trainer._call_callback_hooks("on_batch_end")
self.trainer.logger_connector.on_batch_end()
self.trainer._logger_connector.on_batch_end()

self.batch_progress.increment_completed()

@@ -235,7 +235,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
# -----------------------------------------
# SAVE METRICS TO LOGGERS AND PROGRESS_BAR
# -----------------------------------------
self.trainer.logger_connector.update_train_step_metrics()
self.trainer._logger_connector.update_train_step_metrics()

def on_advance_end(self) -> None:
# -----------------------------------------
@@ -504,7 +504,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
def _save_loggers_on_train_batch_end(self) -> None:
"""Flushes loggers to disk."""
# when loggers should save to disk
should_flush_logs = self.trainer.logger_connector.should_flush_logs
should_flush_logs = self.trainer._logger_connector.should_flush_logs
if should_flush_logs:
for logger in self.trainer.loggers:
logger.save()
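
`_save_loggers_on_train_batch_end` above asks the (now protected) connector whether logs should be flushed and then calls `save()` on every configured logger. Per the changelog entry near the top, `save` now runs on all ranks by default, so a custom logger that wants rank-zero-only writes has to guard this itself. A minimal sketch under those assumptions; the class name is hypothetical and the base class is just an example.

```python
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities import rank_zero_only


class RankZeroCSVLogger(CSVLogger):
    """Hypothetical logger that restricts disk flushes to the global-zero rank."""

    @rank_zero_only
    def save(self) -> None:
        # The training loop calls `save()` on all ranks whenever
        # `trainer._logger_connector.should_flush_logs` is True; the decorator
        # keeps the actual write on rank zero only.
        super().save()
```
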
8 changes: 4 additions & 4 deletions pytorch_lightning/loops/fit_loop.py
@@ -258,7 +258,7 @@ def on_advance_start(self) -> None: # type: ignore[override]

self.epoch_progress.increment_ready()

self.trainer.logger_connector.on_epoch_start()
self.trainer._logger_connector.on_epoch_start()

self.trainer._call_callback_hooks("on_epoch_start")
self.trainer._call_lightning_module_hook("on_epoch_start")
@@ -282,7 +282,7 @@ def advance(self) -> None: # type: ignore[override]

def on_advance_end(self) -> None:
# inform logger the batch loop has finished
self.trainer.logger_connector.epoch_end_reached()
self.trainer._logger_connector.epoch_end_reached()

# get the model and call model.training_epoch_end
model = self.trainer.lightning_module
@@ -312,7 +312,7 @@ def on_advance_end(self) -> None:
self.trainer._call_callback_hooks("on_epoch_end")
self.trainer._call_lightning_module_hook("on_epoch_end")

self.trainer.logger_connector.on_epoch_end()
self.trainer._logger_connector.on_epoch_end()

if self.epoch_loop._num_ready_batches_reached():
self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)
@@ -325,7 +325,7 @@ def on_advance_end(self) -> None:
# TODO(@carmocca): deprecate and rename so users don't get confused
self.global_step -= 1
# log epoch metrics
self.trainer.logger_connector.update_train_epoch_metrics()
self.trainer._logger_connector.update_train_epoch_metrics()
self.global_step += 1

# if fault tolerant is enabled and process has been notified, exit.
24 changes: 12 additions & 12 deletions pytorch_lightning/trainer/trainer.py
@@ -500,7 +500,7 @@ def __init__(
amp_level=amp_level,
plugins=plugins,
)
self.logger_connector = LoggerConnector(self, log_gpu_memory)
self._logger_connector = LoggerConnector(self, log_gpu_memory)
self._callback_connector = CallbackConnector(self)
self._checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
self._signal_connector = SignalConnector(self)
@@ -614,7 +614,7 @@ def __init__(

# init logger flags
self._loggers: List[LightningLoggerBase]
self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
self._logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)

# init debugging flags
self.val_check_interval: Union[int, float]
@@ -1210,8 +1210,8 @@ def _run(
# ----------------------------

# reset logger connector
self.logger_connector.reset_results()
self.logger_connector.reset_metrics()
self._logger_connector.reset_results()
self._logger_connector.reset_metrics()

# strategy will configure model and move it to the device
self.strategy.setup(self)
@@ -1302,7 +1302,7 @@ def _teardown(self):
# loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn`
if loop is not None:
loop.teardown()
self.logger_connector.teardown()
self._logger_connector.teardown()
self._signal_connector.teardown()

def run_stage(self) -> None:
@@ -1397,8 +1397,8 @@ def _run_sanity_check(self) -> None:
self.sanity_checking = True

# reset logger connector
self.logger_connector.reset_results()
self.logger_connector.reset_metrics()
self._logger_connector.reset_results()
self._logger_connector.reset_metrics()

self._call_callback_hooks("on_sanity_check_start")

@@ -1415,8 +1415,8 @@ def _run_sanity_check(self) -> None:
self._call_callback_hooks("on_sanity_check_end")

# reset logger connector
self.logger_connector.reset_results()
self.logger_connector.reset_metrics()
self._logger_connector.reset_results()
self._logger_connector.reset_metrics()

# reset the progress tracking state after sanity checking. we don't need to set the state before
# because sanity check only runs when we are not restarting
@@ -2644,15 +2644,15 @@ def loggers(self, loggers: Optional[List[LightningLoggerBase]]) -> None:

@property
def callback_metrics(self) -> dict:
return self.logger_connector.callback_metrics
return self._logger_connector.callback_metrics

@property
def logged_metrics(self) -> dict:
return self.logger_connector.logged_metrics
return self._logger_connector.logged_metrics

@property
def progress_bar_metrics(self) -> dict:
return self.logger_connector.progress_bar_metrics
return self._logger_connector.progress_bar_metrics

@property
def _results(self) -> Optional[_ResultCollection]:
2 changes: 1 addition & 1 deletion tests/loggers/test_tensorboard.py
@@ -260,7 +260,7 @@ def __init__(self):
def training_step(self, *args):
self.log("foo", 1, on_step=True, on_epoch=True)
if not self.trainer.fit_loop._should_accumulate():
if self.trainer.logger_connector.should_update_logs:
if self.trainer._logger_connector.should_update_logs:
self.indexes.append(self.trainer.global_step)
return super().training_step(*args)

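
The test above polls `trainer._logger_connector.should_update_logs` inside `training_step` to record which global steps actually log. A minimal sketch of the kind of `Trainer` configuration such a test assumes; the flag values are illustrative only.

```python
from pytorch_lightning import Trainer

# With this configuration the connector reports `should_update_logs` as True
# only on steps that line up with `log_every_n_steps`, which is the cadence
# the TensorBoard test asserts against.
trainer = Trainer(max_epochs=1, log_every_n_steps=10, logger=True)
```
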