Skip to content

Commit d2def36

Browse files
authored
[bugfix] Revert inference mode support from #8813 (#9443)
Fixes #9431
1 parent cc2ac02 commit d2def36

File tree

2 files changed

+4
-13
lines changed

2 files changed

+4
-13
lines changed

CHANGELOG.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
114114
- Added a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221))
115115

116116

117-
- Added `inference_mode` for evaluation and prediction ([#8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813))
118-
119-
120117
- Added `remove_checkpoint` to `CheckpointIO` plugin by moving the responsibility from `ModelCheckpoint` Callback ([#9373](https://github.com/PyTorchLightning/pytorch-lightning/pull/9373))
121118

122119

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616
import os
1717
import traceback
1818
import warnings
19-
from contextlib import contextmanager
2019
from datetime import timedelta
2120
from pathlib import Path
22-
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
21+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
2322
from weakref import proxy
2423

2524
import torch
@@ -77,7 +76,7 @@
7776
from pytorch_lightning.utilities.debugging import InternalDebugger
7877
from pytorch_lightning.utilities.distributed import distributed_available
7978
from pytorch_lightning.utilities.exceptions import MisconfigurationException
80-
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9
79+
from pytorch_lightning.utilities.imports import _fault_tolerant_training
8180
from pytorch_lightning.utilities.model_helpers import is_overridden
8281
from pytorch_lightning.utilities.seed import reset_seed
8382
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -1137,7 +1136,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
11371136
# reset trainer on this loop and all child loops in case user connected a custom loop
11381137
self._evaluation_loop.trainer = self
11391138

1140-
with self.profiler.profile(f"run_{self.state.stage}_evaluation"), self._evaluation_context():
1139+
with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
11411140
eval_loop_results = self._evaluation_loop.run()
11421141

11431142
# remove the tensors from the eval results
@@ -1153,7 +1152,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
11531152
self.reset_predict_dataloader(self.lightning_module)
11541153
# reset trainer on this loop and all child loops in case user connected a custom loop
11551154
self.predict_loop.trainer = self
1156-
with self._evaluation_context():
1155+
with torch.no_grad():
11571156
return self.predict_loop.run()
11581157

11591158
def _run_sanity_check(self, ref_model):
@@ -1382,8 +1381,3 @@ def _on_exception(self):
13821381
# save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
13831382
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
13841383
self.save_checkpoint(file_path)
1385-
1386-
@contextmanager
1387-
def _evaluation_context(self) -> Generator:
1388-
with torch.inference_mode() if _TORCH_GREATER_EQUAL_1_9 else torch.no_grad():
1389-
yield

0 commit comments

Comments
 (0)