import functools
import io
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from torchdata.datapipes.iter import (
    IterDataPipe,
    Mapper,
    UnBatcher,
)
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
    read_mat,
    hint_sharding,
    hint_shuffling,
    image_buffer_from_array,
)
from torchvision.prototype.features import Label, Image


class SVHN(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "svhn",
            type=DatasetType.RAW,
            dependencies=("scipy",),
            categories=10,
            homepage="http://ufldl.stanford.edu/housenumbers/",
            valid_options=dict(split=("train", "test", "extra")),
        )

    _CHECKSUMS = {
        "train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8",
        "test": "cdce80dfb2a2c4c6160906d0bd7c68ec5a99d7ca4831afa54f09182025b6a75b",
        "extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3",
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        data = HttpResource(
            f"http://ufldl.stanford.edu/housenumbers/{config.split}_32x32.mat",
            sha256=self._CHECKSUMS[config.split],
        )

        return [data]

    def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> List[Tuple[np.ndarray, np.ndarray]]:
        _, buffer = data
        content = read_mat(buffer)
        # The .mat file stores all images in a single (32, 32, 3, N) array; move the
        # sample dimension to the front and pair each HWC image with its label.
        return list(
            zip(
                content["X"].transpose((3, 0, 1, 2)),
                content["y"].squeeze(),
            )
        )

    def _collate_and_decode_sample(
        self,
        data: Tuple[np.ndarray, np.ndarray],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        image_array, label_array = data

        if decoder is raw:
            # The raw decoder skips re-encoding; convert HWC to CHW for the Image feature.
            image = Image(image_array.transpose((2, 0, 1)))
        else:
            # Otherwise re-encode the array into an image buffer and hand it to the decoder, if any.
            image_buffer = image_buffer_from_array(image_array)
            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]

        return dict(
            image=image,
            # SVHN encodes the digit 0 as label 10, so take the value modulo 10.
            label=Label(int(label_array) % 10),
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        # Each .mat file yields one list of (image, label) pairs; unbatch it into individual samples.
        dp = resource_dps[0]
        dp = Mapper(dp, self._read_images_and_labels)
        dp = UnBatcher(dp)
        dp = hint_sharding(dp)
        dp = hint_shuffling(dp)
        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
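
For context, a minimal usage sketch (not part of this diff; it assumes the prototype `torchvision.prototype.datasets.load()` entry point dispatches to this class and forwards the `split` option and `decoder` as shown above):

# Rough sketch, assuming the prototype datasets.load() API described above.
from torchvision.prototype import datasets
from torchvision.prototype.datasets.decoder import raw

dataset = datasets.load("svhn", split="train", decoder=raw)  # raw -> Image tensors, no re-encoding
sample = next(iter(dataset))
print(sample["image"].shape)  # expected: torch.Size([3, 32, 32])
print(sample["label"])        # Label in [0, 9]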