
Commit c38fbad

rohitgr7 authored and lexierule committed
Fix schedule reset logic in pytorch profiler (#10837)
1 parent 5820711 commit c38fbad

File tree

6 files changed: +74 -19 lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions

@@ -17,6 +17,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))


+- Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837))
+
+
+-
+
+
+-
+
+
 ## [1.5.4] - 2021-11-30

 ### Fixed
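
For context on the fix above: the profiler's default Kineto schedule (`wait=1, warmup=1, active=3`) needs at least 5 profiled steps to work properly, otherwise it can crash (see the comment in pytorch.py below), which is why the profiler falls back to PyTorch's plain default schedule when a stage runs fewer than 5 batches. A minimal sketch of that behaviour, using the upstream torch.profiler.schedule API; the loop and step counts are illustrative only and not part of the commit:

import torch

# Default-style Kineto schedule used by the Lightning profiler: 1 wait + 1 warmup + 3 active steps.
default_schedule = torch.profiler.schedule(wait=1, warmup=1, active=3)

for step in range(5):
    # Step 0 is NONE (wait), step 1 is WARMUP, steps 2-3 are RECORD,
    # and only step 4 reaches RECORD_AND_SAVE - so a run with fewer than
    # 5 profiled steps never completes a full cycle of the schedule.
    print(step, default_schedule(step))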

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 4 additions & 3 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Sequence, Union
+from typing import Any, List, Sequence

 from deprecate.utils import void
 from torch.utils.data.dataloader import DataLoader
@@ -32,7 +32,8 @@ def __init__(self):
         self.epoch_loop = EvaluationEpochLoop()

         self._results = ResultCollection(training=False)
-        self._max_batches: Optional[Union[int, Sequence[int]]] = None
+        self._outputs: List[EPOCH_OUTPUT] = []
+        self._max_batches: List[int] = []
         self._has_run: bool = False

     @property
@@ -147,7 +148,7 @@ def teardown(self) -> None:
         self._results.cpu()
         self.epoch_loop.teardown()

-    def _get_max_batches(self) -> List[Union[int, float]]:
+    def _get_max_batches(self) -> List[int]:
         """Returns the max number of batches for each dataloader."""
         if self.trainer.testing:
             max_batches = self.trainer.num_test_batches

pytorch_lightning/loops/dataloader/prediction_loop.py

Lines changed: 1 addition & 4 deletions

@@ -53,10 +53,7 @@ def num_dataloaders(self) -> int:
     @property
     def max_batches(self) -> List[int]:
         """The max number of batches this loop will run for each dataloader."""
-        max_batches = self.trainer.num_predict_batches
-        if isinstance(max_batches, int):
-            max_batches = [max_batches] * len(self.dataloaders)
-        return max_batches
+        return self.trainer.num_predict_batches

     @property
     def dataloaders(self) -> Sequence[DataLoader]:

pytorch_lightning/profiler/pytorch.py

Lines changed: 20 additions & 5 deletions

@@ -335,9 +335,24 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
         with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph
         self._profiler_kwargs["with_stack"] = with_stack

+    @property
+    def _total_steps(self) -> int:
+        trainer = self._lightning_module.trainer
+        if self._schedule.is_training:
+            return trainer.num_training_batches
+        if self._schedule._current_action == "validation_step":
+            return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches)
+        if self._schedule._current_action == "test_step":
+            return sum(trainer.num_test_batches)
+        if self._schedule._current_action == "predict_step":
+            return sum(trainer.num_predict_batches)
+
     def _should_override_schedule(self) -> bool:
-        return (self._lightning_module is not None and self._lightning_module.trainer.limit_train_batches < 5) and (
-            self._schedule is not None and self._schedule._schedule == self._default_schedule()
+        return (
+            self._lightning_module is not None
+            and self._schedule is not None
+            and self._total_steps < 5
+            and self._schedule._schedule == self._default_schedule()
         )

     @staticmethod
@@ -410,6 +425,9 @@ def stop(self, action_name: str) -> None:
             action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX)
         ):

+            if self._schedule is not None:
+                self._schedule.pre_step(action_name)
+
             # the default schedule requires a minimum of 5 steps to properly work: `wait=1, warmup=1, active=3`.
             # otherwise, this will raise a `segmentation fault`.
             if self._should_override_schedule():
@@ -420,9 +438,6 @@ def stop(self, action_name: str) -> None:
                 self._schedule = None
                 self.profiler.schedule = torch.profiler.profiler._default_schedule_fn

-            if self._schedule is not None:
-                self._schedule.pre_step(action_name)
-
         def on_trace_ready(profiler):
             if self.dirpath is not None:
                 if self._export_to_chrome:
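
A hypothetical usage sketch (not part of the diff) of the behaviour this change enables: previously `_should_override_schedule` only consulted `limit_train_batches`, so short validation/test/predict runs kept the default schedule, whereas `_total_steps` now counts the batches of the stage actually being profiled. The `TinyModel` and `RandomDataset` names below are made up for illustration:

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):  # hypothetical helper, not from the repo
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.randn(32)


class TinyModel(LightningModule):  # hypothetical helper, not from the repo
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def test_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def test_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=2)


# Only 2 test batches (fewer than the 5 steps the default Kineto schedule needs):
# with this fix the profiler should warn "not enough steps to properly record traces"
# and fall back to torch's default schedule instead of crashing or leaking state.
trainer = Trainer(profiler="pytorch", limit_test_batches=4)
trainer.test(TinyModel())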

pytorch_lightning/trainer/data_loading.py

Lines changed: 11 additions & 6 deletions

@@ -52,13 +52,18 @@ class TrainerDataLoadingMixin(ABC):
     val_check_interval: float
     tpu_local_core_rank: int
     train_dataloader: DataLoader
-    num_training_batches: Union[int, float]
-    val_check_batch: float
-    val_dataloaders: Optional[List[DataLoader]]
-    num_val_batches: List[Union[int, float]]
-    test_dataloaders: Optional[List[DataLoader]]
-    num_test_batches: List[Union[int, float]]
     limit_train_batches: Union[int, float]
+    num_training_batches: int
+    val_check_batch: float
+    val_dataloaders: List[DataLoader]
+    limit_val_batches: Union[int, float]
+    num_val_batches: List[int]
+    test_dataloaders: List[DataLoader]
+    limit_test_batches: Union[int, float]
+    num_test_batches: List[int]
+    predict_dataloaders: List[DataLoader]
+    limit_predict_batches: Union[int, float]
+    num_predict_batches: List[int]
     log_every_n_steps: int
     overfit_batches: Union[int, float]
     distributed_sampler_kwargs: dict

tests/profiler/test_profiler.py

Lines changed: 29 additions & 1 deletion

@@ -25,7 +25,7 @@
 from pytorch_lightning.loggers.base import LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
-from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
+from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
@@ -523,3 +523,31 @@ def test_trainer_profiler_incorrect_str_arg():
         match=r"When passing string value for the `profiler` parameter of `Trainer`, it can only be one of.*",
     ):
         Trainer(profiler="unknown_profiler")
+
+
+@pytest.mark.skipif(not _KINETO_AVAILABLE, reason="Requires PyTorch Profiler Kineto")
+@pytest.mark.parametrize(
+    ["trainer_config", "trainer_fn"],
+    [
+        ({"limit_train_batches": 4, "limit_val_batches": 7}, "fit"),
+        ({"limit_train_batches": 7, "limit_val_batches": 4, "num_sanity_val_steps": 0}, "fit"),
+        (
+            {
+                "limit_train_batches": 7,
+                "limit_val_batches": 2,
+            },
+            "fit",
+        ),
+        ({"limit_val_batches": 4}, "validate"),
+        ({"limit_test_batches": 4}, "test"),
+        ({"limit_predict_batches": 4}, "predict"),
+    ],
+)
+def test_pytorch_profiler_raises_warning_for_limited_steps(tmpdir, trainer_config, trainer_fn):
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", max_epochs=1, **trainer_config)
+    warning_cache.clear()
+    with pytest.warns(UserWarning, match="not enough steps to properly record traces"):
+        getattr(trainer, trainer_fn)(model)
+    assert trainer.profiler._schedule is None
+    warning_cache.clear()
