Skip to content

Fully exhaust datapipes that are needed to construct a dataset #6076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Sep 13, 2022
13 changes: 8 additions & 5 deletions torchvision/prototype/datasets/_builtin/cub200.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
LineReader,
Mapper,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
Expand Down Expand Up @@ -114,6 +115,9 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
else:
return None

def _2011_extract_file_name(self, rel_posix_path: str) -> str:
return rel_posix_path.rsplit("/", maxsplit=1)[1]

def _2011_filter_split(self, row: List[str]) -> bool:
_, split_id = row
return {
Expand Down Expand Up @@ -185,17 +189,16 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
)

image_files_dp = CSVParser(image_files_dp, dialect="cub200")
image_files_map = dict(
(image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in image_files_dp
)
image_files_dp = Mapper(image_files_dp, self._2011_extract_file_name, input_col=1)
image_files_map = IterToMapConverter(image_files_dp)

split_dp = CSVParser(split_dp, dialect="cub200")
split_dp = Filter(split_dp, self._2011_filter_split)
split_dp = Mapper(split_dp, getitem(0))
split_dp = Mapper(split_dp, image_files_map.get)
split_dp = Mapper(split_dp, image_files_map.__getitem__)

bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200")
bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.get, input_col=0)
bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0)
Copy link
Contributor

@ejguan ejguan May 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good trick to get the same behavior as zip_with_map
cc: @NivekT


anns_dp = IterKeyZipper(
bounding_boxes_dp,
Expand Down
58 changes: 34 additions & 24 deletions torchvision/prototype/datasets/_builtin/imagenet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import enum
import functools
import pathlib
import re
from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union

from typing import Any, BinaryIO, cast, Dict, Iterator, List, Match, Optional, Tuple, Union

from torchdata.datapipes.iter import (
Demultiplexer,
Expand All @@ -14,6 +14,7 @@
Mapper,
TarArchiveLoader,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
Expand Down Expand Up @@ -47,6 +48,28 @@ class ImageNetDemux(enum.IntEnum):
LABEL = 1


class CategoryAndWordNetIDExtractor(IterDataPipe):
    """Stream ``(category, wnid)`` pairs for every leaf synset found in the meta files of ``datapipe``."""

    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
    _WNID_MAP = {
        "n03126707": "construction crane",
        "n03710721": "tank suit",
    }

    def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None:
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, str]]:
        for _, stream in self.datapipe:
            synsets = read_mat(stream, squeeze_me=True)["synsets"]
            for _, wnid, category, _, num_children, *_ in synsets:
                # num_children > 0 marks a superclass with no direct instances; only leaves are yielded.
                if num_children == 0:
                    primary_category = category.split(",", 1)[0]
                    yield self._WNID_MAP.get(wnid, primary_category), wnid


@register_dataset(NAME)
class ImageNet(Dataset):
"""
Expand Down Expand Up @@ -110,22 +133,6 @@ def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
"ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
}.get(pathlib.Path(data[0]).name)

# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
_WNID_MAP = {
    "n03126707": "construction crane",
    "n03710721": "tank suit",
}

def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
    """Parse a meta ``.mat`` stream into ``(category, wnid)`` pairs for all leaf synsets."""
    pairs: List[Tuple[str, str]] = []
    for _, wnid, category, _, num_children, *_ in read_mat(data[1], squeeze_me=True)["synsets"]:
        # Synsets with children are superclasses without direct instances; skip them.
        if num_children > 0:
            continue
        primary_category = category.split(",", 1)[0]
        pairs.append((self._WNID_MAP.get(wnid, primary_category), wnid))
    return pairs

def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
return wnids[int(imagenet_label) - 1]

Expand Down Expand Up @@ -172,12 +179,15 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)

meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
_, wnids = zip(*next(iter(meta_dp)))
# We cannot use self._wnids here, since we use a different order than the dataset
meta_dp = CategoryAndWordNetIDExtractor(meta_dp)
wnid_dp = Mapper(meta_dp, getitem(1))
wnid_dp = Enumerator(wnid_dp, 1)
wnid_dp = Mapper(wnid_dp, str, input_col=0)
wnid_map = IterToMapConverter(wnid_dp)

label_dp = LineReader(label_dp, decode=True, return_path=False)
# We cannot use self._wnids here, since we use a different order than the dataset
label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
label_dp = Mapper(label_dp, wnid_map.__getitem__)
label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
label_dp = hint_shuffling(label_dp)
label_dp = hint_sharding(label_dp)
Expand Down Expand Up @@ -209,8 +219,8 @@ def _generate_categories(self) -> List[Tuple[str, ...]]:

devkit_dp = resources[1].load(self._root)
meta_dp = Filter(devkit_dp, self._filter_meta)
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
meta_dp = CategoryAndWordNetIDExtractor(meta_dp)

categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
categories_and_wnids = cast(List[Tuple[str, ...]], list(meta_dp))
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
return categories_and_wnids