
Commit b4686f2

Fully exhaust datapipes that are needed to construct a dataset (#6076)
1 parent 2c19af3 commit b4686f2
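
The change replaces eager consumption of auxiliary datapipes during dataset construction (building a dict from a CSV pipe, or grabbing next(iter(meta_dp)) and leaving the pipe half-read) with map-style datapipes built via IterToMapConverter, plus list(meta_dp) in _generate_categories, so every helper pipe is exhausted in full. A minimal standalone sketch of the IterToMapConverter pattern (toy data and names, not code from this commit):

from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.map import IterToMapConverter

# toy (key, value) pipe standing in for e.g. a CSVParser output
pairs_dp = IterableWrapper([("1", "a.jpg"), ("2", "b.jpg")])

# rather than mapping = dict(pairs_dp), wrap the pipe in a map-style datapipe;
# IterToMapConverter consumes the whole source pipe once the map is actually used
mapping = IterToMapConverter(pairs_dp)
assert mapping["2"] == "b.jpg"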

File tree

2 files changed: +42 −32 lines


torchvision/prototype/datasets/_builtin/cub200.py

+8 −5
@@ -13,6 +13,7 @@
     LineReader,
     Mapper,
 )
+from torchdata.datapipes.map import IterToMapConverter
 from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -114,6 +115,9 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         else:
             return None
 
+    def _2011_extract_file_name(self, rel_posix_path: str) -> str:
+        return rel_posix_path.rsplit("/", maxsplit=1)[1]
+
     def _2011_filter_split(self, row: List[str]) -> bool:
         _, split_id = row
         return {
@@ -185,17 +189,16 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
         )
 
         image_files_dp = CSVParser(image_files_dp, dialect="cub200")
-        image_files_map = dict(
-            (image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in image_files_dp
-        )
+        image_files_dp = Mapper(image_files_dp, self._2011_extract_file_name, input_col=1)
+        image_files_map = IterToMapConverter(image_files_dp)
 
         split_dp = CSVParser(split_dp, dialect="cub200")
         split_dp = Filter(split_dp, self._2011_filter_split)
         split_dp = Mapper(split_dp, getitem(0))
-        split_dp = Mapper(split_dp, image_files_map.get)
+        split_dp = Mapper(split_dp, image_files_map.__getitem__)
 
         bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200")
-        bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.get, input_col=0)
+        bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0)
 
         anns_dp = IterKeyZipper(
             bounding_boxes_dp,
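
In the CUB-200 change above, the file-name extraction now happens inside the pipeline through Mapper's input_col argument, which applies the function to one column of each row and writes the result back into that column. A small standalone illustration of that behavior (toy row and names, not from the dataset):

from torchdata.datapipes.iter import IterableWrapper, Mapper

def extract_file_name(rel_posix_path: str) -> str:
    # keep only the part after the last "/"
    return rel_posix_path.rsplit("/", maxsplit=1)[1]

rows = IterableWrapper([("1", "001.Some_Bird/Some_Bird_0001.jpg")])
# input_col=1: only the relative path column is transformed, the image id stays untouched
rows = Mapper(rows, extract_file_name, input_col=1)
print(list(rows))  # [('1', 'Some_Bird_0001.jpg')]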

torchvision/prototype/datasets/_builtin/imagenet.py

+34 −27
@@ -1,8 +1,8 @@
 import enum
-import functools
 import pathlib
 import re
-from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union
+
+from typing import Any, BinaryIO, cast, Dict, Iterator, List, Match, Optional, Tuple, Union
 
 from torchdata.datapipes.iter import (
     Demultiplexer,
@@ -14,6 +14,7 @@
     Mapper,
     TarArchiveLoader,
 )
+from torchdata.datapipes.map import IterToMapConverter
 from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -47,6 +48,28 @@ class ImageNetDemux(enum.IntEnum):
     LABEL = 1
 
 
+class CategoryAndWordNetIDExtractor(IterDataPipe):
+    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
+    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
+    _WNID_MAP = {
+        "n03126707": "construction crane",
+        "n03710721": "tank suit",
+    }
+
+    def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None:
+        self.datapipe = datapipe
+
+    def __iter__(self) -> Iterator[Tuple[str, str]]:
+        for _, stream in self.datapipe:
+            synsets = read_mat(stream, squeeze_me=True)["synsets"]
+            for _, wnid, category, _, num_children, *_ in synsets:
+                if num_children > 0:
+                    # we are looking at a superclass that has no direct instance
+                    continue
+
+                yield self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid
+
+
 @register_dataset(NAME)
 class ImageNet(Dataset):
     """
@@ -110,25 +133,6 @@ def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
             "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
         }.get(pathlib.Path(data[0]).name)
 
-    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
-    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
-    _WNID_MAP = {
-        "n03126707": "construction crane",
-        "n03710721": "tank suit",
-    }
-
-    def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
-        synsets = read_mat(data[1], squeeze_me=True)["synsets"]
-        return [
-            (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
-            for _, wnid, category, _, num_children, *_ in synsets
-            # if num_children > 0, we are looking at a superclass that has no direct instance
-            if num_children == 0
-        ]
-
-    def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
-        return wnids[int(imagenet_label) - 1]
-
     _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
 
     def _val_test_image_key(self, path: pathlib.Path) -> int:
@@ -172,12 +176,15 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
             devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
         )
 
-        meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
-        _, wnids = zip(*next(iter(meta_dp)))
+        # We cannot use self._wnids here, since we use a different order than the dataset
+        meta_dp = CategoryAndWordNetIDExtractor(meta_dp)
+        wnid_dp = Mapper(meta_dp, getitem(1))
+        wnid_dp = Enumerator(wnid_dp, 1)
+        wnid_map = IterToMapConverter(wnid_dp)
 
         label_dp = LineReader(label_dp, decode=True, return_path=False)
-        # We cannot use self._wnids here, since we use a different order than the dataset
-        label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
+        label_dp = Mapper(label_dp, int)
+        label_dp = Mapper(label_dp, wnid_map.__getitem__)
         label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
         label_dp = hint_shuffling(label_dp)
         label_dp = hint_sharding(label_dp)
@@ -209,8 +216,8 @@ def _generate_categories(self) -> List[Tuple[str, ...]]:
 
         devkit_dp = resources[1].load(self._root)
         meta_dp = Filter(devkit_dp, self._filter_meta)
-        meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
+        meta_dp = CategoryAndWordNetIDExtractor(meta_dp)
 
-        categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
+        categories_and_wnids = cast(List[Tuple[str, ...]], list(meta_dp))
         categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
         return categories_and_wnids
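
For ImageNet, the label-to-wnid lookup that used to be built by partially consuming the meta pipe (next(iter(meta_dp))) is now a map-style datapipe keyed by the 1-based ImageNet label, and _generate_categories drains the pipe with list(meta_dp). A rough standalone sketch of the Enumerator + IterToMapConverter combination used on the label path (toy wnids, not the real devkit data):

from torchdata.datapipes.iter import Enumerator, IterableWrapper, Mapper
from torchdata.datapipes.map import IterToMapConverter

# toy stand-in for the wnid column extracted from the devkit meta file
wnid_dp = IterableWrapper(["n01440764", "n01443537", "n01484850"])
# ImageNet labels are 1-based, so enumeration starts at 1: (1, wnid), (2, wnid), ...
wnid_map = IterToMapConverter(Enumerator(wnid_dp, 1))

# ground-truth labels are read as text lines, parsed to int and mapped to wnids
label_dp = IterableWrapper(["2", "1", "3"])
label_dp = Mapper(label_dp, int)
label_dp = Mapper(label_dp, wnid_map.__getitem__)
print(list(label_dp))  # ['n01443537', 'n01440764', 'n01484850']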
