File tree: 2 files changed, +7 −6 lines changed

fabric/strategies/launchers
pytorch/strategies/launchers
fabric/strategies/launchers

@@ -22,16 +22,16 @@
 
 from lightning.fabric.strategies.launchers.launcher import _Launcher
 from lightning.fabric.utilities.apply_func import move_data_to_device
-from lightning.fabric.utilities.imports import _IS_INTERACTIVE
+from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _LIGHTNING_XPU_AVAILABLE
 from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
-from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
 
 if TYPE_CHECKING:
     from lightning.fabric.strategies import ParallelStrategy
-
+
 if _LIGHTNING_XPU_AVAILABLE:
     from lightning_xpu.fabric import XPUAccelerator
 
+
 class _MultiProcessingLauncher(_Launcher):
     r"""Launches processes that run a given function in parallel, and joins them all at the end.
 
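The hunk above folds the duplicated `imports` statement into a single line (isort style) and keeps the optional `lightning_xpu` dependency guarded behind `_LIGHTNING_XPU_AVAILABLE`; the -/+ pairs on blank lines appear to be whitespace-only fixes (trailing spaces removed, a PEP 8 second blank line added before the class). A minimal sketch of this guarded optional-import pattern, rebuilt with only the standard library; the helper name `_maybe_xpu_accelerator` is hypothetical, not the PR's code:

import importlib.util
from typing import TYPE_CHECKING

# Availability flag: True only when the optional package is importable.
_LIGHTNING_XPU_AVAILABLE = importlib.util.find_spec("lightning_xpu") is not None

if TYPE_CHECKING:
    # For type annotations only; never executed at runtime, so it cannot
    # fail when the optional package is missing.
    from lightning_xpu.fabric import XPUAccelerator


def _maybe_xpu_accelerator() -> "XPUAccelerator":  # hypothetical helper
    """Return an XPUAccelerator, failing loudly if the backend is absent."""
    if not _LIGHTNING_XPU_AVAILABLE:
        raise ModuleNotFoundError("lightning_xpu is required for XPU support")
    from lightning_xpu.fabric import XPUAccelerator
    return XPUAccelerator()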
@@ -193,6 +193,7 @@ def _check_bad_cuda_fork() -> None:
         message += " You will have to restart the Python kernel."
     raise RuntimeError(message)
 
+
 def _check_bad_xpu_fork() -> None:
     """Checks whether it is safe to fork and initialize XPU in the new processes, and raises an exception if not.
 
@@ -209,4 +210,4 @@ def _check_bad_xpu_fork() -> None:
     )
     if _IS_INTERACTIVE:
         message += " You will have to restart the Python kernel."
-    raise RuntimeError(message)
+    raise RuntimeError(message)
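The two hunks above insert a PEP 8 blank line between `_check_bad_cuda_fork` and the new `_check_bad_xpu_fork` that mirrors it for XPU; the final -/+ pair rewrites `raise RuntimeError(message)` unchanged, which on GitHub usually indicates a whitespace or end-of-file-newline fix. A hedged sketch of what such a fork-safety guard looks like (the message wording and overall structure are assumptions; only the interactive-kernel hint is taken from the diff):

import sys

import torch

# Interactive sessions (REPL, Jupyter) expose `sys.ps1`; Lightning's
# `_IS_INTERACTIVE` flag relies on a similar heuristic.
_IS_INTERACTIVE = hasattr(sys, "ps1")


def _check_bad_cuda_fork_sketch() -> None:
    """Raise if forking would inherit an already-initialized CUDA context.

    CUDA contexts are not fork-safe: a child created with the fork start
    method that inherits one tends to crash with opaque errors, so it is
    better to fail loudly in the parent before any processes are created.
    """
    if not torch.cuda.is_initialized():
        return  # no CUDA context yet, forking is safe
    message = (
        "Cannot fork new processes after CUDA has been initialized;"
        " create the processes first, or use the 'spawn' start method."
    )
    if _IS_INTERACTIVE:
        message += " You will have to restart the Python kernel."
    raise RuntimeError(message)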
pytorch/strategies/launchers

@@ -34,12 +34,12 @@
 from lightning.pytorch.strategies.launchers.launcher import _Launcher
 from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM
 from lightning.pytorch.trainer.states import TrainerFn, TrainerState
-from lightning.pytorch.utilities.rank_zero import rank_zero_debug
 from lightning.pytorch.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.pytorch.utilities.rank_zero import rank_zero_debug
 
 if _LIGHTNING_XPU_AVAILABLE:
     from lightning_xpu.pytorch import XPUAccelerator
-
+
 log = logging.getLogger(__name__)
 
 
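This file receives the same treatment: the two import lines are swapped into alphabetical order (`imports` before `rank_zero`) and a blank line is rewritten, again whitespace-only. For context, the launcher class both files define follows the spawn-and-join pattern the fabric docstring above describes ("launches processes that run a given function in parallel, and joins them all at the end"); a minimal sketch of that pattern with plain torch.multiprocessing, not the PR's code:

import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    # mp.spawn passes each process its rank as the first argument,
    # followed by the contents of `args`.
    print(f"worker {rank}/{world_size} running")


def main() -> None:
    world_size = 2
    # The "spawn" start method gives children a fresh interpreter with no
    # inherited CUDA/XPU state, which is exactly what the bad-fork guards
    # above protect against.
    mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()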