Commit 42b5417

Sharing Datasets Across Process Boundaries (#10951)

docs/source/advanced/training_tricks.rst

Lines changed: 48 additions & 0 deletions
@@ -154,3 +154,51 @@ Advanced GPU Optimizations
When training on single or multiple GPU machines, Lightning offers a host of advanced optimizations to improve throughput, memory efficiency, and model scaling.
Refer to :doc:`Advanced GPU Optimized Training for more details <../advanced/advanced_gpu>`.


----------


Sharing Datasets Across Process Boundaries
------------------------------------------

The :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class provides an organized way to decouple data loading from training logic, with :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` being used for downloading and pre-processing the dataset on a single process, and :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` loading the pre-processed data for each process individually:
.. code-block:: python

    from typing import Optional

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST


    class MNISTDataModule(pl.LightningDataModule):
        def prepare_data(self):
            # runs on a single process: download / pre-process the dataset
            MNIST(self.data_dir, download=True)

        def setup(self, stage: Optional[str] = None):
            # runs on every process: load the pre-processed dataset
            self.mnist = MNIST(self.data_dir)

        def train_dataloader(self):
            return DataLoader(self.mnist, batch_size=128)

However, when a dataset is held in memory, this approach means that every process keeps a (redundant) replica of it, which may be impractical when many processes are used or when the dataset barely fits into CPU memory, since memory consumption scales linearly with the number of processes.
For example, when training Graph Neural Networks, a common strategy is to load the entire graph into CPU memory for fast access to the full graph structure and its features, and to then perform neighbor sampling to obtain mini-batches that fit onto the GPU.

A simple way to prevent redundant dataset replicas is to rely on :obj:`torch.multiprocessing` to share the `data automatically between spawned processes via shared memory <https://pytorch.org/docs/stable/notes/multiprocessing.html>`_.
For this, all data pre-loading should be done on the main process inside :meth:`LightningDataModule.__init__`.
As a result, all tensor data will be shared automatically across processes when using the :class:`~pytorch_lightning.plugins.DDPSpawnPlugin` training type plugin:
.. warning::

    :obj:`torch.multiprocessing` will send a handle of each individual tensor to other processes.
    In order to prevent any errors due to too many open file handles, try to reduce the number of tensors to share, *e.g.*, by stacking your data into a single tensor (a short sketch of this follows the example below).

.. code-block:: python

    import torchvision.transforms as T


    class MNISTDataModule(pl.LightningDataModule):
        def __init__(self, data_dir: str):
            super().__init__()
            # pre-load the dataset once on the main process; its tensors are
            # then shared with the spawned processes via shared memory
            self.mnist = MNIST(data_dir, download=True, transform=T.ToTensor())

        def train_dataloader(self):
            return DataLoader(self.mnist, batch_size=128)


    model = Model(...)
    datamodule = MNISTDataModule("data/MNIST")

    trainer = Trainer(gpus=2, strategy="ddp_spawn")
    trainer.fit(model, datamodule)
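
As the warning above suggests, sharing works best when the dataset is held in a small number of large tensors. Below is a minimal sketch of the stacking idea, assuming hypothetical ``load_samples()`` and ``load_labels()`` helpers (not part of any library) that return a list of equally-sized sample tensors and their labels:

.. code-block:: python

    import torch
    from torch.utils.data import TensorDataset

    # hypothetical helpers: a list of equally-sized sample tensors and a list
    # of integer labels (placeholders, not part of the Lightning API)
    samples = load_samples()
    labels = load_labels()

    # stack everything into two large tensors so that only two shared-memory
    # handles (instead of one handle per sample) are sent to the other processes
    data = torch.stack(samples)  # shape: [num_samples, ...]
    targets = torch.tensor(labels)  # shape: [num_samples]

    dataset = TensorDataset(data, targets)

The resulting ``dataset`` could then be pre-loaded inside ``__init__`` exactly as in the example above.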
See the `graph-level <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pytorch_lightning/gin.py>`_ and `node-level <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pytorch_lightning/graph_sage.py>`_ prediction examples in PyTorch Geometric for practical use-cases.
