|
1 | 1 | import enum
|
2 |
| -import functools |
3 | 2 | import pathlib
|
4 | 3 | import re
|
5 |
| -from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union |
| 4 | + |
| 5 | +from typing import Any, BinaryIO, cast, Dict, Iterator, List, Match, Optional, Tuple, Union |
6 | 6 |
|
7 | 7 | from torchdata.datapipes.iter import (
|
8 | 8 | Demultiplexer,
|
|
14 | 14 | Mapper,
|
15 | 15 | TarArchiveLoader,
|
16 | 16 | )
|
| 17 | +from torchdata.datapipes.map import IterToMapConverter |
17 | 18 | from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource, OnlineResource
|
18 | 19 | from torchvision.prototype.datasets.utils._internal import (
|
19 | 20 | getitem,
|
@@ -47,6 +48,28 @@ class ImageNetDemux(enum.IntEnum):
|
47 | 48 | LABEL = 1
|
48 | 49 |
|
49 | 50 |
|
class CategoryAndWordNetIDExtractor(IterDataPipe):
    """Datapipe that turns devkit meta streams into ``(category, wnid)`` pairs.

    Every item of the wrapped datapipe is expected to be a ``(path, stream)``
    tuple; the path is ignored. Each stream is loaded with ``read_mat`` and its
    ``"synsets"`` entry is walked. Synsets that report children are skipped, and
    for the remaining ones a ``(category, wnid)`` tuple is yielded.
    """

    # WordNet IDs (wnids) are unique, but the category names attached to them are not. 'crane', for instance, is the
    # label of both n02012849 (the bird) and n03126707 (the construction equipment). The entries below override the
    # category name for the listed wnids.
    _WNID_MAP = {
        "n03126707": "construction crane",
        "n03710721": "tank suit",
    }

    def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None:
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, str]]:
        for _path, buffer in self.datapipe:
            meta = read_mat(buffer, squeeze_me=True)
            for synset in meta["synsets"]:
                _, wnid, category, _, num_children, *_ = synset

                # A synset with children is a superclass that has no direct instance, so only the others are emitted.
                if not num_children > 0:
                    # By default, the category is whatever precedes the first comma of the synset's category field.
                    default_category = category.split(",", 1)[0]
                    yield self._WNID_MAP.get(wnid, default_category), wnid
| 71 | + |
| 72 | + |
50 | 73 | @register_dataset(NAME)
|
51 | 74 | class ImageNet(Dataset):
|
52 | 75 | """
|
@@ -110,25 +133,6 @@ def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
|
110 | 133 | "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
|
111 | 134 | }.get(pathlib.Path(data[0]).name)
|
112 | 135 |
|
113 |
| - # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849 |
114 |
| - # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment |
115 |
| - _WNID_MAP = { |
116 |
| - "n03126707": "construction crane", |
117 |
| - "n03710721": "tank suit", |
118 |
| - } |
119 |
| - |
120 |
| - def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]: |
121 |
| - synsets = read_mat(data[1], squeeze_me=True)["synsets"] |
122 |
| - return [ |
123 |
| - (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid) |
124 |
| - for _, wnid, category, _, num_children, *_ in synsets |
125 |
| - # if num_children > 0, we are looking at a superclass that has no direct instance |
126 |
| - if num_children == 0 |
127 |
| - ] |
128 |
| - |
129 |
| - def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str: |
130 |
| - return wnids[int(imagenet_label) - 1] |
131 |
| - |
132 | 136 | _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
|
133 | 137 |
|
134 | 138 | def _val_test_image_key(self, path: pathlib.Path) -> int:
|
@@ -172,12 +176,15 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
|
172 | 176 | devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
|
173 | 177 | )
|
174 | 178 |
|
175 |
| - meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) |
176 |
| - _, wnids = zip(*next(iter(meta_dp))) |
| 179 | + # We cannot use self._wnids here, since we use a different order than the dataset |
| 180 | + meta_dp = CategoryAndWordNetIDExtractor(meta_dp) |
| 181 | + wnid_dp = Mapper(meta_dp, getitem(1)) |
| 182 | + wnid_dp = Enumerator(wnid_dp, 1) |
| 183 | + wnid_map = IterToMapConverter(wnid_dp) |
177 | 184 |
|
178 | 185 | label_dp = LineReader(label_dp, decode=True, return_path=False)
|
179 |
| - # We cannot use self._wnids here, since we use a different order than the dataset |
180 |
| - label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) |
| 186 | + label_dp = Mapper(label_dp, int) |
| 187 | + label_dp = Mapper(label_dp, wnid_map.__getitem__) |
181 | 188 | label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
|
182 | 189 | label_dp = hint_shuffling(label_dp)
|
183 | 190 | label_dp = hint_sharding(label_dp)
|
@@ -209,8 +216,8 @@ def _generate_categories(self) -> List[Tuple[str, ...]]:
|
209 | 216 |
|
210 | 217 | devkit_dp = resources[1].load(self._root)
|
211 | 218 | meta_dp = Filter(devkit_dp, self._filter_meta)
|
212 |
| - meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) |
| 219 | + meta_dp = CategoryAndWordNetIDExtractor(meta_dp) |
213 | 220 |
|
214 |
| - categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) |
| 221 | + categories_and_wnids = cast(List[Tuple[str, ...]], list(meta_dp)) |
215 | 222 | categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
|
216 | 223 | return categories_and_wnids
|
0 commit comments