Skip to content

Commit ec164b9

Browse files
committed
squash all
1 parent c059db4 commit ec164b9

File tree

7 files changed

+145
-3
lines changed

7 files changed

+145
-3
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Run this with
2+
#
3+
# python -m lightning_lite.cli examples/01_lite_launch/launcher_cli.py --devices 2 --precision bf16
4+
5+
import torch.distributed
6+
7+
from lightning_lite import LightningLite
8+
9+
10+
if __name__ == "__main__":
11+
lite = LightningLite()
12+
print("launched", lite.global_rank)
13+
assert torch.distributed.is_initialized()
14+
lite.barrier()
15+
print("end")
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch.distributed
2+
3+
from lightning_lite import LightningLite
4+
5+
6+
class Lite(LightningLite):
7+
def run(self):
8+
print("launched", self.global_rank)
9+
assert torch.distributed.is_initialized()
10+
self.barrier()
11+
12+
13+
if __name__ == "__main__":
14+
lite = Lite(accelerator="cpu", devices=2, strategy="ddp")
15+
lite.run()
16+
print("after run", lite.global_rank)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch.distributed
2+
3+
from lightning_lite import LightningLite
4+
5+
6+
if __name__ == "__main__":
7+
lite = LightningLite(accelerator="cpu", devices=2, strategy="ddp")
8+
lite.launch()
9+
print("launched", lite.global_rank)
10+
assert torch.distributed.is_initialized()
11+
lite.barrier()
12+
print("end")
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch.distributed
2+
3+
from lightning_lite import LightningLite
4+
5+
6+
def run(lite):
7+
print("launched", lite.global_rank)
8+
assert torch.distributed.is_initialized()
9+
lite.barrier()
10+
print("end")
11+
12+
13+
if __name__ == "__main__":
14+
lite = LightningLite(accelerator="cpu", devices=2, strategy="ddp")
15+
lite.launch(run)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch.distributed
2+
3+
from lightning_lite import LightningLite
4+
5+
6+
def run(lite):
7+
print("launched", lite.global_rank)
8+
assert torch.distributed.is_initialized()
9+
10+
11+
if __name__ == "__main__":
12+
lite = LightningLite(accelerator="cpu", devices=2, strategy="ddp_notebook")
13+
lite.launch(run)
14+
print("main process joins", lite.global_rank)

