
Fully exhaust datapipes that are needed to construct a dataset #6076


Merged: 15 commits, Sep 13, 2022
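For context, here is a minimal sketch (not code from this PR) of the pattern the diff applies, as I read it: datapipes that must be read in full while the dataset is being constructed are split off the source with a `Filter` and exhausted immediately, instead of being taken from a `Demultiplexer`, whose other branches would otherwise have to buffer everything that is skipped. The file names and the lambdas standing in for `path_comparator` and the classify methods are made up.

```python
from torchdata.datapipes.iter import Demultiplexer, Filter, IterableWrapper

# Stand-in for an archive datapipe; the file names are hypothetical.
archive_dp = IterableWrapper(
    [
        ("images.txt", b"1 001.jpg\n"),
        ("train_test_split.txt", b"1 1\n"),
        ("images/001.jpg", b"<jpeg bytes>"),
    ]
)

# Streams that are needed in full at construction time are filtered off the source
# and exhausted right away ...
meta_dp = Filter(archive_dp, lambda data: data[0] == "images.txt")
image_files = list(meta_dp)  # fully exhausted up front, e.g. to build a lookup table

# ... and only the streams that stay lazy go through the Demultiplexer. Reading one
# demux branch to exhaustion during construction would force the other branches to
# buffer everything they skip.
images_dp, split_dp = Demultiplexer(
    archive_dp,
    2,
    lambda data: 0 if data[0].startswith("images/") else 1 if data[0] == "train_test_split.txt" else None,
    drop_none=True,
    buffer_size=-1,  # stand-in for INFINITE_BUFFER_SIZE
)
```

This is why, in the changes below, the 4-way demux in cub200.py shrinks to a 3-way one and the devkit demux in imagenet.py disappears entirely.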
12 changes: 6 additions & 6 deletions torchvision/prototype/datasets/_builtin/cub200.py
@@ -107,10 +107,8 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
             return 0
         elif path.name == "train_test_split.txt":
             return 1
-        elif path.name == "images.txt":
-            return 2
         elif path.name == "bounding_boxes.txt":
-            return 3
+            return 2
         else:
             return None

@@ -180,15 +178,17 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
         prepare_ann_fn: Callable
         if self._year == "2011":
             archive_dp, segmentations_dp = resource_dps
-            images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
-                archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
-            )
 
+            image_files_dp = Filter(archive_dp, path_comparator("name", "images.txt"))
             image_files_dp = CSVParser(image_files_dp, dialect="cub200")
             image_files_map = dict(
                 (image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in image_files_dp
             )
 
+            images_dp, split_dp, bounding_boxes_dp = Demultiplexer(
+                archive_dp, 3, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
+            )
+
             split_dp = CSVParser(split_dp, dialect="cub200")
             split_dp = Filter(split_dp, self._2011_filter_split)
             split_dp = Mapper(split_dp, getitem(0))
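A small worked example of the pre-existing `image_files_map` comprehension kept in the hunk above, using a made-up row in the shape `CSVParser(..., dialect="cub200")` yields, i.e. `(image_id, relative posix path)`: it strips the directory component and keeps only the file name.

```python
rows = [("1", "001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg")]

image_files_map = dict(
    (image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in rows
)
print(image_files_map)  # {'1': 'Black_Footed_Albatross_0046_18.jpg'}
```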
37 changes: 10 additions & 27 deletions torchvision/prototype/datasets/_builtin/imagenet.py
@@ -1,4 +1,3 @@
-import enum
 import functools
 import pathlib
 import re
@@ -10,7 +9,6 @@
     IterKeyZipper,
     Mapper,
     Filter,
-    Demultiplexer,
     TarArchiveLoader,
     Enumerator,
 )
@@ -27,6 +25,7 @@
     hint_shuffling,
     read_categories_file,
     path_accessor,
+    path_comparator,
 )
 from torchvision.prototype.features import Label, EncodedImage

@@ -46,11 +45,6 @@ def __init__(self, **kwargs: Any) -> None:
         super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
 
 
-class ImageNetDemux(enum.IntEnum):
-    META = 0
-    LABEL = 1
-
-
 @register_dataset(NAME)
 class ImageNet(Dataset):
     """
@@ -108,21 +102,19 @@ def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label,
     def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
         return None, data
 
-    def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
-        return {
-            "meta.mat": ImageNetDemux.META,
-            "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
-        }.get(pathlib.Path(data[0]).name)
-
     # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
     # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
     _WNID_MAP = {
         "n03126707": "construction crane",
         "n03710721": "tank suit",
     }
 
-    def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
-        synsets = read_mat(data[1], squeeze_me=True)["synsets"]
+    def _extract_categories_and_wnids(self, devkit_dp: IterDataPipe[Tuple[str, BinaryIO]]) -> List[Tuple[str, str]]:
+        meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
+
+        _, buffer = list(meta_dp)[0]
+        synsets = read_mat(buffer, squeeze_me=True)["synsets"]
+
         return [
             (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
             for _, wnid, category, _, num_children, *_ in synsets
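The new `_extract_categories_and_wnids` reads `meta.mat` eagerly: `Filter` keeps only the devkit entry whose file name is `meta.mat`, and `list(meta_dp)[0]` exhausts that datapipe on the spot. A minimal stand-alone sketch of just that read, with a fake devkit datapipe and a lambda in place of `path_comparator`:

```python
from torchdata.datapipes.iter import Filter, IterableWrapper

# Hypothetical devkit contents; real entries carry open file handles instead of raw bytes.
devkit_dp = IterableWrapper(
    [
        ("ILSVRC2012_devkit_t12/data/meta.mat", b"<mat bytes>"),
        ("ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt", b"1\n2\n"),
    ]
)

meta_dp = Filter(devkit_dp, lambda data: data[0].rsplit("/", 1)[-1] == "meta.mat")
_, buffer = list(meta_dp)[0]  # the datapipe is fully consumed; exactly one match is expected
```

Exhausting with `list(...)` rather than `next(iter(...))` means nothing is left half-consumed at construction time.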
@@ -172,13 +164,9 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
         else:  # config.split == "val":
             images_dp, devkit_dp = resource_dps
 
-            meta_dp, label_dp = Demultiplexer(
-                devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
-            )
-
-            meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
-            _, wnids = zip(*next(iter(meta_dp)))
+            _, wnids = zip(*self._extract_categories_and_wnids(devkit_dp))
 
+            label_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
             label_dp = LineReader(label_dp, decode=True, return_path=False)
             # We cannot use self._wnids here, since we use a different order than the dataset
             label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
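For the val split, the label pipeline above stays lazy: `Filter` picks the ground-truth file out of the devkit, `LineReader` yields one ILSVRC label per line, and the `Mapper` turns each label into a wnid using the order read from `meta.mat`. A rough stand-alone sketch, assuming the label-to-wnid conversion is a plain 1-based lookup (the helper below is hypothetical; the real one is `self._imagenet_label_to_wnid` on the dataset class):

```python
import functools
import io

from torchdata.datapipes.iter import Filter, IterableWrapper, LineReader, Mapper

wnids = ("n01440764", "n01443537")  # order as extracted from meta.mat, not self._wnids

devkit_dp = IterableWrapper(
    [("ILSVRC2012_validation_ground_truth.txt", io.BytesIO(b"1\n2\n"))]
)
label_dp = Filter(devkit_dp, lambda data: data[0] == "ILSVRC2012_validation_ground_truth.txt")
label_dp = LineReader(label_dp, decode=True, return_path=False)


def imagenet_label_to_wnid(imagenet_label: str, *, wnids):
    # Assumption: the ground-truth labels are 1-based indices into the synset order.
    return wnids[int(imagenet_label) - 1]


label_dp = Mapper(label_dp, functools.partial(imagenet_label_to_wnid, wnids=wnids))
print(list(label_dp))  # ['n01440764', 'n01443537']
```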
@@ -204,17 +192,12 @@ def __len__(self) -> int:
             "test": 100_000,
         }[self._split]
 
-    def _filter_meta(self, data: Tuple[str, Any]) -> bool:
-        return self._classifiy_devkit(data) == ImageNetDemux.META
-
     def _generate_categories(self) -> List[Tuple[str, ...]]:
         self._split = "val"
         resources = self._resources()
 
         devkit_dp = resources[1].load(self._root)
-        meta_dp = Filter(devkit_dp, self._filter_meta)
-        meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
 
-        categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
+        categories_and_wnids = self._extract_categories_and_wnids(devkit_dp)
         categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
         return categories_and_wnids