From 502acf76b28c9ae2816903edd903cd00c66f769c Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Sun, 1 Jan 2017 22:44:07 +0100 Subject: [PATCH 1/2] allow load in ImageFolder --- torchvision/datasets/folder.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 5eb3126ae96..7e56004dfa8 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -33,8 +33,14 @@ def make_dataset(dir, class_to_idx): return images + +def default_load(path): + return Image.open(path).convert('RGB') + + class ImageFolder(data.Dataset): - def __init__(self, root, transform=None, target_transform=None): + def __init__(self, root, transform=None, target_transform=None, + load=default_load): classes, class_to_idx = find_classes(root) imgs = make_dataset(root, class_to_idx) @@ -44,10 +50,11 @@ def __init__(self, root, transform=None, target_transform=None): self.class_to_idx = class_to_idx self.transform = transform self.target_transform = target_transform + self.load = load def __getitem__(self, index): path, target = self.imgs[index] - img = Image.open(os.path.join(self.root, path)).convert('RGB') + img = self.load(os.path.join(self.root, path)) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: From fcece53d4683efe370a8577b5f0c88263b25a9a9 Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Sun, 1 Jan 2017 23:14:35 +0100 Subject: [PATCH 2/2] load to loader --- torchvision/datasets/folder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 7e56004dfa8..40adcb9fcec 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -34,13 +34,13 @@ def make_dataset(dir, class_to_idx): return images -def default_load(path): +def default_loader(path): return Image.open(path).convert('RGB') class ImageFolder(data.Dataset): def __init__(self, root, transform=None, target_transform=None, - load=default_load): + loader=default_loader): classes, class_to_idx = find_classes(root) imgs = make_dataset(root, class_to_idx) @@ -50,11 +50,11 @@ def __init__(self, root, transform=None, target_transform=None, self.class_to_idx = class_to_idx self.transform = transform self.target_transform = target_transform - self.load = load + self.loader = loader def __getitem__(self, index): path, target = self.imgs[index] - img = self.load(os.path.join(self.root, path)) + img = self.loader(os.path.join(self.root, path)) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: