@@ -235,6 +235,7 @@ def gdtl(
     workers=5,
     _worker_init_fn=None,
     memory_format=torch.contiguous_format,
+    **kwargs,
 ):
     if torch.distributed.is_initialized():
         rank = torch.distributed.get_rank()
@@ -284,6 +285,7 @@ def gdvl(
     workers=5,
     _worker_init_fn=None,
     memory_format=torch.contiguous_format,
+    **kwargs,
 ):
     if torch.distributed.is_initialized():
         rank = torch.distributed.get_rank()
@@ -413,6 +415,7 @@ def get_pytorch_train_loader(
     start_epoch=0,
     workers=5,
     _worker_init_fn=None,
+    prefetch_factor=2,
     memory_format=torch.contiguous_format,
 ):
     interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
@@ -445,6 +448,7 @@ def get_pytorch_train_loader(
         collate_fn=partial(fast_collate, memory_format),
         drop_last=True,
         persistent_workers=True,
+        prefetch_factor=prefetch_factor,
     )
 
     return (
@@ -464,6 +468,7 @@ def get_pytorch_val_loader(
     _worker_init_fn=None,
     crop_padding=32,
     memory_format=torch.contiguous_format,
+    prefetch_factor=2,
 ):
     interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
         interpolation
@@ -499,6 +504,7 @@ def get_pytorch_val_loader(
         collate_fn=partial(fast_collate, memory_format),
         drop_last=False,
         persistent_workers=True,
+        prefetch_factor=prefetch_factor,
     )
 
     return PrefetchedWrapper(val_loader, 0, num_classes, one_hot), len(val_loader)
@@ -548,6 +554,7 @@ def get_syntetic_loader(
     workers=None,
     _worker_init_fn=None,
     memory_format=torch.contiguous_format,
+    **kwargs,
 ):
     return (
         SynteticDataLoader(
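For reference, a standalone sketch (not part of the diff) of what the forwarded argument does: prefetch_factor is a standard torch.utils.data.DataLoader keyword that sets how many batches each worker loads ahead of time, and it only applies when num_workers > 0. The TensorDataset below is a placeholder used purely for illustration; the real loaders wrap an ImageFolder-style dataset. Presumably the **kwargs added to gdtl, gdvl, and get_syntetic_loader lets those entry points accept the same keyword without using it, so a shared call site does not have to special-case them.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset for illustration only.
dataset = TensorDataset(
    torch.randn(1024, 3, 224, 224),
    torch.zeros(1024, dtype=torch.long),
)

# Each of the 5 workers keeps 4 batches ready in advance, mirroring how the
# patched get_pytorch_train_loader / get_pytorch_val_loader now pass the value through.
loader = DataLoader(
    dataset,
    batch_size=256,
    num_workers=5,
    persistent_workers=True,
    prefetch_factor=4,
    drop_last=True,
)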