Skip to content

Commit c179a7d

Browse files
carmoccaawaelchli
authored andcommitted
Fix move_metrics_to_cpu with evaluation (#10631)
1 parent ef4feb7 commit c179a7d

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2626
- Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408))
2727

2828

29-
-
29+
- Fixed `Trainer(move_metrics_to_cpu=True)` not moving the evaluation logged results to CPU ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631))
3030

3131

32-
-
32+
- Fixed the `{validation,test}_step` outputs getting moved to CPU with `Trainer(move_metrics_to_cpu=True)` ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631))
3333

3434

3535

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from pytorch_lightning.trainer.progress import BatchProgress
2525
from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict
2626
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
27-
from pytorch_lightning.utilities.memory import recursive_detach
2827
from pytorch_lightning.utilities.model_helpers import is_overridden
2928
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
3029

@@ -134,10 +133,13 @@ def advance(
134133
self.trainer.logger_connector.update_eval_step_metrics()
135134

136135
# track epoch level outputs
137-
if self._should_track_batch_outputs_for_epoch_end():
138-
output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
139-
if output is not None:
140-
self.outputs.append(output)
136+
if self._should_track_batch_outputs_for_epoch_end() and output is not None:
137+
self.outputs.append(output)
138+
139+
if self.trainer.move_metrics_to_cpu:
140+
# the evaluation step output is not moved as they are not considered "metrics"
141+
assert self.trainer._results is not None
142+
self.trainer._results.cpu()
141143

142144
if not self.batch_progress.is_last_batch:
143145
# if fault tolerant is enabled and process has been notified, exit.

tests/trainer/logging_/test_eval_loop_logging.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
2727
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2828
from tests.helpers import BoringModel, RandomDataset
29+
from tests.helpers.runif import RunIf
2930

3031

3132
def test__validation_step__log(tmpdir):
@@ -699,3 +700,26 @@ def test_filter_metrics_for_dataloader(kwargs, expected):
699700
"""Logged metrics should only include metrics from the concerned dataloader."""
700701
actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs)
701702
assert actual == expected
703+
704+
705+
@RunIf(min_gpus=1)
706+
def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir):
707+
class TestModel(BoringModel):
708+
def validation_step(self, *args):
709+
x = torch.tensor(2.0, requires_grad=True, device=self.device)
710+
y = x * 2
711+
assert x.requires_grad is True
712+
assert y.grad_fn is None # disabled by validation
713+
714+
self.log("foo", y)
715+
return y
716+
717+
def validation_epoch_end(self, outputs):
718+
# the step outputs were not moved
719+
assert all(o.device == self.device for o in outputs), outputs
720+
# but the logging results were
721+
assert self.trainer.callback_metrics["foo"].device.type == "cpu"
722+
723+
model = TestModel()
724+
trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, gpus=1)
725+
trainer.validate(model, verbose=False)

0 commit comments

Comments
 (0)