Requirements update #17998
Merged
Commits (changes shown from 35 of 47 commits)
- 48b3d0f initial_commit (justusschock)
- f94c4b4 update (justusschock)
- c0dfad6 empty commit (justusschock)
- e3522b6 Merge branch 'master' into resumable_loading (justusschock)
- b239b84 fix import (justusschock)
- 3d4d9ea update test import (justusschock)
- 10526fa [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 3b9a488 fix import (justusschock)
- 6ce77ca Merge branch 'resumable_loading' of https://github.com/lightning-ai/l… (justusschock)
- 3c772d5 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 39e5ba4 imports (justusschock)
- a84f0b1 Merge branch 'resumable_loading' of https://github.com/lightning-ai/l… (justusschock)
- aa6cb70 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 062be88 stupid import (justusschock)
- 405deed Merge branch 'resumable_loading' of https://github.com/lightning-ai/l… (justusschock)
- d84391d update mocks (justusschock)
- 94b4d77 typing (justusschock)
- 7184872 resolve filepath (justusschock)
- 5dbc205 update (justusschock)
- 31299f5 update timeout (justusschock)
- 953721b timeouts (justusschock)
- eaed15a Merge branch 'resumable_loading' of https://github.com/lightning-ai/l… (justusschock)
- 7f0d507 mypy (justusschock)
- de0ee11 don't always run all tests (justusschock)
- c618cf4 add conftest (justusschock)
- 37d3843 update (justusschock)
- 27756f9 run it not everywhere, only linux, others are slooooooow (justusschock)
- 3273ae5 mypy (justusschock)
- b120d43 Update src/lightning/fabric/utilities/spike.py (justusschock)
- 5663790 Update src/lightning/fabric/utilities/spike.py (justusschock)
- 5429c00 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 820b77a remove spike (justusschock)
- f3c3059 chlog (justusschock)
- bef2667 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 301c6bf Merge branch 'master' into resumable_loading (justusschock)
- 251070c Update src/lightning/data/datasets/iterable.py (justusschock)
- 5b6eda5 formatting and docs (justusschock)
- 753f4a1 index update (justusschock)
- c04161f formatting and docs (justusschock)
- a4f527e [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 6310b18 update (justusschock)
- 9e2b0a0 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- ff21ac7 messed up (justusschock)
- 5f27c8f [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 185eaa0 typing (justusschock)
- 9a1dc6c Merge branch 'master' into resumable_loading (justusschock)
- b968c6c [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
```diff
@@ -1,3 +1,3 @@
-from lightning.data.dataset import LightningDataset
+from lightning.data.datasets import LightningDataset, LightningIterableDataset

-__all__ = ["LightningDataset"]
+__all__ = ["LightningDataset", "LightningIterableDataset"]
```
New file (@@ -0,0 +1,4 @@):

```python
from lightning.data.datasets.iterable import LightningIterableDataset
from lightning.data.datasets.mapping import LightningDataset

__all__ = ["LightningDataset", "LightningIterableDataset"]
```
New file (@@ -0,0 +1,37 @@):

```python
from typing import Any

from torch.utils.data import Dataset as TorchDataset

from lightning.data.backends import _DatasetBackend, LocalDatasetBackend, S3DatasetBackend
from lightning.data.fileio import OpenCloudFileObj


class _Dataset(TorchDataset):
    """Base dataset class for streaming data from cloud storage.

    Args:
        backend: storage location of the data_source. Current options are "s3" or "local"
    """

    def __init__(self, backend: str = "local"):
        self.backend = self._init_backend(backend=backend)

        assert isinstance(self.backend, _DatasetBackend)

    def _init_backend(self, backend: str) -> _DatasetBackend:
        """Picks the correct backend handler."""
        if backend == "s3":
            return S3DatasetBackend()
        if backend == "local":
            return LocalDatasetBackend()
        raise ValueError(f"Unsupported backend {backend}")

    def open(self, file: str, mode: str = "r", kwargs_for_open: Any = {}, **kwargs: Any) -> OpenCloudFileObj:
        """Opens a stream for the given file.

        Returns:
            A stream object of the file.
        """
        return OpenCloudFileObj(
            path=file, mode=mode, kwargs_for_open={**self.backend.credentials(), **kwargs_for_open}, **kwargs
        )
```
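The `_init_backend` logic in `_Dataset` above is a plain string-dispatch pattern. A standalone sketch illustrates it without the real `lightning.data.backends` classes; the backend classes and `init_backend` function here are hypothetical stand-ins, not the library's API:

```python
# Minimal stand-alone sketch of the string-dispatch pattern used by
# _Dataset._init_backend. These classes are illustrative placeholders,
# not the real lightning.data.backends implementations.


class _DatasetBackend:
    """Common interface: a backend only needs to supply credentials."""

    def credentials(self) -> dict:
        return {}


class LocalDatasetBackend(_DatasetBackend):
    pass


class S3DatasetBackend(_DatasetBackend):
    def credentials(self) -> dict:
        # A real implementation would load AWS credentials here.
        return {"aws_access_key_id": "...", "aws_secret_access_key": "..."}


def init_backend(backend: str) -> _DatasetBackend:
    """Routes a backend name to its handler, mirroring _Dataset._init_backend."""
    if backend == "s3":
        return S3DatasetBackend()
    if backend == "local":
        return LocalDatasetBackend()
    raise ValueError(f"Unsupported backend {backend}")


if __name__ == "__main__":
    assert isinstance(init_backend("local"), LocalDatasetBackend)
    assert isinstance(init_backend("s3"), S3DatasetBackend)
```

Dispatching on a plain string (rather than passing backend instances) keeps the constructor signature simple while still failing fast with a `ValueError` for unknown names.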
New file (@@ -0,0 +1,147 @@):

```python
from typing import Optional

import torch
from torch.utils.data import get_worker_info


class DistributedEnv:
    """The environment of the distributed training.

    Args:
        world_size: The number of total distributed training processes
        global_rank: The rank of the current process within this pool of training processes
    """

    def __init__(self, world_size: int, global_rank: int):
        self.world_size = world_size
        self.global_rank = global_rank

    @classmethod
    def detect(cls) -> "DistributedEnv":
        """Tries to automatically detect the distributed environment parameters.

        Note:
            This detection may not work in processes spawned from the distributed processes (e.g. DataLoader workers)
            as the distributed framework won't be initialized there.
            It will default to 1 distributed process in this case.
        """
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            global_rank = torch.distributed.get_rank()
        else:
            world_size = None
            global_rank = 0

        if world_size is None or world_size == -1:
            world_size = 1

        return cls(world_size=world_size, global_rank=global_rank)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"

    def __str__(self) -> str:
        return repr(self)


class WorkerEnv:
    """Contains the environment for the current dataloader within the current training process.

    Args:
        world_size: The number of dataloader workers for the current training process
        rank: The rank of the current worker within the number of workers
    """

    def __init__(self, world_size: int, rank: int):
        self.world_size = world_size
        self.rank = rank

    @classmethod
    def detect(cls) -> "WorkerEnv":
        """Automatically detects the number of workers and the current rank.

        Note:
            This only works reliably within a dataloader worker as otherwise the necessary information won't be
            present. In such a case it will default to 1 worker.
        """
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        current_worker_rank = worker_info.id if worker_info is not None else 0

        return cls(world_size=num_workers, rank=current_worker_rank)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(world_size: {self.world_size}, rank: {self.rank})"

    def __str__(self) -> str:
        return repr(self)


class Environment:
    """Contains the compute environment. If not passed, will try to detect.

    Args:
        dist_env: The distributed environment (distributed world size and global rank)
        worker_env: The worker environment (number of workers, worker rank)
    """

    def __init__(self, dist_env: Optional[DistributedEnv], worker_env: Optional[WorkerEnv]):
        self.worker_env = worker_env
        self.dist_env = dist_env

    @classmethod
    def from_args(
        cls,
        dist_world_size: int,
        global_rank: int,
        num_workers: int,
        current_worker_rank: int,
    ) -> "Environment":
        """Generates the Environment class from already given arguments instead of detecting them.

        Args:
            dist_world_size: The world size used for distributed training (= total number of distributed processes)
            global_rank: The distributed global rank of the current process
            num_workers: The number of workers per distributed training process
            current_worker_rank: The rank of the current worker within the number of workers of
                the current training process
        """
        dist_env = DistributedEnv(dist_world_size, global_rank)
        worker_env = WorkerEnv(num_workers, current_worker_rank)
        return cls(dist_env=dist_env, worker_env=worker_env)

    @property
    def num_shards(self) -> int:
        """Returns the total number of shards.

        Note:
            This may not be accurate in a non-dataloader-worker process like the main training process
            as it doesn't necessarily know about the number of dataloader workers.
        """
        assert self.worker_env is not None
        assert self.dist_env is not None
        return self.worker_env.world_size * self.dist_env.world_size

    @property
    def shard_rank(self) -> int:
        """Returns the rank of the current process wrt. the total number of shards.

        Note:
            This may not be accurate in a non-dataloader-worker process like the main training process as it
            doesn't necessarily know about the number of dataloader workers.
        """
        assert self.worker_env is not None
        assert self.dist_env is not None
        return self.dist_env.global_rank * self.worker_env.world_size + self.worker_env.rank

    def __repr__(self) -> str:
        dist_env_repr = repr(self.dist_env)
        worker_env_repr = repr(self.worker_env)

        return (
            f"{self.__class__.__name__}(\n\tdist_env: {dist_env_repr},\n\tworker_env: "
            + f"{worker_env_repr}\n\tnum_shards: {self.num_shards},\n\tshard_rank: {self.shard_rank})"
        )

    def __str__(self) -> str:
        return repr(self)
```
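The shard arithmetic in `Environment` deserves a worked check: `num_shards` multiplies the distributed world size by the worker count, and `shard_rank` interleaves process rank and worker rank so that every (process, worker) pair lands on a distinct shard. A minimal sketch, reimplementing just the two formulas as free functions (illustrative stand-ins, not the library API):

```python
# Worked check of the shard arithmetic from Environment:
#   num_shards = dist_world_size * num_workers
#   shard_rank = global_rank * num_workers + worker_rank
# These plain functions are illustrative stand-ins, not the library API.


def num_shards(dist_world_size: int, num_workers: int) -> int:
    # Total shards = one shard per dataloader worker across all processes.
    return dist_world_size * num_workers


def shard_rank(global_rank: int, num_workers: int, worker_rank: int) -> int:
    # Row-major flattening of the (process, worker) grid.
    return global_rank * num_workers + worker_rank


if __name__ == "__main__":
    # 2 training processes, each with 3 dataloader workers -> 6 shards.
    procs, workers = 2, 3
    ranks = [shard_rank(g, workers, w) for g in range(procs) for w in range(workers)]
    print(ranks)  # [0, 1, 2, 3, 4, 5]
    assert sorted(ranks) == list(range(num_shards(procs, workers)))
```

Because the flattening is bijective, no two workers ever read the same shard and no shard is skipped, which is exactly what resumable, distributed streaming requires.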
File renamed without changes.