
Commit 308649b

carmocca and rohitgr7 authored and committed
Fix scripting causing false positive deprecation warnings (#10555)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent: 074d51e · commit: 308649b

File tree: 3 files changed, +15 −9 lines

CHANGELOG.md (1 addition, 1 deletion)

@@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))


-- Fixed `to_torchscript()` causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/issues/10470))
+- Fixed scripting causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/pull/10470), [#10555](https://github.com/PyTorchLightning/pytorch-lightning/pull/10555))


 - Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))

pytorch_lightning/loggers/tensorboard.py (2 additions, 0 deletions)

@@ -240,7 +240,9 @@ def log_graph(self, model: "pl.LightningModule", input_array=None):

         if input_array is not None:
             input_array = model._apply_batch_transfer_handler(input_array)
+            model._running_torchscript = True
             self.experiment.add_graph(model, input_array)
+            model._running_torchscript = False
         else:
             rank_zero_warn(
                 "Could not log computational graph since the"
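The tensorboard change follows a simple pattern: set a guard flag before a call that triggers scripting/tracing, and clear it afterwards, so deprecation warnings fire only when user code accesses a deprecated attribute directly. A minimal, self-contained sketch of that pattern (the class, property, and function names here are illustrative stand-ins, not Lightning's actual API):

```python
import warnings


class ToyModule:
    """Toy stand-in for a LightningModule (illustrative, not Lightning's API)."""

    def __init__(self):
        self._running_torchscript = False

    @property
    def some_deprecated_property(self):
        # Warn only when user code reads the property directly, not when
        # scripting/graph-export machinery walks the attributes.
        if not self._running_torchscript:
            warnings.warn("`some_deprecated_property` is deprecated", DeprecationWarning)
        return {}


def export_graph(model):
    """Stand-in for `experiment.add_graph`: exporting the graph may read
    deprecated attributes as a side effect of tracing the model."""
    return model.some_deprecated_property


model = ToyModule()
model._running_torchscript = True  # guard on: tracing won't emit false positives
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    export_graph(model)
model._running_torchscript = False  # guard off: direct user access still warns
```

With the guard set, `caught` stays empty; on an unguarded instance the same call records one `DeprecationWarning`.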

pytorch_lightning/plugins/training_type/ipu.py (12 additions, 8 deletions)

@@ -237,21 +237,25 @@ def to_tensor(x):
         args = apply_to_collection(args, dtype=(int, float), function=to_tensor)
         return args

-    def training_step(self, *args, **kwargs):
+    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any):
         args = self._prepare_input(args)
-        return self.poptorch_models[RunningStage.TRAINING](*args, **kwargs)
+        poptorch_model = self.poptorch_models[stage]
+        self.lightning_module._running_torchscript = True
+        out = poptorch_model(*args, **kwargs)
+        self.lightning_module._running_torchscript = False
+        return out
+
+    def training_step(self, *args, **kwargs):
+        return self._step(RunningStage.TRAINING, *args, **kwargs)

     def validation_step(self, *args, **kwargs):
-        args = self._prepare_input(args)
-        return self.poptorch_models[RunningStage.VALIDATING](*args, **kwargs)
+        return self._step(RunningStage.VALIDATING, *args, **kwargs)

     def test_step(self, *args, **kwargs):
-        args = self._prepare_input(args)
-        return self.poptorch_models[RunningStage.TESTING](*args, **kwargs)
+        return self._step(RunningStage.TESTING, *args, **kwargs)

     def predict_step(self, *args, **kwargs):
-        args = self._prepare_input(args)
-        return self.poptorch_models[RunningStage.PREDICTING](*args, **kwargs)
+        return self._step(RunningStage.PREDICTING, *args, **kwargs)

     def teardown(self) -> None:
         # undo dataloader patching
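The IPU refactor collapses four near-identical step methods into one `_step` keyed by the running stage, which gives a single place to toggle the guard flag. A self-contained sketch of that dispatch pattern (the class and the stand-in callables are illustrative, not the real poptorch-compiled models):

```python
from enum import Enum
from typing import Any, Callable, Dict


class RunningStage(Enum):
    TRAINING = "train"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"


class ToyIPUPlugin:
    """Toy sketch: four near-identical step methods collapse into one
    `_step` keyed by the running stage."""

    def __init__(self):
        # One "compiled model" per stage; plain callables stand in here.
        # Default arg `_stage=stage` binds the current stage per iteration
        # (a bare closure over `stage` would late-bind to the last value).
        self.poptorch_models: Dict[RunningStage, Callable[..., Any]] = {
            stage: (lambda *args, _stage=stage, **kwargs: (_stage.value, args))
            for stage in RunningStage
        }
        self._running_torchscript = False

    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> Any:
        model = self.poptorch_models[stage]
        self._running_torchscript = True  # guard deprecated accessors
        try:
            out = model(*args, **kwargs)
        finally:
            self._running_torchscript = False
        return out

    def training_step(self, *args, **kwargs):
        return self._step(RunningStage.TRAINING, *args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self._step(RunningStage.VALIDATING, *args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self._step(RunningStage.TESTING, *args, **kwargs)

    def predict_step(self, *args, **kwargs):
        return self._step(RunningStage.PREDICTING, *args, **kwargs)
```

Unlike the patch above, this sketch clears the flag in a `finally` block, a defensive variation so an exception inside the model call cannot leave the flag stuck on.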

0 commit comments