Commit 400bc59

Merge pull request #5 from abhilash1910/patch-3
add seeding for pytorch utilities
2 parents 9487579 + fa9da18 commit 400bc59

3 files changed (+33 / -4 lines)


src/lightning/fabric/strategies/launchers/multiprocessing.py

Lines changed: 24 additions & 1 deletion
@@ -24,10 +24,13 @@
 from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.fabric.utilities.imports import _IS_INTERACTIVE
 from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
+from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
 
 if TYPE_CHECKING:
     from lightning.fabric.strategies import ParallelStrategy
-
+
+if _LIGHTNING_XPU_AVAILABLE:
+    from lightning_xpu.fabric import XPUAccelerator
 
 class _MultiProcessingLauncher(_Launcher):
     r"""Launches processes that run a given function in parallel, and joins them all at the end.
@@ -85,6 +88,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
         """
         if self._start_method in ("fork", "forkserver"):
             _check_bad_cuda_fork()
+            if XPUAccelerator.is_available():
+                _check_bad_xpu_fork()
 
         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
@@ -187,3 +192,21 @@ def _check_bad_cuda_fork() -> None:
     if _IS_INTERACTIVE:
         message += " You will have to restart the Python kernel."
     raise RuntimeError(message)
+
+def _check_bad_xpu_fork() -> None:
+    """Checks whether it is safe to fork and initialize XPU in the new processes, and raises an exception if not.
+
+    The error message replaces PyTorch's 'Cannot re-initialize XPU in forked subprocess' with helpful advice for
+    Lightning users.
+    """
+    if not XPUAccelerator.is_xpu_initialized():
+        return
+
+    message = (
+        "Lightning can't create new processes if XPU is already initialized. Did you manually call"
+        " `torch.xpu.*` functions, have moved the model to the device, or allocated memory on the GPU any"
+        " other way? Please remove any such calls, or change the selected strategy."
+    )
+    if _IS_INTERACTIVE:
+        message += " You will have to restart the Python kernel."
+    raise RuntimeError(message)
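
For context, a minimal sketch (not part of the commit) of the situation the new _check_bad_xpu_fork() guards against: the XPU runtime is touched in the main process and workers are then started with the "fork" start method. It assumes a PyTorch build with torch.xpu support and an available XPU device.

# Illustrative only: reproduces the unsafe "XPU initialized before fork" pattern.
import torch
import torch.multiprocessing as mp


def _worker(rank: int) -> None:
    torch.xpu.set_device(rank)  # re-initializing XPU in a forked child is not supported


if __name__ == "__main__":
    torch.xpu.init()              # XPU touched in the parent process
    ctx = mp.get_context("fork")  # fork start method, as in the launcher above
    ctx.Process(target=_worker, args=(0,)).start()

With the change above, Lightning's fork-based launcher raises the descriptive RuntimeError in the parent process instead of letting each forked child fail inside the XPU backend.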

src/lightning/pytorch/strategies/launchers/multiprocessing.py

Lines changed: 7 additions & 1 deletion
@@ -27,15 +27,19 @@
 from torch import Tensor
 
 import lightning.pytorch as pl
-from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork
+from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork, _check_bad_xpu_fork
 from lightning.fabric.utilities import move_data_to_device
 from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.strategies.launchers.launcher import _Launcher
 from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM
 from lightning.pytorch.trainer.states import TrainerFn, TrainerState
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug
+from lightning.pytorch.utilities.imports import _LIGHTNING_XPU_AVAILABLE
 
+if _LIGHTNING_XPU_AVAILABLE:
+    from lightning_xpu.pytorch import XPUAccelerator
+
 log = logging.getLogger(__name__)
 
 
@@ -97,6 +101,8 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         self._check_torchdistx_support()
         if self._start_method in ("fork", "forkserver"):
             _check_bad_cuda_fork()
+            if XPUAccelerator.is_available():
+                _check_bad_xpu_fork()
 
         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
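
Both launchers apply the same guarded pattern; a condensed sketch follows, with names taken from the diff and a hypothetical helper standing in for the fork/forkserver branch of launch(). The extra _LIGHTNING_XPU_AVAILABLE guard is an assumption added here so the snippet stays importable when the optional lightning_xpu plugin is not installed.

from lightning.fabric.strategies.launchers.multiprocessing import (
    _check_bad_cuda_fork,
    _check_bad_xpu_fork,
)
from lightning.pytorch.utilities.imports import _LIGHTNING_XPU_AVAILABLE

if _LIGHTNING_XPU_AVAILABLE:
    from lightning_xpu.pytorch import XPUAccelerator  # optional third-party plugin


def _fork_safety_checks() -> None:
    """Hypothetical helper mirroring the fork/forkserver branch of launch()."""
    _check_bad_cuda_fork()
    if _LIGHTNING_XPU_AVAILABLE and XPUAccelerator.is_available():
        _check_bad_xpu_fork()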

src/lightning/pytorch/utilities/seed.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 
 
 @contextmanager
-def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
+def isolate_rng(include_cuda: bool = True, include_xpu: bool = True) -> Generator[None, None, None]:
     """A context manager that resets the global random state on exit to what it was before entering.
 
     It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
@@ -39,6 +39,6 @@ def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
     >>> torch.rand(1)
     tensor([0.7576])
     """
-    states = _collect_rng_states(include_cuda)
+    states = _collect_rng_states(include_cuda, include_xpu)
     yield
     _set_rng_states(states)
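
A short usage sketch of the widened isolate_rng() signature (the import path matches the file above; how include_xpu is consumed lives in the fabric-side _collect_rng_states/_set_rng_states helpers, which are not part of this diff):

import torch
from lightning.pytorch.utilities.seed import isolate_rng

torch.manual_seed(1)
with isolate_rng(include_cuda=True, include_xpu=True):
    torch.rand(2)        # random draws inside the block consume isolated state
print(torch.rand(1))     # prints the same value as if the block had never run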
