From f0aed43947dec66b8dd5d7aebff7d33c0de21402 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 13 Apr 2021 14:28:31 +0530 Subject: [PATCH 01/26] Add Plugins Registry --- pytorch_lightning/plugins/plugins_registry.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 pytorch_lightning/plugins/plugins_registry.py diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py new file mode 100644 index 0000000000000..2e25efaa45be2 --- /dev/null +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -0,0 +1,68 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import UserDict +from typing import Any, Callable, List, Optional + + +class _PluginsRegistry(UserDict): + """ + This class is a Registry that stores information about the Plugins. + + The Plugins are mapped to strings. These strings are names that identify + a plugin, eg., "deepspeed". It also returns Optional description and + parameters to initialize the Plugin, which were defined during the + registeration. + + The motivation for having a PluginRegistry is to make it convenient + for the Users to try different Plugins by passing just strings + to the plugins flag to the Trainer.
+ + """ + + def register(self, name: str, func: Optional[Callable] = None, description: Optional[str] = None, **init_params): + + if not (name is None or isinstance(name, str)): + raise TypeError(f'`name` must be a str, found {name}') + + data = {} + data["description"] = description if description is not None else "" + + data["init_params"] = init_params + + def do_register(func): + data["func"] = func + self[name] = data + return data + + if func is not None: + return do_register(func) + + return do_register + + def get(self, name: str): + if name in self: + return self[name] + raise KeyError("Key not Found") + + def remove(self, name: str): + self.pop(name) + + def available_plugins(self) -> List: + return list(self.keys()) + + def __str__(self): + return "Registered Plugins: {}".format(", ".join(self.keys())) + + +PluginsRegistry = _PluginsRegistry() From 38004f0f78847c8d7d6bc5b0ed4d0633544d70ab Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 13 Apr 2021 14:46:21 +0530 Subject: [PATCH 02/26] Update get --- pytorch_lightning/plugins/plugins_registry.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 2e25efaa45be2..dc4847fa0a08e 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -14,6 +14,8 @@ from collections import UserDict from typing import Any, Callable, List, Optional +from pytorch_lightning.utilities.exceptions import MisconfigurationException + class _PluginsRegistry(UserDict): """ @@ -35,6 +37,9 @@ def register(self, name: str, func: Optional[Callable] = None, description: Opti if not (name is None or isinstance(name, str)): raise TypeError(f'`name` must be a str, found {name}') + if name in self: + raise MisconfigurationException(f"{name} is already present in the registry.") + data = {} data["description"] = description if description is not None else "" @@ -53,7 +58,9 @@ def do_register(func): def get(self, name: str): if name in self: return self[name] - raise KeyError("Key not Found") + err_msg = "'{}' not found in registry. Available names: {}" + available_names = ", ".join(sorted(self.keys())) or "none" + raise KeyError(err_msg.format(name, available_names)) def remove(self, name: str): self.pop(name) From 25ca516ef2f757980120b44bc19b9ba26bdc99e5 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 13 Apr 2021 14:57:06 +0530 Subject: [PATCH 03/26] update --- pytorch_lightning/plugins/plugins_registry.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index dc4847fa0a08e..189fee1396594 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -32,13 +32,23 @@ class _PluginsRegistry(UserDict): """ - def register(self, name: str, func: Optional[Callable] = None, description: Optional[str] = None, **init_params): + def register( + self, + name: str, + func: Optional[Callable] = None, + description: Optional[str] = None, + override: bool = False, + **init_params + ): if not (name is None or isinstance(name, str)): raise TypeError(f'`name` must be a str, found {name}') - if name in self: - raise MisconfigurationException(f"{name} is already present in the registry.") + if name in self and not override: + raise MisconfigurationException( + f"'{name}' is already present in the registry." 
+ " HINT: Use `override=True`." + ) data = {} data["description"] = description if description is not None else "" @@ -57,7 +67,9 @@ def do_register(func): def get(self, name: str): if name in self: - return self[name] + data = self[name] + return data["func"](**data["init_params"]) + err_msg = "'{}' not found in registry. Available names: {}" available_names = ", ".join(sorted(self.keys())) or "none" raise KeyError(err_msg.format(name, available_names)) From 62a3c9f5b765b51f429641aca43d34ee0693149d Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 13 Apr 2021 15:55:32 +0530 Subject: [PATCH 04/26] Add training type registry --- pytorch_lightning/plugins/__init__.py | 37 +++++++++++++++++++ pytorch_lightning/plugins/plugins_registry.py | 8 ++-- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index a67235baa4767..e5cfc97a26264 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,4 +1,5 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 +from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 @@ -47,3 +48,39 @@ 'DDPShardedPlugin', 'DDPSpawnShardedPlugin', ] + +REGISTRY_TRAINING_TYPE_PLUGINS = [{ + "plugin": DeepSpeedPlugin, + "variants": [{ + "name": "deepspeed", + "description": "Default DeepSpeed Plugin" + }, { + "name": "deepspeed_stage_2", + "description": "DeepSpeed with ZeRO Stage 2 enabled", + "stage": 2 + }, { + "name": "deepspeed_stage_2_offload", + "description": "DeepSpeed with ZeRO Stage 2 enabled and Offload", + "stage": 2, + "cpu_offload": True + }, { + "name": "deepspeed_stage_3", + "description": "DeepSpeed with ZeRO Stage 3 enabled", + "stage": 3 + }, { + "name": "deepspeed_stage_3_offload", + "description": "DeepSpeed with ZeRO Stage 2 enabled and Offload", + "stage": 3, + "cpu_offload": True + }] +}] + + +def register_training_type_plugins(plugins): + for plugin_info in plugins: + plugin = plugin_info["plugin"] + for variant in plugin_info["variants"]: + TrainingTypePluginsRegistry.register(**variant) + + +register_training_type_plugins(REGISTRY_TRAINING_TYPE_PLUGINS) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 189fee1396594..9dff900cfbea5 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import UserDict -from typing import Any, Callable, List, Optional +from typing import Callable, List, Optional from pytorch_lightning.utilities.exceptions import MisconfigurationException -class _PluginsRegistry(UserDict): +class _TrainingTypePluginsRegistry(UserDict): """ - This class is a Registry that stores information about the Plugins. + This class is a Registry that stores information about the Training Type Plugins. The Plugins are mapped to strings. These strings are names that idenitify a plugin, eg., "deepspeed". 
It also returns Optional description and parameters to initialize the Plugin, which were defined during the registeration. @@ -84,4 +84,4 @@ def __str__(self): return "Registered Plugins: {}".format(", ".join(self.keys())) -PluginsRegistry = _PluginsRegistry() +TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry() From a05a0f90eecf014306ba39af31dac3b54fcc6534 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 13 Apr 2021 17:55:55 +0530 Subject: [PATCH 05/26] Example --- pytorch_lightning/plugins/plugins_registry.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 9dff900cfbea5..319205d8f733a 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -26,10 +26,17 @@ class _TrainingTypePluginsRegistry(UserDict): parameters to initialize the Plugin, which were defined during the registeration. - The motivation for having a PluginRegistry is to make it convenient + The motivation for having a TrainingTypePluginsRegistry is to make it convenient for the Users to try different Plugins by passing just strings to the plugins flag to the Trainer. + Example:: + + @TrainingTypePluginsRegistry.register("lightning", description="Super fast", a=1, b=True) + class LightningPlugin: + def __init__(self, a, b): + ... + """ def register( From 4e9cec836af97fe57d5e4a48a894eb34b070caf5 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 14 Apr 2021 08:59:53 +0530 Subject: [PATCH 06/26] Update register plugin --- pytorch_lightning/plugins/__init__.py | 36 ------------------- .../plugins/training_type/deepspeed.py | 6 ++++ 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index e5cfc97a26264..9aaf6aad64af3 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -48,39 +48,3 @@ 'DDPShardedPlugin', 'DDPSpawnShardedPlugin', ] - -REGISTRY_TRAINING_TYPE_PLUGINS = [{ - "plugin": DeepSpeedPlugin, - "variants": [{ - "name": "deepspeed", - "description": "Default DeepSpeed Plugin" - }, { - "name": "deepspeed_stage_2", - "description": "DeepSpeed with ZeRO Stage 2 enabled", - "stage": 2 - }, { - "name": "deepspeed_stage_2_offload", - "description": "DeepSpeed with ZeRO Stage 2 enabled and Offload", - "stage": 2, - "cpu_offload": True - }, { - "name": "deepspeed_stage_3", - "description": "DeepSpeed with ZeRO Stage 3 enabled", - "stage": 3 - }, { - "name": "deepspeed_stage_3_offload", - "description": "DeepSpeed with ZeRO Stage 3 enabled and Offload", - "stage": 3, - "cpu_offload": True - }] -}] - - -def register_training_type_plugins(plugins): - for plugin_info in plugins: - plugin = plugin_info["plugin"] - for variant in plugin_info["variants"]: - TrainingTypePluginsRegistry.register(**variant) - - -register_training_type_plugins(REGISTRY_TRAINING_TYPE_PLUGINS) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 3dc52b60055d8..350e46fd396f1 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -26,6 +26,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry from
pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import AMPType @@ -70,6 +71,11 @@ def _move_float_tensors_to_half(self, batch: Any): return batch +@TrainingTypePluginsRegistry.register('deepspeed') +@TrainingTypePluginsRegistry.register('deepspeed_stage_2', stage=2) +@TrainingTypePluginsRegistry.register('deepspeed_stage_2_offload', stage=2, cpu_offload=True) +@TrainingTypePluginsRegistry.register('deepspeed_stage_3', stage=3) +@TrainingTypePluginsRegistry.register('deepspeed_stage_3_offload', stage=3, cpu_offload=True) class DeepSpeedPlugin(DDPPlugin): distributed_backend = "deepspeed" DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH" From e85de51961e2292078b83cc6632cdbd38c79e678 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 14 Apr 2021 14:00:04 +0530 Subject: [PATCH 07/26] Update registry --- pytorch_lightning/plugins/plugins_registry.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 319205d8f733a..894244f0d3c17 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -42,7 +42,7 @@ def __init__(self, a, b): def register( self, name: str, - func: Optional[Callable] = None, + plugin: Optional[Callable] = None, description: Optional[str] = None, override: bool = False, **init_params @@ -62,20 +62,20 @@ def register( data["init_params"] = init_params - def do_register(func): - data["func"] = func + def do_register(plugin): + data["plugin"] = plugin self[name] = data - return data + return plugin - if func is not None: - return do_register(func) + if plugin is not None: + return do_register(plugin) return do_register def get(self, name: str): if name in self: data = self[name] - return data["func"](**data["init_params"]) + return data["plugin"](**data["init_params"]) err_msg = "'{}' not found in registry. 
Available names: {}" available_names = ", ".join(sorted(self.keys())) or "none" From 40842af4ca23ed5274db880d631032255ff5a797 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 14 Apr 2021 14:26:40 +0530 Subject: [PATCH 08/26] add description --- .../plugins/training_type/deepspeed.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 350e46fd396f1..fdfaf458670b6 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -71,11 +71,21 @@ def _move_float_tensors_to_half(self, batch: Any): return batch -@TrainingTypePluginsRegistry.register('deepspeed') -@TrainingTypePluginsRegistry.register('deepspeed_stage_2', stage=2) -@TrainingTypePluginsRegistry.register('deepspeed_stage_2_offload', stage=2, cpu_offload=True) -@TrainingTypePluginsRegistry.register('deepspeed_stage_3', stage=3) -@TrainingTypePluginsRegistry.register('deepspeed_stage_3_offload', stage=3, cpu_offload=True) +@TrainingTypePluginsRegistry.register("deepspeed", description="Default DeepSpeed Plugin") +@TrainingTypePluginsRegistry.register("deepspeed_stage_2", description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) +@TrainingTypePluginsRegistry.register( + "deepspeed_stage_2_offload", + description="DeepSpeed with ZeRO Stage 2 enabled and Offload", + stage=2, + cpu_offload=True +) +@TrainingTypePluginsRegistry.register("deepspeed_stage_3", description="DeepSpeed with ZeRO Stage 3 enabled", stage=3) +@TrainingTypePluginsRegistry.register( + "deepspeed_stage_3_offload", + description="DeepSpeed with ZeRO Stage 3 enabled and Offload", + stage=3, + cpu_offload=True +) class DeepSpeedPlugin(DDPPlugin): distributed_backend = "deepspeed" DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH" From 7ce6f86805c784e235f15a40f376c7d4bf19d041 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 14 Apr 2021 15:14:41 +0530 Subject: [PATCH 09/26] add docs --- pytorch_lightning/plugins/plugins_registry.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 894244f0d3c17..c3b5de73d0f52 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -46,8 +46,17 @@ def register( description: Optional[str] = None, override: bool = False, **init_params - ): - + ) -> Callable: + """ + Registers a plugin mapped to a name and with required metadata. + + Args: + name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3" + plugin (callable): plugin class + description (str): plugin description + override (bool): overrides the registered plugin, if True + init_params: parameters to initialize the plugin + """ if not (name is None or isinstance(name, str)): raise TypeError(f'`name` must be a str, found {name}') @@ -72,7 +81,14 @@ def do_register(plugin): return do_register - def get(self, name: str): + def get(self, name: str) -> Callable: + """ + Calls the registered plugin with the required parameters + and returns the plugin object + + Args: + name (str): the name that identifies a plugin, e.g. 
"deepspeed_stage_3" + """ if name in self: data = self[name] return data["plugin"](**data["init_params"]) @@ -81,10 +97,12 @@ def get(self, name: str): available_names = ", ".join(sorted(self.keys())) or "none" raise KeyError(err_msg.format(name, available_names)) - def remove(self, name: str): + def remove(self, name: str) -> None: + """Removes the registered plugin by name""" self.pop(name) def available_plugins(self) -> List: + """Returns a list of registered plugins""" return list(self.keys()) def __str__(self): From 8fb0a1652c7a1571d414f18725791a0e1e53ac62 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 14 Apr 2021 16:37:37 +0530 Subject: [PATCH 10/26] Update acc connector --- .../trainer/connectors/accelerator_connector.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index aa52ec1c40d82..65d8b9c74d2bc 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -42,6 +42,7 @@ TPUHalfPrecisionPlugin, TPUSpawnPlugin, TrainingTypePlugin, + TrainingTypePluginsRegistry, ) from pytorch_lightning.plugins.environments import ( ClusterEnvironment, @@ -163,7 +164,16 @@ def handle_given_plugins( cluster_environment = None for plug in plugins: - if isinstance(plug, str): + if isinstance(plug, str) and plug in TrainingTypePluginsRegistry: + if training_type is None: + training_type = TrainingTypePluginsRegistry.get(plug) + else: + raise MisconfigurationException( + 'You can only specify one precision and one training type plugin.' + ' Found more than 1 training type plugin:' + f' {type(TrainingTypePluginsRegistry.get(plug)).__name__}' + ) + elif isinstance(plug, str): # Reset the distributed type as the user has overridden training type # via the plugins argument self._distrib_type = None @@ -515,7 +525,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): rank_zero_warn( 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' ) - # todo: in some cases it yield in comarison None and int + # todo: in some cases it yield in comparison None and int if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1): self._distrib_type = DistributedType.DDP else: From 33e5329114a0a4d00932d4aed33000e419526124 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 14 Apr 2021 16:46:50 +0530 Subject: [PATCH 11/26] Update acc connector --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 65d8b9c74d2bc..444caff052c08 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -171,7 +171,7 @@ def handle_given_plugins( raise MisconfigurationException( 'You can only specify one precision and one training type plugin.' 
' Found more than 1 training type plugin:' - f' {type(TrainingTypePluginsRegistry.get(plug)).__name__}' + f' {TrainingTypePluginsRegistry["deepspeed"]["plugin"]} registered to {plug}' ) elif isinstance(plug, str): # Reset the distributed type as the user has overridden training type From 73718d8a10e48f80f86d66e093c8bcf2389e69dd Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 14 Apr 2021 17:44:28 +0530 Subject: [PATCH 12/26] add register plugin method --- .../plugins/training_type/deepspeed.py | 33 ++++++++++--------- .../training_type/training_type_plugin.py | 4 +++ .../connectors/accelerator_connector.py | 2 +- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index fdfaf458670b6..10097e5b8c383 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -71,21 +71,6 @@ def _move_float_tensors_to_half(self, batch: Any): return batch -@TrainingTypePluginsRegistry.register("deepspeed", description="Default DeepSpeed Plugin") -@TrainingTypePluginsRegistry.register("deepspeed_stage_2", description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) -@TrainingTypePluginsRegistry.register( - "deepspeed_stage_2_offload", - description="DeepSpeed with ZeRO Stage 2 enabled and Offload", - stage=2, - cpu_offload=True -) -@TrainingTypePluginsRegistry.register("deepspeed_stage_3", description="DeepSpeed with ZeRO Stage 3 enabled", stage=3) -@TrainingTypePluginsRegistry.register( - "deepspeed_stage_3_offload", - description="DeepSpeed with ZeRO Stage 3 enabled and Offload", - stage=3, - cpu_offload=True -) class DeepSpeedPlugin(DDPPlugin): distributed_backend = "deepspeed" DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH" @@ -536,3 +521,21 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) -> if total_batch_idx % self._original_accumulate_grad_batches == 0: current_global_step += 1 return current_global_step + + @classmethod + def register_plugins(cls, plugin_registry): + plugin_registry.register("deepspeed", description="Default DeepSpeed Plugin") + plugin_registry.register("deepspeed_stage_2", description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) + plugin_registry.register( + "deepspeed_stage_2_offload", + description="DeepSpeed with ZeRO Stage 2 enabled and Offload", + stage=2, + cpu_offload=True + ) + plugin_registry.register("deepspeed_stage_3", description="DeepSpeed with ZeRO Stage 3 enabled", stage=3) + plugin_registry.register( + "deepspeed_stage_3_offload", + description="DeepSpeed with ZeRO Stage 3 enabled and Offload", + stage=3, + cpu_offload=True + ) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 6fd02142bf410..8ba002e1641d3 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -284,3 +284,7 @@ def call_configure_sharded_model_hook(self) -> bool: @call_configure_sharded_model_hook.setter def call_configure_sharded_model_hook(self, mode: bool) -> None: self._call_configure_sharded_model_hook = mode + + @classmethod + def register_plugins(cls, plugin_registry): + pass diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 444caff052c08..ee58bce1c8fc6 100644 --- 
a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -171,7 +171,7 @@ def handle_given_plugins( raise MisconfigurationException( 'You can only specify one precision and one training type plugin.' ' Found more than 1 training type plugin:' - f' {TrainingTypePluginsRegistry["deepspeed"]["plugin"]} registered to {plug}' + f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}' ) elif isinstance(plug, str): # Reset the distributed type as the user has overridden training type From b691237878445f5a287cfc99c196c2f1a67de639 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 14 Apr 2021 20:22:03 +0530 Subject: [PATCH 13/26] add call register plugins --- pytorch_lightning/plugins/__init__.py | 9 +++++- pytorch_lightning/plugins/plugins_registry.py | 29 +++++++++++++++++++ .../plugins/training_type/deepspeed.py | 9 +++--- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 9aaf6aad64af3..46bdcaa42f56d 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,5 +1,5 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 -from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry +from pytorch_lightning.plugins.plugins_registry import call_register_plugins, TrainingTypePluginsRegistry from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 @@ -48,3 +48,10 @@ 'DDPShardedPlugin', 'DDPSpawnShardedPlugin', ] + +from pathlib import Path + +FILE_ROOT = Path(__file__).parent +TRAINING_TYPE_BASE_MODULE = "pytorch_lightning.plugins.training_type" + +call_register_plugins(FILE_ROOT, TRAINING_TYPE_BASE_MODULE) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index c3b5de73d0f52..2e1afef760a8d 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -11,9 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import importlib +import os +import sys from collections import UserDict +from inspect import getmembers, isclass from typing import Callable, List, Optional +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -110,3 +115,27 @@ def __str__(self): TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry() + + +def is_register_plugins_overriden(plugin): + method_name = "register_plugins" + plugin_attr = getattr(plugin, method_name) + super_attr = getattr(TrainingTypePlugin, method_name) + + if hasattr(plugin_attr, 'patch_loader_code'): + is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__) + else: + is_overridden = plugin_attr.__code__ is not super_attr.__code__ + return is_overridden + + +def call_register_plugins(root: str, base_module: str) -> None: + for file in os.listdir(str(root) + "/training_type"): + if file.endswith(".py") and not file.startswith("_"): + module = file[:file.find(".py")] + if module not in sys.modules: + module = importlib.import_module(".".join([base_module, module])) + for mod_info in getmembers(module, isclass): + mod = mod_info[1] + if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overriden(mod): + mod.register_plugins(TrainingTypePluginsRegistry) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 10097e5b8c383..af8b78acd1e0b 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -26,7 +26,6 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import AMPType @@ -524,17 +523,19 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) -> @classmethod def register_plugins(cls, plugin_registry): - plugin_registry.register("deepspeed", description="Default DeepSpeed Plugin") - plugin_registry.register("deepspeed_stage_2", description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) + plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin") + plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) plugin_registry.register( "deepspeed_stage_2_offload", + cls, description="DeepSpeed with ZeRO Stage 2 enabled and Offload", stage=2, cpu_offload=True ) - plugin_registry.register("deepspeed_stage_3", description="DeepSpeed with ZeRO Stage 3 enabled", stage=3) + plugin_registry.register("deepspeed_stage_3", cls, description="DeepSpeed with ZeRO Stage 3 enabled", stage=3) plugin_registry.register( "deepspeed_stage_3_offload", + cls, description="DeepSpeed with ZeRO Stage 3 enabled and Offload", stage=3, cpu_offload=True From 870bee516aabedec572c482e2a869216831c317e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 14 Apr 2021 20:29:01 +0530 Subject: [PATCH 14/26] Update --- pytorch_lightning/plugins/__init__.py | 4 ++-- pytorch_lightning/plugins/plugins_registry.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 
deletions(-) diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 46bdcaa42f56d..aaa5f83613154 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,5 +1,5 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 -from pytorch_lightning.plugins.plugins_registry import call_register_plugins, TrainingTypePluginsRegistry +from pytorch_lightning.plugins.plugins_registry import call_training_type_register_plugins, TrainingTypePluginsRegistry from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 @@ -54,4 +54,4 @@ FILE_ROOT = Path(__file__).parent TRAINING_TYPE_BASE_MODULE = "pytorch_lightning.plugins.training_type" -call_register_plugins(FILE_ROOT, TRAINING_TYPE_BASE_MODULE) +call_training_type_register_plugins(FILE_ROOT, TRAINING_TYPE_BASE_MODULE) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 2e1afef760a8d..4520323a46086 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -117,7 +117,7 @@ def __str__(self): TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry() -def is_register_plugins_overriden(plugin): +def is_register_plugins_overriden(plugin) -> bool: method_name = "register_plugins" plugin_attr = getattr(plugin, method_name) super_attr = getattr(TrainingTypePlugin, method_name) @@ -129,13 +129,13 @@ def is_register_plugins_overriden(plugin): return is_overridden -def call_register_plugins(root: str, base_module: str) -> None: - for file in os.listdir(str(root) + "/training_type"): +def call_training_type_register_plugins(root: str, base_module: str) -> None: + directory = "training_type" + for file in os.listdir(root / directory): if file.endswith(".py") and not file.startswith("_"): module = file[:file.find(".py")] if module not in sys.modules: module = importlib.import_module(".".join([base_module, module])) - for mod_info in getmembers(module, isclass): - mod = mod_info[1] + for _, mod in getmembers(module, isclass): if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overriden(mod): mod.register_plugins(TrainingTypePluginsRegistry) From 388e4878a4e6e76b5dacbf4315a1e4ee32aeff9e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 15 Apr 2021 00:10:06 +0530 Subject: [PATCH 15/26] fix code format --- pytorch_lightning/plugins/plugins_registry.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 4520323a46086..3cdd37b7015e7 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -16,7 +16,8 @@ import sys from collections import UserDict from inspect import getmembers, isclass -from typing import Callable, List, Optional +from pathlib import Path +from typing import Any, Callable, List, Optional from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -50,7 +51,7 @@ def register( plugin: Optional[Callable] = None, description: Optional[str] = None, override: bool = False, - **init_params + 
**init_params: Any, ) -> Callable: """ Registers a plugin mapped to a name and with required metadata. @@ -76,7 +77,7 @@ def register( data["init_params"] = init_params - def do_register(plugin): + def do_register(plugin: Callable) -> Callable: data["plugin"] = plugin self[name] = data return plugin @@ -110,14 +111,14 @@ def available_plugins(self) -> List: """Returns a list of registered plugins""" return list(self.keys()) - def __str__(self): + def __str__(self) -> str: return "Registered Plugins: {}".format(", ".join(self.keys())) TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry() -def is_register_plugins_overriden(plugin) -> bool: +def is_register_plugins_overriden(plugin: Callable) -> bool: method_name = "register_plugins" plugin_attr = getattr(plugin, method_name) super_attr = getattr(TrainingTypePlugin, method_name) @@ -129,7 +130,7 @@ def is_register_plugins_overriden(plugin) -> bool: return is_overridden -def call_training_type_register_plugins(root: str, base_module: str) -> None: +def call_training_type_register_plugins(root: Path, base_module: str) -> None: directory = "training_type" for file in os.listdir(root / directory): if file.endswith(".py") and not file.startswith("_"): From adea1b4c0206b67c41e477254cd842ab6d2285ee Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 15 Apr 2021 00:58:49 +0530 Subject: [PATCH 16/26] add tests --- pytorch_lightning/plugins/__init__.py | 5 +- pytorch_lightning/plugins/plugins_registry.py | 2 +- tests/plugins/test_plugins_registry.py | 68 +++++++++++++++++++ 3 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 tests/plugins/test_plugins_registry.py diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index aaa5f83613154..444d2aaef978b 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,5 +1,8 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 -from pytorch_lightning.plugins.plugins_registry import call_training_type_register_plugins, TrainingTypePluginsRegistry +from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401 + call_training_type_register_plugins, + TrainingTypePluginsRegistry, +) from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 3cdd37b7015e7..0aac68c084348 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -87,7 +87,7 @@ def do_register(plugin: Callable) -> Callable: return do_register - def get(self, name: str) -> Callable: + def get(self, name: str) -> Any: """ Calls the registered plugin with the required parameters and returns the plugin object diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py new file mode 100644 index 0000000000000..12c8c0957d5fb --- /dev/null +++ b/tests/plugins/test_plugins_registry.py @@ -0,0 +1,68 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry +from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin + + +def test_training_type_plugins_registry_with_new_plugin(): + + class TestPlugin: + + def __init__(self, param1, param2): + self.param1 = param1 + self.param2 = param2 + + plugin_name = "test_plugin" + plugin_description = "Test Plugin" + + TrainingTypePluginsRegistry.register( + plugin_name, TestPlugin, description=plugin_description, param1="abc", param2=123 + ) + + assert plugin_name in TrainingTypePluginsRegistry + assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description + assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123} + assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin) + + TrainingTypePluginsRegistry.remove(plugin_name) + assert plugin_name not in TrainingTypePluginsRegistry + + +@pytest.mark.parametrize( + "plugin_name, init_params", + [ + ("deepspeed", {}), + ("deepspeed_stage_2", { + "stage": 2 + }), + ("deepspeed_stage_2_offload", { + "stage": 2, + "cpu_offload": True + }), + ("deepspeed_stage_3", { + "stage": 3 + }), + ("deepspeed_stage_3_offload", { + "stage": 3, + "cpu_offload": True + }), + ], +) +def test_training_type_plugins_registry_with_deepspeed_plugins(plugin_name, init_params): + + assert plugin_name in TrainingTypePluginsRegistry + assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == init_params + assert TrainingTypePluginsRegistry[plugin_name]["plugin"] == DeepSpeedPlugin From 410fc612e41dd5d3495d29058d999c818efaa43c Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 15 Apr 2021 01:25:50 +0530 Subject: [PATCH 17/26] Update tests --- tests/plugins/test_plugins_registry.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py index 12c8c0957d5fb..20be936d89d35 100644 --- a/tests/plugins/test_plugins_registry.py +++ b/tests/plugins/test_plugins_registry.py @@ -15,6 +15,8 @@ from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf def test_training_type_plugins_registry_with_new_plugin(): @@ -66,3 +68,20 @@ def test_training_type_plugins_registry_with_deepspeed_plugins(plugin_name, init assert plugin_name in TrainingTypePluginsRegistry assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == init_params assert TrainingTypePluginsRegistry[plugin_name]["plugin"] == DeepSpeedPlugin + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +@pytest.mark.parametrize("plugin", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"]) +def test_training_type_plugins_registry_with_trainer(tmpdir, plugin): + + model = BoringModel() + + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins=plugin, + gpus=1, + precision=16, + ) + + trainer.fit(model) 
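Patch 17's test exercises the mechanism end to end: a plain string in the Trainer's `plugins` argument is looked up in the registry by the accelerator connector (patch 10) and instantiated with its stored `init_params`. A minimal sketch of how a user-defined plugin would hook into the same mechanism, using the decorator form added in patch 05; `CustomDDPPlugin`, the name "custom_ddp_example", and the `my_param` knob are illustrative assumptions, not part of this PR:

    # Hypothetical user code, not part of this PR.
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import TrainingTypePluginsRegistry
    from pytorch_lightning.plugins.training_type.ddp import DDPPlugin

    @TrainingTypePluginsRegistry.register(
        "custom_ddp_example",  # the string users would pass to Trainer(plugins=...)
        description="DDP with an illustrative extra knob",
        my_param=True,  # stored as init_params and replayed on lookup
    )
    class CustomDDPPlugin(DDPPlugin):

        def __init__(self, my_param=False, **kwargs):
            super().__init__(**kwargs)
            self.my_param = my_param

    # handle_given_plugins resolves the string through the registry, so this
    # yields CustomDDPPlugin(my_param=True) as the training type plugin.
    trainer = Trainer(plugins="custom_ddp_example")
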
From 7af1b65d7a12c3c572ee342e6d02b2b35655f66e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 15 Apr 2021 01:34:09 +0530 Subject: [PATCH 18/26] update --- pytorch_lightning/plugins/plugins_registry.py | 7 ++++++- tests/plugins/test_plugins_registry.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 0aac68c084348..3a568698f7f60 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -19,7 +19,7 @@ from pathlib import Path from typing import Any, Callable, List, Optional -from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -43,6 +43,10 @@ class LightningPlugin: def __init__(self, a, b): ... + or + + TrainingTypePluginsRegistry.register("lightning", LightningPlugin, description="Super fast", a=1, b=True) + """ def register( @@ -131,6 +135,7 @@ def is_register_plugins_overriden(plugin: Callable) -> bool: def call_training_type_register_plugins(root: Path, base_module: str) -> None: + # Ref: https://github.com/facebookresearch/ClassyVision/blob/master/classy_vision/generic/registry_utils.py#L14 directory = "training_type" for file in os.listdir(root / directory): if file.endswith(".py") and not file.startswith("_"): diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py index 20be936d89d35..617be30a15309 100644 --- a/tests/plugins/test_plugins_registry.py +++ b/tests/plugins/test_plugins_registry.py @@ -13,6 +13,7 @@ # limitations under the License. 
import pytest +from pytorch_lightning import Trainer from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin from tests.helpers.boring_model import BoringModel From 80bc1c564b1666e9be202928d76b8059c8b4fc00 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 15 Apr 2021 02:42:44 +0530 Subject: [PATCH 19/26] update --- tests/plugins/test_plugins_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py index 617be30a15309..fb6f3b0cba16a 100644 --- a/tests/plugins/test_plugins_registry.py +++ b/tests/plugins/test_plugins_registry.py @@ -14,7 +14,7 @@ import pytest from pytorch_lightning import Trainer -from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry +from pytorch_lightning.plugins import TrainingTypePluginsRegistry from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf From 6ad2c62dc2a2a9b95ed19a25d9e6a4d0c4da7735 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 16 Apr 2021 00:56:15 +0530 Subject: [PATCH 20/26] update tests --- pytorch_lightning/plugins/plugins_registry.py | 10 +++++----- tests/plugins/test_plugins_registry.py | 9 ++------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 3a568698f7f60..847a04efe7559 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -140,8 +140,8 @@ def call_training_type_register_plugins(root: Path, base_module: str) -> None: for file in os.listdir(root / directory): if file.endswith(".py") and not file.startswith("_"): module = file[:file.find(".py")] - if module not in sys.modules: - module = importlib.import_module(".".join([base_module, module])) - for _, mod in getmembers(module, isclass): - if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overriden(mod): - mod.register_plugins(TrainingTypePluginsRegistry) + module = importlib.import_module(".".join([base_module, module])) + for _, mod in getmembers(module, isclass): + if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overriden(mod): + mod.register_plugins(TrainingTypePluginsRegistry) + break diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py index fb6f3b0cba16a..91d9596578dfc 100644 --- a/tests/plugins/test_plugins_registry.py +++ b/tests/plugins/test_plugins_registry.py @@ -16,7 +16,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import TrainingTypePluginsRegistry from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -71,18 +70,14 @@ def test_training_type_plugins_registry_with_deepspeed_plugins(plugin_name, init assert TrainingTypePluginsRegistry[plugin_name]["plugin"] == DeepSpeedPlugin -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(deepspeed=True) @pytest.mark.parametrize("plugin", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"]) def test_training_type_plugins_registry_with_trainer(tmpdir, plugin): - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, default_root_dir=tmpdir, plugins=plugin, - gpus=1, precision=16, ) - 
trainer.fit(model) + assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin) From a2010872667620070dda65002cdd43182ba44f53 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 16 Apr 2021 00:59:46 +0530 Subject: [PATCH 21/26] fix flake8 --- pytorch_lightning/plugins/plugins_registry.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 847a04efe7559..aad10d3eae4d4 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -13,7 +13,6 @@ # limitations under the License. import importlib import os -import sys from collections import UserDict from inspect import getmembers, isclass from pathlib import Path From ef080a8715fe9bb70ee2405ac4d0faa6122c85a1 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 16 Apr 2021 11:59:30 +0100 Subject: [PATCH 22/26] Update pytorch_lightning/plugins/training_type/deepspeed.py Co-authored-by: Sean Naren --- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index af8b78acd1e0b..c50f0e1d97bdd 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -528,7 +528,7 @@ def register_plugins(cls, plugin_registry): plugin_registry.register( "deepspeed_stage_2_offload", cls, - description="DeepSpeed with ZeRO Stage 2 enabled and Offload", + description="DeepSpeed ZeRO Stage 2 and CPU Offload", stage=2, cpu_offload=True ) From 57c609f5cd0cdefef5d849c17d21446ee8bb4ddf Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 16 Apr 2021 16:40:50 +0530 Subject: [PATCH 23/26] Update pytorch_lightning/plugins/training_type/deepspeed.py Co-authored-by: Sean Naren --- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index c50f0e1d97bdd..1e18491c9ee22 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -532,7 +532,7 @@ def register_plugins(cls, plugin_registry): stage=2, cpu_offload=True ) - plugin_registry.register("deepspeed_stage_3", cls, description="DeepSpeed with ZeRO Stage 3 enabled", stage=3) + plugin_registry.register("deepspeed_stage_3", cls, description="DeepSpeed ZeRO Stage 3", stage=3) plugin_registry.register( "deepspeed_stage_3_offload", cls, From 844d2a9982029054d45e817d2376eb9ccd58d2d4 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 16 Apr 2021 16:40:58 +0530 Subject: [PATCH 24/26] Update pytorch_lightning/plugins/training_type/deepspeed.py Co-authored-by: Sean Naren --- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 1e18491c9ee22..f3af6346120f8 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -536,7 +536,7 @@ def register_plugins(cls, plugin_registry): plugin_registry.register( "deepspeed_stage_3_offload", cls, - description="DeepSpeed with ZeRO Stage 
3 enabled and Offload", + description="DeepSpeed ZeRO Stage 3 and CPU Offload", stage=3, cpu_offload=True ) From 89b1e390c352e38626dcb4cf57cee4c5abdb7b96 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 16 Apr 2021 16:51:58 +0530 Subject: [PATCH 25/26] Apply code suggestions --- pytorch_lightning/plugins/plugins_registry.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index aad10d3eae4d4..16defd6253264 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -27,7 +27,7 @@ class _TrainingTypePluginsRegistry(UserDict): This class is a Registry that stores information about the Training Type Plugins. The Plugins are mapped to strings. These strings are names that idenitify - a plugin, eg., "deepspeed". It also returns Optional description and + a plugin, e.g., "deepspeed". It also returns Optional description and parameters to initialize the Plugin, which were defined durng the registeration. @@ -60,10 +60,10 @@ def register( Registers a plugin mapped to a name and with required metadata. Args: - name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3" - plugin (callable): plugin class - description (str): plugin description - override (bool): overrides the registered plugin, if True + name : the name that identifies a plugin, e.g. "deepspeed_stage_3" + plugin : plugin class + description : plugin description + override : overrides the registered plugin, if True init_params: parameters to initialize the plugin """ if not (name is None or isinstance(name, str)): @@ -121,7 +121,7 @@ def __str__(self) -> str: TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry() -def is_register_plugins_overriden(plugin: Callable) -> bool: +def is_register_plugins_overridden(plugin: Callable) -> bool: method_name = "register_plugins" plugin_attr = getattr(plugin, method_name) super_attr = getattr(TrainingTypePlugin, method_name) @@ -141,6 +141,6 @@ def call_training_type_register_plugins(root: Path, base_module: str) -> None: module = file[:file.find(".py")] module = importlib.import_module(".".join([base_module, module])) for _, mod in getmembers(module, isclass): - if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overriden(mod): + if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overridden(mod): mod.register_plugins(TrainingTypePluginsRegistry) break From 47a6d70f1aa07f5ee6007dd482e796a0ab922ec6 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 16 Apr 2021 16:54:18 +0530 Subject: [PATCH 26/26] fix typo --- pytorch_lightning/plugins/plugins_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 16defd6253264..59dd7d8db6bff 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -29,7 +29,7 @@ class _TrainingTypePluginsRegistry(UserDict): The Plugins are mapped to strings. These strings are names that idenitify a plugin, e.g., "deepspeed". It also returns Optional description and parameters to initialize the Plugin, which were defined durng the - registeration. + registration. The motivation for having a TrainingTypePluginRegistry is to make it convenient for the Users to try different Plugins by passing just strings