
Restrict setup methods to accept a single model #10064


Merged · 10 commits · Oct 25, 2021
Changes from 7 commits

CHANGELOG.md (6 changes: 3 additions & 3 deletions)
@@ -210,10 +210,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
     * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018), [#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
-    * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
+    * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
     * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
-    * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
-    * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
+    * Implemented `DeepSpeedPlugin._setup_model_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
+    * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_model_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
     * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))


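The changelog entries above boil down to a small signature change. The sketch below is illustrative only (the `plugin` object and the commented call sites are placeholders, not code from this PR) and shows the call shape before and after the rename.

from torch import nn
from torch.optim import SGD

model = nn.Linear(2, 2)
optimizer = SGD(model.parameters(), lr=0.1)

# Before (#9994/#10028): the model had to be packed into a single-element list and unpacked again.
#   [model], optimizers = plugin._setup_models_and_optimizers([model], [optimizer])
# After (#10064): exactly one model goes in and comes back out; only the optimizers remain a list.
#   model, optimizers = plugin._setup_model_and_optimizers(model, [optimizer])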

pytorch_lightning/plugins/training_type/deepspeed.py (20 changes: 9 additions & 11 deletions)
@@ -379,30 +379,28 @@ def pre_dispatch(self):
         self.init_deepspeed()
         self.barrier()
 
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Setup multiple models and multiple optimizers together.
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+        """Setup a model and multiple optimizers together.
 
-        Currently only one model paired with a single optimizer is supported.
+        Currently only a single optimizer is supported.
 
         Return:
-            A list with one model wrapped into a :class:`deepspeed.DeepSpeedEngine` and list with a single
+            The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
             deepspeed optimizer.
         """
-        if not (len(models) == len(optimizers) == 1):
+        if len(optimizers) != 1:
             raise ValueError(
-                f"Currently only one model and one optimizer is supported with DeepSpeed."
-                f" Got {len(models)} models and {len(optimizers)} optimizers instead."
+                f"Currently only one optimizer is supported with DeepSpeed."
+                f" Got {len(optimizers)} optimizers instead."
             )
 
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # normally we set this to the batch size, but it is not available here unless the user provides it
         # as part of the config
         self.config.setdefault("train_micro_batch_size_per_gpu", 1)
-        self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
+        self._model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
-        return [self._model], [optimizer]
+        return self._model, [optimizer]
 
     def _setup_model_and_optimizer(
         self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
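
To make the new contract concrete, here is a small self-contained mirror of the guard shown in the diff above. It is a toy function that only reproduces the validation; the real `DeepSpeedPlugin` goes on to build a `deepspeed.DeepSpeedEngine` instead of passing the inputs through.

from typing import List, Tuple

from torch import nn
from torch.optim import SGD, Adam, Optimizer


def setup_model_and_optimizers(model: nn.Module, optimizers: List[Optimizer]) -> Tuple[nn.Module, List[Optimizer]]:
    # DeepSpeed pairs one engine with one optimizer, so anything but exactly one optimizer is rejected.
    if len(optimizers) != 1:
        raise ValueError(
            "Currently only one optimizer is supported with DeepSpeed."
            f" Got {len(optimizers)} optimizers instead."
        )
    # The real plugin would build the DeepSpeed engine here; this toy passes the inputs through.
    return model, optimizers


model = nn.Linear(4, 4)
setup_model_and_optimizers(model, [SGD(model.parameters(), lr=0.1)])  # accepted
try:
    setup_model_and_optimizers(model, [SGD(model.parameters(), lr=0.1), Adam(model.parameters())])
except ValueError as err:
    print(err)  # rejected: only one optimizer is supported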

pytorch_lightning/plugins/training_type/sharded.py (22 changes: 6 additions & 16 deletions)
@@ -47,34 +47,24 @@ def configure_ddp(self) -> None:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
+        self._model, optimizers = self._setup_model_and_optimizers(
+            model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers
         trainer.convert_to_lightning_optimizers()
 
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.
 
-        Currently only one model can be setup at once.
-
         Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
             and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
         """
-        if len(models) > 1:
-            raise ValueError(
-                "DDPSharded only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
         optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
         setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
-        return [model], optimizers
+        return model, optimizers
 
     def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
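
The fairscale wrapping this method performs can be sketched in isolation. The snippet below is an assumption-laden illustration rather than plugin code: it assumes fairscale is installed, uses a throwaway single-process gloo group so the wrappers can be built on CPU, and only shows the `OSS`/`ShardedDataParallel` pairing, not the plugin's `_wrap_optimizers` bookkeeping.

import os

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# A throwaway single-process process group; real training would use the launcher-provided one.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)
# OSS shards the optimizer state across ranks; ShardedDataParallel reduces each gradient to the rank owning it.
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
model = ShardedDataParallel(model, sharded_optimizer=optimizer)

dist.destroy_process_group()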

pytorch_lightning/plugins/training_type/sharded_spawn.py (22 changes: 6 additions & 16 deletions)
@@ -39,33 +39,23 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
 
     def configure_ddp(self) -> None:
         trainer = self.lightning_module.trainer
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
+        self._model, optimizers = self._setup_model_and_optimizers(
+            model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers
 
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.
 
-        Currently only one model can be setup at once.
-
         Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
             and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
         """
-        if len(models) > 1:
-            raise ValueError(
-                f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
         optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
         setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
-        return [model], optimizers
+        return model, optimizers
 
     def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):

@@ -61,18 +61,16 @@ def setup_environment(self) -> None:
     def setup(self) -> None:
         """Called by the accelerator to finish setup."""
 
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Setup multiple models and multiple optimizers together.
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+        """Setup a model and multiple optimizers together.
 
         The returned objects are expected to be in the same order they were passed in. The default implementation will
-        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists.
+        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
         """
         # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
-        models = [self._setup_model(model) for model in models]
+        model = self._setup_model(model)
         optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
-        return models, optimizers
+        return model, optimizers
 
     def _setup_model(self, model: Module) -> Module:
         """Performs setup for the model, e.g., by wrapping it by another class."""
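
Finally, the base-class default above just composes the two per-object hooks. The toy class below mirrors that composition; it is a hypothetical stand-in for illustration, not the Lightning `TrainingTypePlugin`.

from typing import List, Tuple

from torch import nn
from torch.optim import SGD, Optimizer


class ToyPlugin:
    """Hypothetical plugin mirroring the base-class default: delegate to the per-object hooks."""

    def _setup_model(self, model: nn.Module) -> nn.Module:
        # A real plugin would wrap the model here (DDP, sharding, a DeepSpeed engine, ...).
        return model

    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        return optimizer

    def _setup_model_and_optimizers(
        self, model: nn.Module, optimizers: List[Optimizer]
    ) -> Tuple[nn.Module, List[Optimizer]]:
        model = self._setup_model(model)
        optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
        return model, optimizers


model = nn.Linear(4, 2)
model, optimizers = ToyPlugin()._setup_model_and_optimizers(model, [SGD(model.parameters(), lr=0.1)])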