
Commit 821ea00

Requirements update (#17998)

Authored by justusschock, awaelchli, and pre-commit-ci[bot]
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent 9c775c9 · commit 821ea00

File tree

18 files changed: +1391 -45 lines changed

.github/workflows/ci-tests-data.yml

Lines changed: 1 addition & 1 deletion

@@ -97,7 +97,7 @@ jobs:
         # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
         run: |
           python -m coverage run --source lightning \
-            -m pytest -v --timeout=30 --durations=50
+            -m pytest -v --timeout=60 --durations=60
 
       - name: Statistics
         if: success()

src/lightning/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -22,7 +22,7 @@
 from lightning.app.perf import pdb  # noqa: E402
 from lightning.app.utilities.packaging.build_config import BuildConfig  # noqa: E402
 from lightning.app.utilities.packaging.cloud_compute import CloudCompute  # noqa: E402
-from lightning.data import LightningDataset  # noqa: E402
+from lightning.data import LightningDataset, LightningIterableDataset  # noqa: E402
 from lightning.fabric.fabric import Fabric  # noqa: E402
 from lightning.fabric.utilities.seed import seed_everything  # noqa: E402
 from lightning.pytorch.callbacks import Callback  # noqa: E402
@@ -44,6 +44,7 @@
     "CloudCompute",
     "Trainer",
     "LightningDataset",
+    "LightningIterableDataset",
     "LightningDataModule",
     "LightningModule",
     "Callback",

src/lightning/data/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -9,3 +9,5 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - Added `LightningDataset` for optimized data loading including fast loading for S3 buckets. ([#17743](https://github.com/Lightning-AI/lightning/pull/17743))
+
+- Added `LightningIterableDataset` for resumable dataloading with iterable datasets ([#17998](https://github.com/Lightning-AI/lightning/pull/17998))

src/lightning/data/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
-from lightning.data.dataset import LightningDataset
+from lightning.data.datasets import LightningDataset, LightningIterableDataset
 
-__all__ = ["LightningDataset"]
+__all__ = ["LightningDataset", "LightningIterableDataset"]

src/lightning/data/datasets/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+from lightning.data.datasets.iterable import LightningIterableDataset
+from lightning.data.datasets.mapping import LightningDataset
+
+__all__ = ["LightningDataset", "LightningIterableDataset"]

src/lightning/data/datasets/base.py

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+from typing import Any, Literal
+
+from torch.utils.data import Dataset as TorchDataset
+
+from lightning.data.backends import _DatasetBackend, LocalDatasetBackend, S3DatasetBackend
+from lightning.data.fileio import OpenCloudFileObj
+
+
+class _Dataset(TorchDataset):
+    """Base dataset class for streaming data from a cloud storage.
+
+    Args:
+        backend: storage location of the data_source. current options are "s3" or "local"
+    """
+
+    def __init__(self, backend: Literal["local", "s3"] = "local"):
+        self.backend = self._init_backend(backend=backend)
+
+        assert isinstance(self.backend, _DatasetBackend)
+
+    def _init_backend(self, backend: str) -> _DatasetBackend:
+        """Picks the correct backend handler."""
+        if backend == "s3":
+            return S3DatasetBackend()
+        if backend == "local":
+            return LocalDatasetBackend()
+        raise ValueError(f"Unsupported backend {backend}")
+
+    def open(self, file: str, mode: str = "r", kwargs_for_open: Any = {}, **kwargs: Any) -> OpenCloudFileObj:
+        """Opens a stream for the given file.
+
+        Returns:
+            A stream object of the file.
+        """
+        return OpenCloudFileObj(
+            path=file, mode=mode, kwargs_for_open={**self.backend.credentials(), **kwargs_for_open}, **kwargs
+        )
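
The class is private, but to show how the pieces fit together, here is a minimal sketch of a hypothetical subclass. `MyDataset`, the file list, and the assumption that the returned stream is file-like (supports `read()` and the context-manager protocol) are illustrative, not part of the diff:

from lightning.data.datasets.base import _Dataset


class MyDataset(_Dataset):  # hypothetical subclass, for illustration only
    def __init__(self, files):
        super().__init__(backend="local")  # "s3" would select S3DatasetBackend instead
        self.files = files

    def __getitem__(self, index):
        # open() merges the backend credentials into kwargs_for_open
        # before handing them to OpenCloudFileObj.
        with self.open(self.files[index], mode="r") as stream:  # assumes the stream is a context manager
            return stream.read()

    def __len__(self):
        return len(self.files)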

src/lightning/data/datasets/env.py

Lines changed: 147 additions & 0 deletions

@@ -0,0 +1,147 @@
+from typing import Optional
+
+import torch
+from torch.utils.data import get_worker_info
+
+
+class _DistributedEnv:
+    """The environment of the distributed training.
+
+    Args:
+        world_size: The total number of distributed training processes
+        global_rank: The rank of the current process within this pool of training processes
+    """
+
+    def __init__(self, world_size: int, global_rank: int):
+        self.world_size = world_size
+        self.global_rank = global_rank
+
+    @classmethod
+    def detect(cls) -> "_DistributedEnv":
+        """Tries to automatically detect the distributed environment parameters.
+
+        Note:
+            This detection may not work in processes spawned from the distributed processes (e.g. DataLoader workers)
+            as the distributed framework won't be initialized there.
+            It will default to 1 distributed process in this case.
+        """
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            world_size = torch.distributed.get_world_size()
+            global_rank = torch.distributed.get_rank()
+        else:
+            world_size = None
+            global_rank = 0
+
+        if world_size is None or world_size == -1:
+            world_size = 1
+
+        return cls(world_size=world_size, global_rank=global_rank)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"
+
+    def __str__(self) -> str:
+        return repr(self)
+
+
+class _WorkerEnv:
+    """Contains the environment for the current dataloader within the current training process.
+
+    Args:
+        world_size: The number of dataloader workers for the current training process
+        rank: The rank of the current worker within the number of workers
+    """
+
+    def __init__(self, world_size: int, rank: int):
+        self.world_size = world_size
+        self.rank = rank
+
+    @classmethod
+    def detect(cls) -> "_WorkerEnv":
+        """Automatically detects the number of workers and the current rank.
+
+        Note:
+            This only works reliably within a dataloader worker as otherwise the necessary information won't be
+            present. In such a case it will default to 1 worker.
+        """
+        worker_info = get_worker_info()
+        num_workers = worker_info.num_workers if worker_info is not None else 1
+        current_worker_rank = worker_info.id if worker_info is not None else 0
+
+        return cls(world_size=num_workers, rank=current_worker_rank)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(world_size: {self.world_size}, rank: {self.rank})"
+
+    def __str__(self) -> str:
+        return repr(self)
+
+
+class Environment:
+    """Contains the compute environment. If not passed, will try to detect.
+
+    Args:
+        dist_env: The distributed environment (distributed world size and global rank)
+        worker_env: The worker environment (number of workers, worker rank)
+    """
+
+    def __init__(self, dist_env: Optional[_DistributedEnv], worker_env: Optional[_WorkerEnv]):
+        self.worker_env = worker_env
+        self.dist_env = dist_env
+
+    @classmethod
+    def from_args(
+        cls,
+        dist_world_size: int,
+        global_rank: int,
+        num_workers: int,
+        current_worker_rank: int,
+    ) -> "Environment":
+        """Generates the Environment class from given arguments instead of detecting them.
+
+        Args:
+            dist_world_size: The world size used for distributed training (= total number of distributed processes)
+            global_rank: The distributed global rank of the current process
+            num_workers: The number of workers per distributed training process
+            current_worker_rank: The rank of the current worker within the number of workers of
+                the current training process
+        """
+        dist_env = _DistributedEnv(dist_world_size, global_rank)
+        worker_env = _WorkerEnv(num_workers, current_worker_rank)
+        return cls(dist_env=dist_env, worker_env=worker_env)
+
+    @property
+    def num_shards(self) -> int:
+        """Returns the total number of shards.
+
+        Note:
+            This may not be accurate in a non-dataloader-worker process like the main training process
+            as it doesn't necessarily know about the number of dataloader workers.
+        """
+        assert self.worker_env is not None
+        assert self.dist_env is not None
+        return self.worker_env.world_size * self.dist_env.world_size
+
+    @property
+    def shard_rank(self) -> int:
+        """Returns the rank of the current process wrt. the total number of shards.
+
+        Note:
+            This may not be accurate in a non-dataloader-worker process like the main training process as it
+            doesn't necessarily know about the number of dataloader workers.
+        """
+        assert self.worker_env is not None
+        assert self.dist_env is not None
+        return self.dist_env.global_rank * self.worker_env.world_size + self.worker_env.rank
+
+    def __repr__(self) -> str:
+        dist_env_repr = repr(self.dist_env)
+        worker_env_repr = repr(self.worker_env)
+
+        return (
+            f"{self.__class__.__name__}(\n\tdist_env: {dist_env_repr},\n\tworker_env: "
+            + f"{worker_env_repr}\n\tnum_shards: {self.num_shards},\n\tshard_rank: {self.shard_rank})"
+        )
+
+    def __str__(self) -> str:
+        return repr(self)
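
To make the sharding arithmetic concrete, a small sketch with illustrative numbers: with 2 distributed processes and 3 dataloader workers per process there are 2 x 3 = 6 shards, and the worker with global rank 1 and worker rank 2 owns the last one.

from lightning.data.datasets.env import Environment

# 2 training processes x 3 dataloader workers each -> 6 shards in total.
env = Environment.from_args(
    dist_world_size=2,
    global_rank=1,
    num_workers=3,
    current_worker_rank=2,
)

assert env.num_shards == 6  # worker_env.world_size * dist_env.world_size
assert env.shard_rank == 5  # global_rank * num_workers + worker rank = 1 * 3 + 2

Each (process, worker) pair thus gets a unique shard_rank in [0, num_shards), which is what allows an iterable dataset to be split deterministically across processes and workers.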
