
Commit 6ac6992

rohitgr7 authored and carmocca committed
Fix batch size extraction when set by the user in LightningModule.log (#10408)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 5c404de commit 6ac6992

File tree: 8 files changed (+109 −62 lines)
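For context, this is the user-facing contract the commit repairs: a `batch_size` passed explicitly to `LightningModule.log` must take precedence over the value Lightning infers from the batch. A minimal sketch of that usage (module body and the logged name are illustrative, not taken from the diff):

    import pytorch_lightning as pl


    class SketchModule(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            loss = batch.sum()
            # the explicit batch_size=4 must win over whatever Lightning
            # would extract from `batch` when weighting the epoch mean
            self.log("train_loss", loss, on_epoch=True, batch_size=4)
            return loss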

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 1 addition & 5 deletions

@@ -161,9 +161,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

         self.batch_progress.increment_ready()

-        # cache the batch size value to avoid extracting it again after the batch loop runs as the value will be
-        # different if tbptt is enabled
-        batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch)
+        self.trainer.logger_connector.on_batch_start(batch_idx, batch)

         if batch is None:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
@@ -194,8 +192,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         with self.trainer.profiler.profile("run_training_batch"):
             batch_output = self.batch_loop.run(batch, batch_idx)

-        self.trainer._results.batch_size = batch_size
-
         self.batch_progress.increment_processed()

         # update non-plateau LR schedulers

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 11 additions & 12 deletions

@@ -210,7 +210,6 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:

     def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
         self._split_idx = split_idx
-        self.on_new_batch(split_batch)

     def update_train_step_metrics(self) -> None:
         if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
@@ -253,28 +252,23 @@ def _log_gpus_metrics(self) -> None:
     Utilities and properties
     """

-    def on_new_batch(self, batch: Any) -> int:
-        # when the user requests `dataloader_iter`, we can't track the batch_size
-        # and this is left to user responsibility.
-        if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter):
-            assert self.trainer._results is not None
-            return self.trainer._results.extract_batch_size(batch)
-        return 1
-
     def on_epoch_start(self) -> None:
         self._epoch_end_reached = False

-    def on_batch_start(self, batch_idx: int, batch: Any) -> int:
+    def on_batch_start(self, batch_idx: int, batch: Any) -> None:
         self._batch_idx = batch_idx
         self._epoch_end_reached = False
-        return self.on_new_batch(batch)
+
+        assert self.trainer._results is not None
+        # attach reference to the new batch and remove the cached batch_size
+        self.trainer._results.batch = batch
+        self.trainer._results.batch_size = None

     def epoch_end_reached(self) -> None:
         self._epoch_end_reached = True
         self._batch_idx = None
         self._split_idx = None
         assert self.trainer._results is not None
-        self.trainer._results.batch_size = 1

     def on_epoch_end(self) -> None:
         assert self._epoch_end_reached
@@ -291,6 +285,11 @@ def on_batch_end(self) -> None:
         self._callback_metrics.update(metrics["callback"])
         self._logged_metrics.update(metrics["log"])

+        assert self.trainer._results is not None
+        # drop the reference to current batch and batch_size
+        self.trainer._results.batch = None
+        self.trainer._results.batch_size = None
+
     def should_reset_tensors(self, fx: str) -> bool:
         is_different_fx = self._current_fx != fx
         if self._split_idx is None:
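The connector now brackets each batch: `on_batch_start` attaches the raw batch to the results and invalidates any cached size, and `on_batch_end` drops both so no stale reference outlives the batch. A self-contained toy model of that lifecycle (class names are illustrative, not the real Lightning classes):

    from typing import Any, Optional


    class ToyResults:
        # stand-in for ResultCollection's two new attributes
        batch: Optional[Any] = None
        batch_size: Optional[int] = None


    class ToyConnector:
        def __init__(self, results: ToyResults) -> None:
            self.results = results

        def on_batch_start(self, batch: Any) -> None:
            # attach a reference to the new batch, invalidate the cached size
            self.results.batch = batch
            self.results.batch_size = None

        def on_batch_end(self) -> None:
            # drop both so the batch can be garbage collected
            self.results.batch = None
            self.results.batch_size = None


    results = ToyResults()
    connector = ToyConnector(results)
    connector.on_batch_start([1, 2, 3])
    assert results.batch == [1, 2, 3] and results.batch_size is None
    connector.on_batch_end()
    assert results.batch is None and results.batch_size is None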

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 25 additions & 26 deletions

@@ -211,7 +211,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
         if self.meta.is_mean_reduction:
             self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum)

-    def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
+    def update(self, value: _IN_METRIC, batch_size: int) -> None:
         if self.is_tensor:
             value = value.float()
             if self.meta.on_step:
@@ -250,7 +250,7 @@ def reset(self) -> None:
         self.value.reset()
         self.has_reset = True

-    def forward(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
+    def forward(self, value: _IN_METRIC, batch_size: int) -> None:
         if self.meta.enable_graph:
             with torch.no_grad():
                 self.update(value, batch_size)
@@ -376,8 +376,9 @@ class ResultCollection(dict):
     def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
         super().__init__()
         self.training = training
-        self._batch_size = torch.tensor(1, device=device)
         self.device: Optional[Union[str, torch.device]] = device
+        self.batch: Optional[Any] = None
+        self.batch_size: Optional[int] = None

     @property
     def result_metrics(self) -> List[ResultMetric]:
@@ -390,14 +391,23 @@ def append_fn(v: ResultMetric) -> None:
         apply_to_collection(list(self.values()), ResultMetric, append_fn)
         return o

-    @property
-    def batch_size(self) -> torch.Tensor:
-        # performance: cache the `batch_size` tensor instead of re-creating it
-        return self._batch_size
+    def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int:
+        # check if we have extracted the batch size already
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        if batch_size is not None:
+            return batch_size

-    @batch_size.setter
-    def batch_size(self, value: int) -> None:
-        self._batch_size = torch.tensor(value, device=self.device)
+        batch_size = 1
+        if self.batch is not None and meta.on_epoch and meta.is_mean_reduction:
+            try:
+                batch_size = extract_batch_size(self.batch)
+                self.batch_size = batch_size
+            except RecursionError:
+                pass
+
+        return batch_size

     def log(
         self,
@@ -458,10 +468,8 @@ def log(
             f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
         )

-        if batch_size is not None:
-            self.batch_size = batch_size
-
-        self.update_metrics(key, value)
+        batch_size = self._extract_batch_size(batch_size, meta)
+        self.update_metrics(key, value, batch_size)

     def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
         """Create one ResultMetric object per value.
@@ -478,10 +486,10 @@ def fn(v: _IN_METRIC) -> ResultMetric:
             value = ResultMetricCollection(value)
         self[key] = value

-    def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
-        def fn(result_metric: ResultMetric, v: ResultMetric) -> None:
+    def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None:
+        def fn(result_metric: ResultMetric, v: torch.Tensor) -> None:
             # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
-            result_metric.forward(v.to(self.device), self.batch_size)
+            result_metric.forward(v.to(self.device), batch_size)
             result_metric.has_reset = False

         apply_to_collections(self[key], value, ResultMetric, fn)
@@ -575,19 +583,10 @@ def fn(item: ResultMetric) -> None:

         apply_to_collection(self, ResultMetric, fn)

-    def extract_batch_size(self, batch: Any) -> int:
-        try:
-            batch_size = extract_batch_size(batch)
-        except RecursionError:
-            batch_size = 1
-        self.batch_size = batch_size  # the setter converts it to `Tensor`
-        return batch_size
-
     def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
         """Move all data to the given device."""
         self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))

-        self._batch_size = self._batch_size.to(*args, **kwargs)
         if "device" in kwargs:
             self.device = kwargs["device"]
         return self
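The new `_extract_batch_size` resolves the size in a fixed order: the explicit `batch_size` argument, then the value already cached for this batch, then extraction from the attached batch (only when an on-epoch mean reduction actually needs a weight), and finally a default of 1. A standalone sketch of that resolution order, using `len` in place of Lightning's `extract_batch_size` helper:

    from typing import Any, Optional


    def resolve_batch_size(
        explicit: Optional[int],
        cached: Optional[int],
        batch: Optional[Any],
        needs_weighting: bool,  # stands in for `meta.on_epoch and meta.is_mean_reduction`
    ) -> int:
        # 1) an explicit user-supplied value always wins
        if explicit is not None:
            return explicit
        # 2) reuse the size already extracted for this batch
        if cached is not None:
            return cached
        # 3) extract only when a weighted (mean) epoch reduction needs it
        size = 1
        if batch is not None and needs_weighting:
            try:
                size = len(batch)  # the real code calls `extract_batch_size`
            except TypeError:
                pass  # the real code guards against RecursionError instead
        return size


    assert resolve_batch_size(8, None, [0] * 32, True) == 8
    assert resolve_batch_size(None, 16, [0] * 32, True) == 16
    assert resolve_batch_size(None, None, [0] * 32, True) == 32
    assert resolve_batch_size(None, None, [0] * 32, False) == 1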

pytorch_lightning/utilities/data.py

Lines changed: 4 additions & 1 deletion

@@ -29,7 +29,10 @@

 def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
     if isinstance(batch, torch.Tensor):
-        yield batch.size(0)
+        if batch.ndim == 0:
+            yield 1
+        else:
+            yield batch.size(0)
     elif isinstance(batch, str):
         yield len(batch)
     elif isinstance(batch, (Iterable, Mapping)):
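With the guard in place, a zero-dimensional tensor (for example a scalar label) reports a batch size of 1 instead of raising on `size(0)`. A quick check through the public helper, assuming `extract_batch_size` is importable from `pytorch_lightning.utilities.data` as the tests below do:

    import torch

    from pytorch_lightning.utilities.data import extract_batch_size

    assert extract_batch_size(torch.tensor(2.5)) == 1     # 0-dim: counts as 1
    assert extract_batch_size(torch.zeros(16, 10)) == 16  # otherwise the leading dim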

tests/deprecated_api/__init__.py

Lines changed: 21 additions & 7 deletions

@@ -14,7 +14,7 @@
 """Test deprecated functionality which will be removed in vX.Y.Z."""
 import sys
 from contextlib import contextmanager
-from typing import Optional
+from typing import Optional, Type

 import pytest

@@ -26,14 +26,28 @@ def _soft_unimport_module(str_module):


 @contextmanager
-def no_deprecated_call(match: Optional[str] = None):
+def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
     with pytest.warns(None) as record:
         yield
+
+    if match is None:
         try:
-            w = record.pop(DeprecationWarning)
-            if match is not None and match not in str(w.message):
-                return
+            w = record.pop(expected_warning)
         except AssertionError:
-            # no DeprecationWarning raised
+            # no warning raised
+            return
+    else:
+        for w in record.list:
+            if w.category is expected_warning and match in w.message.args[0]:
+                break
+        else:
             return
-    raise AssertionError(f"`DeprecationWarning` was raised: {w}")
+
+    msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
+    raise AssertionError(f"{msg} was raised: {w}")
+
+
+@contextmanager
+def no_deprecated_call(match: Optional[str] = None):
+    with no_warning_call(expected_warning=DeprecationWarning, match=match):
+        yield
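`no_warning_call` is the inverse of `pytest.warns`: it fails if a matching warning was raised, and `no_deprecated_call` becomes a thin wrapper over it. A usage sketch (the warning texts are made up; the import assumes the repo-internal path used by the tests below):

    import warnings

    import pytest

    from tests.deprecated_api import no_warning_call


    def test_passes_when_quiet():
        # a non-matching category does not trip the check
        with no_warning_call(UserWarning, match="batch_size"):
            warnings.warn("something unrelated", DeprecationWarning)


    def test_fails_on_match():
        # a matching warning makes the context manager raise AssertionError
        with pytest.raises(AssertionError, match="`UserWarning` was raised"):
            with no_warning_call(UserWarning, match="batch_size"):
                warnings.warn("Trying to infer the `batch_size`", UserWarning)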

tests/loops/test_loop_state_dict.py

Lines changed: 8 additions & 5 deletions

@@ -14,7 +14,6 @@
 from unittest.mock import Mock

 import pytest
-import torch

 from pytorch_lightning.loops import FitLoop
 from pytorch_lightning.trainer.trainer import Trainer
@@ -80,14 +79,16 @@ def test_loops_state_dict_structure():
             "is_last_batch": False,
         },
         "epoch_loop.val_loop._results": {
+            "batch": None,
+            "batch_size": None,
             "training": False,
-            "_batch_size": torch.tensor(1),
             "device": None,
             "items": {},
         },
         "epoch_loop._results": {
+            "batch": None,
+            "batch_size": None,
             "training": True,
-            "_batch_size": torch.tensor(1),
             "device": None,
             "items": {},
         },
@@ -106,8 +107,9 @@ def test_loops_state_dict_structure():
             "is_last_batch": False,
         },
         "_results": {
+            "batch": None,
+            "batch_size": None,
             "training": False,
-            "_batch_size": torch.tensor(1),
             "device": None,
             "items": {},
         },
@@ -122,8 +124,9 @@ def test_loops_state_dict_structure():
             "is_last_batch": False,
         },
         "_results": {
+            "batch": None,
+            "batch_size": None,
             "training": False,
-            "_batch_size": torch.tensor(1),
             "device": None,
             "items": {},
         },

tests/trainer/logging_/test_train_loop_logging.py

Lines changed: 34 additions & 4 deletions

@@ -27,6 +27,7 @@
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.deprecated_api import no_warning_call
 from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
 from tests.helpers.runif import RunIf

@@ -715,19 +716,15 @@ def on_validation_epoch_end(self):
         assert all(v == 3 for v in self.trainer.callback_metrics.values())

     def on_train_batch_start(self, batch, batch_idx):
-        assert self.trainer._results.batch_size == 2
         self.log("on_train_batch_start", 1.0, reduce_fx="sum")

     def on_train_batch_end(self, outputs, batch, batch_idx):
-        assert self.trainer._results.batch_size == 2
         self.log("on_train_batch_end", 1.0, reduce_fx="sum")

     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        assert self.trainer._results.batch_size == 2
         self.log("on_validation_batch_start", 1.0, reduce_fx="sum")

     def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
-        assert self.trainer._results.batch_size == 2
         self.log("on_validation_batch_end", 1.0, reduce_fx="sum")

     def training_epoch_end(self, *_) -> None:
@@ -749,3 +746,36 @@ def validation_epoch_end(self, *_) -> None:
     train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
     val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
     trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
+
+
+def test_no_batch_size_extraction_with_specifying_explictly(tmpdir):
+    batch_size = BoringModel().train_dataloader().batch_size + 1
+    fast_dev_run = 2
+    log_val = 7
+
+    class CustomBoringModel(BoringModel):
+        def on_before_batch_transfer(self, batch, *args, **kwargs):
+            # This is an ambiguous batch which has multiple potential batch sizes
+            if self.trainer.training:
+                batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch}
+            return batch
+
+        def training_step(self, batch, batch_idx):
+            self.log("step_log_val", log_val, on_epoch=False)
+            self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True)
+            self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum")
+            return super().training_step(batch["batch2"], batch_idx)
+
+        def on_train_epoch_end(self, *args, **kwargs):
+            results = self.trainer._results
+            assert results["training_step.step_log_val"].value == log_val
+            assert results["training_step.step_log_val"].cumulated_batch_size == 0
+            assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run
+            assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run
+            assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run
+
+    model = CustomBoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run)
+
+    with no_warning_call(match="Trying to infer the `batch_size`"):
+        trainer.fit(model)
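The epoch-level assertions follow from the weighted mean-reduction bookkeeping: each of the `fast_dev_run` steps accumulates `log_val * batch_size` into `value` and `batch_size` into `cumulated_batch_size`, while the sum-reduced metric accumulates `log_val` unweighted. Spelled out with the test's numbers (assuming BoringModel's default loader yields `batch_size == 2` here):

    log_val, batch_size, steps = 7, 2, 2  # steps == fast_dev_run

    value = sum(log_val * batch_size for _ in range(steps))  # 28
    cumulated = sum(batch_size for _ in range(steps))        # 4
    assert value == log_val * batch_size * steps
    assert cumulated == batch_size * steps
    assert value / cumulated == log_val                      # the mean is still 7

    sum_reduced = sum(log_val for _ in range(steps))         # 14, no weighting
    assert sum_reduced == log_val * steps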

tests/utilities/test_data.py

Lines changed: 5 additions & 2 deletions

@@ -12,16 +12,16 @@
     warning_cache,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.deprecated_api import no_warning_call
 from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset


 def test_extract_batch_size():
     """Tests the behavior of extracting the batch size."""

     def _check_warning_not_raised(data, expected):
-        with pytest.warns(None) as record:
+        with no_warning_call(match="Trying to infer the `batch_size`"):
             assert extract_batch_size(data) == expected
-        assert len(record) == 0

     def _check_warning_raised(data, expected):
         with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."):
@@ -43,6 +43,9 @@ def _check_warning_raised(data, expected):
     batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
     _check_warning_not_raised(batch, 11)

+    batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
+    _check_warning_raised(batch, 1)
+
     batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
     _check_warning_raised(batch, 11)
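The added case is deliberately ambiguous: traversal reaches the 0-dim tensors under "a" first (each counting as size 1), while "b" suggests 4, so extraction settles on 1 and warns. The same behavior as a standalone check, mirroring the test above:

    import pytest
    import torch

    from pytorch_lightning.utilities.data import extract_batch_size

    batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
    # the first leaf found is 0-dim (size 1); the 4-element tensor disagrees,
    # so the result is 1 and the ambiguity warning fires
    with pytest.warns(UserWarning, match="Trying to infer the `batch_size`"):
        assert extract_batch_size(batch) == 1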
