Skip to content

Commit 5c39840

Browse files
szagoruykoapaszke
authored andcommitted
Allow to pass load function in ImageFolder (#20)
1 parent df55747 commit 5c39840

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

torchvision/datasets/folder.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,14 @@ def make_dataset(dir, class_to_idx):
3333

3434
return images
3535

36+
37+
def default_loader(path):
38+
return Image.open(path).convert('RGB')
39+
40+
3641
class ImageFolder(data.Dataset):
37-
def __init__(self, root, transform=None, target_transform=None):
42+
def __init__(self, root, transform=None, target_transform=None,
43+
loader=default_loader):
3844
classes, class_to_idx = find_classes(root)
3945
imgs = make_dataset(root, class_to_idx)
4046

@@ -44,10 +50,11 @@ def __init__(self, root, transform=None, target_transform=None):
4450
self.class_to_idx = class_to_idx
4551
self.transform = transform
4652
self.target_transform = target_transform
53+
self.loader = loader
4754

4855
def __getitem__(self, index):
4956
path, target = self.imgs[index]
50-
img = Image.open(os.path.join(self.root, path)).convert('RGB')
57+
img = self.loader(os.path.join(self.root, path))
5158
if self.transform is not None:
5259
img = self.transform(img)
5360
if self.target_transform is not None:

0 commit comments

Comments
 (0)