Commit 529c4e2: Remove BasePlugin

1 parent: de57fef

7 files changed: 59 additions & 65 deletions

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
@@ -147,8 +147,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` [#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958)
 
 
--
-
 ### Removed
 
 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))

@@ -205,6 +203,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))
 
 
+- Removed `Plugin` in `base_plugin.py`, access `TrainingTypePlugin` and `PrecisionPlugin` directly instead ([#9066](https://github.com/PyTorchLightning/pytorch-lightning/pull/9066))
+
+
 ### Fixed
 
 - Fixed save/load/resume from checkpoint for DeepSpeed Plugin (

pytorch_lightning/accelerators/accelerator.py

Lines changed: 4 additions & 4 deletions
@@ -186,7 +186,7 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
                 - hiddens(:class:`~torch.Tensor`): Passed in if
                   :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
         """
-        with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
+        with self.precision_plugin.train_step_context():
             return self.training_type_plugin.training_step(*step_kwargs.values())
 
     def post_training_step(self) -> None:

@@ -204,7 +204,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
                 - dataloader_idx (int): The index of the dataloader that produced this batch
                   (only if multiple val dataloaders used)
         """
-        with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
+        with self.precision_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*step_kwargs.values())
 
     def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:

@@ -219,7 +219,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
                 - dataloader_idx (int): The index of the dataloader that produced this batch
                   (only if multiple test dataloaders used).
         """
-        with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
+        with self.precision_plugin.test_step_context():
             return self.training_type_plugin.test_step(*step_kwargs.values())
 
     def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:

@@ -234,7 +234,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
                 - dataloader_idx (int): The index of the dataloader that produced this batch
                   (only if multiple predict dataloaders used).
         """
-        with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context():
+        with self.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
 
     def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
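All four hunks make the same change: each step call is now wrapped only in the precision plugin's context manager instead of two stacked ones. For reference, a `with a(), b():` statement enters `a` then `b` and exits in reverse order, so dropping a second manager whose body is a bare `yield` preserves behaviour. A minimal, self-contained sketch (plain Python, not Lightning code):

import contextlib

@contextlib.contextmanager
def precision_ctx():
    # stands in for PrecisionPlugin.train_step_context()
    print("enter precision context")
    yield
    print("exit precision context")

@contextlib.contextmanager
def training_type_ctx():
    # the dropped TrainingTypePlugin context was a no-op like this
    yield

# before: with precision_ctx(), training_type_ctx(): step()
# after:  with precision_ctx(): step()
with precision_ctx():
    result = "training_step output"  # stands in for the actual step call
print(result)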

pytorch_lightning/plugins/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-from pytorch_lightning.plugins.base_plugin import Plugin
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
 from pytorch_lightning.plugins.plugins_registry import (  # noqa: F401
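Downstream code that imported the removed class can import the two concrete base classes directly, as the trainer.py hunk further below does. A migration sketch (the `is_plugin` helper is hypothetical, shown only to illustrate replacing an `isinstance(obj, Plugin)` check):

# before (removed by this commit):
#     from pytorch_lightning.plugins import Plugin
# after:
from pytorch_lightning.plugins import PrecisionPlugin, TrainingTypePlugin

def is_plugin(obj: object) -> bool:
    # hypothetical helper: checks against both remaining base classes
    return isinstance(obj, (TrainingTypePlugin, PrecisionPlugin))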

pytorch_lightning/plugins/base_plugin.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 32 additions & 3 deletions
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Optional, Tuple, Union
+import contextlib
+from typing import Any, Callable, Generator, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor

@@ -20,12 +21,11 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.hooks import CheckpointHooks
-from pytorch_lightning.plugins.base_plugin import Plugin
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.types import _PARAMETERS
 
 
-class PrecisionPlugin(Plugin, CheckpointHooks):
+class PrecisionPlugin(CheckpointHooks):
     """
     Base class for all plugins handling the precision-specific parts of the training.
     The class attribute precision must be overwritten in child classes.

@@ -136,3 +136,32 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
         """Clip gradients by norm"""
         parameters = self.master_params(optimizer)
         torch.nn.utils.clip_grad_norm_(parameters, clip_val)
+
+    def pre_dispatch(self) -> None:
+        """Hook to do something before the training/evaluation/prediction starts."""
+
+    def dispatch(self, trainer: "pl.Trainer") -> None:
+        """Hook to do something when ``Accelerator.dispatch()`` gets called."""
+
+    def post_dispatch(self) -> None:
+        """Hook to do something after the training/evaluation/prediction finishes."""
+
+    @contextlib.contextmanager
+    def train_step_context(self) -> Generator:
+        """A contextmanager for the training step"""
+        yield
+
+    @contextlib.contextmanager
+    def val_step_context(self) -> Generator:
+        """A contextmanager for the validation step"""
+        yield
+
+    @contextlib.contextmanager
+    def test_step_context(self) -> Generator:
+        """A contextmanager for the test step"""
+        yield
+
+    @contextlib.contextmanager
+    def predict_step_context(self) -> Generator:
+        """A contextmanager for the predict step"""
+        yield
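Since the `*_step_context` managers now live on `PrecisionPlugin`, precision subclasses override them to wrap the step. A hypothetical sketch, assuming a CUDA device; the class name and body are illustrative, not the shipped mixed-precision plugin:

import contextlib
from typing import Generator

import torch

from pytorch_lightning.plugins import PrecisionPlugin

class AutocastPrecisionPlugin(PrecisionPlugin):
    # hypothetical subclass; the base class requires children to set `precision`
    precision = 16

    @contextlib.contextmanager
    def train_step_context(self) -> Generator:
        # run the wrapped training step under automatic mixed precision
        with torch.cuda.amp.autocast():
            yield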

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 10 additions & 2 deletions
@@ -25,14 +25,13 @@
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
-from pytorch_lightning.plugins.base_plugin import Plugin
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 
 TBroadcast = TypeVar("T")
 
 
-class TrainingTypePlugin(Plugin, ABC):
+class TrainingTypePlugin(ABC):
     """
     Base class for all training type plugins that change the behaviour of the training, validation and test-loop.
     """

@@ -352,3 +351,12 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int)
         Called in the training loop before anything happens for that batch.
         """
         pass
+
+    def pre_dispatch(self) -> None:
+        """Hook to do something before the training/evaluation/prediction starts."""
+
+    def dispatch(self, trainer: "pl.Trainer") -> None:
+        """Hook to do something at trainer run_stage starts."""
+
+    def post_dispatch(self) -> None:
+        """Hook to do something after the training/evaluation/prediction finishes."""

pytorch_lightning/trainer/trainer.py

Lines changed: 10 additions & 2 deletions
@@ -32,7 +32,7 @@
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
 from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
-from pytorch_lightning.plugins import DDPSpawnPlugin, Plugin
+from pytorch_lightning.plugins import DDPSpawnPlugin, PrecisionPlugin, TrainingTypePlugin
 from pytorch_lightning.plugins.environments import ClusterEnvironment
 from pytorch_lightning.profiler import (
     AdvancedProfiler,

@@ -151,7 +151,15 @@ def __init__(
         terminate_on_nan: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         prepare_data_per_node: Optional[bool] = None,
-        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
+        plugins: Optional[
+            Union[
+                List[Union[TrainingTypePlugin, PrecisionPlugin, ClusterEnvironment, str]],
+                TrainingTypePlugin,
+                PrecisionPlugin,
+                ClusterEnvironment,
+                str,
+            ]
+        ] = None,
         amp_backend: str = "native",
         amp_level: str = "O2",
         distributed_backend: Optional[str] = None,
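For callers the widened annotation changes nothing in practice: a `TrainingTypePlugin`, a `PrecisionPlugin`, a registry string, or a list mixing them all remain valid arguments. A usage sketch, assuming a multi-GPU-capable environment (`DDPPlugin` comes from the same `pytorch_lightning.plugins` namespace):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# a single TrainingTypePlugin instance
trainer = Trainer(plugins=DDPPlugin(find_unused_parameters=False))

# a registry string, or a list mixing plugin types, is equally valid:
# trainer = Trainer(plugins="ddp_sharded")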
