
Commit e4f9656

rohitgr7 authored and carmocca committed
Fix filtration logic for eval results with multiple dataloaders (#10810)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent c1184dc commit e4f9656
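
The user-facing effect of the fix: with multiple evaluation dataloaders, `trainer.validate`/`trainer.test` return one result dict per dataloader, and metrics logged with `add_dataloader_idx=False` appear only in the dict of the dataloader that produced them instead of being duplicated into every dict. A minimal sketch of that behavior, using the repo's `BoringModel`/`RandomDataset` test helpers; the model, metric names, and asserted keys are illustrative, not taken from this diff:

import torch
from pytorch_lightning import Trainer
from tests.helpers import BoringModel, RandomDataset


class TwoValLoaderModel(BoringModel):
    def validation_step(self, batch, batch_idx, dataloader_idx):
        # gets the automatic suffix: "val_loss/dataloader_idx_0" and "val_loss/dataloader_idx_1"
        self.log("val_loss", float(dataloader_idx))
        # no suffix; after this fix it is reported only for the dataloader that logged it
        self.log(f"metric_{dataloader_idx}", 1.0, add_dataloader_idx=False)

    def val_dataloader(self):
        return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(2)]


model = TwoValLoaderModel()
model.validation_epoch_end = None  # keep the sketch minimal, mirroring the test added in this commit
trainer = Trainer(fast_dev_run=1)
results = trainer.validate(model, verbose=False)
assert len(results) == 2             # one dict per dataloader
assert "metric_1" not in results[0]  # no duplication across dataloaders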

File tree

7 files changed: +72 -47 lines


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))
 - Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611))
 - Fixed `SignalConnector._has_already_handler` check for callable type ([#10483](https://github.com/PyTorchLightning/pytorch-lightning/pull/10483))
+- Fixed an issue to return the results for each dataloader separately instead of duplicating them for each ([#10810](https://github.com/PyTorchLightning/pytorch-lightning/pull/10810))
 - Improved exception message if `rich` version is less than `10.2.2` ([#10839](https://github.com/PyTorchLightning/pytorch-lightning/pull/10839))
 - Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))
 - Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837))
@@ -27,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))
 - Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))

+
 ### Removed

 - Removed PyTorch 1.6 support ([#10367](https://github.com/PyTorchLightning/pytorch-lightning/pull/10367), [#10738](https://github.com/PyTorchLightning/pytorch-lightning/pull/10738))

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 1 deletion
@@ -486,7 +486,8 @@ def log(
             on_epoch=on_epoch,
             reduce_fx=reduce_fx,
             enable_graph=enable_graph,
-            dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
+            add_dataloader_idx=add_dataloader_idx,
+            dataloader_idx=self._current_dataloader_idx,
             batch_size=batch_size,
             sync_dist=sync_dist and distributed_available(),
             sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
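
Before this change, `LightningModule.log` collapsed the two flags into one argument: with `add_dataloader_idx=False` it passed `dataloader_idx=None` downstream, so the result collection lost track of which dataloader an unsuffixed metric came from. The new code forwards both values and lets the metadata keep them side by side. A toy sketch of the distinction; the `Meta` dataclass below is an illustrative stand-in, not the real `_Metadata`:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Meta:  # toy stand-in for `_Metadata` in result.py
    name: str
    add_dataloader_idx: bool = True
    dataloader_idx: Optional[int] = None


# Old flow: dataloader_idx was dropped when suffixing was disabled,
# so the metric could no longer be attributed to its dataloader.
old = Meta("acc", dataloader_idx=None)

# New flow: both pieces of information are kept.
new = Meta("acc", add_dataloader_idx=False, dataloader_idx=1)
assert new.dataloader_idx == 1  # still attributable to dataloader 1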

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 10 additions & 11 deletions
@@ -154,21 +154,20 @@ def update_eval_step_metrics(self) -> None:
         # increment the step even if nothing was logged
         self._increment_eval_log_step()

-    @staticmethod
-    def _filter_metrics_for_dataloader(
-        dl_idx: int, metrics: _OUT_DICT, metric_prefix: str = "dataloader_idx"
-    ) -> _OUT_DICT:
-        return {k: v for k, v in metrics.items() if metric_prefix not in k or k.endswith(f"{metric_prefix}_{dl_idx}")}
-
-    def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
+    def _prepare_eval_loop_results(self) -> None:
         if self.trainer.sanity_checking:
             return

+        on_step = not self._epoch_end_reached
         num_dataloaders = self.trainer._evaluation_loop.num_dataloaders
         has_been_initialized = len(self.eval_loop_results) == num_dataloaders
-        for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders):
-            # remove callback metrics that don't belong to this dataloader
-            callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics)
+        assert self.trainer._evaluation_loop._results is not None
+        for dl_idx in range(num_dataloaders):
+            metrics = self.trainer._evaluation_loop._results.metrics(
+                on_step, dataloader_idx=dl_idx if num_dataloaders > 1 else None
+            )
+            callback_metrics = metrics["callback"]
+
             if has_been_initialized:
                 self.eval_loop_results[dl_idx].update(callback_metrics)
             else:
@@ -182,7 +181,7 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
         # log all the metrics as a single dict
         self.log_metrics(metrics["log"])

-        self._prepare_eval_loop_results(metrics["callback"])
+        self._prepare_eval_loop_results()

         # log results of evaluation
         if (
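
The removed `_filter_metrics_for_dataloader` assigned metrics to dataloaders by string-matching key names, so any key without a `dataloader_idx_*` suffix (exactly what `add_dataloader_idx=False` produces) matched every dataloader and was copied into every result dict. The replacement instead asks the result collection for the metrics recorded under each `dataloader_idx`. A self-contained sketch of the old failure mode; the helper body is copied from the removed code, while the metric names and values are made up:

def _filter_metrics_for_dataloader(dl_idx, metrics, metric_prefix="dataloader_idx"):
    # the removed helper, reproduced for illustration
    return {k: v for k, v in metrics.items() if metric_prefix not in k or k.endswith(f"{metric_prefix}_{dl_idx}")}


# "f1" was logged with add_dataloader_idx=False, so it carries no suffix
metrics = {"acc/dataloader_idx_0": 0.8, "acc/dataloader_idx_1": 0.7, "f1": 0.5}

# Key-name filtering cannot tell which dataloader "f1" belongs to,
# so it is duplicated into every per-dataloader result dict:
assert _filter_metrics_for_dataloader(0, metrics) == {"acc/dataloader_idx_0": 0.8, "f1": 0.5}
assert _filter_metrics_for_dataloader(1, metrics) == {"acc/dataloader_idx_1": 0.7, "f1": 0.5}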

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 14 additions & 6 deletions
@@ -113,6 +113,7 @@ class _Metadata:
     on_epoch: bool = True
     reduce_fx: Callable = torch.mean
     enable_graph: bool = False
+    add_dataloader_idx: bool = True
     dataloader_idx: Optional[int] = None
     metric_attribute: Optional[str] = None
     _sync: Optional[_Sync] = None
@@ -434,6 +435,7 @@ def log(
         sync_dist: bool = False,
         sync_dist_fn: Callable = _Sync.no_op,
         sync_dist_group: Optional[Any] = None,
+        add_dataloader_idx: bool = True,
         dataloader_idx: Optional[int] = None,
         batch_size: Optional[int] = None,
         metric_attribute: Optional[str] = None,
@@ -451,7 +453,7 @@ def log(
         # storage key
         key = f"{fx}.{name}"
         # add dataloader_suffix to both key and fx
-        if dataloader_idx is not None:
+        if add_dataloader_idx and dataloader_idx is not None:
             key += f".{dataloader_idx}"
             fx += f".{dataloader_idx}"

@@ -464,6 +466,7 @@ def log(
                 on_epoch=on_epoch,
                 reduce_fx=reduce_fx,
                 enable_graph=enable_graph,
+                add_dataloader_idx=add_dataloader_idx,
                 dataloader_idx=dataloader_idx,
                 metric_attribute=metric_attribute,
             )
@@ -522,24 +525,29 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten
             return cache.detach()
         return cache

-    def valid_items(self) -> Generator:
+    def valid_items(self, dataloader_idx: Optional[int] = None) -> Generator:
         """This function is used to iterate over current valid metrics."""
-        return ((k, v) for k, v in self.items() if not (isinstance(v, ResultMetric) and v.has_reset))
+        return (
+            (k, v)
+            for k, v in self.items()
+            if not (isinstance(v, ResultMetric) and v.has_reset) and (dataloader_idx in (None, v.meta.dataloader_idx))
+        )

     def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
         name = result_metric.meta.name
         forked_name = result_metric.meta.forked_name(on_step)
+        add_dataloader_idx = result_metric.meta.add_dataloader_idx
         dl_idx = result_metric.meta.dataloader_idx
-        if dl_idx is not None:
+        if add_dataloader_idx and dl_idx is not None:
             dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx)
             name += dataloader_suffix
             forked_name += dataloader_suffix
         return name, forked_name

-    def metrics(self, on_step: bool) -> _METRICS:
+    def metrics(self, on_step: bool, dataloader_idx: Optional[int] = None) -> _METRICS:
         metrics = _METRICS(callback={}, log={}, pbar={})

-        for _, result_metric in self.valid_items():
+        for _, result_metric in self.valid_items(dataloader_idx):

             # extract forward_cache or computed from the ResultMetric. ignore when the output is None
             value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
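
The filtering in the new `valid_items` boils down to the predicate `dataloader_idx in (None, v.meta.dataloader_idx)`: either no dataloader filter was requested, or the metric's recorded dataloader matches the requested one. A tiny sketch of just that predicate, with hypothetical helper and argument names:

from typing import Optional


def belongs_to(requested: Optional[int], logged_from: Optional[int]) -> bool:
    # the predicate used by the new `valid_items`: no filter, or an exact match
    return requested in (None, logged_from)


assert belongs_to(None, 1)   # no filtering requested: every metric is yielded
assert belongs_to(1, 1)      # logged from dataloader 1, requested dataloader 1
assert not belongs_to(0, 1)  # logged from dataloader 1 is skipped when dataloader 0 is requested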

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -1396,7 +1396,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
                 " The best model of the previous `fit` call will be used."
                 f" You can pass `{fn}(ckpt_path='best')` to use and best model"
                 " checkpoint and avoid this warning or"
-                " `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
+                " `ckpt_path=trainer.checkpoint_callback.last_model_path` to use the last model."
             )
             ckpt_path = "best"
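
The corrected warning points at an attribute that actually exists on the `Trainer`: `trainer.checkpoint_callback` (the first `ModelCheckpoint` callback), whose `last_model_path` holds the most recent checkpoint. A small usage sketch, assuming the repo's `BoringModel` test helper and a `ModelCheckpoint(save_last=True)` callback:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2, callbacks=[ModelCheckpoint(save_last=True)])
trainer.fit(model)

# Resume evaluation from the last checkpoint, as the corrected message suggests.
trainer.validate(model, ckpt_path=trainer.checkpoint_callback.last_model_path)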

tests/plugins/test_ddp_spawn_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ def on_predict_start(self) -> None:
         assert isinstance(self.trainer.model, LightningModule)


-@RunIf(skip_windows=True, skip_49370=True)
+@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
 def test_ddp_spawn_configure_ddp(tmpdir):
     """Tests with ddp spawn plugin."""
     trainer = Trainer(default_root_dir=tmpdir, num_processes=2, strategy="ddp_spawn", fast_dev_run=True)

tests/trainer/logging_/test_eval_loop_logging.py

Lines changed: 42 additions & 27 deletions
@@ -23,7 +23,6 @@

 from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -676,32 +675,6 @@ def val_dataloader(self):
     trainer.fit(model)


-@pytest.mark.parametrize(
-    ["kwargs", "expected"],
-    [
-        ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}),
-        (
-            {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}},
-            {"acc/dataloader_idx_0": 123},
-        ),
-        (
-            {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}},
-            {"acc/dataloader_idx_10": 321},
-        ),
-        (
-            {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}},
-            {"top_3_acc/dataloader_idx_3": 321},
-        ),
-        # theoretical case, as `/dataloader_idx_3` would have been added
-        ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}),
-    ],
-)
-def test_filter_metrics_for_dataloader(kwargs, expected):
-    """Logged metrics should only include metrics from the concerned dataloader."""
-    actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs)
-    assert actual == expected
-
-
 @RunIf(min_gpus=1)
 def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir):
     class TestModel(BoringModel):
@@ -723,3 +696,45 @@ def validation_epoch_end(self, outputs):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, gpus=1)
     trainer.validate(model, verbose=False)
+
+
+def test_logging_results_with_no_dataloader_idx(tmpdir):
+    num_dataloaders = 2
+    log_common_same_val = {"test_log_common": 789}
+    log_common_diff_val = "test_log_common_diff_value"
+    log_key_no_dl_idx = "test_log_no_dl_idx_{}"
+    log_key_dl0 = {"test_log_a_class": 123}
+    log_key_dl1 = {"test_log_b_class": 456}
+
+    class CustomBoringModel(BoringModel):
+        def test_step(self, batch, batch_idx, dataloader_idx):
+            self.log_dict(log_common_same_val)
+            self.log(log_common_diff_val, dataloader_idx + 1)
+            self.log(
+                log_key_no_dl_idx.format(dataloader_idx),
+                321 * (dataloader_idx + 1),
+                add_dataloader_idx=False,
+            )
+            self.log_dict(log_key_dl0 if dataloader_idx == 0 else log_key_dl1, add_dataloader_idx=False)
+
+        def test_dataloader(self):
+            return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)]
+
+    model = CustomBoringModel()
+    model.test_epoch_end = None
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+    results = trainer.test(model)
+
+    assert len(results) == num_dataloaders
+    assert results[0] == {
+        "test_log_common/dataloader_idx_0": 789.0,
+        "test_log_common_diff_value/dataloader_idx_0": 1.0,
+        "test_log_no_dl_idx_0": 321,
+        "test_log_a_class": 123.0,
+    }
+    assert results[1] == {
+        "test_log_common/dataloader_idx_1": 789.0,
+        "test_log_common_diff_value/dataloader_idx_1": 2.0,
+        "test_log_no_dl_idx_1": 321 * 2,
+        "test_log_b_class": 456.0,
+    }