Skip to content

Commit 8f1c855

Browse files
authored
Fix ResultCollection._get_cache with multielement tensors (#9582)
1 parent e64f358 commit 8f1c855

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
360360
- Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))
361361

362362

363+
- Fixed check on torchmetrics logged whose `compute()` output is a multielement tensor ([#9582](https://github.com/PyTorchLightning/pytorch-lightning/pull/9582))
364+
365+
363366
- Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))
364367

365368

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten
467467
if on_step and result_metric.meta.on_step:
468468
cache = result_metric._forward_cache
469469
elif not on_step and result_metric.meta.on_epoch:
470-
if not result_metric._computed:
470+
if result_metric._computed is None:
471471
# always reduce on epoch end
472472
should = result_metric.meta.sync.should
473473
result_metric.meta.sync.should = True

tests/core/test_metric_result_integration.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727
import tests.helpers.utils as tutils
2828
from pytorch_lightning import Trainer
2929
from pytorch_lightning.callbacks import ModelCheckpoint
30-
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
30+
from pytorch_lightning.trainer.connectors.logger_connector.result import (
31+
_Metadata,
32+
_Sync,
33+
ResultCollection,
34+
ResultMetric,
35+
)
3136
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7
3237
from tests.helpers import BoringModel
3338
from tests.helpers.runif import RunIf
@@ -544,3 +549,16 @@ def on_train_epoch_end(self) -> None:
544549

545550
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=2, limit_val_batches=0)
546551
trainer.fit(model)
552+
553+
554+
def test_metric_result_computed_check():
555+
"""Unittest ``_get_cache`` with multielement tensors."""
556+
sync = _Sync()
557+
metadata = _Metadata("foo", "bar", on_epoch=True, enable_graph=True)
558+
metadata.sync = sync
559+
rm = ResultMetric(metadata, is_tensor=True)
560+
computed_value = torch.tensor([1, 2, 3])
561+
rm._computed = computed_value
562+
cache = ResultCollection._get_cache(rm, on_step=False)
563+
# `enable_graph=True` so no detach, identity works
564+
assert cache is computed_value

0 commit comments

Comments
 (0)