diff --git a/CHANGELOG.md b/CHANGELOG.md index d7211081bb374..daa3dc0ede39b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -243,7 +243,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898)) -- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917)) +- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917)) - Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index db8eb28e2bce5..30454436994b5 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -15,22 +15,26 @@ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union import torch +from torch import Tensor +from torch.nn import Module from torch.optim import Optimizer from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum +if _NATIVE_AMP_AVAILABLE: + from torch.cuda.amp import GradScaler + _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None] -class Accelerator(object): +class Accelerator: """ The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware. @@ -52,7 +56,6 @@ def __init__( training_type_plugin: TrainingTypePlugin, ) -> None: """ - Args: precision_plugin: the plugin to handle precision-specific parts training_type_plugin: the plugin to handle different training routines @@ -64,7 +67,7 @@ def __init__( self.lr_schedulers: Sequence = [] self.optimizer_frequencies: Sequence = [] - def connect(self, model: LightningModule) -> None: + def connect(self, model: 'pl.LightningModule') -> None: """Transfers ownership of the model to this plugin""" self.training_type_plugin.connect(model) @@ -76,7 +79,7 @@ def setup_environment(self) -> None: """ self.training_type_plugin.setup_environment() - def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None: + def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None: """ Setup plugins for the trainer fit and creates optimizers. @@ -111,22 +114,22 @@ def post_dispatch(self, trainer: 'pl.Trainer') -> None: self.precision_plugin.post_dispatch() @property - def model(self) -> torch.nn.Module: - """Returns the model. This can also be a wrapped LightningModule. + def model(self) -> Module: + """ + Returns the model. This can also be a wrapped LightningModule. 
For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module` - """ return self.training_type_plugin.model @model.setter - def model(self, new_model: torch.nn.Module) -> None: + def model(self, new_model: Module) -> None: self.training_type_plugin.model = new_model @property - def lightning_module(self) -> LightningModule: - """Returns the pure LightningModule. + def lightning_module(self) -> 'pl.LightningModule': + """ + Returns the pure LightningModule. To get the potentially wrapped model use :attr:`Accelerator.model` - """ return self.training_type_plugin.lightning_module @@ -135,7 +138,8 @@ def root_device(self) -> torch.device: return self.training_type_plugin.root_device def teardown(self) -> None: - """This method is called to teardown the training process. + """ + This method is called to teardown the training process. It is the right place to release memory and free other ressources. """ pass @@ -268,13 +272,13 @@ def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: def backward( self, - closure_loss: torch.Tensor, + closure_loss: Tensor, optimizer: Optimizer, optimizer_idx: int, should_accumulate: bool, *args: Any, **kwargs: Any, - ) -> torch.Tensor: + ) -> Tensor: """Forwards backward-calls to the precision plugin. Args: @@ -325,9 +329,7 @@ def clip_gradients( gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: """clips all the optimizer parameters to the given value""" - self.precision_plugin.clip_gradients( - self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm - ) + self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm) def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None: """Hook to do something on the end of an training epoch @@ -342,11 +344,11 @@ def on_train_end(self) -> None: pass def setup_optimizers(self, trainer: 'pl.Trainer') -> None: - """creates optimizers and schedulers + """ + Creates optimizers and schedulers Args: trainer: the Trainer, these optimizers should be connected to - model: the model to be optimized by the created optimizers """ if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING): return @@ -357,7 +359,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None: self.lr_schedulers = lr_schedulers self.optimizer_frequencies = optimizer_frequencies - def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None: """Attaches the training type plugin to the accelerator.""" plugin.setup(model) @@ -390,22 +392,21 @@ def precision(self) -> Union[str, int]: return self.precision_plugin.precision @property - def scaler(self) -> Optional['torch.cuda.amp.GradScaler']: - + def scaler(self) -> Optional['GradScaler']: return getattr(self.precision_plugin, 'scaler', None) @property def rpc_enabled(self) -> bool: return self.training_type_plugin.rpc_enabled - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: """ Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom plugins. 
""" return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer) - def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]: + def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]: return self.training_type_plugin.on_save(checkpoint) def barrier(self, name: Optional[str] = None) -> None: @@ -420,7 +421,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: """ return self.training_type_plugin.broadcast(obj, src) - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: """ Function to gather a tensor from several distributed processes. @@ -464,7 +465,7 @@ def model_sharded_context(self) -> Generator[None, None, None]: yield # todo: remove in v1.5 - def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None: """ Attaches the training type plugin to the accelerator. Also transfers ownership of the model to this plugin diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 9aac6854db142..b1b9a2d96f7f5 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -15,6 +15,7 @@ from torch.optim import Optimizer +import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin @@ -26,10 +27,9 @@ import torch_xla.core.xla_model as xm from torch_xla._patched_functions import clip_grad_norm_ + # rename to mock in a test xla_clip_grad_norm_ = clip_grad_norm_ -import pytorch_lightning as pl - class TPUAccelerator(Accelerator): """ Accelerator for TPU devices. """ @@ -59,19 +59,16 @@ def clip_gradients( self, optimizer: Optimizer, clip_val: Union[float, int], - norm_type: float = 2.0, - gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: - assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \ + assert gradient_clip_algorithm == GradClipAlgorithmType.NORM, \ "Only NORM gradient clipping is supported on TPU for now" - model = self.lightning_module - parameters = model.parameters() - grad_clip_val = float(clip_val) if grad_clip_val <= 0: return - max_norm = grad_clip_val + parameters = self.model.parameters() + norm_type = 2.0 - xla_clip_grad_norm_(parameters, max_norm, norm_type) + xla_clip_grad_norm_(parameters, grad_clip_val, norm_type) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index b2b1c726a0467..30614d3faa187 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Generator, List, Sequence, Tuple, Type +from typing import Any, Callable, ContextManager, Iterator, List, Sequence, Tuple, Type import torch +from torch import Tensor +from torch.nn import Module from torch.optim import Optimizer from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin -from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn +from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType + +PARAMETERS = Iterator[torch.nn.Parameter] if _APEX_AVAILABLE: from apex import amp @@ -32,11 +36,15 @@ def __init__(self, amp_level: str = "O2") -> None: self.backend = AMPType.APEX self.amp_level = amp_level - def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]: + def master_params(self, optimizer: Optimizer) -> PARAMETERS: return amp.master_params(optimizer) - def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer], - lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence[Optimizer], Sequence[Any]]: + def connect( + self, + model: Module, + optimizers: Sequence[Optimizer], + lr_schedulers: Sequence[Any], + ) -> Tuple[Module, Sequence[Optimizer], Sequence[Any]]: """Connects the precision plugin to the training process, configures apex and reinits the schedulers """ @@ -49,28 +57,28 @@ def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer], def backward( self, model: LightningModule, - closure_loss: torch.Tensor, + closure_loss: Tensor, optimizer: Optimizer, opt_idx: int, should_accumulate: bool, *args: Any, **kwargs: Any, - ) -> torch.Tensor: + ) -> Tensor: """performs the actual backpropagation Args: model: the model to be optimized closure_loss: the loss value obtained from the closure optimizer: the optimizer to perform the step lateron - opt_idx: the optimizer's index + opt_idx: the optimizer index should_accumulate: whether to accumulate gradients or not """ - closure_loss = amp.scale_loss(closure_loss, model.trainer.optimizers if optimizer is None else optimizer) + opt = model.trainer.optimizers if optimizer is None else optimizer + scaled_loss: ContextManager[Tensor] = amp.scale_loss(closure_loss, opt) # enter apex context - context = closure_loss - closure_loss = closure_loss.__enter__() + closure_loss = scaled_loss.__enter__() # do backward pass # TODO: not entirely sure, why we need this @@ -84,10 +92,8 @@ def backward( closure_loss.backward(*args, **kwargs) # exit amp context - a, b, c = None, None, None - error = context.__exit__(a, b, c) + error = scaled_loss.__exit__(None, None, None) if error: - rank_zero_warn(a, b, c) raise Exception("apex unscale error") # once backward has been applied, release graph @@ -97,17 +103,17 @@ def backward( def configure_apex( self, amp: Type, - model: LightningModule, + model: Module, optimizers: List[Optimizer], amp_level: str, - ) -> Tuple[LightningModule, List[Optimizer]]: + ) -> Tuple[Module, List[Optimizer]]: r""" Override to init AMP your own way. Must return a model and list of optimizers. Args: amp: pointer to amp library object. - model: pointer to current :class:`LightningModule`. + model: pointer to current :class:`torch.nn.Module`. optimizers: list of optimizers passed in :meth:`configure_optimizers`. amp_level: AMP mode chosen ('O1', 'O2', etc...) 
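[Reviewer note: illustrative only, not part of the patch] The apex_amp.py hunk above stops reusing `closure_loss` as both the context manager and the yielded tensor: `amp.scale_loss(...)` is now held in `scaled_loss` and entered/exited explicitly, and the bare `__exit__(None, None, None)` replaces the old `rank_zero_warn(a, b, c)` call on dummy values. For readers unfamiliar with apex, this is the manual form of the usual `with`-statement idiom. A minimal sketch, assuming apex is installed and the model/optimizer were already wrapped by `amp.initialize`; the helper name `apex_backward` is hypothetical:

    from apex import amp  # requires NVIDIA apex

    def apex_backward(loss, optimizer):
        # Equivalent `with`-form of what ApexMixedPrecisionPlugin.backward does by
        # calling __enter__/__exit__ on the scale_loss context manager directly:
        # the yielded tensor is the scaled loss, and gradients are unscaled on exit.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
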
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 22fa5bf082357..dc29a5cee4014 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Callable, Union -import torch +from torch import Tensor from torch.optim import Optimizer import pytorch_lightning as pl @@ -54,13 +54,13 @@ def pre_optimizer_step( def backward( self, model: 'pl.LightningModule', - closure_loss: torch.Tensor, + closure_loss: Tensor, optimizer: Optimizer, opt_idx: int, should_accumulate: bool, *args: Any, **kwargs: Any, - ) -> torch.Tensor: + ) -> Tensor: if is_overridden('backward', model): warning_cache.warn( "Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles" @@ -76,7 +76,6 @@ def backward( def clip_gradients( self, - model: 'pl.LightningModule', optimizer: Optimizer, clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 059506f830b8f..ac33eeea287eb 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Any, Callable, Generator, Sequence, Tuple, Union +from typing import Any, Callable, Iterator, Sequence, Tuple, Union import torch -import torch.nn as nn +from torch import Tensor +from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.plugins.base_plugin import Plugin from pytorch_lightning.utilities import GradClipAlgorithmType +PARAMETERS = Iterator[torch.nn.Parameter] + class PrecisionPlugin(Plugin): """ @@ -32,17 +35,10 @@ class PrecisionPlugin(Plugin): EPSILON: float = 1e-6 precision: Union[str, int] = 32 - def __init__(self) -> None: - super().__init__() - self.clip_grad_funcs = { - GradClipAlgorithmType.VALUE: self.clip_grad_by_value, - GradClipAlgorithmType.NORM: self.clip_grad_by_norm, - } - - def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]: - """The master params of the model. Returns the plain model params here. + def master_params(self, optimizer: Optimizer) -> PARAMETERS: + """ + The master params of the model. Returns the plain model params here. Maybe different in other precision plugins. 
- """ for group in optimizer.param_groups: for p in group["params"]: @@ -50,23 +46,23 @@ def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, N def connect( self, - model: nn.Module, + model: Module, optimizers: Sequence[Optimizer], lr_schedulers: Sequence[Any], - ) -> Tuple[nn.Module, Sequence[Optimizer], Sequence[Any]]: + ) -> Tuple[Module, Sequence[Optimizer], Sequence[Any]]: """Connects this plugin to the accelerator and the training process""" return model, optimizers, lr_schedulers def backward( self, model: 'pl.LightningModule', - closure_loss: torch.Tensor, + closure_loss: Tensor, optimizer: Optimizer, opt_idx: int, should_accumulate: bool, *args: Any, **kwargs: Any, - ) -> torch.Tensor: + ) -> Tensor: """performs the actual backpropagation Args: @@ -106,7 +102,6 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: def clip_gradients( self, - model: 'pl.LightningModule', optimizer: Optimizer, clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, @@ -119,24 +114,25 @@ def clip_gradients( if clip_val <= 0: return - clip_grad_func = self.clip_grad_funcs[gradient_clip_algorithm] - clip_grad_func(optimizer, clip_val) # type: ignore + if gradient_clip_algorithm == GradClipAlgorithmType.VALUE: + self.clip_grad_by_value(optimizer, clip_val) + elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: + # TODO: there should be a mechanism to set `norm_type` + self.clip_grad_by_norm(optimizer, clip_val, eps=self.EPSILON) def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by value""" - parameters = list(self.master_params(optimizer)) + parameters = self.master_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0) -> None: + def clip_grad_by_norm( + self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6 + ) -> None: """Clip gradients by norm""" - # TODO: separate TPU case from here - parameters = list(self.master_params(optimizer)) - max_norm = clip_val + parameters = self.master_params(optimizer) - if isinstance(parameters, torch.Tensor): - parameters = [parameters] + # TODO: replace this with torch.nn.clip_grad_norm_ parameters = list(filter(lambda p: p.grad is not None, parameters)) - device = parameters[0].device if norm_type == math.inf: @@ -147,9 +143,7 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], n torch.norm(p.grad.data.to(device), norm_type, out=out[i]) total_norm = torch.norm(out, norm_type) - eps = self.EPSILON - - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps) + clip_coef = torch.tensor(clip_val, device=device) / (total_norm + eps) clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) for p in parameters: p.grad.data.mul_(clip_coef.to(p.grad.data.device)) diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index 28555a1a60b8d..4d8a2f0f934ed 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import cast, Union - -from torch.optim import Optimizer +from typing import Union from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE @@ -24,13 +22,13 @@ class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): - """Mixed Precision for Sharded Training - """ + """Mixed Precision for Sharded Training""" def __init__(self) -> None: super().__init__() self.scaler = ShardedGradScaler() - def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: - optimizer = cast(OSS, optimizer) + def clip_grad_by_norm( + self, optimizer: 'OSS', clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6 + ) -> None: optimizer.clip_grad_norm(clip_val, norm_type=norm_type) diff --git a/setup.cfg b/setup.cfg index 70139348462aa..3fa6e39076725 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,6 @@ omit = [flake8] -# TODO: this should be 88 or 100 according PEP8 max-line-length = 120 exclude = .tox, @@ -105,9 +104,6 @@ NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = false [mypy] -# Typing tests is low priority, but enabling type checking on the -# untyped test functions (using `--check-untyped-defs`) is still -# high-value because it helps test the typing. files = pytorch_lightning, pl_examples, benchmarks, tests disallow_untyped_defs = True ignore_missing_imports = True @@ -115,12 +111,10 @@ show_error_codes = True warn_redundant_casts = True warn_unused_configs = True warn_unused_ignores = True +allow_redefinition = True +# disable this rule as the Trainer attributes are defined in the connectors, not in its __init__ disable_error_code = attr-defined -# todo: this is magically failing, need to be revisited -[mypy-pytorch_lightning.accelerators.tpu.*] -ignore_errors = True - # todo: add proper typing to this module... [mypy-pytorch_lightning.callbacks.*] ignore_errors = True @@ -164,8 +158,7 @@ ignore_errors = True # todo: add proper typing to this module... 
[mypy-pytorch_lightning.trainer.*] ignore_errors = True - -# whitelist evaluation_loop.py +# whitelist [mypy-pytorch_lightning.trainer.evaluation_loop] ignore_errors = False diff --git a/setup.py b/setup.py index 7e75c514734b5..264f219e22b55 100755 --- a/setup.py +++ b/setup.py @@ -23,7 +23,6 @@ try: from pytorch_lightning import __about__ as info from pytorch_lightning import setup_tools - except ImportError: # alternative https://stackoverflow.com/a/67692/4521646 sys.path.append("pytorch_lightning") diff --git a/tests/trainer/test_training_loop.py b/tests/trainer/test_training_loop.py index 25be29d73f1a4..b024a7eabbecc 100644 --- a/tests/trainer/test_training_loop.py +++ b/tests/trainer/test_training_loop.py @@ -72,9 +72,7 @@ def optimizer_step( super().optimizer_step( epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs ) - self.called.append( - "optimizer_step" - ) # append after as closure calls other methods + self.called.append("optimizer_step") # append after as closure calls other methods def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.called.append("on_train_batch_end") diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py index 9c6900f81fcae..57e49df2df066 100644 --- a/tests/utilities/test_parsing.py +++ b/tests/utilities/test_parsing.py @@ -14,7 +14,6 @@ import inspect import pytest - from torch.jit import ScriptModule from pytorch_lightning.utilities.parsing import (
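
[Reviewer note: illustrative only, not part of the patch] The precision_plugin.py hunk replaces the `clip_grad_funcs` dict built in `__init__` with an explicit `if`/`elif` on `GradClipAlgorithmType`, and `clip_gradients` no longer takes the unused `model` argument anywhere in the call chain (accelerator, DeepSpeed, sharded, TPU). A minimal sketch of the resulting dispatch as a standalone helper, under stated assumptions: the helper name `clip_gradients_sketch` and the `algorithm` string are hypothetical stand-ins, and the norm branch delegates to `torch.nn.utils.clip_grad_norm_` as the new in-code TODO suggests, whereas the plugin currently still computes the norm by hand:

    from typing import Iterable, Union

    from torch import nn

    def clip_gradients_sketch(
        parameters: Iterable[nn.Parameter],
        clip_val: Union[int, float],
        algorithm: str = "norm",  # stands in for GradClipAlgorithmType.NORM / .VALUE
        norm_type: float = 2.0,
    ) -> None:
        if clip_val <= 0:  # same early return as PrecisionPlugin.clip_gradients
            return
        if algorithm == "value":
            nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
        elif algorithm == "norm":
            nn.utils.clip_grad_norm_(parameters, max_norm=clip_val, norm_type=norm_type)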