
Commit 75af776

add SVHN prototype dataset (#5155)
* add SVHN prototype dataset
* add test
1 parent 4d08a67 commit 75af776

File tree

3 files changed: +117 −0 lines changed

test/builtin_dataset_mocks.py

Lines changed: 20 additions & 0 deletions
@@ -1294,3 +1294,23 @@ def generate(cls, root):
 def cub200(info, root, config):
     num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
     return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}
+
+
+@DATASET_MOCKS.set_from_named_callable
+def svhn(info, root, config):
+    import scipy.io as sio
+
+    num_samples = {
+        "train": 2,
+        "test": 3,
+        "extra": 4,
+    }[config.split]
+
+    sio.savemat(
+        root / f"{config.split}_32x32.mat",
+        {
+            "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8),
+            "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8),
+        },
+    )
+    return num_samples
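
Not part of the commit, but a quick sanity check of the mock above: a minimal sketch, assuming only numpy and scipy and a temporary directory in place of the test root, that writes the same .mat layout and reads it back to confirm the (H, W, C, N) image array and 1-D label vector that the SVHN datapipe expects.

import pathlib
import tempfile

import numpy as np
import scipy.io as sio

# Write a tiny SVHN-style .mat file the same way the mock above does.
root = pathlib.Path(tempfile.mkdtemp())
num_samples = 2
sio.savemat(
    root / "train_32x32.mat",
    {
        "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8),
        "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8),
    },
)

# Read it back and confirm the layout: images stored as (H, W, C, N),
# labels recoverable as a 1-D array after squeezing savemat's 2-D wrapping.
content = sio.loadmat(root / "train_32x32.mat")
assert content["X"].shape == (32, 32, 3, num_samples)
assert content["y"].squeeze().shape == (num_samples,)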

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,4 +11,5 @@
 from .oxford_iiit_pet import OxfordIITPet
 from .sbd import SBD
 from .semeion import SEMEION
+from .svhn import SVHN
 from .voc import VOC
torchvision/prototype/datasets/_builtin/svhn.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+import functools
+import io
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torchdata.datapipes.iter import (
+    IterDataPipe,
+    Mapper,
+    UnBatcher,
+)
+from torchvision.prototype.datasets.decoder import raw
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    HttpResource,
+    OnlineResource,
+    DatasetType,
+)
+from torchvision.prototype.datasets.utils._internal import (
+    read_mat,
+    hint_sharding,
+    hint_shuffling,
+    image_buffer_from_array,
+)
+from torchvision.prototype.features import Label, Image
+
+
+class SVHN(Dataset):
+    def _make_info(self) -> DatasetInfo:
+        return DatasetInfo(
+            "svhn",
+            type=DatasetType.RAW,
+            dependencies=("scipy",),
+            categories=10,
+            homepage="http://ufldl.stanford.edu/housenumbers/",
+            valid_options=dict(split=("train", "test", "extra")),
+        )
+
+    _CHECKSUMS = {
+        "train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8",
+        "test": "cdce80dfb2a2c4c6160906d0bd7c68ec5a99d7ca4831afa54f09182025b6a75b",
+        "extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3",
+    }
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        data = HttpResource(
+            f"http://ufldl.stanford.edu/housenumbers/{config.split}_32x32.mat",
+            sha256=self._CHECKSUMS[config.split],
+        )
+
+        return [data]
+
+    def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> List[Tuple[np.ndarray, np.ndarray]]:
+        _, buffer = data
+        content = read_mat(buffer)
+        return list(
+            zip(
+                content["X"].transpose((3, 0, 1, 2)),
+                content["y"].squeeze(),
+            )
+        )
+
+    def _collate_and_decode_sample(
+        self,
+        data: Tuple[np.ndarray, np.ndarray],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> Dict[str, Any]:
+        image_array, label_array = data
+
+        if decoder is raw:
+            image = Image(image_array.transpose((2, 0, 1)))
+        else:
+            image_buffer = image_buffer_from_array(image_array)
+            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
+
+        return dict(
+            image=image,
+            label=Label(int(label_array) % 10),
+        )
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        dp = resource_dps[0]
+        dp = Mapper(dp, self._read_images_and_labels)
+        dp = UnBatcher(dp)
+        dp = hint_sharding(dp)
+        dp = hint_shuffling(dp)
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
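
A note on the label mapping: the SVHN .mat files label the digit 0 as 10, so the `% 10` in `_collate_and_decode_sample` remaps it back into the 0-9 range. Below is a hedged usage sketch, not part of this diff, assuming the prototype API's `datasets.load()` entry point and its default (raw) decoder for `DatasetType.RAW` datasets.

from torchvision.prototype import datasets

# Illustrative only: load the SVHN "train" split through the prototype
# registry. Each sample is the dict produced by _collate_and_decode_sample,
# i.e. {"image": ..., "label": ...}.
dp = datasets.load("svhn", split="train")

sample = next(iter(dp))
print(sample["image"].shape)  # with the raw decoder path: a 3 x 32 x 32 CHW tensor
print(int(sample["label"]))   # a digit in [0, 9]; the .mat label 10 becomes 0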
