
Commit 61c1f69

[App] Enable state broadcast with MultiNode (#15607)
1 parent 4ea44dd commit 61c1f69

10 files changed (+166 -94 lines)

examples/app_multi_node/train_lite.py (+15 -12)

@@ -6,23 +6,26 @@


 class LitePyTorchDistributed(L.LightningWork):
-    @staticmethod
-    def run():
-        # 1. Create LightningLite.
-        lite = LightningLite(strategy="ddp", precision=16)
+    def run(self):
+        # 1. Prepare the model
+        model = torch.nn.Sequential(
+            torch.nn.Linear(1, 1),
+            torch.nn.ReLU(),
+            torch.nn.Linear(1, 1),
+        )

-        # 2. Prepare distributed model and optimizer.
-        model = torch.nn.Linear(32, 2)
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-        model, optimizer = lite.setup(model, optimizer)
+        # 2. Create LightningLite.
+        lite = LightningLite(strategy="ddp", precision=16)
+        model, optimizer = lite.setup(model, torch.optim.SGD(model.parameters(), lr=0.01))
         criterion = torch.nn.MSELoss()

-        # 3. Train the model for 50 steps.
-        for step in range(50):
+        # 3. Train the model for 1000 steps.
+        for step in range(1000):
             model.zero_grad()
-            x = torch.randn(64, 32).to(lite.device)
+            x = torch.tensor([0.8]).to(lite.device)
+            target = torch.tensor([1.0]).to(lite.device)
             output = model(x)
-            loss = criterion(output, torch.ones_like(output))
+            loss = criterion(output, target)
             print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
             lite.backward(loss)
             optimizer.step()
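
Note: the example above only defines the work. A minimal sketch of how it would typically be wired into an app is shown below; the LiteMultiNode component name and the "gpu-fast-multi" compute preset are assumptions, not part of this diff.

# Sketch only (assumed component and compute names; adjust to your release).
import lightning as L
from lightning_app.components import LiteMultiNode

app = L.LightningApp(
    LiteMultiNode(
        LitePyTorchDistributed,  # the work defined above
        num_nodes=2,  # one replica of the work per node
        cloud_compute=L.CloudCompute("gpu-fast-multi"),
    )
)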

examples/app_multi_node/train_lt.py (+2 -3)

@@ -4,11 +4,10 @@


 class LightningTrainerDistributed(L.LightningWork):
-    @staticmethod
-    def run():
+    def run(self):
         model = BoringModel()
         trainer = L.Trainer(
-            max_epochs=10,
+            max_steps=1000,
             strategy="ddp",
         )
         trainer.fit(model)

examples/app_multi_node/train_pytorch.py (+13 -14)

@@ -18,29 +18,28 @@ def distributed_train(local_rank: int, main_address: str, main_port: int, num_no
         init_method=f"tcp://{main_address}:{main_port}",
     )

-    # 2. Prepare distributed model
-    model = torch.nn.Linear(32, 2)
+    # 2. Prepare the model
+    model = torch.nn.Sequential(
+        torch.nn.Linear(1, 1),
+        torch.nn.ReLU(),
+        torch.nn.Linear(1, 1),
+    )

     # 3. Setup distributed training
-    if torch.cuda.is_available():
-        device = torch.device(f"cuda:{local_rank}")
-        torch.cuda.set_device(device)
-    else:
-        device = torch.device("cpu")
-
-    model = model.to(device)
-    model = DistributedDataParallel(model, device_ids=[device.index] if torch.cuda.is_available() else None)
+    device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
+    model = DistributedDataParallel(model.to(device), device_ids=[local_rank] if torch.cuda.is_available() else None)

     # 4. Prepare loss and optimizer
     criterion = torch.nn.MSELoss()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

-    # 5. Train the model for 50 steps.
-    for step in range(50):
+    # 5. Train the model for 1000 steps.
+    for step in range(1000):
         model.zero_grad()
-        x = torch.randn(64, 32).to(device)
+        x = torch.tensor([0.8]).to(device)
+        target = torch.tensor([1.0]).to(device)
         output = model(x)
-        loss = criterion(output, torch.ones_like(output))
+        loss = criterion(output, target)
         print(f"global_rank: {global_rank} step: {step} loss: {loss}")
         loss.backward()
         optimizer.step()

examples/app_multi_node/train_pytorch_spawn.py (+16 -17)

@@ -6,38 +6,37 @@


 class PyTorchDistributed(L.LightningWork):
-
-    # Note: Only staticmethod are support for now with `PyTorchSpawnMultiNode`
-    @staticmethod
     def run(
+        self,
         world_size: int,
         node_rank: int,
         global_rank: str,
         local_rank: int,
     ):
-        # 1. Prepare distributed model
-        model = torch.nn.Linear(32, 2)
+        # 1. Prepare the model
+        model = torch.nn.Sequential(
+            torch.nn.Linear(1, 1),
+            torch.nn.ReLU(),
+            torch.nn.Linear(1, 1),
+        )

         # 2. Setup distributed training
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{local_rank}")
-            torch.cuda.set_device(device)
-        else:
-            device = torch.device("cpu")
-
-        model = model.to(device)
-        model = DistributedDataParallel(model, device_ids=[device.index] if torch.cuda.is_available() else None)
+        device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
+        model = DistributedDataParallel(
+            model.to(device), device_ids=[local_rank] if torch.cuda.is_available() else None
+        )

         # 3. Prepare loss and optimizer
         criterion = torch.nn.MSELoss()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

-        # 4. Train the model for 50 steps.
-        for step in range(50):
+        # 4. Train the model for 1000 steps.
+        for step in range(1000):
             model.zero_grad()
-            x = torch.randn(64, 32).to(device)
+            x = torch.tensor([0.8]).to(device)
+            target = torch.tensor([1.0]).to(device)
             output = model(x)
-            loss = criterion(output, torch.ones_like(output))
+            loss = criterion(output, target)
             print(f"global_rank: {global_rank} step: {step} loss: {loss}")
             loss.backward()
             optimizer.step()
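
Note: with run now a bound method, a work can also publish progress through its own state attributes, which is the state broadcasting this commit enables (local rank 0 of each node runs the observer added in pytorch_spawn.py below). A hypothetical sketch with illustrative attribute names that are not part of the diff:

# Hypothetical sketch: attributes set inside run() are part of the work state
# and can be broadcast back to the flow. Names are illustrative only.
import lightning as L


class TrackedPyTorchDistributed(L.LightningWork):
    def __init__(self):
        super().__init__()
        self.step = 0
        self.last_loss = None

    def run(self, world_size: int, node_rank: int, global_rank: str, local_rank: int):
        for step in range(1000):
            # ... same training code as in the example above ...
            self.step = step  # state delta picked up by the observer
            self.last_loss = 0.0  # placeholder for float(loss)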

src/lightning_app/CHANGELOG.md (+2)

@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `bi-directional` delta updates between the flow and the works ([#15582](https://github.com/Lightning-AI/lightning/pull/15582))

+- Enabled MultiNode Components to support state broadcasting ([#15607](https://github.com/Lightning-AI/lightning/pull/15607))
+

 ### Changed

src/lightning_app/components/multi_node/lite.py (-6)

@@ -7,7 +7,6 @@
 from lightning_app.components.multi_node.base import MultiNode
 from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
 from lightning_app.core.work import LightningWork
-from lightning_app.utilities.app_helpers import is_static_method
 from lightning_app.utilities.packaging.cloud_compute import CloudCompute
 from lightning_app.utilities.tracer import Tracer

@@ -82,11 +81,6 @@ def __init__(
         **work_kwargs: Any,
     ) -> None:
         assert issubclass(work_cls, _LiteWorkProtocol)
-        if not is_static_method(work_cls, "run"):
-            raise TypeError(
-                f"The provided {work_cls} run method needs to be static for now."
-                "HINT: Remove `self` and add staticmethod decorator."
-            )

         # Note: Private way to modify the work run executor
         # Probably exposed to the users in the future if needed.

src/lightning_app/components/multi_node/pytorch_spawn.py (+31 -11)

@@ -3,10 +3,10 @@
 from typing_extensions import Protocol, runtime_checkable

 from lightning_app.components.multi_node.base import MultiNode
+from lightning_app.core.queues import MultiProcessQueue
 from lightning_app.core.work import LightningWork
-from lightning_app.utilities.app_helpers import is_static_method
 from lightning_app.utilities.packaging.cloud_compute import CloudCompute
-from lightning_app.utilities.proxies import WorkRunExecutor
+from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver


 @runtime_checkable

@@ -22,6 +22,9 @@ def run(


 class _PyTorchSpawnRunExecutor(WorkRunExecutor):
+
+    enable_start_observer: bool = False
+
     def __call__(
         self,
         main_address: str,

@@ -31,10 +34,31 @@
     ):
         import torch

-        nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
-        torch.multiprocessing.spawn(
-            self.run, args=(self.work_run, main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
-        )
+        with self.enable_spawn():
+            nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
+            queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
+            torch.multiprocessing.spawn(
+                self.dispatch_run,
+                args=(self.__class__, self.work, queue, main_address, main_port, num_nodes, node_rank, nprocs),
+                nprocs=nprocs,
+            )
+
+    @staticmethod
+    def dispatch_run(local_rank, cls, work, delta_queue, *args, **kwargs):
+        if local_rank == 0:
+            if isinstance(delta_queue, dict):
+                delta_queue = cls.process_queue(delta_queue)
+                work._request_queue = cls.process_queue(work._request_queue)
+                work._response_queue = cls.process_queue(work._response_queue)
+
+            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
+            state_observer.start()
+            _proxy_setattr(work, delta_queue, state_observer)
+
+        cls.run(local_rank, unwrap(work.run), *args, **kwargs)
+
+        if local_rank == 0:
+            state_observer.join(0)

     @staticmethod
     def run(

@@ -46,6 +70,7 @@ def run(
         node_rank: int,
         nprocs: int,
     ):
+
         import torch

         # 1. Setting distributed environment

@@ -76,11 +101,6 @@ def __init__(
         **work_kwargs: Any,
     ) -> None:
         assert issubclass(work_cls, _PyTorchSpawnWorkProtocol)
-        if not is_static_method(work_cls, "run"):
-            raise TypeError(
-                f"The provided {work_cls} run method needs to be static for now."
-                "HINT: Remove `self` and add staticmethod decorator."
-            )

         # Note: Private way to modify the work run executor
         # Probably exposed to the users in the future if needed.
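
Note: the executor change follows a standard spawn pattern: a live queue client is generally not picklable across torch.multiprocessing.spawn, so it crosses the process boundary as a plain dict (via the new to_dict) and is rebuilt on local rank 0 by process_queue, which presumably dispatches on the "type" key. A standalone illustration of the pattern (not Lightning code):

# Generic sketch: pass only picklable state into spawn; rebuild heavyweight
# objects (Redis/HTTP clients) inside the child process.
import torch.multiprocessing as mp


def worker(local_rank: int, conn_state: dict) -> None:
    # In the executor above, this is where process_queue / from_dict would
    # reconstruct the delta queue from conn_state.
    print(f"rank {local_rank}: would reconnect using {conn_state}")


if __name__ == "__main__":
    conn_state = {"type": "redis", "name": "delta_queue", "default_timeout": 5}
    mp.spawn(worker, args=(conn_state,), nprocs=2)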

src/lightning_app/components/multi_node/trainer.py (-6)

@@ -7,7 +7,6 @@
 from lightning_app.components.multi_node.base import MultiNode
 from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
 from lightning_app.core.work import LightningWork
-from lightning_app.utilities.app_helpers import is_static_method
 from lightning_app.utilities.packaging.cloud_compute import CloudCompute
 from lightning_app.utilities.tracer import Tracer

@@ -81,11 +80,6 @@ def __init__(
         **work_kwargs: Any,
     ) -> None:
         assert issubclass(work_cls, _LightningTrainerWorkProtocol)
-        if not is_static_method(work_cls, "run"):
-            raise TypeError(
-                f"The provided {work_cls} run method needs to be static for now."
-                "HINT: Remove `self` and add staticmethod decorator."
-            )

         # Note: Private way to modify the work run executor
         # Probably exposed to the users in the future if needed.

src/lightning_app/core/queues.py (+29 -4)

@@ -235,12 +235,12 @@ def __init__(
         """
         if name is None:
             raise ValueError("You must specify a name for the queue")
-        host = host or REDIS_HOST
-        port = port or REDIS_PORT
-        password = password or REDIS_PASSWORD
+        self.host = host or REDIS_HOST
+        self.port = port or REDIS_PORT
+        self.password = password or REDIS_PASSWORD
         self.name = name
         self.default_timeout = default_timeout
-        self.redis = redis.Redis(host=host, port=port, password=password)
+        self.redis = redis.Redis(host=self.host, port=self.port, password=self.password)

     def put(self, item: Any) -> None:
         from lightning_app import LightningWork

@@ -329,6 +329,20 @@ def is_running(self) -> bool:
         except redis.exceptions.ConnectionError:
             return False

+    def to_dict(self):
+        return {
+            "type": "redis",
+            "name": self.name,
+            "default_timeout": self.default_timeout,
+            "host": self.host,
+            "port": self.port,
+            "password": self.password,
+        }
+
+    @classmethod
+    def from_dict(cls, state):
+        return cls(**state)
+

 class HTTPQueue(BaseQueue):
     def __init__(self, name: str, default_timeout: float):

@@ -414,6 +428,17 @@ def _split_app_id_and_queue_name(queue_name):
         app_id, queue_name = queue_name.split("_", 1)
         return app_id, queue_name

+    def to_dict(self):
+        return {
+            "type": "http",
+            "name": self.name,
+            "default_timeout": self.default_timeout,
+        }
+
+    @classmethod
+    def from_dict(cls, state):
+        return cls(**state)
+

 def debug_log_callback(message: str, *args: Any, **kwargs: Any) -> None:
     if QUEUE_DEBUG_ENABLED or (Path(LIGHTNING_DIR) / "QUEUE_DEBUG_ENABLED").exists():
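
Note: the round trip these hooks enable looks roughly as follows. The Redis-backed class is assumed here to be RedisQueue with the constructor kwargs shown; because from_dict calls cls(**state), the "type" key is presumably stripped by process_queue before reconstruction, as done below.

# Sketch only: serialize a queue to a picklable dict and rebuild it elsewhere.
from lightning_app.core.queues import RedisQueue  # assumed class name

queue = RedisQueue(name="delta_queue", default_timeout=5)
state = queue.to_dict()  # {"type": "redis", "name": "delta_queue", "host": ..., ...}
state.pop("type")  # dispatch tag, not a constructor kwarg
rebuilt = RedisQueue.from_dict(state)
assert rebuilt.name == queue.name and rebuilt.default_timeout == queue.default_timeout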
