3
3
from typing_extensions import Protocol , runtime_checkable
4
4
5
5
from lightning_app .components .multi_node .base import MultiNode
6
+ from lightning_app .core .queues import MultiProcessQueue
6
7
from lightning_app .core .work import LightningWork
7
- from lightning_app .utilities .app_helpers import is_static_method
8
8
from lightning_app .utilities .packaging .cloud_compute import CloudCompute
9
- from lightning_app .utilities .proxies import WorkRunExecutor
9
+ from lightning_app .utilities .proxies import _proxy_setattr , unwrap , WorkRunExecutor , WorkStateObserver
10
10
11
11
12
12
@runtime_checkable
@@ -22,6 +22,9 @@ def run(
22
22
23
23
24
24
class _PyTorchSpawnRunExecutor (WorkRunExecutor ):
25
+
26
+ enable_start_observer : bool = False
27
+
25
28
def __call__ (
26
29
self ,
27
30
main_address : str ,
@@ -31,10 +34,31 @@ def __call__(
31
34
):
32
35
import torch
33
36
34
- nprocs = torch .cuda .device_count () if torch .cuda .is_available () else 1
35
- torch .multiprocessing .spawn (
36
- self .run , args = (self .work_run , main_address , main_port , num_nodes , node_rank , nprocs ), nprocs = nprocs
37
- )
37
+ with self .enable_spawn ():
38
+ nprocs = torch .cuda .device_count () if torch .cuda .is_available () else 1
39
+ queue = self .delta_queue if isinstance (self .delta_queue , MultiProcessQueue ) else self .delta_queue .to_dict ()
40
+ torch .multiprocessing .spawn (
41
+ self .dispatch_run ,
42
+ args = (self .__class__ , self .work , queue , main_address , main_port , num_nodes , node_rank , nprocs ),
43
+ nprocs = nprocs ,
44
+ )
45
+
46
    @staticmethod
    def dispatch_run(local_rank, cls, work, delta_queue, *args, **kwargs):
        """Per-process entry point passed to ``torch.multiprocessing.spawn``.

        ``local_rank`` is injected by ``spawn`` as the first argument; the remaining
        positional args are forwarded unchanged to ``cls.run``.
        """
        if local_rank == 0:
            # Only the first local process talks back to the app: rebuild the
            # queues that were serialized to dicts to cross the spawn boundary.
            if isinstance(delta_queue, dict):
                delta_queue = cls.process_queue(delta_queue)
            work._request_queue = cls.process_queue(work._request_queue)
            work._response_queue = cls.process_queue(work._response_queue)

            # Watch the work's state and forward changes through the delta queue;
            # _proxy_setattr reroutes attribute writes on `work` via the observer.
            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
            state_observer.start()
            _proxy_setattr(work, delta_queue, state_observer)

        # Every rank runs the user's (unwrapped) `run`; `unwrap` strips the proxy
        # layer so the raw function executes inside the child process.
        cls.run(local_rank, unwrap(work.run), *args, **kwargs)

        if local_rank == 0:
            # join(0): poke the observer thread without blocking process exit.
            state_observer.join(0)
38
62
39
63
@staticmethod
40
64
def run (
@@ -46,6 +70,7 @@ def run(
46
70
node_rank : int ,
47
71
nprocs : int ,
48
72
):
73
+
49
74
import torch
50
75
51
76
# 1. Setting distributed environment
@@ -76,11 +101,6 @@ def __init__(
76
101
** work_kwargs : Any ,
77
102
) -> None :
78
103
assert issubclass (work_cls , _PyTorchSpawnWorkProtocol )
79
- if not is_static_method (work_cls , "run" ):
80
- raise TypeError (
81
- f"The provided { work_cls } run method needs to be static for now."
82
- "HINT: Remove `self` and add staticmethod decorator."
83
- )
84
104
85
105
# Note: Private way to modify the work run executor
86
106
# Probably exposed to the users in the future if needed.
0 commit comments