|
24 | 24 | from lightning.fabric.utilities.apply_func import move_data_to_device
|
25 | 25 | from lightning.fabric.utilities.imports import _IS_INTERACTIVE
|
26 | 26 | from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
|
| 27 | +from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE |
27 | 28 |
|
28 | 29 | if TYPE_CHECKING:
|
29 | 30 | from lightning.fabric.strategies import ParallelStrategy
|
30 |
| - |
| 31 | + |
| 32 | +if _LIGHTNING_XPU_AVAILABLE: |
| 33 | + from lightning_xpu.fabric import XPUAccelerator |
31 | 34 |
|
32 | 35 | class _MultiProcessingLauncher(_Launcher):
|
33 | 36 | r"""Launches processes that run a given function in parallel, and joins them all at the end.
|
@@ -85,6 +88,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
|
85 | 88 | """
|
86 | 89 | if self._start_method in ("fork", "forkserver"):
|
87 | 90 | _check_bad_cuda_fork()
|
| 91 | + if XPUAccelerator.is_available(): |
| 92 | + _check_bad_xpu_fork() |
88 | 93 |
|
89 | 94 | # The default cluster environment in Lightning chooses a random free port number
|
90 | 95 | # This needs to be done in the main process here before starting processes to ensure each rank will connect
|
@@ -187,3 +192,21 @@ def _check_bad_cuda_fork() -> None:
|
187 | 192 | if _IS_INTERACTIVE:
|
188 | 193 | message += " You will have to restart the Python kernel."
|
189 | 194 | raise RuntimeError(message)
|
| 195 | + |
def _check_bad_xpu_fork() -> None:
    """Raise a helpful error if forking would clash with an already-initialized XPU context.

    Mirrors ``_check_bad_cuda_fork``: instead of letting PyTorch fail later with its terse
    'Cannot re-initialize XPU in forked subprocess' error, Lightning raises early with advice
    on how to avoid touching the device before processes are spawned.
    """
    # Nothing to do while the XPU runtime has not been touched yet — forking is safe then.
    if XPUAccelerator.is_xpu_initialized():
        error = (
            "Lightning can't create new processes if XPU is already initialized. Did you manually call"
            " `torch.xpu.*` functions, have moved the model to the device, or allocated memory on the GPU any"
            " other way? Please remove any such calls, or change the selected strategy."
        )
        # In notebooks the interpreter keeps the poisoned XPU state alive, so a kernel restart is required.
        if _IS_INTERACTIVE:
            error += " You will have to restart the Python kernel."
        raise RuntimeError(error)
0 commit comments