Requirements update #17998

Merged
47 commits merged into master from resumable_loading on Jul 7, 2023. The changes shown below are from 35 of the 47 commits.

Commits:
48b3d0f - initial_commit (justusschock, Jul 6, 2023)
f94c4b4 - update (justusschock, Jul 6, 2023)
c0dfad6 - empty commit (justusschock, Jul 6, 2023)
e3522b6 - Merge branch 'master' into resumable_loading (justusschock, Jul 6, 2023)
b239b84 - fix import (justusschock, Jul 6, 2023)
3d4d9ea - update test import (justusschock, Jul 6, 2023)
10526fa - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 6, 2023)
3b9a488 - fix import (justusschock, Jul 6, 2023)
6ce77ca - Merge branch 'resumable_loading' of https://github.com/lightning-ai/l… (justusschock, Jul 6, 2023)
3c772d5 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 6, 2023)
39e5ba4 - imports (justusschock, Jul 6, 2023)
a84f0b1 - Merge branch 'resumable_loading' of https://github.com/lightning-ai/l… (justusschock, Jul 6, 2023)
aa6cb70 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 6, 2023)
062be88 - stupid import (justusschock, Jul 6, 2023)
405deed - Merge branch 'resumable_loading' of https://github.com/lightning-ai/l… (justusschock, Jul 6, 2023)
d84391d - update mocks (justusschock, Jul 6, 2023)
94b4d77 - typing (justusschock, Jul 6, 2023)
7184872 - resolve filepath (justusschock, Jul 6, 2023)
5dbc205 - update (justusschock, Jul 6, 2023)
31299f5 - update timeout (justusschock, Jul 6, 2023)
953721b - timeouts (justusschock, Jul 6, 2023)
eaed15a - Merge branch 'resumable_loading' of https://github.com/lightning-ai/l… (justusschock, Jul 6, 2023)
7f0d507 - mypy (justusschock, Jul 6, 2023)
de0ee11 - don't always run all tests (justusschock, Jul 6, 2023)
c618cf4 - add conftest (justusschock, Jul 6, 2023)
37d3843 - update (justusschock, Jul 6, 2023)
27756f9 - run it not everywhere, only linux, others are slooooooow (justusschock, Jul 6, 2023)
3273ae5 - mypy (justusschock, Jul 6, 2023)
b120d43 - Update src/lightning/fabric/utilities/spike.py (justusschock, Jul 7, 2023)
5663790 - Update src/lightning/fabric/utilities/spike.py (justusschock, Jul 7, 2023)
5429c00 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 7, 2023)
820b77a - remove spike (justusschock, Jul 7, 2023)
f3c3059 - chlog (justusschock, Jul 7, 2023)
bef2667 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 7, 2023)
301c6bf - Merge branch 'master' into resumable_loading (justusschock, Jul 7, 2023)
251070c - Update src/lightning/data/datasets/iterable.py (justusschock, Jul 7, 2023)
5b6eda5 - formatting and docs (justusschock, Jul 7, 2023)
753f4a1 - index update (justusschock, Jul 7, 2023)
c04161f - formatting and docs (justusschock, Jul 7, 2023)
a4f527e - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 7, 2023)
6310b18 - update (justusschock, Jul 7, 2023)
9e2b0a0 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 7, 2023)
ff21ac7 - messed up (justusschock, Jul 7, 2023)
5f27c8f - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 7, 2023)
185eaa0 - typing (justusschock, Jul 7, 2023)
9a1dc6c - Merge branch 'master' into resumable_loading (justusschock, Jul 7, 2023)
b968c6c - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 7, 2023)
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests-data.yml
@@ -97,7 +97,7 @@ jobs:
# NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
run: |
python -m coverage run --source lightning \
-    -m pytest -v --timeout=30 --durations=50
+    -m pytest -v --timeout=60 --durations=60

- name: Statistics
if: success()
3 changes: 2 additions & 1 deletion src/lightning/__init__.py
@@ -22,7 +22,7 @@
from lightning.app.perf import pdb # noqa: E402
from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402
from lightning.app.utilities.packaging.cloud_compute import CloudCompute # noqa: E402
-from lightning.data import LightningDataset  # noqa: E402
+from lightning.data import LightningDataset, LightningIterableDataset  # noqa: E402
from lightning.fabric.fabric import Fabric # noqa: E402
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
from lightning.pytorch.callbacks import Callback # noqa: E402
@@ -44,6 +44,7 @@
"CloudCompute",
"Trainer",
"LightningDataset",
"LightningIterableDataset",
"LightningDataModule",
"LightningModule",
"Callback",
2 changes: 2 additions & 0 deletions src/lightning/data/CHANGELOG.md
@@ -9,3 +9,5 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `LightningDataset` for optimized data loading including fast loading for S3 buckets. ([#17743](https://github.com/Lightning-AI/lightning/pull/17743))

- Added `LightningIterableDataset` for resumable dataloading with iterable datasets. ([#17998](https://github.com/Lightning-AI/lightning/pull/17998))
4 changes: 2 additions & 2 deletions src/lightning/data/__init__.py
@@ -1,3 +1,3 @@
-from lightning.data.dataset import LightningDataset
+from lightning.data.datasets import LightningDataset, LightningIterableDataset

__all__ = ["LightningDataset"]
__all__ = ["LightningDataset", "LightningIterableDataset"]
4 changes: 4 additions & 0 deletions src/lightning/data/datasets/__init__.py
@@ -0,0 +1,4 @@
from lightning.data.datasets.iterable import LightningIterableDataset
from lightning.data.datasets.mapping import LightningDataset

__all__ = ["LightningDataset", "LightningIterableDataset"]
37 changes: 37 additions & 0 deletions src/lightning/data/datasets/base.py
@@ -0,0 +1,37 @@
from typing import Any

from torch.utils.data import Dataset as TorchDataset

from lightning.data.backends import _DatasetBackend, LocalDatasetBackend, S3DatasetBackend
from lightning.data.fileio import OpenCloudFileObj


class _Dataset(TorchDataset):
    """Base dataset class for streaming data from cloud storage.

    Args:
        backend: Storage location of the data source. Current options are "s3" or "local".
    """

    def __init__(self, backend: str = "local"):
        self.backend = self._init_backend(backend=backend)

        assert isinstance(self.backend, _DatasetBackend)

    def _init_backend(self, backend: str) -> _DatasetBackend:
        """Picks the correct backend handler."""
        if backend == "s3":
            return S3DatasetBackend()
        if backend == "local":
            return LocalDatasetBackend()
        raise ValueError(f"Unsupported backend {backend}")

    def open(self, file: str, mode: str = "r", kwargs_for_open: Any = {}, **kwargs: Any) -> OpenCloudFileObj:
        """Opens a stream for the given file.

        Returns:
            A stream object of the file.
        """
        return OpenCloudFileObj(
            path=file, mode=mode, kwargs_for_open={**self.backend.credentials(), **kwargs_for_open}, **kwargs
        )
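To illustrate how subclasses are meant to use `_Dataset`, here is a minimal sketch. It is not part of this PR: the `TextFileDataset` class is hypothetical, and it assumes the returned `OpenCloudFileObj` behaves like a file object with `read()`/`close()`, which this diff does not show.

    from typing import List

    from lightning.data.datasets.base import _Dataset


    class TextFileDataset(_Dataset):
        # Hypothetical subclass for illustration only.
        def __init__(self, files: List[str], backend: str = "local"):
            super().__init__(backend=backend)  # resolves to LocalDatasetBackend or S3DatasetBackend
            self.files = files

        def __len__(self) -> int:
            return len(self.files)

        def __getitem__(self, index: int) -> str:
            # `open` merges the backend credentials into `kwargs_for_open`
            # and returns an `OpenCloudFileObj` stream for the file.
            stream = self.open(self.files[index], mode="r")
            try:
                return stream.read()  # assumes a file-like stream object
            finally:
                stream.close()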
147 changes: 147 additions & 0 deletions src/lightning/data/datasets/env.py
@@ -0,0 +1,147 @@
from typing import Optional

import torch
from torch.utils.data import get_worker_info


class DistributedEnv:
    """The environment of the distributed training.

    Args:
        world_size: The total number of distributed training processes
        global_rank: The rank of the current process within this pool of training processes
    """

    def __init__(self, world_size: int, global_rank: int):
        self.world_size = world_size
        self.global_rank = global_rank

    @classmethod
    def detect(cls) -> "DistributedEnv":
        """Tries to automatically detect the distributed environment parameters.

        Note:
            This detection may not work in processes spawned by the distributed processes (e.g. DataLoader workers),
            as the distributed framework won't be initialized there.
            It will default to 1 distributed process in this case.
        """
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            global_rank = torch.distributed.get_rank()
        else:
            world_size = None
            global_rank = 0

        if world_size is None or world_size == -1:
            world_size = 1

        return cls(world_size=world_size, global_rank=global_rank)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank})"

    def __str__(self) -> str:
        return repr(self)


class WorkerEnv:
    """Contains the environment for the current dataloader within the current training process.

    Args:
        world_size: The number of dataloader workers for the current training process
        rank: The rank of the current worker within the number of workers
    """

    def __init__(self, world_size: int, rank: int):
        self.world_size = world_size
        self.rank = rank

    @classmethod
    def detect(cls) -> "WorkerEnv":
        """Automatically detects the number of workers and the current rank.

        Note:
            This only works reliably within a dataloader worker, as otherwise the necessary information
            won't be present. In that case it will default to 1 worker.
        """
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        current_worker_rank = worker_info.id if worker_info is not None else 0

        return cls(world_size=num_workers, rank=current_worker_rank)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(world_size: {self.world_size}, rank: {self.rank})"

    def __str__(self) -> str:
        return repr(self)


class Environment:
    """Contains the compute environment. If the sub-environments are not passed, they can be detected automatically.

    Args:
        dist_env: The distributed environment (distributed world size and global rank)
        worker_env: The worker environment (number of workers, worker rank)
    """

    def __init__(self, dist_env: Optional[DistributedEnv], worker_env: Optional[WorkerEnv]):
        self.worker_env = worker_env
        self.dist_env = dist_env

    @classmethod
    def from_args(
        cls,
        dist_world_size: int,
        global_rank: int,
        num_workers: int,
        current_worker_rank: int,
    ) -> "Environment":
        """Creates the Environment from explicitly given arguments instead of detecting them.

        Args:
            dist_world_size: The world size used for distributed training (= total number of distributed processes)
            global_rank: The distributed global rank of the current process
            num_workers: The number of workers per distributed training process
            current_worker_rank: The rank of the current worker within the number of workers of
                the current training process
        """
        dist_env = DistributedEnv(dist_world_size, global_rank)
        worker_env = WorkerEnv(num_workers, current_worker_rank)
        return cls(dist_env=dist_env, worker_env=worker_env)

    @property
    def num_shards(self) -> int:
        """Returns the total number of shards.

        Note:
            This may not be accurate in a non-dataloader-worker process like the main training process,
            as it doesn't necessarily know about the number of dataloader workers.
        """
        assert self.worker_env is not None
        assert self.dist_env is not None
        return self.worker_env.world_size * self.dist_env.world_size

    @property
    def shard_rank(self) -> int:
        """Returns the rank of the current process with respect to the total number of shards.

        Note:
            This may not be accurate in a non-dataloader-worker process like the main training process,
            as it doesn't necessarily know about the number of dataloader workers.
        """
        assert self.worker_env is not None
        assert self.dist_env is not None
        return self.dist_env.global_rank * self.worker_env.world_size + self.worker_env.rank

    def __repr__(self) -> str:
        dist_env_repr = repr(self.dist_env)
        worker_env_repr = repr(self.worker_env)

        return (
            f"{self.__class__.__name__}(\n\tdist_env: {dist_env_repr},\n\tworker_env: "
            + f"{worker_env_repr}\n\tnum_shards: {self.num_shards},\n\tshard_rank: {self.shard_rank})"
        )

    def __str__(self) -> str:
        return repr(self)
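A quick sketch of how these pieces compose (illustrative, not part of the PR's tests): outside `torch.distributed` and outside a DataLoader worker, detection falls back to a single shard, while `from_args` makes the sharding arithmetic explicit.

    from lightning.data.datasets.env import DistributedEnv, Environment, WorkerEnv

    # In a plain process (torch.distributed not initialized, not a dataloader worker),
    # both detections fall back to world_size=1 / rank=0.
    env = Environment(dist_env=DistributedEnv.detect(), worker_env=WorkerEnv.detect())
    assert env.num_shards == 1 and env.shard_rank == 0

    # Explicit construction: 2 distributed processes x 3 workers each = 6 shards.
    env = Environment.from_args(dist_world_size=2, global_rank=1, num_workers=3, current_worker_rank=2)
    assert env.num_shards == 6
    # shard_rank = global_rank * num_workers + worker_rank = 1 * 3 + 2 = 5
    assert env.shard_rank == 5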