src/lightning_lite/cli.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
from argparse import ArgumentParser
3+
import torch.distributed.run as torchrun
4+
5+
6+
def main():
7+
parser = ArgumentParser()
8+
parser.add_argument("script", type=str)
9+
parser.add_argument("--accelerator", type=str, default="cpu", choices=("cpu", "cuda", "mps", "tpu", "auto"))
10+
# TODO: note for some accelerators/strategies, torchrun won't make sense (e.g. dp)
11+
# TODO: should we include spawn?
12+
parser.add_argument("--strategy", type=str, default=None, choices=(None, "ddp", "dp", "deepspeed"))
13+
parser.add_argument("--devices", type=str, default="1")
14+
parser.add_argument("--num-nodes", type=int, default=1)
15+
parser.add_argument("--node-rank", type=int, default=0)
16+
parser.add_argument("--main-address", type=str, default="127.0.0.1")
17+
parser.add_argument("--main-port", type=int, default=29400)
18+
parser.add_argument("--precision", type=str, default="32", choices=("32", "16", "bf16"))
19+
args = parser.parse_args()
20+
21+
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
22+
if args.strategy:
23+
os.environ["LT_STRATEGY"] = str(args.strategy)
24+
os.environ["LT_DEVICES"] = str(args.devices)
25+
os.environ["LT_NUM_NODES"] = str(args.num_nodes)
26+
os.environ["LT_PRECISION"] = str(args.precision)
27+
28+
num_devices = int(args.devices) # TODO: count them
29+
30+
torchrun_args = []
31+
torchrun_args.extend(["--nproc_per_node", str(num_devices)])
32+
torchrun_args.extend(["--nnodes", str(args.num_nodes)])
33+
torchrun_args.extend(["--node_rank", str(args.node_rank)])
34+
torchrun_args.extend(["--master_addr", args.main_address])
35+
torchrun_args.extend(["--master_port", str(args.main_port)])
36+
torchrun_args.append(args.script)
37+
38+
os.environ.setdefault("OMP_NUM_THREADS", str(max(1, os.cpu_count() // num_devices)))
39+
40+
torchrun.main(torchrun_args)
41+
42+
43+
if __name__ == "__main__":
44+
main()

src/lightning_lite/lite.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ def __init__(
7777
precision: _PRECISION_INPUT = 32,
7878
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
7979
) -> None:
80+
accelerator = os.getenv("LT_ACCELERATOR", accelerator)
81+
strategy = os.getenv("LT_STRATEGY", strategy)
82+
devices = os.getenv("LT_DEVICES", devices)
83+
num_nodes = os.getenv("LT_NUM_NODES", num_nodes)
84+
precision = os.getenv("LT_PRECISION", precision)
85+
precision = int(precision) if precision in ("16", "32") else precision
86+
8087
self._connector = _Connector(
8188
accelerator=accelerator,
8289
strategy=strategy,
@@ -93,6 +100,9 @@ def __init__(
93100
# wrap the run method so we can inject setup logic or spawn processes for the user
94101
setattr(self, "run", partial(self._run_impl, self.run))
95102

103+
if "LT_ACCELERATOR" in os.environ:
104+
self._strategy.setup_environment()
105+
96106
@property
97107
def device(self) -> torch.device:
98108
"""The current device this process runs on.
@@ -126,7 +136,7 @@ def is_global_zero(self) -> bool:
126136
"""Wether this rank is rank zero."""
127137
return self._strategy.is_global_zero
128138

129-
@abstractmethod
139+
# TODO(lite): Error/warn when run overridden but launcher is used
130140
def run(self, *args: Any, **kwargs: Any) -> Any:
131141
"""All the code inside this run method gets accelerated by Lite.
132142
@@ -367,6 +377,15 @@ def load(self, filepath: Union[str, Path]) -> Any:
367377
"""
368378
return self._strategy.load_checkpoint(filepath)
369379

380+
def launch(self, function: Optional[Callable] = None, *args: Any, **kwargs: Any) -> Any:
381+
function = _do_nothing if function is None else function
382+
function = partial(self._function_with_strategy_setup, function)
383+
args = [self, *args]
384+
if self._strategy.launcher is not None:
385+
return self._strategy.launcher.launch(function, *args, **kwargs)
386+
else:
387+
return function(*args, **kwargs)
388+
370389
@staticmethod
371390
def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
372391
"""Helper function to seed everything without explicitly importing Lightning.
@@ -380,9 +399,8 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
380399
return seed_everything(seed=seed, workers=workers)
381400

382401
def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
383-
# wrap the real run method with setup logic
402+
# TODO: skip launcher if already launched externally!
384403
run_method = partial(self._run_with_setup, run_method)
385-
386404
if self._strategy.launcher is not None:
387405
return self._strategy.launcher.launch(run_method, *args, **kwargs)
388406
else:
@@ -396,6 +414,11 @@ def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> An
396414
), _replace_dunder_methods(BatchSampler):
397415
return run_method(*args, **kwargs)
398416

417+
def _function_with_strategy_setup(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
418+
self._strategy.setup_environment()
419+
with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler):
420+
return function(*args, **kwargs)
421+
399422
def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
400423
initial_device = next(model.parameters()).device
401424
if any(param.device != initial_device for param in model.parameters()):
@@ -450,3 +473,6 @@ def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
450473

451474
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
452475
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
476+
477+
478+
def _do_nothing(*_): pass

0 commit comments

Comments
 (0)