
Commit 006fde9

Browse files
justusschock, Borda, akihironitta, and lantiga authored
FCCV Docs (#15598)
* add custom data iter docs
* add custom data iter docs
* Update docs/source-pytorch/data/custom_data_iterables.rst
* remove ToDevice
* nit
* Update docs/source-pytorch/data/custom_data_iterables.rst
  Co-authored-by: Luca Antiga <[email protected]>
* clarification for @lantiga
* typo
* Update docs/source-pytorch/data/custom_data_iterables.rst
* Update docs/source-pytorch/data/custom_data_iterables.rst
* Update docs/source-pytorch/data/custom_data_iterables.rst

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Luca Antiga <[email protected]>
1 parent 88b2e5a commit 006fde9

File tree

2 files changed: +123 -0 lines changed

docs/source-pytorch/data/custom_data_iterables.rst

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
.. _dataiters:

##################################
Injecting 3rd Party Data Iterables
##################################

When training a model on a specific task, data loading and preprocessing might become a bottleneck.
Lightning does not enforce a specific data loading approach, nor does it try to control it.
The only assumption Lightning makes is that the data is returned as an iterable of batches.

For PyTorch-based programs, these iterables are typically instances of :class:`~torch.utils.data.DataLoader`.

However, Lightning also supports other types of data, such as a plain list of batches, a generator, or any other custom iterable.

.. code-block:: python

    import torch
    from pytorch_lightning import Trainer

    # a random list of batches; LitClassifier is any LightningModule defined elsewhere
    data = [(torch.rand(32, 3, 32, 32), torch.randint(0, 10, (32,))) for _ in range(100)]

    model = LitClassifier()
    trainer = Trainer()
    trainer.fit(model, data)

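A generator works in the same way. Here is a minimal sketch, assuming the same ``LitClassifier``, ``model``, and ``trainer`` as in the snippet above:

.. code-block:: python

    def batch_generator(num_batches=100, batch_size=32):
        # yield one (inputs, targets) batch at a time
        for _ in range(num_batches):
            yield torch.rand(batch_size, 3, 32, 32), torch.randint(0, 10, (batch_size,))


    trainer.fit(model, batch_generator())
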
Examples of such custom iterables include `NVIDIA DALI <https://github.com/NVIDIA/DALI>`__ and `FFCV <https://github.com/libffcv/ffcv>`__ for computer vision.
Both libraries offer support for custom, hardware-accelerated data loading and preprocessing and can be used with Lightning.

For example, taking the code from FFCV's README, we can use it with Lightning simply by removing the hardcoded ``ToDevice(0)``,
since Lightning takes care of GPU placement. If you want to run some data transformations on the GPU, change
``ToDevice(0)`` to ``ToDevice(self.trainer.local_rank)`` to correctly map to the desired GPU in your pipeline.

.. code-block:: python

    from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder
    from ffcv.loader import Loader, OrderOption
    from ffcv.transforms import Cutout, ToDevice, ToTensor, ToTorchImage


    class CustomClassifier(LitClassifier):
        def train_dataloader(self):
            # Random resized crop
            decoder = RandomResizedCropRGBImageDecoder((224, 224))

            # Data decoding and augmentation
            image_pipeline = [decoder, Cutout(), ToTensor(), ToTorchImage()]
            label_pipeline = [IntDecoder(), ToTensor()]

            # Pipeline for each data field
            pipelines = {"image": image_pipeline, "label": label_pipeline}

            # Replaces the PyTorch data loader (`torch.utils.data.DataLoader`);
            # `write_path` points to a dataset previously written with FFCV's DatasetWriter
            loader = Loader(
                write_path, batch_size=bs, num_workers=num_workers, order=OrderOption.RANDOM, pipelines=pipelines
            )

            return loader

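If you prefer to keep some transformations on the GPU, a minimal sketch of the image pipeline with ``ToDevice`` (reusing the imports and ``train_dataloader`` from the block above) could look like this:

.. code-block:: python

    # inside CustomClassifier.train_dataloader
    image_pipeline = [
        decoder,
        Cutout(),
        ToTensor(),
        # map to the GPU owned by this process instead of hardcoding device 0
        ToDevice(self.trainer.local_rank),
        ToTorchImage(),
    ]
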
When moving data to a specific device, you can always refer to ``self.trainer.local_rank`` to get the index of the
accelerator used by the current process.

By changing ``device_id=0`` to ``device_id=self.trainer.local_rank``, we can also leverage DALI's GPU decoding:

.. code-block:: python

    import os

    from nvidia.dali.pipeline import pipeline_def
    import nvidia.dali.types as types
    import nvidia.dali.fn as fn
    from nvidia.dali.plugin.pytorch import DALIGenericIterator


    class CustomLitClassifier(LitClassifier):
        def train_dataloader(self):
            # To run with different data, see the documentation of nvidia.dali.fn.readers.file
            # DALI_EXTRA_PATH points to a clone of https://github.com/NVIDIA/DALI_extra
            data_root_dir = os.environ["DALI_EXTRA_PATH"]
            images_dir = os.path.join(data_root_dir, "db", "single", "jpeg")

            @pipeline_def(num_threads=4, device_id=self.trainer.local_rank)
            def get_dali_pipeline():
                images, labels = fn.readers.file(file_root=images_dir, random_shuffle=True, name="Reader")
                # decode data on the GPU
                images = fn.decoders.image_random_crop(images, device="mixed", output_type=types.RGB)
                # the rest of the processing happens on the GPU as well
                images = fn.resize(images, resize_x=256, resize_y=256)
                images = fn.crop_mirror_normalize(
                    images,
                    crop_h=224,
                    crop_w=224,
                    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                    std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                    mirror=fn.random.coin_flip(),
                )
                return images, labels

            train_data = DALIGenericIterator(
                [get_dali_pipeline(batch_size=16)],
                ["data", "label"],
                reader_name="Reader",
            )

            return train_data

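``DALIGenericIterator`` yields a list with one dictionary per pipeline, keyed by the output map passed above (``"data"`` and ``"label"``). A minimal sketch of a ``training_step`` consuming such a batch, assuming ``LitClassifier``'s forward produces class logits, might look like this:

.. code-block:: python

    import torch.nn.functional as F


    class CustomLitClassifier(LitClassifier):
        def training_step(self, batch, batch_idx):
            # one dict per pipeline; only a single pipeline was registered above
            out = batch[0]
            images, labels = out["data"], out["label"]
            logits = self(images)
            # labels come back as int32 of shape [batch, 1]; flatten them and move them alongside the logits
            loss = F.cross_entropy(logits, labels.squeeze(-1).long().to(self.device))
            return loss
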
Limitations
-----------

Lightning works with all kinds of custom data iterables as shown above. There are, however, a few features that cannot
be supported this way. These restrictions exist because, in order to support them,
Lightning needs to know a great deal about the internals of these iterables.

- In a distributed multi-GPU setting (ddp),
  Lightning automatically replaces the DataLoader's sampler with its distributed counterpart.
  This makes sure that each GPU sees a different part of the dataset.
  As sampling can be implemented in arbitrary ways with custom iterables,
  there is no way for Lightning to know how to replace the sampler (see the sketch after this list for a manual workaround).

- When training fails for some reason, Lightning is able to extract all of the relevant data from the model,
  optimizers, trainer, and dataloader to resume at the exact same batch it crashed on.
  This feature is called fault-tolerance and is limited to PyTorch DataLoaders.
  Lightning needs to know a lot about sampling, fast-forwarding, and random number handling to enable fault tolerance,
  meaning that it cannot be supported for arbitrary iterables.
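If you still want to use a custom iterable in a multi-GPU run, one possible workaround is to shard the data yourself using the trainer's rank information. Here is a sketch based on the hypothetical list-of-batches example from the top of this page:

.. code-block:: python

    class ShardedListClassifier(LitClassifier):
        def train_dataloader(self):
            # hypothetical pre-built batches; replace with your own iterable
            all_batches = [(torch.rand(32, 3, 32, 32), torch.randint(0, 10, (32,))) for _ in range(100)]
            # give each process every world_size-th batch, offset by its global rank,
            # so that every GPU sees a different part of the data
            return all_batches[self.trainer.global_rank :: self.trainer.world_size]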

docs/source-pytorch/index.rst

Lines changed: 1 addition & 0 deletions
@@ -206,6 +206,7 @@ Current Lightning Users
    Train on single or multiple TPUs <accelerators/tpu>
    Train on MPS <accelerators/mps>
    Use a pretrained model <advanced/pretrained>
+   Inject Custom Data Iterables <data/custom_data_iterables>
    model/own_your_loop

 .. toctree::
