diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index bb3f712c59d..c07166a960c 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -13,6 +13,7 @@ LineReader, Mapper, ) +from torchdata.datapipes.map import IterToMapConverter from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -114,6 +115,9 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: else: return None + def _2011_extract_file_name(self, rel_posix_path: str) -> str: + return rel_posix_path.rsplit("/", maxsplit=1)[1] + def _2011_filter_split(self, row: List[str]) -> bool: _, split_id = row return { @@ -185,17 +189,16 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, ) image_files_dp = CSVParser(image_files_dp, dialect="cub200") - image_files_map = dict( - (image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in image_files_dp - ) + image_files_dp = Mapper(image_files_dp, self._2011_extract_file_name, input_col=1) + image_files_map = IterToMapConverter(image_files_dp) split_dp = CSVParser(split_dp, dialect="cub200") split_dp = Filter(split_dp, self._2011_filter_split) split_dp = Mapper(split_dp, getitem(0)) - split_dp = Mapper(split_dp, image_files_map.get) + split_dp = Mapper(split_dp, image_files_map.__getitem__) bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200") - bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.get, input_col=0) + bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0) anns_dp = IterKeyZipper( bounding_boxes_dp, diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 062e240a8b8..3192f1f5503 100644 --- 
a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,8 +1,8 @@ import enum -import functools import pathlib import re -from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union + +from typing import Any, BinaryIO, cast, Dict, Iterator, List, Match, Optional, Tuple, Union from torchdata.datapipes.iter import ( Demultiplexer, @@ -14,6 +14,7 @@ Mapper, TarArchiveLoader, ) +from torchdata.datapipes.map import IterToMapConverter from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -47,6 +48,28 @@ class ImageNetDemux(enum.IntEnum): LABEL = 1 +class CategoryAndWordNetIDExtractor(IterDataPipe): + # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849 + # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment + _WNID_MAP = { + "n03126707": "construction crane", + "n03710721": "tank suit", + } + + def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None: + self.datapipe = datapipe + + def __iter__(self) -> Iterator[Tuple[str, str]]: + for _, stream in self.datapipe: + synsets = read_mat(stream, squeeze_me=True)["synsets"] + for _, wnid, category, _, num_children, *_ in synsets: + if num_children > 0: + # we are looking at a superclass that has no direct instance + continue + + yield self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid + + @register_dataset(NAME) class ImageNet(Dataset): """ @@ -110,25 +133,6 @@ def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]: "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL, }.get(pathlib.Path(data[0]).name) - # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. 
For example, both n02012849 - and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment - _WNID_MAP = { - "n03126707": "construction crane", - "n03710721": "tank suit", - } - - def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]: - synsets = read_mat(data[1], squeeze_me=True)["synsets"] - return [ - (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid) - for _, wnid, category, _, num_children, *_ in synsets - # if num_children > 0, we are looking at a superclass that has no direct instance - if num_children == 0 - ] - - def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str: - return wnids[int(imagenet_label) - 1] - - _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG") def _val_test_image_key(self, path: pathlib.Path) -> int: @@ -172,12 +176,15 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) - meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) - _, wnids = zip(*next(iter(meta_dp))) + # We cannot use self._wnids here, since we use a different order than the dataset + meta_dp = CategoryAndWordNetIDExtractor(meta_dp) + wnid_dp = Mapper(meta_dp, getitem(1)) + wnid_dp = Enumerator(wnid_dp, 1) + wnid_map = IterToMapConverter(wnid_dp) label_dp = LineReader(label_dp, decode=True, return_path=False) - # We cannot use self._wnids here, since we use a different order than the dataset - label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) + label_dp = Mapper(label_dp, int) + label_dp = Mapper(label_dp, wnid_map.__getitem__) label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) label_dp = hint_shuffling(label_dp) label_dp = hint_sharding(label_dp) @@ -209,8 +216,8 @@ def _generate_categories(self) -> List[Tuple[str, ...]]:
devkit_dp = resources[1].load(self._root) meta_dp = Filter(devkit_dp, self._filter_meta) - meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) + meta_dp = CategoryAndWordNetIDExtractor(meta_dp) - categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) + categories_and_wnids = cast(List[Tuple[str, ...]], list(meta_dp)) categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) return categories_and_wnids