
Commit 59020b9

awaelchli and carmocca committed

Fix min/max logging default value (#11310)
Co-authored-by: Carlos Mocholi <[email protected]>

1 parent 10e9892 · commit 59020b9

File tree: 3 files changed, +19 −3 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

 - Fixed `LightningCLI` race condition while saving the config ([#11199](https://github.com/PyTorchLightning/pytorch-lightning/pull/11199))
+- Fixed the default value used with `log(reduce_fx=min|max)` ([#11310](https://github.com/PyTorchLightning/pytorch-lightning/pull/11310))
 - Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
 - Fixed data fetcher selection ([#11294](https://github.com/PyTorchLightning/pytorch-lightning/pull/11294))
 - Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
```
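A compact way to see why the fix chooses `-inf`/`+inf`: each reduction needs its *identity element* as the starting value, i.e. a value that leaves the reduction unchanged. This is a standalone sketch of that idea (plain Python, not Lightning code; the `identities` table is illustrative):

```python
from functools import reduce

# Identity elements: start values that leave each reduction unchanged.
# sum -> 0.0, max -> -inf, min -> +inf
identities = {"sum": 0.0, "max": float("-inf"), "min": float("inf")}

values = [-2.0, -5.0, -3.0]  # e.g. all-negative logged values

# Starting max at its identity (-inf) recovers the true maximum, -2.0;
# the old unconditional 0.0 start value would have returned 0.0 instead.
result = reduce(lambda acc, v: max(acc, v), values, identities["max"])
```

The same reasoning gives `+inf` for `min`: any real value compared against `+inf` wins, so no logged value is ever masked by the initial state.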

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -208,8 +208,14 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
         self.meta = metadata
         self.has_reset = False
         if is_tensor:
+            if metadata.is_max_reduction:
+                default = float("-inf")
+            elif metadata.is_min_reduction:
+                default = float("inf")
+            else:
+                default = 0.0
             # do not set a dtype in case the default dtype was changed
-            self.add_state("value", torch.tensor(0.0), dist_reduce_fx=torch.sum)
+            self.add_state("value", torch.tensor(default), dist_reduce_fx=torch.sum)
             if self.meta.is_mean_reduction:
                 self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
```
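The motivation for the new `default` branch above: with the old unconditional `0.0` start value, a `max` reduction over all-negative logged values silently returns 0. A minimal standalone repro of that arithmetic (plain Python mimicking per-step metric updates, not the `ResultMetric` class itself; `accumulate` and `logged` are hypothetical names):

```python
def accumulate(values, reduce_fx, start):
    """Fold values into an accumulator, one call per logged step."""
    acc = start
    for v in values:
        acc = reduce_fx(acc, v)
    return acc

logged = [-2.0, -5.0]  # all-negative values logged with reduce_fx=max

buggy = accumulate(logged, max, 0.0)            # old 0.0 default masks every value
fixed = accumulate(logged, max, float("-inf"))  # new -inf default yields the true max
```

Here `buggy` comes out as `0.0` even though 0 was never logged, while `fixed` correctly yields `-2.0`; the symmetric failure for `min` over all-positive values is what the `float("inf")` branch prevents.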

tests/core/test_metric_result_integration.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -575,14 +575,14 @@ def test_metric_result_respects_dtype(floating_dtype):
     assert rm.cumulated_batch_size.dtype == fixed_dtype

     # two fixed point numbers - should be converted
-    value, batch_size = torch.tensor(2), torch.tensor(3)
+    value, batch_size = torch.tensor(2), 3
     assert value.dtype == fixed_dtype
     with pytest.warns(
         UserWarning, match=rf"`self.log\('bar', ...\)` in your `foo` .* Converting it to {floating_dtype}"
     ):
         rm.update(value, batch_size)
     # floating and fixed
-    rm.update(torch.tensor(4.0), torch.tensor(5))
+    rm.update(torch.tensor(4.0), 5)

     total = rm.compute()

@@ -591,3 +591,12 @@ def test_metric_result_respects_dtype(floating_dtype):

     # restore to avoid impacting other tests
     torch.set_default_dtype(torch.float)
+
+
+@pytest.mark.parametrize(["reduce_fx", "expected"], [(max, -2), (min, 2)])
+def test_result_metric_max_min(reduce_fx, expected):
+    metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
+    metadata.sync = _Sync()
+    rm = ResultMetric(metadata, is_tensor=True)
+    rm.update(torch.tensor(expected), 1)
+    assert rm.compute() == expected
```

0 commit comments