diff --git a/.github/workflows/ci-tests-data.yml b/.github/workflows/ci-tests-data.yml index 6c2a083a2a877..7e2d4287fd1a6 100644 --- a/.github/workflows/ci-tests-data.yml +++ b/.github/workflows/ci-tests-data.yml @@ -97,7 +97,7 @@ jobs: # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 run: | python -m coverage run --source lightning \ - -m pytest -v --timeout=30 --durations=50 + -m pytest -v --timeout=60 --durations=60 - name: Statistics if: success() diff --git a/src/lightning/__init__.py b/src/lightning/__init__.py index 0a209e9064131..11daa4d78243f 100644 --- a/src/lightning/__init__.py +++ b/src/lightning/__init__.py @@ -22,7 +22,7 @@ from lightning.app.perf import pdb # noqa: E402 from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402 from lightning.app.utilities.packaging.cloud_compute import CloudCompute # noqa: E402 -from lightning.data import LightningDataset # noqa: E402 +from lightning.data import LightningDataset, LightningIterableDataset # noqa: E402 from lightning.fabric.fabric import Fabric # noqa: E402 from lightning.fabric.utilities.seed import seed_everything # noqa: E402 from lightning.pytorch.callbacks import Callback # noqa: E402 @@ -44,6 +44,7 @@ "CloudCompute", "Trainer", "LightningDataset", + "LightningIterableDataset", "LightningDataModule", "LightningModule", "Callback", diff --git a/src/lightning/data/CHANGELOG.md b/src/lightning/data/CHANGELOG.md index 552f04c4a8c89..7bac4402f73ca 100644 --- a/src/lightning/data/CHANGELOG.md +++ b/src/lightning/data/CHANGELOG.md @@ -9,3 +9,5 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added `LightningDataset` for optimized data loading including fast loading for S3 buckets. 
([#17743](https://github.com/Lightning-AI/lightning/pull/17743)) + +- Added `LightningIterableDataset` for resumable dataloading with iterable datasets ([#17998](https://github.com/Lightning-AI/lightning/pull/17998)) diff --git a/src/lightning/data/__init__.py b/src/lightning/data/__init__.py index e976bab846553..8cf8860401d4c 100644 --- a/src/lightning/data/__init__.py +++ b/src/lightning/data/__init__.py @@ -1,3 +1,3 @@ -from lightning.data.dataset import LightningDataset +from lightning.data.datasets import LightningDataset, LightningIterableDataset -__all__ = ["LightningDataset"] +__all__ = ["LightningDataset", "LightningIterableDataset"] diff --git a/src/lightning/data/datasets/__init__.py b/src/lightning/data/datasets/__init__.py new file mode 100644 index 0000000000000..0ab577fbcb872 --- /dev/null +++ b/src/lightning/data/datasets/__init__.py @@ -0,0 +1,4 @@ +from lightning.data.datasets.iterable import LightningIterableDataset +from lightning.data.datasets.mapping import LightningDataset + +__all__ = ["LightningDataset", "LightningIterableDataset"] diff --git a/src/lightning/data/datasets/base.py b/src/lightning/data/datasets/base.py new file mode 100644 index 0000000000000..d0ebc163402ad --- /dev/null +++ b/src/lightning/data/datasets/base.py @@ -0,0 +1,37 @@ +from typing import Any, Literal + +from torch.utils.data import Dataset as TorchDataset + +from lightning.data.backends import _DatasetBackend, LocalDatasetBackend, S3DatasetBackend +from lightning.data.fileio import OpenCloudFileObj + + +class _Dataset(TorchDataset): + """Base dataset class for streaming data from a cloud storage. + + Args: + backend: storage location of the data_source. 
current options are "s3" or "local" + """ + + def __init__(self, backend: Literal["local", "s3"] = "local"): + self.backend = self._init_backend(backend=backend) + + assert isinstance(self.backend, _DatasetBackend) + + def _init_backend(self, backend: str) -> _DatasetBackend: + """Picks the correct backend handler.""" + if backend == "s3": + return S3DatasetBackend() + if backend == "local": + return LocalDatasetBackend() + raise ValueError(f"Unsupported backend {backend}") + + def open(self, file: str, mode: str = "r", kwargs_for_open: Any = {}, **kwargs: Any) -> OpenCloudFileObj: + """Opens a stream for the given file. + + Returns: + A stream object of the file. + """ + return OpenCloudFileObj( + path=file, mode=mode, kwargs_for_open={**self.backend.credentials(), **kwargs_for_open}, **kwargs + ) diff --git a/src/lightning/data/datasets/env.py b/src/lightning/data/datasets/env.py new file mode 100644 index 0000000000000..4e923f34735f8 --- /dev/null +++ b/src/lightning/data/datasets/env.py @@ -0,0 +1,147 @@ +from typing import Optional + +import torch +from torch.utils.data import get_worker_info + + +class _DistributedEnv: + """The environment of the distributed training. + + Args: + world_size: The number of total distributed training processes + global_rank: The rank of the current process within this pool of training processes + """ + + def __init__(self, world_size: int, global_rank: int): + self.world_size = world_size + self.global_rank = global_rank + + @classmethod + def detect(cls) -> "_DistributedEnv": + """Tries to automatically detect the distributed environment parameters. + + Note: + This detection may not work in processes spawned from the distributed processes (e.g. DataLoader workers) + as the distributed framework won't be initialized there. + It will default to 1 distributed process in this case. 
+ """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + global_rank = torch.distributed.get_rank() + else: + world_size = None + global_rank = 0 + + if world_size is None or world_size == -1: + world_size = 1 + + return cls(world_size=world_size, global_rank=global_rank) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)" + + def __str__(self) -> str: + return repr(self) + + +class _WorkerEnv: + """Contains the environment for the current dataloader within the current training process. + + Args: + world_size: The number of dataloader workers for the current training process + rank: The rank of the current worker within the number of workers + """ + + def __init__(self, world_size: int, rank: int): + self.world_size = world_size + self.rank = rank + + @classmethod + def detect(cls) -> "_WorkerEnv": + """Automatically detects the number of workers and the current rank. + + Note: + This only works reliably within a dataloader worker as otherwise the necessary information won't be present. + In such a case it will default to 1 worker + """ + worker_info = get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + current_worker_rank = worker_info.id if worker_info is not None else 0 + + return cls(world_size=num_workers, rank=current_worker_rank) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(world_size: {self.world_size}, rank: {self.rank})" + + def __str__(self) -> str: + return repr(self) + + +class Environment: + """Contains the compute environment. If not passed, will try to detect. 
+ + Args: + dist_env: The distributed environment (distributed worldsize and global rank) + worker_env: The worker environment (number of workers, worker rank) + """ + + def __init__(self, dist_env: Optional[_DistributedEnv], worker_env: Optional[_WorkerEnv]): + self.worker_env = worker_env + self.dist_env = dist_env + + @classmethod + def from_args( + cls, + dist_world_size: int, + global_rank: int, + num_workers: int, + current_worker_rank: int, + ) -> "Environment": + """Generates the Environment class by already given arguments instead of detecting them. + + Args: + dist_world_size: The worldsize used for distributed training (=total number of distributed processes) + global_rank: The distributed global rank of the current process + num_workers: The number of workers per distributed training process + current_worker_rank: The rank of the current worker within the number of workers of + the current training process + """ + dist_env = _DistributedEnv(dist_world_size, global_rank) + worker_env = _WorkerEnv(num_workers, current_worker_rank) + return cls(dist_env=dist_env, worker_env=worker_env) + + @property + def num_shards(self) -> int: + """Returns the total number of shards. + + Note: + This may not be accurate in a non-dataloader-worker process like the main training process + as it doesn't necessarily know about the number of dataloader workers. + """ + assert self.worker_env is not None + assert self.dist_env is not None + return self.worker_env.world_size * self.dist_env.world_size + + @property + def shard_rank(self) -> int: + """Returns the rank of the current process wrt. the total number of shards. + + Note: + This may not be accurate in a non-dataloader-worker process like the main training process as it + doesn't necessarily know about the number of dataloader workers. 
+ """ + assert self.worker_env is not None + assert self.dist_env is not None + return self.dist_env.global_rank * self.worker_env.world_size + self.worker_env.rank + + def __repr__(self) -> str: + dist_env_repr = repr(self.dist_env) + worker_env_repr = repr(self.worker_env) + + return ( + f"{self.__class__.__name__}(\n\tdist_env: {dist_env_repr},\n\tworker_env: " + + f"{worker_env_repr}\n\tnum_shards: {self.num_shards},\n\tshard_rank: {self.shard_rank})" + ) + + def __str__(self) -> str: + return repr(self) diff --git a/src/lightning/data/dataset_index.py b/src/lightning/data/datasets/index.py similarity index 100% rename from src/lightning/data/dataset_index.py rename to src/lightning/data/datasets/index.py diff --git a/src/lightning/data/datasets/iterable.py b/src/lightning/data/datasets/iterable.py new file mode 100644 index 0000000000000..f97b16f89e20f --- /dev/null +++ b/src/lightning/data/datasets/iterable.py @@ -0,0 +1,444 @@ +import math +import warnings +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, Dict, Generator, List, Literal, Optional, Protocol, runtime_checkable, Sequence, Tuple + +import torch +from torch.utils.data import DataLoader as _DataLoader +from torch.utils.data import IterableDataset + +from lightning.data.datasets.base import _Dataset +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv, Environment + + +class _StatefulIterableDataset(ABC, IterableDataset): + @abstractmethod + def state_dict(self, returned_samples: int, num_workers: int) -> Dict[str, Any]: + pass + + @abstractmethod + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + pass + + +class _Chunk: + """A single chunk of data. 
+ + Args: + chunk_data: The original data contained by this chunk + chunk_size: The number of samples contained in this chunk + start_index: the index from where to start sampling the chunk (already retrieved samples) + """ + + def __init__(self, chunk_data: Any, chunk_size: int, start_index: int = 0): + self._chunk_data = chunk_data + self._index_permutations: Optional[Tuple[int, ...]] = None + self._start_index = start_index + self._chunk_size = chunk_size + + def shuffle(self, generator: Optional[torch.Generator] = None) -> "_Chunk": + """Shuffles index permutations for the current chunk.""" + new_indices = torch.randperm(self.chunk_size, generator=generator, device="cpu").tolist() + self._index_permutations = tuple(new_indices) + return self + + def __iter__(self) -> Generator[int, None, None]: + """Returns an iterator over the index permutations.""" + # iterates over indices + index_permutations = self.index_permutations + for i in range(self._start_index, self.chunk_size): + yield index_permutations[i] + + @property + def chunk_size(self) -> int: + return self._chunk_size + + @property + def index_permutations(self) -> Tuple[int, ...]: + if self._index_permutations is None: + return tuple(range(self._chunk_size)) + return self._index_permutations + + +class LightningIterableDataset(_StatefulIterableDataset, _Dataset): + """An iterable dataset that can be resumed mid-epoch, implements chunking and sharding of chunks. The behavior + of this dataset can be customized with the following hooks: + + - ``prepare_chunk`` gives the possibility to prepare the chunk one iteration before it's actually loaded + (e.g. download from s3). + - ``load_chunk`` implements how an entire chunk is loaded into memory + (e.g. loading the previously downloaded file into memory) + - ``load_sample_from_chunk`` implements how to retrieve a single sample from the current chunk + (e.g. 
indexing the chunk if it's a list or just returning it if the chunk has a size of 1) + + Args: + chunks: The chunked_data to load. + chunk_size: The number of samples in each chunk + num_parallel_chunks: How many chunks to load in parallel. + env: The compute-environment. Important for sharding. Contains the distributed world-size, + the distributed global rank, the number of workers on the current rank and the current worker rank. + If None, it will try to detect these things automatically. + shuffle: Whether to shuffle your data. Will shuffle both, order of chunks before sharding and order of + samples within each chunk. + seed: The seed for the random number generator. If :param:`shuffle` = False, the seed has no effect. + wrap: Whether to restart your dataset if it's exhausted. If set to True, it results in a + virtually infinite dataset looping through the same data over and over again. + lazy_shuffle: Whether to shuffle your data lazily instead of upfront. + This consumes a lot less memory, but may yield nondeterministic results. + backend: A string pointing to the respective cloud-backend to use. Currently "s3" and "local" are supported. + + Note: + :param:`lazy_shuffle` is experimental, consumes less memory than shuffling everything in advance (default) + but may result in nondeterministic behavior. + + Note: + On resume from a state-dict, we always skip currently started chunks as these would make the data-order + impossible to determine with sharding. Upon resuming from a point, where a new chunk would be started + anyways, nothing is skipped. + + Note: + Order of data is only guaranteed when resuming with the same distributed settings and the same number of + workers. Everything else leads to different sharding and therefore results in different data order. 
+ """ + + def __init__( + self, + chunks: Sequence[Any], + chunk_size: int = 1, + num_parallel_chunks: int = 1, + env: Optional[Environment] = None, + shuffle: bool = False, + seed: Optional[int] = None, + wrap: bool = False, + lazy_shuffle: bool = False, + backend: Literal["local", "s3"] = "local", + ): + _StatefulIterableDataset.__init__(self) + _Dataset.__init__(self, backend=backend) + + chunks = [_Chunk(c, chunk_size=chunk_size) for c in chunks] + if env is None: + # must detect distributed env here since distributed is not initialized in worker processes with ddp spawn + # can't detect worker env here since workers not yet initialized + env = Environment(_DistributedEnv.detect(), None) + self._env = env + self._shuffle = shuffle + self._lazy_shuffle = lazy_shuffle + + # prepare shuffling + if shuffle: + generator = torch.Generator() + if seed is not None: + generator = generator.manual_seed(seed) + else: + generator = None + + self._seed = seed + self._generator = generator + self._initial_generator_state = self._generator.get_state() if self._generator is not None else None + + self._num_parallel_chunks = num_parallel_chunks + self._chunks = chunks + self._original_chunks = chunks + + self._chunk_size = chunk_size + self._local_chunks: List[_Chunk] = [] + self._wrap = wrap + + self._start_index_chunk = 0 + self._start_index_sample = 0 + self._curr_chunk_index = 0 + self._curr_sample_index = 0 + + self._curr_loaded_chunks: List[_Chunk] = [] + self._curr_loaded_num_samples = 0 + + @abstractmethod + def load_chunk(self, chunk: Any) -> Any: + """Implement this to load a single chunk into memory. This could e.g. mean loading the file that has + previously been downloaded from s3. + + Args: + chunk: The chunk that should be currently loaded + """ + + @abstractmethod + def load_sample_from_chunk(self, chunk: Any, index: int) -> Any: + """Implement this to retrieve a single sample from a given (already loaded) chunk. 
This could be indexing a + list or returning the entire chunk if its size is 1. + + Args: + chunk: The chunk the sample should be retrieved from + index: The index of the current sample to retrieve within the chunk. + """ + + def prepare_chunk(self, chunk: Any) -> None: + """Prepares a single chunk before it is actually loaded. This could e.g. download the actual file from s3. + + Args: + chunk: the chunk data to prepare. + """ + + def __iter__(self) -> "LightningIterableDataset": + """Creates an iterator. + + Before that, detects the env if necessary, shuffles chunks, shards the data and shuffles sample orders within + chunks. + """ + self._curr_chunk_index = self._start_index_chunk + self._curr_sample_index = self._start_index_sample + if self._env.worker_env is None: + self._env.worker_env = _WorkerEnv.detect() + + self._chunks = self._shuffle_if_necessary(self._chunks, 0, shuffle_chunk_order=True, shuffle_sample_order=False) + self._apply_sharding() + self._local_chunks = self._shuffle_if_necessary( + self._local_chunks, + self._curr_chunk_index, + shuffle_chunk_order=False, + shuffle_sample_order=True, + ) + self._ensure_chunks_loaded() + return self + + def __next__(self) -> Any: + """Returns the next sample. + + If necessary, this also loads the new chunks. 
+ """ + self._check_if_sharded() + self._ensure_chunks_loaded() + + if self._curr_sample_index >= self._curr_loaded_num_samples: + self._curr_chunk_index += self._num_parallel_chunks + self._check_dataset_end() + + self._load_next_chunks() + self._curr_sample_index = 0 + + remainder = self._curr_sample_index + curr_loaded_chunk_idx = 0 + for i, c in enumerate(self._curr_loaded_chunks): + if c.chunk_size > remainder: + curr_loaded_chunk_idx = i + break + + remainder -= c.chunk_size + + sample = self.load_sample_from_chunk( + self._curr_loaded_chunks[curr_loaded_chunk_idx]._chunk_data, + self._curr_loaded_chunks[curr_loaded_chunk_idx].index_permutations[remainder], + ) + self._curr_sample_index += 1 + + return sample + + def state_dict(self, returned_samples: int, num_workers: int) -> Dict[str, Any]: + """Returns a global state-dict across all shards and workers. For construction of a global state-dict the + `returned_samples` and `num_workers` arguments are required, since the main process, which is taking this + state-dict, typically does not have access to worker_info. + + Args: + returned_samples: the number of totally returned samples by the dataloader(s) (across all distributed + training processes). + num_workers: number of dataloader workers per distributed training process. + """ + + # compute indices locally again since other workers may have different offsets + if num_workers == 0: + # num_workers=0 indicate loading in the main process --> main process becomes 1 effective worker + num_workers = 1 + + # manually compute num_shards since env doesn't know about num_workers in main process outside dataloader iter + assert self._env.dist_env is not None + num_shards = self._env.dist_env.world_size * num_workers + + # fast-forward so that each chunk on each shard is finished -> this may skip a few samples! 
+ curr_index = math.ceil(returned_samples / num_shards / self._chunk_size) * num_shards * self._chunk_size + + # since we go to next chunk, always start at beginning of chunk + curr_sample_in_chunk = 0 + + # global chunk index + curr_chunk_index = math.ceil(curr_index / self._chunk_size) + return { + "current_chunk": curr_chunk_index, + "current_sample_in_chunk": curr_sample_in_chunk, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Loads a previous state dict to resume from its state. + + Args: + state_dict: the previous state-dict containing internal indices and random number generator states. + + Note: + Some of the changes only take effect when creating a new iterator + """ + state_dict = deepcopy(state_dict) + self._start_index_chunk = state_dict.pop("current_chunk") + self._start_index_sample = state_dict.pop("current_sample_in_chunk") + + self._curr_chunk_index = self._start_index_chunk + self._curr_sample_index = self._start_index_sample + + self._curr_loaded_chunks = [] + self._curr_loaded_num_samples = 0 + + def _ensure_chunks_loaded(self) -> None: + """Ensures that the correct number of chunks is loaded.""" + if len(self._curr_loaded_chunks) != self._num_parallel_chunks: + self._check_dataset_end() + self._load_next_chunks() + + def _load_next_chunks(self) -> None: + """Loads the current chunks and prepares the chunks thereafter.""" + self._curr_loaded_chunks = [] + self._curr_loaded_num_samples = 0 + # load next N chunks + for i in range(self._num_parallel_chunks): + curr_chunk = self._local_chunks[self._curr_chunk_index + i] + loaded_chunk = _Chunk( + self.load_chunk(curr_chunk._chunk_data), + chunk_size=curr_chunk.chunk_size, + start_index=curr_chunk._start_index, + ) + if self._lazy_shuffle: + loaded_chunk.shuffle(generator=self._generator) + else: + loaded_chunk._index_permutations = curr_chunk.index_permutations + self._curr_loaded_chunks.append(loaded_chunk) + self._curr_loaded_num_samples += loaded_chunk.chunk_size + + 
# prepare the next N chunks after currently loaded ones + for i in range(self._num_parallel_chunks, self._num_parallel_chunks * 2): + if self._curr_chunk_index + i >= len(self._local_chunks): + break + curr_chunk = self._local_chunks[self._curr_chunk_index + i] + self.prepare_chunk(curr_chunk._chunk_data) + + def _apply_sharding(self) -> None: + """Shards the chunks if necessary. + + No-op if already sharded + """ + if not self._local_chunks: + num_shards = self._env.num_shards + + # every shard must have the same number of chunks -> truncate if not evenly divisible + max_chunks = len(self._chunks) // num_shards * num_shards + self._local_chunks = self._chunks[self._env.shard_rank : max_chunks : num_shards] + + # if state-dict was set, the curr chunk index was the global number across all shards + # --> divide it to get local number + if self._start_index_chunk and self._start_index_chunk == self._curr_chunk_index: + self._curr_chunk_index = math.ceil(self._curr_chunk_index // self._env.num_shards) + + def _check_if_sharded(self) -> None: + """Raises a warning if the dataset is not sharded.""" + if not self._local_chunks: + warnings.warn( + "Chunks have not been sharded yet. Call iter() on your dataset to ensure sharding is done correctly. " + "It won't recognize dataloader workers when manually calling it outside an actual dataloader." + ) + + def _check_dataset_end(self) -> None: + """Checks if the dataset has reached its end or should be restarted.""" + if self._curr_chunk_index >= len(self._local_chunks): + if self._wrap: + self._curr_chunk_index = 0 + else: + raise StopIteration + + def _shuffle_if_necessary( + self, + chunks: List[_Chunk], + first_chunk_index: int, + shuffle_chunk_order: bool = True, + shuffle_sample_order: bool = True, + ) -> List[_Chunk]: + """This shuffles the chunk-order and the order of samples within each chunk. 
+ + Args: + chunks: The chunks to optionally shuffle + first_chunk_index: The point to which the generator should be replayed + shuffle_chunk_order: Whether to shuffle the order of chunks + shuffle_sample_order: Whether to shuffle the order of samples within a chunk + """ + # re-seed generator + if self._generator is not None and self._initial_generator_state is not None: + self._generator = self._generator.set_state(self._initial_generator_state) + + # shuffle chunks if necessary + chunks = _Chunk(chunks, len(chunks)) + if self._shuffle and shuffle_chunk_order: + chunks.shuffle(generator=self._generator) + # this is annoying but otherwise we cannot make sure the states are the same + elif self._shuffle: + _dummy_chunks = _Chunk(None, len(chunks._chunk_data)) + _dummy_chunks.shuffle(generator=self._generator) + chunks = [chunks._chunk_data[i] for i in chunks] + + if not shuffle_sample_order: + return chunks + + # after shuffling all chunks -> fast forward to first_chunk_index + if self._shuffle and self._lazy_shuffle: + for _ in range(first_chunk_index): + chunk = _Chunk(None, chunk_size=self._chunk_size) + chunk.shuffle(generator=self._generator) + # shuffle samples within each chunk + elif self._shuffle: + chunks = [c.shuffle(generator=self._generator) for c in chunks] + + return chunks + + +class DataLoader(_DataLoader): + __doc__ = _DataLoader.__doc__ + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.returned_samples = 0 + + def __iter__(self) -> Generator[Any, None, None]: # type: ignore + base_iter = super().__iter__() + + for batch in base_iter: + self.returned_samples += self._get_batch_size(batch) + yield batch + + def _get_batch_size(self, batch: Any) -> int: + if isinstance(batch, torch.Tensor): + return batch.size(0) + if isinstance(batch, Sequence): + return len(batch[0]) + + assert isinstance(self.batch_size, int) + return self.batch_size + + def state_dict(self) -> Dict[str, Any]: + """Returns the 
state-dict of the dataset.""" + if isinstance(self.dataset, _Stateful): + state_dict = self.dataset.state_dict(returned_samples=self.returned_samples, num_workers=self.num_workers) + return {"returned_samples": self.returned_samples, "dataset": state_dict} + + raise TypeError("The dataset has no method `state_dict` that accepts `returned_samples` and `num_workers`") + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Loads a given state-dict onto the dataset.""" + self.returned_samples = state_dict.pop("returned_samples") + if isinstance(self.dataset, _Stateful): + return self.dataset.load_state_dict(state_dict["dataset"]) + + raise TypeError("The dataset has no method `load_state_dict` accepting a `state_dict`") + + +@runtime_checkable +class _Stateful(Protocol): + def state_dict(self, returned_samples: int, num_workers: int) -> Dict[str, Any]: + pass + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + pass diff --git a/src/lightning/data/dataset.py b/src/lightning/data/datasets/mapping.py similarity index 62% rename from src/lightning/data/dataset.py rename to src/lightning/data/datasets/mapping.py index 831306dfa08ce..046824cb2d38a 100644 --- a/src/lightning/data/dataset.py +++ b/src/lightning/data/datasets/mapping.py @@ -1,29 +1,26 @@ import os import tempfile from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Literal, Optional -from torch.utils.data import Dataset as TorchDataset - -from lightning.data.backends import _DatasetBackend, LocalDatasetBackend, S3DatasetBackend -from lightning.data.dataset_index import get_index +from lightning.data.datasets.base import _Dataset +from lightning.data.datasets.index import get_index from lightning.data.fileio import OpenCloudFileObj -class LightningDataset(TorchDataset, ABC): +class LightningDataset(_Dataset, ABC): """Dataset wrapper for optimized dataloading. - Arguments: - + Args: data_source: path of data directory. ex. 
s3://mybucket/path - backend: storage location of the data_source. current options are "s3" or "local" - path_to_index_file: path to index file that lists all file contents of the data_source. """ - def __init__(self, data_source: str, backend: str = "local", path_to_index_file: Optional[str] = None): - super().__init__() + def __init__( + self, data_source: str, backend: Literal["local", "s3"] = "local", path_to_index_file: Optional[str] = None + ): + super().__init__(backend=backend) self.data_source = data_source if not path_to_index_file: @@ -34,18 +31,6 @@ def __init__(self, data_source: str, backend: str = "local", path_to_index_file: self.files = self.get_index() - self.backend = self._init_backend(backend=backend) - - assert isinstance(self.backend, _DatasetBackend) - - def _init_backend(self, backend: str) -> _DatasetBackend: - """Picks the correct backend handler.""" - if backend == "s3": - return S3DatasetBackend() - if backend == "local": - return LocalDatasetBackend() - raise ValueError(f"Unsupported backend {backend}") - def get_index(self) -> Any: """Gets existing index or triggers an index generation if it doesn't exist for the provided data_source. @@ -59,16 +44,6 @@ def get_index(self) -> Any: index = f.readlines() return (line.strip("\n") for line in index) - def open(self, file: str, mode: str = "r", kwargs_for_open: Any = {}, **kwargs: Any) -> OpenCloudFileObj: - """Opens a stream for the given file. - - Returns: - A stream object of the file. - """ - return OpenCloudFileObj( - path=file, mode=mode, kwargs_for_open={**self.backend.credentials(), **kwargs_for_open}, **kwargs - ) - def __getitem__(self, idx: int) -> Any: """Get's item from the dataset at provided index. 
diff --git a/tests/tests_data/conftest.py b/tests/tests_data/conftest.py new file mode 100644 index 0000000000000..d7bf63b80ae36 --- /dev/null +++ b/tests/tests_data/conftest.py @@ -0,0 +1,10 @@ +import pytest +import torch.distributed + + +@pytest.fixture(autouse=True) +def teardown_process_group(): + """Ensures that the distributed process group gets closed before the next test runs.""" + yield + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() diff --git a/tests/tests_data/datasets/__init__.py b/tests/tests_data/datasets/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_data/test_data/test_index.txt b/tests/tests_data/datasets/test_data/test_index.txt similarity index 100% rename from tests/tests_data/test_data/test_index.txt rename to tests/tests_data/datasets/test_data/test_index.txt diff --git a/tests/tests_data/test_data/test_index_s3.txt b/tests/tests_data/datasets/test_data/test_index_s3.txt similarity index 100% rename from tests/tests_data/test_data/test_index_s3.txt rename to tests/tests_data/datasets/test_data/test_index_s3.txt diff --git a/tests/tests_data/datasets/test_env.py b/tests/tests_data/datasets/test_env.py new file mode 100644 index 0000000000000..6be415cb7e021 --- /dev/null +++ b/tests/tests_data/datasets/test_env.py @@ -0,0 +1,116 @@ +from functools import partial + +import pytest +import torch +from torch.utils.data import get_worker_info + +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv, Environment +from lightning.fabric import Fabric + + +@pytest.mark.parametrize( + ( + "num_workers", + "current_worker_rank", + "dist_world_size", + "global_rank", + "expected_num_shards", + "expected_shard_rank", + ), + [ + pytest.param(1, 0, 1, 0, 1, 0), + pytest.param(1, 0, 2, 0, 2, 0), + pytest.param(1, 0, 2, 1, 2, 1), + pytest.param(2, 0, 1, 0, 2, 0), + pytest.param(2, 1, 1, 0, 2, 1), + pytest.param(2, 0, 2, 0, 4, 
0), + pytest.param(2, 1, 2, 0, 4, 1), + pytest.param(2, 0, 2, 1, 4, 2), + pytest.param(2, 1, 2, 1, 4, 3), + ], +) +def test_environment( + num_workers, + current_worker_rank, + dist_world_size, + global_rank, + expected_num_shards, + expected_shard_rank, +): + env = Environment.from_args(dist_world_size, global_rank, num_workers, current_worker_rank) + assert env.num_shards == expected_num_shards + assert env.shard_rank == expected_shard_rank + + assert "Environment(" in repr(env) + assert "Environment(" in str(env) + + assert "\n\tdist_env: _DistributedEnv(" in repr(env) + assert "\n\tdist_env: _DistributedEnv(" in str(env) + assert "_DistributedEnv(" in repr(env.dist_env) + assert "_DistributedEnv(" in str(env.dist_env) + + assert "\n\tworker_env: _WorkerEnv(" in repr(env) + assert "\n\tworker_env: _WorkerEnv(" in str(env) + assert "_WorkerEnv(" in repr(env.worker_env) + assert "_WorkerEnv(" in str(env.worker_env) + + assert f"world_size: {num_workers}" in repr(env) + assert f"world_size: {num_workers}" in str(env) + assert f"world_size: {num_workers}" in repr(env.worker_env) + assert f"world_size: {num_workers}" in str(env.worker_env) + + assert f"rank: {current_worker_rank}" in repr(env) + assert f"rank: {current_worker_rank}" in str(env) + assert f"rank: {current_worker_rank}" in repr(env.worker_env) + assert f"rank: {current_worker_rank}" in str(env.worker_env) + + assert f"world_size: {dist_world_size}" in repr(env) + assert f"world_size: {dist_world_size}" in str(env) + assert f"world_size: {dist_world_size}" in repr(env.dist_env) + assert f"world_size: {dist_world_size}" in str(env.dist_env) + + assert f"global_rank: {global_rank}" in repr(env) + assert f"global_rank: {global_rank}" in str(env) + assert f"global_rank: {global_rank}" in repr(env.dist_env) + assert f"global_rank: {global_rank}" in str(env.dist_env) + + assert f"shard_rank: {expected_shard_rank}" in repr(env) + assert f"shard_rank: {expected_shard_rank}" in str(env) + + assert f"num_shards: 
{expected_num_shards}" in repr(env) + assert f"num_shards: {expected_num_shards}" in str(env) + + +class EnvTestDataset(torch.utils.data.IterableDataset): + def __init__(self, num_workers, dist_size, global_rank): + self.num_workers = num_workers + self.dist_size = dist_size + self.global_rank = global_rank + self.env = Environment(_DistributedEnv.detect(), None) + + def __iter__(self): + worker_info = get_worker_info() + env = self.env + env.worker_env = _WorkerEnv.detect() + assert env.worker_env.world_size == self.num_workers + assert env.dist_env.world_size == self.dist_size + assert env.dist_env.global_rank == self.global_rank + assert env.worker_env.rank == (worker_info.id if worker_info is not None else 0) + + yield 0 + + +def env_auto_test(fabric: Fabric, num_workers): + dset = EnvTestDataset(max(1, num_workers), fabric.world_size, fabric.global_rank) + loader = torch.utils.data.DataLoader(dset, num_workers=num_workers) + + # this triggers the `__iter__` of the dataset containing the actual test + for _ in loader: + pass + + +@pytest.mark.parametrize("num_workers", [0, 1, 2]) +@pytest.mark.parametrize("dist_world_size", [1, 2]) +def test_env_auto(num_workers, dist_world_size): + fabric = Fabric(accelerator="cpu", devices=dist_world_size, strategy="ddp_spawn") + fabric.launch(partial(env_auto_test, num_workers=num_workers)) diff --git a/tests/tests_data/test_get_index.py b/tests/tests_data/datasets/test_get_index.py similarity index 94% rename from tests/tests_data/test_get_index.py rename to tests/tests_data/datasets/test_get_index.py index 81703b4d45de5..821a9280cfb12 100644 --- a/tests/tests_data/test_get_index.py +++ b/tests/tests_data/datasets/test_get_index.py @@ -6,8 +6,8 @@ import pytest from lightning_utilities.core.imports import package_available -from lightning.data import dataset_index -from lightning.data.dataset_index import get_index +import lightning.data.datasets.index as dataset_index +from lightning.data.datasets.index import get_index 
THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -54,7 +54,7 @@ def image_set(tmp_path_factory): @pytest.mark.skip(reason="Need a valid AWS key and AWS secret key in CI for this to work") -@mock.patch("lightning.data.dataset_index.LightningClient", MagicMock()) +@mock.patch("lightning.data.datasets.index.LightningClient", MagicMock()) def test_get_index_generate_for_s3_bucket(monkeypatch): """Can generate an index as s3 bucket mounted localled on the Lightning AI platform.""" @@ -84,7 +84,7 @@ def test_get_index_generate_for_s3_bucket(monkeypatch): @pytest.mark.skipif(not package_available("lightning"), reason="Supported only with mono-package") -@mock.patch("lightning.data.dataset_index.LightningClient", MagicMock()) +@mock.patch("lightning.data.datasets.index.LightningClient", MagicMock()) def test_get_index_generate_for_local_folder(image_set, monkeypatch): """Can generate an index for an s3 bucket.""" diff --git a/tests/tests_data/datasets/test_iterable.py b/tests/tests_data/datasets/test_iterable.py new file mode 100644 index 0000000000000..8cbeeb525ff4d --- /dev/null +++ b/tests/tests_data/datasets/test_iterable.py @@ -0,0 +1,610 @@ +import math +import sys +from collections import Counter +from functools import partial +from typing import Any, Dict + +import pytest +import torch + +import lightning +from lightning.data.datasets.iterable import ( + _Chunk, + _Stateful, + _StatefulIterableDataset, + DataLoader, + LightningIterableDataset, +) + + +class Foo1: + def state_dict(self, returned_samples: int) -> Dict[str, Any]: + pass + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + pass + + +class Foo2: + def state_dict(self) -> Dict[str, Any]: + pass + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + pass + + +class Bar1: + pass + + +class Bar2: + def state_dict(self) -> Dict[str, Any]: + pass + + +@pytest.mark.parametrize( + ("klass", "fullfilled"), + [ + pytest.param(Foo1, True), + pytest.param(Foo2, True), + 
pytest.param(Bar1, False), + pytest.param(Bar2, False), + ], +) +def test_serializable(klass, fullfilled): + assert isinstance(klass(), _Stateful) == fullfilled + + +class DummyIterableDataset(_StatefulIterableDataset): + def __init__(self, length: int): + super().__init__() + self.length = length + self.index = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + + self.index += 1 + return 0 + + +class WrongDummySerializableIterableDataset1(DummyIterableDataset): + def state_dict(self): + return {"length": self.length, "index": self.index} + + +class WrongDummySerializableIterableDataset2(DummyIterableDataset): + def load_state_dict(self, state_dict): + self.length = state_dict.pop("length") + self.index = state_dict.pop("index") + + +class WorkingDummySerializableIterableDataset( + WrongDummySerializableIterableDataset1, WrongDummySerializableIterableDataset2 +): + pass + + +@pytest.mark.parametrize( + ("klass", "missing_method"), + [ + pytest.param(WrongDummySerializableIterableDataset1, "load_state_dict"), + pytest.param(WrongDummySerializableIterableDataset2, "state_dict"), + ], +) +def test_required_abstract_methods_serializable_dataset(klass, missing_method): + with pytest.raises( + TypeError, + match=f"Can't instantiate abstract class {klass.__name__} with abstract method.* {missing_method}", + ): + klass(10) + + +def test_serialization_iterable_dataset(): + dset = WorkingDummySerializableIterableDataset(10) + + dset_iter = iter(dset) + + assert dset_iter is dset + + for i in range(10): + assert dset.state_dict() == {"length": 10, "index": i} + next(dset_iter) + assert dset.state_dict() == {"length": 10, "index": i + 1} + + +def test_iteration_serializable_iterable_dataset(): + dset = WorkingDummySerializableIterableDataset(10) + + i = 0 + + for _ in dset: + i = i + 1 + + assert i == 10 + + +def test_resume_iterable_dataset(): + dset1 = WorkingDummySerializableIterableDataset(10) + 
dset1_iter = iter(dset1) + + for _ in range(5): + next(dset1_iter) + + assert dset1.state_dict() == {"length": 10, "index": 5} + + dset2 = WorkingDummySerializableIterableDataset(12) + dset2.load_state_dict(dset1.state_dict()) + + assert dset2.length == 10 + assert dset2.index == 5 + + i = 0 + for _ in dset2: + i = i + 1 + + assert i == 5 + + dset2.length = 12 + for _ in dset2: + i = i + 1 + + assert i == 7 + assert dset2.state_dict() == {"length": 12, "index": 12} + + +class WrongChunkedDataset1(LightningIterableDataset): + def load_chunk(self, curr_chunk: int): + return [(curr_chunk, i) for i in range(self._chunk_size)] + + +class WrongChunkedDataset2(LightningIterableDataset): + def load_sample_from_chunk(self, curr_chunk, curr_index): + return curr_chunk[curr_index] + + +class WorkingChunkedDataset(WrongChunkedDataset1, WrongChunkedDataset2): + pass + + +@pytest.mark.parametrize( + ("klass", "missing_method"), + [ + pytest.param(WrongChunkedDataset1, "load_sample_from_chunk"), + pytest.param(WrongChunkedDataset2, "load_chunk"), + ], +) +def test_required_abstract_methods_chunked_dataset(klass, missing_method): + with pytest.raises( + TypeError, + match=f"Can't instantiate abstract class {klass.__name__} with abstract method.* {missing_method}", + ): + klass([10], 10) + + +def test_chunked_dataset_iteration(): + dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=False, wrap=False) + + curr_item = 0 + for i, item in enumerate(dset): + assert item[0] == curr_item + assert item[1] == i % 2 + curr_item += item[1] + + # goes to 4 but increases again in last item + assert curr_item == 5 + assert i == 9 + + +@pytest.mark.parametrize("lazy_shuffle", [False, True]) +def test_chunk_dataset_iteration_shuffle(lazy_shuffle): + dset = WorkingChunkedDataset( + list(range(5)), + chunk_size=2, + shuffle=True, + seed=12345, + wrap=False, + lazy_shuffle=lazy_shuffle, + ) + counter = Counter() + + series = [] + unexpected_series = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] + + 
series_keys = [] + unexpected_series_keys = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + + for item, key in dset: + counter.update({item: 1}) + series.append(item) + series_keys.append(key) + + for val in counter.values(): + assert val == 2 + + # with shuffling it can't be equal to ordered! + assert series != unexpected_series + assert series_keys != unexpected_series_keys + + +def test_chunked_dataset_wrap(): + dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=True, seed=12345, wrap=True) + + dset_iter = iter(dset) + + # dataset has length 10, so this wraps 2 times + for i in range(21): + _ = next(dset_iter) + + +def test_chunked_dataset_resume_and_reset(): + dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=False, wrap=False) + + for i, item in enumerate(dset): + assert item[0] == 0 + assert item[1] == i + if i == 1: + break + + # Every iterator starts from scratch + for i, item in enumerate(dset): + assert item[0] == 0 + assert item[1] == i + if i == 1: + break + + # this would be set when we load from state dict + dset._start_index_sample = 1 + for i, item in enumerate(dset): + assert item[0] == i + assert item[1] == (i + 1) % 2 + if i == 1: + break + + dset._start_index_chunk == 1 + for i, item in enumerate(dset): + assert item[0] == 1 + assert item[1] == (i + 1) % 2 + if i == 1: + break + + +@pytest.mark.parametrize("shuffle", [False, True]) +def test_chunked_dataset_serialization(shuffle): + dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=shuffle, wrap=False) + + assert dset.state_dict(0, 0) == {"current_chunk": 0, "current_sample_in_chunk": 0} + + dset_iter = iter(dset) + assert dset.state_dict(0, 0) == {"current_chunk": 0, "current_sample_in_chunk": 0} + + dset.load_state_dict(dset.state_dict(0, 0)) + assert dset.state_dict(0, 0) == {"current_chunk": 0, "current_sample_in_chunk": 0} + + dset_iter = iter(dset) + + # throw away first few batches to alter internal state + for i in range(3): + next(dset_iter) + + 
curr_state = dset.state_dict(3, 0) + + original = [next(dset_iter) for _ in range(5)] + + dset.load_state_dict(curr_state) + dset_iter = iter(dset) + after_loading = [next(dset_iter) for _ in range(5)] + + # this isn't because we always skip to beginning of next chunk when loading and not already at beginning of chunk + assert original != after_loading + assert original[1:] == after_loading[:-1] + + # this actually puts us already on beginning of a chunk, but we'll forward to beginning of next chunk, + # otherwise we'd two times resume from same checkpoint and assert equal behavior + dset.load_state_dict(curr_state) + _ = [next(dset_iter) for _ in range(2)] + + new_curr_state = dset.state_dict(6, 0) + + new_original = [next(dset_iter) for _ in range(3)] + + dset.load_state_dict(new_curr_state) + new_after_loading = [next(dset_iter) for _ in range(3)] + + # this is equal since we exactly stopped at beginning of new chunk + assert new_original == new_after_loading + + +class ChunkedTestDatasetDistributed(WorkingChunkedDataset): + def _apply_sharding(self): + super()._apply_sharding() + + assert len(self._local_chunks) == self.expected_num_chunks + + for i in range(1, len(self._local_chunks)): + assert self._local_chunks[i]._chunk_data - self._local_chunks[i - 1]._chunk_data == self.expected_step_width + + +def sharding_test(fabric: lightning.Fabric, num_workers): + dset = ChunkedTestDatasetDistributed(list(range(50)), 2, shuffle=False, wrap=False) + + num_shards = max(1, num_workers) * fabric.world_size + + # num_workers = 0 still has a single worker (the main process) + expected_num_chunks = 50 // num_shards + dset.expected_num_chunks = expected_num_chunks + dset.expected_step_width = fabric.world_size * max(num_workers, 1) + + num_samples_per_rank = max(num_workers, 1) * 2 * expected_num_chunks + loader = torch.utils.data.DataLoader(dset, num_workers=num_workers) + + for i, _ in enumerate(loader): + fabric.barrier() + + assert i == num_samples_per_rank - 1 + + 
+@pytest.mark.parametrize( + ("num_workers", "world_size"), + [ + pytest.param(0, 1), + pytest.param( + 0, + 2, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + pytest.param( + 1, + 1, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + pytest.param( + 1, + 2, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + pytest.param( + 2, + 1, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + pytest.param( + 2, + 2, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + ], +) +def test_sharding(num_workers, world_size): + fabric = lightning.Fabric(accelerator="cpu", devices=world_size, strategy="ddp_spawn") + fabric.launch(partial(sharding_test, num_workers=num_workers)) + + +def sharding_resume_test(fabric: lightning.Fabric, num_workers): + chunk_size = 2 + dset = WorkingChunkedDataset(list(range(100)), chunk_size, shuffle=False, wrap=False) + loader = torch.utils.data.DataLoader(dset, num_workers=num_workers, shuffle=False) + num_shards = max(1, num_workers) * fabric.world_size + + for i in [23, 37, 10, 20]: + curr_index = math.ceil(i / num_shards / chunk_size) * num_shards * chunk_size + next_chunk = math.ceil(curr_index / chunk_size) + + curr_state = dset.state_dict(i, num_workers=num_workers) + assert curr_state == {"current_chunk": next_chunk, "current_sample_in_chunk": 0} + + dset.load_state_dict(curr_state) + loader = torch.utils.data.DataLoader(dset, num_workers=num_workers, shuffle=False) + + # calculate starting chunks + # next_chunk + fabric.global_rank * max(1,num_workers) determines the base offset for each rank + # i % chunk_size makes sure that workers are alternating + # e.g. 
w0 returns first element of first chunk then w1 returns first element of second chunk then w0 returns + # second element of first chunk etc. + # i // num_shards * num_shards progresses to next chunks + curr_worker_chunk = { + i: next_chunk + + fabric.global_rank * max(1, num_workers) + + i % max(1, num_workers) + + i // (chunk_size * num_shards) + for i in range(max(1, num_workers)) + } + curr_worker_chunk_elem = {i: 0 for i in range(max(1, num_workers))} + + for i, batch in enumerate(loader): + curr_worker = i % max(1, num_workers) + assert batch[0] == curr_worker_chunk[curr_worker] + assert batch[1] == curr_worker_chunk_elem[curr_worker] + + curr_worker_chunk_elem[curr_worker] += 1 + + if curr_worker_chunk_elem[curr_worker] == chunk_size: + curr_worker_chunk[curr_worker] += num_shards + curr_worker_chunk_elem[curr_worker] = 0 + fabric.barrier() + + +@pytest.mark.parametrize( + ("num_workers", "world_size"), + [ + pytest.param(0, 1), + pytest.param( + 0, + 2, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + pytest.param( + 1, + 1, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + pytest.param( + 1, + 2, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + pytest.param( + 2, + 1, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + pytest.param( + 2, + 2, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + ], +) +def test_chunked_dataset_sharded_state_dict_resume(num_workers, world_size): + fabric = lightning.Fabric(accelerator="cpu", devices=world_size, strategy="ddp_spawn") + fabric.launch(partial(sharding_resume_test, num_workers=num_workers)) + + +@pytest.mark.parametrize("chunk_size", [20, 30, 40]) 
+@pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.parametrize("shuffle_seed", [None, 123]) +@pytest.mark.parametrize("delayed_start", [False, True]) +def test_chunk(chunk_size, shuffle, shuffle_seed, delayed_start): + data = list(range(chunk_size)) + delayed_start_index = int(delayed_start) * (chunk_size - 10) + chunk = _Chunk(data, chunk_size=chunk_size, start_index=delayed_start_index) + linear_permutation = tuple(range(chunk_size)) + assert chunk.index_permutations == linear_permutation + + for i, index in enumerate(chunk): + assert index == delayed_start_index + i + + assert chunk.chunk_size == chunk_size + + if shuffle: + generator = torch.Generator().manual_seed(shuffle_seed) if shuffle_seed else None + + chunk = chunk.shuffle(generator=generator) + + old_permutation = chunk.index_permutations + assert old_permutation != linear_permutation + + new_perm = [] + + for i, index in enumerate(chunk): + new_perm.append(index) + + assert tuple(new_perm) == tuple([old_permutation[k] for k in range(delayed_start_index, chunk_size)]) + assert len(new_perm) == chunk_size - delayed_start_index + + if shuffle_seed: + chunk = chunk.shuffle(generator=generator.manual_seed(shuffle_seed)) + assert chunk.index_permutations == old_permutation + + assert chunk.chunk_size == chunk_size + + +class MyDataset(_StatefulIterableDataset): + def __init__(self, length): + self.length = length + self.samples = list(range(length)) + self.curr_iter = 0 + + def __iter__(self): + for sample in self.samples[self.curr_iter :]: + yield sample + self.curr_iter += 1 + + def state_dict(self, returned_samples, num_workers): + return {"curr_iter": returned_samples, "num_workers": num_workers} + + def load_state_dict(self, state_dict): + self.curr_iter = state_dict.pop("curr_iter") + + +@pytest.mark.parametrize("batch_size", [1, 2, 3]) +@pytest.mark.parametrize( + "num_workers", + [ + pytest.param(0), + pytest.param( + 1, + marks=pytest.mark.skipif( + sys.platform != "linux", 
reason="multiprocessing on other platforms takes forever" + ), + ), + pytest.param( + 2, + marks=pytest.mark.skipif( + sys.platform != "linux", reason="multiprocessing on other platforms takes forever" + ), + ), + ], +) +@pytest.mark.parametrize("prefetch_factor", [1, 2, 3]) +@pytest.mark.parametrize("length", [100, 101]) +@pytest.mark.parametrize("num_batches", [1, 2, 7]) +def test_resumable_loader(batch_size, num_workers, prefetch_factor, length, num_batches): + dset = MyDataset(length) + loader = DataLoader( + dset, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor if num_workers > 0 else None, + ) + + loader_iter = iter(loader) + for i, batch in enumerate(loader_iter): + assert loader._get_batch_size(batch) == batch_size + if i == num_batches - 1: + break + + state_dict = loader.state_dict() + assert state_dict["returned_samples"] == batch_size * num_batches + assert state_dict["dataset"] == { + "curr_iter": batch_size * num_batches, + "num_workers": num_workers, + } + + state_dict["returned_samples"] += 1 + state_dict["dataset"]["curr_iter"] += 1 + loader.load_state_dict(state_dict) + assert loader.returned_samples == batch_size * num_batches + 1 + assert loader.dataset.curr_iter == batch_size * num_batches + 1 + + +def test_state_dict_error(): + loader = DataLoader([1, 2, 3]) + with pytest.raises( + TypeError, + match="The dataset has no method `state_dict` that accepts `returned_samples` and `num_workers`", + ): + loader.state_dict() + + +def test_load_state_dict_error(): + loader = DataLoader([1, 2, 3]) + with pytest.raises( + TypeError, + match="The dataset has no method `load_state_dict` accepting a `state_dict`", + ): + loader.load_state_dict({"returned_samples": 1, "dataset": {"some_key": "some_val"}}) diff --git a/tests/tests_data/test_dataset.py b/tests/tests_data/datasets/test_mapping.py similarity index 94% rename from tests/tests_data/test_dataset.py rename to tests/tests_data/datasets/test_mapping.py index 
0ac232608ed8e..afda27526fbe7 100644 --- a/tests/tests_data/test_dataset.py +++ b/tests/tests_data/datasets/test_mapping.py @@ -8,8 +8,8 @@ import pytest from lightning_utilities.core.imports import package_available -from lightning.data import dataset_index -from lightning.data.dataset import LightningDataset +from lightning.data.datasets import index as dataset_index +from lightning.data.datasets import LightningDataset from lightning.data.fileio import OpenCloudFileObj @@ -70,7 +70,7 @@ def load_sample(self, file_path, stream): @pytest.mark.skipif(not isConnectedWithInternet(), reason="Not connected to internet") @pytest.mark.skipif(not package_available("lightning"), reason="Supported only with mono-package") -@mock.patch("lightning.data.dataset_index.LightningClient", MagicMock()) +@mock.patch("lightning.data.datasets.index.LightningClient", MagicMock()) def test_lightning_dataset(tmpdir, image_set, monkeypatch): client = MagicMock() client.projects_service_list_project_cluster_bindings.return_value = None