@@ -33,8 +33,14 @@ def make_dataset(dir, class_to_idx):
33
33
34
34
return images
35
35
36
+
37
+ def default_loader (path ):
38
+ return Image .open (path ).convert ('RGB' )
39
+
40
+
36
41
class ImageFolder (data .Dataset ):
37
- def __init__ (self , root , transform = None , target_transform = None ):
42
+ def __init__ (self , root , transform = None , target_transform = None ,
43
+ loader = default_loader ):
38
44
classes , class_to_idx = find_classes (root )
39
45
imgs = make_dataset (root , class_to_idx )
40
46
@@ -44,10 +50,11 @@ def __init__(self, root, transform=None, target_transform=None):
44
50
self .class_to_idx = class_to_idx
45
51
self .transform = transform
46
52
self .target_transform = target_transform
53
+ self .loader = loader
47
54
48
55
def __getitem__ (self , index ):
49
56
path , target = self .imgs [index ]
50
- img = Image . open (os .path .join (self .root , path )). convert ( 'RGB' )
57
+ img = self . loader (os .path .join (self .root , path ))
51
58
if self .transform is not None :
52
59
img = self .transform (img )
53
60
if self .target_transform is not None :
0 commit comments