
2/n Simplify spawn plugins: Spawn immediately #10896

Merged
merged 119 commits on Dec 9, 2021
Commits (119 total; the file changes shown are from 110 commits)
d05363f
improve spawn queue
awaelchli Oct 20, 2021
d650e26
clean up
awaelchli Oct 20, 2021
5fda23a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2021
d6b4a34
Merge branch 'master' into feature/simple-spawn
awaelchli Nov 30, 2021
bcfb853
fix
awaelchli Nov 30, 2021
97b4bf6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2021
38b3a54
rename
awaelchli Nov 30, 2021
955b6c8
delete dead code
awaelchli Nov 30, 2021
13393e8
Merge remote-tracking branch 'origin/feature/simple-spawn' into featu…
awaelchli Nov 30, 2021
f3216b2
clean up
awaelchli Nov 30, 2021
2d00231
update lite
awaelchli Nov 30, 2021
7aa3646
retain the queue interface in hooks
awaelchli Nov 30, 2021
fb0c0d8
update tests
awaelchli Nov 30, 2021
1bc59ae
Merge branch 'master' into feature/simple-spawn
awaelchli Nov 30, 2021
7e6c75e
_notebooks
awaelchli Nov 30, 2021
b7efc50
reset notebooks
awaelchli Nov 30, 2021
84ca8b4
avoid circular import
awaelchli Nov 30, 2021
965c724
fix unused imports
awaelchli Nov 30, 2021
1aae8dd
reset debugging script
awaelchli Nov 30, 2021
4b998db
typing _ExtraQueue
awaelchli Nov 30, 2021
5871a4b
bring changes to tpu_spawn plugin
awaelchli Nov 30, 2021
aa76840
unify
awaelchli Nov 30, 2021
37f9db9
remove dead code
awaelchli Nov 30, 2021
d68cb35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2021
dd80be9
remove queue from tpu spawn
awaelchli Nov 30, 2021
f97eee8
type annotation for new_process
awaelchli Nov 30, 2021
ad61d27
Merge remote-tracking branch 'origin/feature/simple-spawn' into refac…
awaelchli Nov 30, 2021
459121e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2021
72535ff
unused imports
awaelchli Nov 30, 2021
3095da9
Merge remote-tracking branch 'origin/feature/simple-spawn' into refac…
awaelchli Nov 30, 2021
61192df
move check
awaelchli Nov 30, 2021
801f529
revert
awaelchli Nov 30, 2021
1cd258b
collect results on tpu
awaelchli Nov 30, 2021
ae6019e
Merge branch 'master' into refactor/spawn/simple-spawn
awaelchli Nov 30, 2021
10ecbfd
rename
awaelchli Nov 30, 2021
ebba63f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2021
d7df4d9
fix merge errors
awaelchli Nov 30, 2021
4c547aa
fix merge errors
awaelchli Nov 30, 2021
e4e2a77
re-add clean_logger
awaelchli Dec 1, 2021
86e43b2
Merge branch 'master' into refactor/spawn/simple-spawn
awaelchli Dec 1, 2021
acac29d
fix typing
awaelchli Dec 1, 2021
0ae457a
Merge branch 'master' into refactor/spawn/simple-spawn
awaelchli Dec 1, 2021
880c8fc
changelog entries
awaelchli Dec 1, 2021
5eeb02a
Merge branch 'master' into refactor/spawn/simple-spawn
awaelchli Dec 1, 2021
7520adc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 1, 2021
96f2749
rename _ExtraQueue -> _FakeQueue
awaelchli Dec 1, 2021
65d183c
missing typing updates
awaelchli Dec 1, 2021
8c4e2e4
Introducing NamedTuple for spawn output typing
awaelchli Dec 1, 2021
213b447
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 1, 2021
de4617f
remove post_dispatch
awaelchli Dec 2, 2021
815172e
step 1
awaelchli Dec 2, 2021
be735bd
update flow
awaelchli Dec 2, 2021
2879ccb
fix it
awaelchli Dec 2, 2021
ace196e
jackpot!
awaelchli Dec 2, 2021
4ff41a9
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 2, 2021
34a889a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 2, 2021
ad3f39d
update sharded and tests
awaelchli Dec 2, 2021
c897a20
pull down spawn call
awaelchli Dec 2, 2021
90054cf
simplify test
awaelchli Dec 2, 2021
009abfa
attach model as early as possible
awaelchli Dec 2, 2021
376e4fe
demonstrate which tests fails
awaelchli Dec 2, 2021
de1811e
set module
awaelchli Dec 3, 2021
ef61a0b
update exception
awaelchli Dec 3, 2021
809014a
imports
awaelchli Dec 3, 2021
440b639
transfer trainer state
awaelchli Dec 3, 2021
ab5559e
fix problem with getqueue
awaelchli Dec 3, 2021
f4f1269
deprecation calls don't come through ddp_spawn
awaelchli Dec 3, 2021
b30c352
prepare data only gets called on rank 0
awaelchli Dec 3, 2021
5434ae5
import
awaelchli Dec 3, 2021
24f05f1
update test
awaelchli Dec 3, 2021
3959955
update exception
awaelchli Dec 3, 2021
f491abe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2021
0c808ce
adapt tpu spawn
awaelchli Dec 3, 2021
d6dd343
imports
awaelchli Dec 3, 2021
63e6e21
Merge remote-tracking branch 'origin/refactor/spawn/dispatch' into re…
awaelchli Dec 3, 2021
15dabb8
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 3, 2021
b047687
update
awaelchli Dec 3, 2021
c524e52
add missing arg
awaelchli Dec 3, 2021
223e7aa
fix exception import on torch < 1.8
awaelchli Dec 3, 2021
ed309d6
debug
awaelchli Dec 3, 2021
12eed61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2021
9401e66
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 3, 2021
be73261
debug tpu
awaelchli Dec 3, 2021
c71fc57
fix docs
awaelchli Dec 3, 2021
2ed6333
fix teardown being called twice
awaelchli Dec 3, 2021
2a8b9b4
revert a sate check
awaelchli Dec 3, 2021
5335664
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2021
5c7e159
Merge remote-tracking branch 'origin/refactor/spawn/dispatch' into re…
awaelchli Dec 3, 2021
93cfaf8
fix
awaelchli Dec 3, 2021
26408b8
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 4, 2021
70b332d
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
d9669c7
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
3663bd7
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
3d81c11
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
fb47802
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
dde5a3a
reset bug report model
awaelchli Dec 6, 2021
77329b2
fix merge error
awaelchli Dec 6, 2021
eb05fc9
barrier clean ups
awaelchli Dec 7, 2021
dbcb76c
update comments in trainer
awaelchli Dec 7, 2021
ed0defa
unused import
awaelchli Dec 7, 2021
79975f2
debug
awaelchli Dec 7, 2021
d5ec0b7
changelog
awaelchli Dec 7, 2021
b2f8347
update changelog
awaelchli Dec 7, 2021
d8e6218
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2021
436572b
update changelog
awaelchli Dec 7, 2021
a3bc1b1
Update tests/trainer/test_trainer.py
awaelchli Dec 7, 2021
b2ce8eb
Merge remote-tracking branch 'origin/refactor/spawn/dispatch' into re…
awaelchli Dec 7, 2021
bafd95c
add clarification comment
awaelchli Dec 8, 2021
338605a
update changelog
awaelchli Dec 8, 2021
c992a55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2021
ac1428d
skip test that can't run on too old torch version on windows
awaelchli Dec 8, 2021
77ee0ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2021
c7dd23d
remove todo
awaelchli Dec 8, 2021
18f32bc
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 9, 2021
8af6d48
remove obsolete start_* methods from TTP
awaelchli Dec 9, 2021
f181c59
update changelog
awaelchli Dec 9, 2021
8599673
update user guide inside _run() code
awaelchli Dec 9, 2021
d51f482
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 9, 2021
bef2416
fix call
awaelchli Dec 9, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -96,6 +96,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934))


- Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
* All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
* The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under an initialized process group for spawn-based plugins, just like their non-spawn counterparts
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)


### Deprecated

- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
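
To make the redesign described in the CHANGELOG entry above concrete, here is a minimal, self-contained sketch of a "spawn immediately" flow. It is not the Lightning implementation: `launch`, `worker` and `run_stage` are hypothetical names, and the `gloo` backend over localhost is assumed only so the snippet runs anywhere.

```python
# Minimal sketch of the "spawn immediately" flow described in the CHANGELOG
# entry above. Illustrative only: `launch`, `worker` and `run_stage` are
# hypothetical names, not Lightning APIs.
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int, run_fn) -> None:
    # The process group is initialized before any hooks run, so `prepare_data`,
    # `setup`, `configure_sharded_model` and `teardown` all execute with an
    # initialized process group, just like in the non-spawn plugins.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        run_fn(rank)  # the entire fit/validate/test/predict stage runs in here
    finally:
        dist.destroy_process_group()


def launch(run_fn, world_size: int = 2) -> None:
    # Processes are spawned as soon as Trainer.{fit,validate,test,predict} is
    # entered, instead of being deferred to a later `start_training` call.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    mp.spawn(worker, args=(world_size, run_fn), nprocs=world_size)


def run_stage(rank: int) -> None:
    print(f"running the trainer stage on rank {rank}")


if __name__ == "__main__":
    launch(run_stage)
```
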
32 changes: 4 additions & 28 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -132,23 +132,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
return {"nprocs": self.num_processes}

def start_training(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self._recover_results_in_main_process(spawn_output, trainer)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []
return spawn_output.trainer_results

def start_evaluating(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self._recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results

def start_predicting(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self._recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results

def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
"""Spawn processes that run the given function.

@@ -184,7 +167,9 @@ def _worker_setup(self, process_idx: int):
self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
)

def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
super().pre_dispatch(trainer)

# move the model to the correct device
self.model_to_device()

@@ -196,15 +181,6 @@ def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

self.barrier()

results = trainer.run_stage()
outputs = self._collect_rank_zero_results(trainer, results)

# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
return outputs

def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
# Many models require setting this parameter to True, as there are corner cases
@@ -268,7 +244,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt

return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)

def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
# transfer back the best path to the trainer
if trainer.checkpoint_callback:
trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
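
The `_collect_rank_zero_results` / `_recover_results_in_main_process` pair touched above follows a simple hand-off pattern: rank zero packs its results, ships them to the main process through a multiprocessing queue, and the main process restores its state from them. Below is a simplified sketch of that pattern with hypothetical names (`SpawnOutput`, `worker`, `main`), not Lightning's actual `_SpawnOutput` plumbing:

```python
# Simplified sketch of the rank-zero result hand-off used by spawn-based
# plugins. `SpawnOutput`, `worker` and `main` are illustrative stand-ins.
from typing import Any, NamedTuple, Optional

import torch.multiprocessing as mp


class SpawnOutput(NamedTuple):
    best_model_path: Optional[str]
    trainer_results: Any


def worker(rank: int, return_queue) -> None:
    results = {"loss": 0.1 * (rank + 1)}  # stand-in for trainer.run_stage()
    if rank == 0:
        # Only rank zero reports back to the main process. Plain builtins are
        # used for the payload so it pickles cleanly from a spawned script;
        # Lightning ships a NamedTuple defined in an importable module instead.
        return_queue.put({"best_model_path": "checkpoints/best.ckpt", "trainer_results": results})


def main(world_size: int = 2) -> SpawnOutput:
    ctx = mp.get_context("spawn")
    return_queue = ctx.SimpleQueue()
    mp.spawn(worker, args=(return_queue,), nprocs=world_size)
    # Back in the main process: recover what rank zero produced, e.g. restore
    # the best checkpoint path on the (never-trained) main-process trainer.
    payload = return_queue.get()
    return SpawnOutput(**payload)


if __name__ == "__main__":
    print(main())
```
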
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -20,7 +20,7 @@

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput, DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.enums import _StrategyType
@@ -114,12 +114,12 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
def post_training_step(self):
pass

def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
# Ensure that the scaler points to the correct process group
# which is re-initialized in a new process
if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
self._precision_plugin.scaler = ShardedGradScaler()
return super().new_process(trainer)
return super().pre_dispatch(trainer)

@classmethod
def register_plugins(cls, plugin_registry: Dict) -> None:
84 changes: 22 additions & 62 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,7 +23,6 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
@@ -118,10 +117,23 @@ def connect(self, model: "pl.LightningModule") -> None:
return super().connect(model)

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
super().pre_dispatch(trainer)
self._move_optimizer_state()
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)

if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()

shared_params = find_shared_parameters(self.model)
self.model_to_device()
if is_overridden("on_post_move_to_device", self.lightning_module):
self.model.module.on_post_move_to_device()
else:
set_shared_parameters(self.model.module, shared_params)

self.setup_optimizers(trainer)
self.precision_plugin.connect(self._model, None, None)

def setup(self, trainer: "pl.Trainer") -> None:
self.start_method = "fork"
super().setup(trainer)
@@ -154,37 +166,6 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None:
def set_world_ranks(self, process_idx: int = 0) -> None:
pass

def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()

shared_params = find_shared_parameters(self.model)
self.model_to_device()
if is_overridden("on_post_move_to_device", self.lightning_module):
self.model.module.on_post_move_to_device()
else:
set_shared_parameters(self.model.module, shared_params)

trainer.training_type_plugin.setup_optimizers(trainer)
trainer.precision_plugin.connect(self._model, None, None)

self.barrier("pre-run-stage")

results = trainer.run_stage()

outputs = self._collect_rank_zero_results(trainer, results)

# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
self.barrier("end-process")

# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
if self.local_rank == 0:
time.sleep(2)

# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
return outputs

def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)

@@ -215,8 +196,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
if is_overridden("add_to_queue", self.lightning_module):
# TODO: Remove the if in v1.7
self.lightning_module.add_to_queue(extra)
else:
self.add_to_queue(trainer, extra)
self.add_to_queue(trainer, extra)

return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)

@@ -263,6 +243,10 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
}

def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
# TODO: this todo is unclear, does it still apply?
# todo: precision pluging is call in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
context = mp.get_context(self.start_method or "fork")
return_queue = context.SimpleQueue()
xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -276,7 +260,10 @@ def _wrapped_function(
if self.local_rank == 0:
return_queue.put(move_data_to_device(result, "cpu"))

# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
self.barrier("end-process")

# Ensure that the rank 0 process is the one exiting last
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
if self.local_rank == 0:
time.sleep(2)
@@ -287,21 +274,6 @@ def _worker_setup(self, process_idx: int):
self.tpu_global_core_rank = xm.get_ordinal()
rank_zero_only.rank = self.global_rank

def start_training(self, trainer: "pl.Trainer") -> Any:
# todo: precision pluging is call in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
self._clean_logger(trainer)
return super().start_training(trainer)

def start_evaluating(self, trainer: "pl.Trainer") -> Any:
self._clean_logger(trainer)
return super().start_evaluating(trainer)

def start_predicting(self, trainer: "pl.Trainer") -> Any:
self._clean_logger(trainer)
return super().start_predicting(trainer)

def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
return self.model(*args, **kwargs)
@@ -358,9 +330,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
return xm.all_gather(tensor)

def teardown(self) -> None:
# TPU teardown
os.environ.pop("PT_XLA_DEBUG", None)
self.barrier("teardown")

@property
def should_rank_save_checkpoint(self) -> bool:
@@ -377,13 +347,3 @@ def checkpoint_io(self) -> CheckpointIO:
@checkpoint_io.setter
def checkpoint_io(self, plugin: CheckpointIO) -> None:
raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")

@staticmethod
def _clean_logger(trainer: "pl.Trainer") -> None:
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
for logger in loggers:
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
# we want to make sure these are closed before we spawn our own threads.
# assuming nothing else references the experiment object, python should instantly `__del__` it.
logger._experiment = None
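
The barrier-then-sleep sequence kept in `_wrapped_function` above (rank zero must be the last process to exit, per the linked pytorch/xla issues) boils down to the generic pattern below. The sketch substitutes plain `torch.distributed` with the `gloo` backend for XLA so it stays self-contained; all names are illustrative.

```python
# Generic illustration of the exit-ordering pattern in `_wrapped_function`:
# only rank 0 returns a result, every process synchronizes before exiting,
# and rank 0 leaves last. Uses gloo instead of XLA so it runs anywhere.
import os
import time

import torch.distributed as dist
import torch.multiprocessing as mp


def wrapped_function(rank: int, world_size: int, return_queue) -> None:
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    result = {"rank": rank}  # stand-in for the spawned trainer function's result
    if rank == 0:
        return_queue.put(result)

    # Make sure no process exits while others may still need collectives.
    dist.barrier()

    # Ensure rank 0 is the last process to exit (mirrors the XLA workaround).
    if rank == 0:
        time.sleep(2)
    dist.destroy_process_group()


def main(world_size: int = 2) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    ctx = mp.get_context("spawn")
    return_queue = ctx.SimpleQueue()
    mp.spawn(wrapped_function, args=(world_size, return_queue), nprocs=world_size)
    print(return_queue.get())


if __name__ == "__main__":
    main()
```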