Skip to content

log metrics for correct dataloader only #10522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 18, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,19 @@ def update_eval_step_metrics(self) -> None:
# increment the step even if nothing was logged
self._increment_eval_log_step()

@staticmethod
def _filter_metrics_for_dataloader(
dl_idx: int, metrics: Dict[str, Union[Any, Dict[str, Any]]], metric_prefix: str = "dataloader_idx"
) -> Dict[str, Union[Any, Dict[str, Any]]]:
result = {}
for k, v in metrics.items():
if metric_prefix not in k:
result[k] = v
continue
if k.endswith(f"{metric_prefix}_{dl_idx}"):
result[k] = v
return result

def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
if self.trainer.sanity_checking:
return
Expand All @@ -162,9 +175,7 @@ def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
has_been_initialized = len(self.eval_loop_results) == num_dataloaders
for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders):
# remove callback metrics that don't belong to this dataloader
callback_metrics = {
k: v for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k
}
callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics)
if has_been_initialized:
self.eval_loop_results[dl_idx].update(callback_metrics)
else:
Expand Down
27 changes: 27 additions & 0 deletions tests/trainer/logging_/test_eval_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset

Expand Down Expand Up @@ -672,3 +673,29 @@ def val_dataloader(self):
enable_model_summary=False,
)
trainer.fit(model)


@pytest.mark.parametrize(
    ["kwargs", "expected"],
    [
        # no dataloader suffix: everything is kept
        ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}),
        # only the metric suffixed with this dataloader's index survives
        (
            {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}},
            {"acc/dataloader_idx_0": 123},
        ),
        # multi-digit index must not be confused with its single-digit prefix
        (
            {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}},
            {"acc/dataloader_idx_10": 321},
        ),
        # a digit inside the metric name itself must not trigger a match
        (
            {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}},
            {"top_3_acc/dataloader_idx_3": 321},
        ),
        # theoretical case, as `/dataloader_idx_3` would have been added
        ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}),
    ],
)
def test_filter_metrics_for_dataloader(kwargs, expected):
    """Logged metrics should only include metrics from the concerned dataloader."""
    assert LoggerConnector._filter_metrics_for_dataloader(**kwargs) == expected