diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 5eb3126ae96..40adcb9fcec 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -33,8 +33,14 @@ def make_dataset(dir, class_to_idx): return images + +def default_loader(path): + return Image.open(path).convert('RGB') + + class ImageFolder(data.Dataset): - def __init__(self, root, transform=None, target_transform=None): + def __init__(self, root, transform=None, target_transform=None, + loader=default_loader): classes, class_to_idx = find_classes(root) imgs = make_dataset(root, class_to_idx) @@ -44,10 +50,11 @@ def __init__(self, root, transform=None, target_transform=None): self.class_to_idx = class_to_idx self.transform = transform self.target_transform = target_transform + self.loader = loader def __getitem__(self, index): path, target = self.imgs[index] - img = Image.open(os.path.join(self.root, path)).convert('RGB') + img = self.loader(os.path.join(self.root, path)) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: