Commit 06c5903

Simplify several profile calls (#11031)
1 parent: 61a744f
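
The pattern repeated across all five files: call sites stop opening their own `self.trainer.profiler.profile(...)` context managers and instead route hook invocations through the trainer's helpers (`_call_lightning_module_hook`, `_call_ttp_hook`), which profile each hook in one place under its plain hook name. The diff shows only the call sites, so the following is a minimal sketch of the assumed helper pattern, not the real trainer code:

```python
import time
from contextlib import contextmanager


class SimpleProfiler:
    """Toy stand-in for a Lightning profiler: records wall time per action."""

    def __init__(self):
        self.recorded = {}

    @contextmanager
    def profile(self, action_name):
        start = time.monotonic()
        try:
            yield
        finally:
            self.recorded.setdefault(action_name, []).append(time.monotonic() - start)


class TrainerSketch:
    """Hypothetical trainer core; only the helper pattern is the point here."""

    def __init__(self, lightning_module):
        self.profiler = SimpleProfiler()
        self.lightning_module = lightning_module

    def _call_lightning_module_hook(self, hook_name, *args, **kwargs):
        # Profiling happens once, here, under the generic hook name;
        # call sites no longer open their own stage-specific profile blocks.
        hook = getattr(self.lightning_module, hook_name)
        with self.profiler.profile(hook_name):
            return hook(*args, **kwargs)
```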

5 files changed: +11 additions, −14 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)


+- Changed `batch_to_device` entry in profiling from stage-specific to generic, to match profiling of other hooks ([#11031](https://github.com/PyTorchLightning/pytorch-lightning/pull/11031))
+
+
 - Changed the info message for finalizing ddp-spawn worker processes to a debug-level message ([#10864](https://github.com/PyTorchLightning/pytorch-lightning/pull/10864))

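
The user-visible effect of this entry is in profiler reports: the device transfer that previously appeared under stage-specific keys (`training_batch_to_device`, `evaluation_batch_to_device`, `predict_batch_to_device`) is now recorded under the single generic key `batch_to_device`, like other hooks. A small way to observe this, assuming a toy model; everything below except the `profiler="simple"` flag is scaffolding, not part of this diff:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
trainer = pl.Trainer(profiler="simple", max_epochs=1)
trainer.fit(TinyModule(), loader)

# After this commit the report shows one generic `batch_to_device` action
# instead of a stage-specific `training_batch_to_device` row.
print(trainer.profiler.summary())
```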

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 1 addition & 3 deletions
@@ -116,9 +116,7 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         if tbptt_steps == 0:
             return [batch]

-        model_ref = self.trainer.lightning_module
-        with self.trainer.profiler.profile("tbptt_split_batch"):
-            splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
+        splits = self.trainer._call_lightning_module_hook("tbptt_split_batch", batch, tbptt_steps)
         return splits

     def _update_running_loss(self, current_loss: Tensor) -> None:
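
For context, `tbptt_split_batch` is the LightningModule hook now routed through `_call_lightning_module_hook`: it splits a batch along the time dimension for truncated backpropagation through time. A minimal sketch of an override, assuming `(batch, time, features)`-shaped tensors; the class and its body are illustrative, not from this diff:

```python
import pytorch_lightning as pl


class TBPTTModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # A non-zero value is what makes the training batch loop call
        # the `tbptt_split_batch` hook at all.
        self.truncated_bptt_steps = 2

    def tbptt_split_batch(self, batch, split_size):
        # Slice each tensor of an (x, y) batch into `split_size`-wide
        # chunks along the time dimension (dim 1).
        x, y = batch
        return [
            (x[:, t : t + split_size], y[:, t : t + split_size])
            for t in range(0, x.size(1), split_size)
        ]
```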

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 4 additions & 5 deletions
@@ -107,8 +107,7 @@ def advance( # type: ignore[override]
             raise StopIteration

         if not data_fetcher.store_on_device:
-            with self.trainer.profiler.profile("evaluation_batch_to_device"):
-                batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=(dataloader_idx or 0))
+            batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=(dataloader_idx or 0))

         self.batch_progress.increment_ready()

@@ -121,9 +120,9 @@ def advance( # type: ignore[override]
         self.batch_progress.increment_started()

         # lightning module methods
-        with self.trainer.profiler.profile("evaluation_step_and_end"):
-            output = self._evaluation_step(**kwargs)
-            output = self._evaluation_step_end(output)
+
+        output = self._evaluation_step(**kwargs)
+        output = self._evaluation_step_end(output)

         self.batch_progress.increment_processed()
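
Both evaluation hunks lean on the trainer's training-type-plugin hook caller. The diff shows only the call sites, so the body below is an assumption: presumably the helper resolves the hook on the plugin and profiles it under the bare hook name, which is why all stages now share one `batch_to_device` entry:

```python
def _call_ttp_hook(self, hook_name, *args, **kwargs):
    # Hypothetical body: resolve the hook on the training type plugin and
    # profile it under the generic hook name, so training, evaluation, and
    # prediction all record the same "batch_to_device" action.
    hook = getattr(self.training_type_plugin, hook_name)
    with self.profiler.profile(hook_name):
        return hook(*args, **kwargs)
```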

pytorch_lightning/loops/epoch/prediction_epoch_loop.py

Lines changed: 2 additions & 4 deletions
@@ -96,13 +96,11 @@ def advance( # type: ignore[override]
         if batch is None:
             raise StopIteration

-        with self.trainer.profiler.profile("predict_batch_to_device"):
-            batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx)
+        batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)

         self.batch_progress.increment_ready()

-        with self.trainer.profiler.profile("predict_step"):
-            self._predict_step(batch, batch_idx, dataloader_idx)
+        self._predict_step(batch, batch_idx, dataloader_idx)

     def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
         """Returns the predictions and the corresponding batch indices."""

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 1 addition & 2 deletions
@@ -156,8 +156,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)

         if not data_fetcher.store_on_device:
-            with self.trainer.profiler.profile("training_batch_to_device"):
-                batch = self.trainer.training_type_plugin.batch_to_device(batch)
+            batch = self.trainer._call_ttp_hook("batch_to_device", batch)

         self.batch_progress.increment_ready()

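
With all four loops converted, one way to confirm that every stage records the same action name is a profiler that logs actions as they start. A sketch, assuming `Trainer` accepts a profiler instance and that `PassThroughProfiler.start` receives the action name:

```python
import pytorch_lightning as pl
from pytorch_lightning.profiler import PassThroughProfiler


class ActionLogger(PassThroughProfiler):
    """Prints each profiled action name; otherwise a no-op profiler."""

    def start(self, action_name: str) -> None:
        print(f"profiling: {action_name}")
        super().start(action_name)


# With this commit, fit/validate/predict runs should all log the generic
# "profiling: batch_to_device" rather than three stage-specific names.
trainer = pl.Trainer(profiler=ActionLogger())
```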