Commit a079d7f

Enable inference mode for testing and predicting (#8813)

Authored by tangbinh and ananthsub
Co-authored-by: ananthsub <[email protected]>
1 parent 25af4b1 commit a079d7f

File tree

2 files changed: +15 −5 lines


CHANGELOG.md

Lines changed: 5 additions & 1 deletion

@@ -107,9 +107,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183))

 - Add a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221))

+- Added `inference_mode` for evaluation and prediction ([8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813))
+
 ### Changed

 - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))

@@ -289,7 +293,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed `EarlyStopping` running on train epoch end when `check_val_every_n_epoch>1` is set ([#9156](https://github.com/PyTorchLightning/pytorch-lightning/pull/9156))

-- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
+- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8685](https://github.com/PyTorchLightning/pytorch-lightning/pull/8685))

 - Fixed the Apex and DeepSpeed plugin closure running after the `on_before_optimizer_step` hook ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288))
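The changelog entry above switches evaluation and prediction from `torch.no_grad()` to `torch.inference_mode()`. A minimal sketch of the user-visible difference between the two contexts (assumes torch >= 1.9 is installed; the variable names are illustrative only):

```python
import torch

x = torch.ones(2, requires_grad=True)

# Under no_grad, operations are untracked, but the results are
# ordinary tensors that could still be fed back into autograd later.
with torch.no_grad():
    a = x * 2
assert not a.requires_grad

# Under inference_mode, the results are "inference tensors": they can
# never re-enter autograd, which lets torch skip version-counter and
# view bookkeeping for extra speed.
with torch.inference_mode():
    b = x * 2
assert not b.requires_grad
assert b.is_inference()
```

This is why the change is safe for `test`/`predict` but not for training: inference tensors cannot participate in a later backward pass.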

pytorch_lightning/trainer/trainer.py

Lines changed: 10 additions & 4 deletions

@@ -16,9 +16,10 @@
 import os
 import traceback
 import warnings
+from contextlib import contextmanager
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
 from weakref import proxy

 import torch
@@ -76,7 +77,7 @@
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
 from pytorch_lightning.utilities.seed import reset_seed
@@ -1146,7 +1147,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
         # reset trainer on this loop and all child loops in case user connected a custom loop
         self._evaluation_loop.trainer = self

-        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
+        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), self._evaluation_context():
             eval_loop_results = self._evaluation_loop.run()

         # remove the tensors from the eval results
@@ -1162,7 +1163,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         self.reset_predict_dataloader(self.lightning_module)
         # reset trainer on this loop and all child loops in case user connected a custom loop
         self.predict_loop.trainer = self
-        with torch.no_grad():
+        with self._evaluation_context():
             return self.predict_loop.run()

     def _run_sanity_check(self, ref_model):
@@ -1391,3 +1392,8 @@ def _on_exception(self):
         # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
         file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
         self.save_checkpoint(file_path)
+
+    @contextmanager
+    def _evaluation_context(self) -> Generator:
+        with torch.inference_mode() if _TORCH_GREATER_EQUAL_1_9 else torch.no_grad():
+            yield
