.. _dataiters:

##################################
Injecting 3rd Party Data Iterables
##################################

When training a model on a specific task, data loading and preprocessing might become a bottleneck.
Lightning does not enforce a specific data loading approach, nor does it try to control it.
The only assumption Lightning makes is that the data is returned as an iterable of batches.

For PyTorch-based programs, these iterables are typically instances of :class:`~torch.utils.data.DataLoader`.

However, Lightning also supports other data types such as plain lists of batches, generators, or other custom iterables.

.. code-block:: python

    # a random list of batches
    data = [(torch.rand(32, 3, 32, 32), torch.randint(0, 10, (32,))) for _ in range(100)]
    model = LitClassifier()
    trainer = Trainer()
    trainer.fit(model, data)

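Since a generator is also an iterable of batches, it works the same way. Here is a minimal sketch, reusing the
``LitClassifier`` and ``Trainer`` from the example above (a plain generator is exhausted after one pass, so we
train for a single epoch):

.. code-block:: python

    import torch


    def batch_generator(num_batches=100, batch_size=32):
        # yield (inputs, labels) tuples one batch at a time
        for _ in range(num_batches):
            yield torch.rand(batch_size, 3, 32, 32), torch.randint(0, 10, (batch_size,))


    model = LitClassifier()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model, batch_generator())
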
Examples of custom iterables include `NVIDIA DALI <https://github.com/NVIDIA/DALI>`__ or `FFCV <https://github.com/libffcv/ffcv>`__ for computer vision.
Both libraries offer custom data loading and preprocessing (optionally hardware-accelerated) and can be used with Lightning.


For example, taking the example from FFCV's readme, we can use it with Lightning simply by removing the hardcoded ``ToDevice(0)``,
since Lightning takes care of GPU placement. If you want to run some data transformations on the GPU, change
``ToDevice(0)`` to ``ToDevice(self.trainer.local_rank)`` to correctly map to the desired GPU in your pipeline.

.. code-block:: python

    from ffcv.loader import Loader, OrderOption
    from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
    from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder


    class CustomClassifier(LitClassifier):
        def train_dataloader(self):
            # Random resized crop
            decoder = RandomResizedCropRGBImageDecoder((224, 224))

            # Data decoding and augmentation
            image_pipeline = [decoder, Cutout(), ToTensor(), ToTorchImage()]
            label_pipeline = [IntDecoder(), ToTensor()]

            # Pipeline for each data field
            pipelines = {"image": image_pipeline, "label": label_pipeline}

            # Replaces the PyTorch data loader (`torch.utils.data.DataLoader`);
            # `write_path` points to the FFCV dataset file written beforehand,
            # `bs` and `num_workers` are user-defined
            loader = Loader(
                write_path, batch_size=bs, num_workers=num_workers, order=OrderOption.RANDOM, pipelines=pipelines
            )

            return loader

When moving data to a specific device, you can always refer to ``self.trainer.local_rank`` to get the index of the
device used by the current process.

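For instance, keeping a GPU-side ``ToDevice`` step in the FFCV pipeline above could look like the following sketch
(``GPUTransformClassifier`` is a hypothetical name; ``write_path``, ``bs`` and ``num_workers`` remain user-defined
placeholders, as before):

.. code-block:: python

    from ffcv.loader import Loader, OrderOption
    from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
    from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder


    class GPUTransformClassifier(LitClassifier):
        def train_dataloader(self):
            decoder = RandomResizedCropRGBImageDecoder((224, 224))

            # ToDevice moves the batch onto the GPU assigned to this process,
            # so every transform placed after it runs on that device
            image_pipeline = [decoder, Cutout(), ToTensor(), ToDevice(self.trainer.local_rank), ToTorchImage()]
            label_pipeline = [IntDecoder(), ToTensor(), ToDevice(self.trainer.local_rank)]

            pipelines = {"image": image_pipeline, "label": label_pipeline}
            return Loader(
                write_path, batch_size=bs, num_workers=num_workers, order=OrderOption.RANDOM, pipelines=pipelines
            )
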
Similarly, by changing ``device_id=0`` to ``device_id=self.trainer.local_rank``, we can also leverage DALI's GPU decoding:

.. code-block:: python

    import os

    from nvidia.dali.pipeline import pipeline_def
    import nvidia.dali.types as types
    import nvidia.dali.fn as fn
    from nvidia.dali.plugin.pytorch import DALIGenericIterator


    class CustomLitClassifier(LitClassifier):
        def train_dataloader(self):
            # To run with different data, see the documentation of nvidia.dali.fn.readers.file;
            # this points to https://github.com/NVIDIA/DALI_extra
            data_root_dir = os.environ["DALI_EXTRA_PATH"]
            images_dir = os.path.join(data_root_dir, "db", "single", "jpeg")

            @pipeline_def(num_threads=4, device_id=self.trainer.local_rank)
            def get_dali_pipeline():
                images, labels = fn.readers.file(file_root=images_dir, random_shuffle=True, name="Reader")
                # decode data on the GPU
                images = fn.decoders.image_random_crop(images, device="mixed", output_type=types.RGB)
                # the rest of the processing happens on the GPU as well
                images = fn.resize(images, resize_x=256, resize_y=256)
                images = fn.crop_mirror_normalize(
                    images,
                    crop_h=224,
                    crop_w=224,
                    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                    std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                    mirror=fn.random.coin_flip(),
                )
                return images, labels

            train_data = DALIGenericIterator(
                [get_dali_pipeline(batch_size=16)],
                ["data", "label"],
                reader_name="Reader",
            )

            return train_data

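A multi-GPU run with this classifier could then look like the following sketch (the accelerator and device count
are illustrative):

.. code-block:: python

    model = CustomLitClassifier()
    # each process decodes on its own GPU, selected via self.trainer.local_rank
    trainer = Trainer(accelerator="gpu", devices=2)
    trainer.fit(model)
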
Limitations
-----------
Lightning works with all kinds of custom data iterables, as shown above. There are, however, a few features that
cannot be supported this way. These restrictions come from the fact that, to support them, Lightning would need to
know a lot about the internals of these iterables.

- In a distributed multi-GPU setting (ddp),
  Lightning automatically replaces the DataLoader's sampler with its distributed counterpart.
  This makes sure that each GPU sees a different part of the dataset.
  Since sampling can be implemented in arbitrary ways with custom iterables,
  there is no way for Lightning to know how to replace the sampler (see the sketch after this list).

- When training fails for some reason, Lightning is able to extract all of the relevant data from the model,
  optimizers, trainer and dataloader to resume at the exact batch where it crashed.
  This feature is called fault tolerance and is limited to PyTorch DataLoaders.
  Lightning needs to know a lot about sampling, fast-forwarding and random number handling to enable fault tolerance,
  meaning that it cannot be supported for arbitrary iterables.
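
For reference, the sampler replacement Lightning performs for a regular :class:`~torch.utils.data.DataLoader` in ddp
looks roughly like the following simplified sketch (not Lightning's actual implementation). It relies on the loader
exposing its dataset and construction arguments, which an arbitrary custom iterable does not:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, DistributedSampler


    def replace_sampler(dataloader: DataLoader) -> DataLoader:
        # shard the dataset across processes so that each GPU sees a different part of it
        sampler = DistributedSampler(
            dataloader.dataset,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank(),
        )
        # rebuild the loader with the distributed sampler injected
        return DataLoader(
            dataloader.dataset,
            batch_size=dataloader.batch_size,
            num_workers=dataloader.num_workers,
            sampler=sampler,
        )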