File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
pytorch_lightning/trainer Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change 37
37
CaptureMapDataset ,
38
38
FastForwardSampler ,
39
39
)
40
- from pytorch_lightning .utilities .data import has_iterable_dataset , has_len_all_ranks
40
+ from pytorch_lightning .utilities .data import get_len , has_iterable_dataset , has_len_all_ranks
41
41
from pytorch_lightning .utilities .enums import DistributedType
42
42
from pytorch_lightning .utilities .exceptions import MisconfigurationException
43
43
from pytorch_lightning .utilities .imports import _fault_tolerant_training
@@ -282,10 +282,11 @@ def _get_dataloader_init_kwargs(
282
282
dl_kwargs ["sampler" ] = None
283
283
284
284
if _fault_tolerant_training ():
285
- if isinstance (dl_kwargs ["dataset" ], IterableDataset ):
285
+ dataset = dl_kwargs ["dataset" ]
286
+ if isinstance (dataset , IterableDataset ):
286
287
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
287
288
dl_kwargs ["dataset" ] = CaptureIterableDataset (dataset = dl_kwargs ["dataset" ])
288
- elif len ( dl_kwargs [ " dataset" ] ):
289
+ elif get_len ( dataset ) != float ( "inf" ):
289
290
dl_kwargs ["dataset" ] = CaptureMapDataset (dataset = dl_kwargs ["dataset" ])
290
291
else :
291
292
raise MisconfigurationException (
You can’t perform that action at this time.
0 commit comments