Skip to content

Commit 2edbd8d

Browse files
authored
Merge branch 'main' into models/convnext
2 parents daf07e0 + 75af776 commit 2edbd8d

File tree

7 files changed

+236
-82
lines changed

7 files changed

+236
-82
lines changed

test/builtin_dataset_mocks.py

Lines changed: 80 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pickle
1111
import random
1212
import tempfile
13+
import unittest.mock
1314
import xml.etree.ElementTree as ET
1415
from collections import defaultdict, Counter, UserDict
1516

@@ -21,7 +22,8 @@
2122
from torch.nn.functional import one_hot
2223
from torch.testing import make_tensor as _make_tensor
2324
from torchvision.prototype import datasets
24-
from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find
25+
from torchvision.prototype.datasets._api import find
26+
from torchvision.prototype.utils._internal import sequence_to_str
2527

2628
make_tensor = functools.partial(_make_tensor, device="cpu")
2729
make_scalar = functools.partial(make_tensor, ())
@@ -49,7 +51,7 @@ class DatasetMock:
4951
def __init__(self, name, mock_data_fn, *, configs=None):
5052
self.dataset = find(name)
5153
self.root = TEST_HOME / self.dataset.name
52-
self.mock_data_fn = self._parse_mock_data(mock_data_fn)
54+
self.mock_data_fn = mock_data_fn
5355
self.configs = configs or self.info._configs
5456
self._cache = {}
5557

@@ -61,77 +63,71 @@ def info(self):
6163
def name(self):
6264
return self.info.name
6365

64-
def _parse_mock_data(self, mock_data_fn):
65-
def wrapper(info, root, config):
66-
mock_infos = mock_data_fn(info, root, config)
66+
def _parse_mock_data(self, config, mock_infos):
67+
if mock_infos is None:
68+
raise pytest.UsageError(
69+
f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
70+
f"integer indicating the number of samples for the current `config`."
71+
)
72+
73+
key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
74+
if datasets.utils.DatasetConfig not in key_types:
75+
mock_infos = {config: mock_infos}
76+
elif len(key_types) > 1:
77+
raise pytest.UsageError(
78+
f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
79+
f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
80+
)
6781

68-
if mock_infos is None:
82+
for config_, mock_info in list(mock_infos.items()):
83+
if config_ in self._cache:
6984
raise pytest.UsageError(
70-
f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
71-
f"integer indicating the number of samples for the current `config`."
85+
f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
86+
f"already exists in the cache."
7287
)
73-
74-
key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
75-
if datasets.utils.DatasetConfig not in key_types:
76-
mock_infos = {config: mock_infos}
77-
elif len(key_types) > 1:
88+
if isinstance(mock_info, int):
89+
mock_infos[config_] = dict(num_samples=mock_info)
90+
elif not isinstance(mock_info, dict):
7891
raise pytest.UsageError(
79-
f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
80-
f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
92+
f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
93+
f"{config_}. The returned object should be a dictionary containing at least the number of "
94+
f"samples for the key `'num_samples'`. If no additional information is required for specific "
95+
f"tests, the number of samples can also be returned as an integer."
96+
)
97+
elif "num_samples" not in mock_info:
98+
raise pytest.UsageError(
99+
f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
100+
f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
81101
)
82102

83-
for config_, mock_info in list(mock_infos.items()):
84-
if config_ in self._cache:
85-
raise pytest.UsageError(
86-
f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
87-
f"already exists in the cache."
88-
)
89-
if isinstance(mock_info, int):
90-
mock_infos[config_] = dict(num_samples=mock_info)
91-
elif not isinstance(mock_info, dict):
92-
raise pytest.UsageError(
93-
f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
94-
f"{config_}. The returned object should be a dictionary containing at least the number of "
95-
f"samples for the key `'num_samples'`. If no additional information is required for specific "
96-
f"tests, the number of samples can also be returned as an integer."
97-
)
98-
elif "num_samples" not in mock_info:
99-
raise pytest.UsageError(
100-
f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
101-
f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
102-
)
103-
104-
return mock_infos
105-
106-
return wrapper
103+
return mock_infos
107104

108-
def _load_mock(self, config):
105+
def _prepare_resources(self, config):
109106
with contextlib.suppress(KeyError):
110107
return self._cache[config]
111108

112109
self.root.mkdir(exist_ok=True)
113-
for config_, mock_info in self.mock_data_fn(self.info, self.root, config).items():
114-
mock_resources = [
115-
ResourceMock(dataset_name=self.name, dataset_config=config_, file_name=resource.file_name)
116-
for resource in self.dataset.resources(config_)
117-
]
118-
self._cache[config_] = (mock_resources, mock_info)
110+
mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))
111+
112+
available_file_names = {path.name for path in self.root.glob("*")}
113+
for config_, mock_info in mock_infos.items():
114+
required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
115+
missing_file_names = required_file_names - available_file_names
116+
if missing_file_names:
117+
raise pytest.UsageError(
118+
f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
119+
f"for {config_}, but they were not created by the mock data function."
120+
)
121+
122+
self._cache[config_] = mock_info
119123

120124
return self._cache[config]
121125

122-
def load(self, config, *, decoder=DEFAULT_DECODER):
123-
try:
124-
self.info.check_dependencies()
125-
except ModuleNotFoundError as error:
126-
pytest.skip(str(error))
127-
128-
mock_resources, mock_info = self._load_mock(config)
129-
datapipe = self.dataset._make_datapipe(
130-
[resource.load(self.root) for resource in mock_resources],
131-
config=config,
132-
decoder=DEFAULT_DECODER_MAP.get(self.info.type) if decoder is DEFAULT_DECODER else decoder,
133-
)
134-
return datapipe, mock_info
126+
@contextlib.contextmanager
127+
def prepare(self, config):
128+
mock_info = self._prepare_resources(config)
129+
with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
130+
yield mock_info
135131

136132

137133
def config_id(name, config):
@@ -1000,7 +996,7 @@ def dtd(info, root, _):
1000996
def fer2013(info, root, config):
1001997
num_samples = 5 if config.split == "train" else 3
1002998

1003-
path = root / f"{config.split}.txt"
999+
path = root / f"{config.split}.csv"
10041000
with open(path, "w", newline="") as file:
10051001
field_names = ["emotion"] if config.split == "train" else []
10061002
field_names.append("pixels")
@@ -1061,7 +1057,7 @@ def clevr(info, root, config):
10611057
file,
10621058
)
10631059

1064-
make_zip(root, f"{data_folder.name}.zip")
1060+
make_zip(root, f"{data_folder.name}.zip", data_folder)
10651061

10661062
return {config_: num_samples_map[config_.split] for config_ in info._configs}
10671063

@@ -1121,8 +1117,8 @@ def generate(self, root):
11211117
for path in segmentation_files:
11221118
path.with_name(f".{path.name}").touch()
11231119

1124-
make_tar(root, "images.tar")
1125-
make_tar(root, anns_folder.with_suffix(".tar").name)
1120+
make_tar(root, "images.tar.gz", compression="gz")
1121+
make_tar(root, anns_folder.with_suffix(".tar.gz").name, compression="gz")
11261122

11271123
return num_samples_map
11281124

@@ -1211,7 +1207,7 @@ def _make_segmentations(cls, root, image_files):
12111207
size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()],
12121208
)
12131209

1214-
make_tar(root, segmentations_folder.with_suffix(".tgz").name)
1210+
make_tar(root, segmentations_folder.with_suffix(".tgz").name, compression="gz")
12151211

12161212
@classmethod
12171213
def generate(cls, root):
@@ -1298,3 +1294,23 @@ def generate(cls, root):
12981294
def cub200(info, root, config):
12991295
num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
13001296
return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}
1297+
1298+
1299+
@DATASET_MOCKS.set_from_named_callable
1300+
def svhn(info, root, config):
1301+
import scipy.io as sio
1302+
1303+
num_samples = {
1304+
"train": 2,
1305+
"test": 3,
1306+
"extra": 4,
1307+
}[config.split]
1308+
1309+
sio.savemat(
1310+
root / f"{config.split}_32x32.mat",
1311+
{
1312+
"X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8),
1313+
"y": np.random.randint(10, size=(num_samples,), dtype=np.uint8),
1314+
},
1315+
)
1316+
return num_samples

test/datasets_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -868,9 +868,13 @@ def _split_files_or_dirs(root, *files_or_dirs):
868868
def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
869869
archive = pathlib.Path(root) / name
870870
if not files_or_dirs:
871-
dir = archive.with_suffix("")
872-
if dir.exists() and dir.is_dir():
873-
files_or_dirs = (dir,)
871+
# We need to invoke `Path.with_suffix("")` once per suffix, since each call only strips the last suffix if multiple suffixes are
872+
# present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
873+
file_or_dir = archive
874+
for _ in range(len(archive.suffixes)):
875+
file_or_dir = file_or_dir.with_suffix("")
876+
if file_or_dir.exists():
877+
files_or_dirs = (file_or_dir,)
874878
else:
875879
raise ValueError("No file or dir provided.")
876880

test/test_datasets_utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import contextlib
22
import gzip
33
import os
4+
import pathlib
5+
import re
46
import tarfile
57
import zipfile
68

79
import pytest
810
import torchvision.datasets.utils as utils
911
from torch._utils_internal import get_file_path_2
12+
from torchvision.datasets.folder import make_dataset
1013
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
1114

12-
1315
TEST_FILE = get_file_path_2(
1416
os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
1517
)
@@ -214,5 +216,29 @@ def test_verify_str_arg(self):
214216
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
215217

216218

219+
@pytest.mark.parametrize(
220+
("kwargs", "expected_error_msg"),
221+
[
222+
(dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"),
223+
(dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")),
224+
(dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")),
225+
],
226+
)
227+
def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
228+
tmpdir = pathlib.Path(tmpdir)
229+
230+
(tmpdir / "a").mkdir()
231+
(tmpdir / "a" / "a.png").touch()
232+
233+
(tmpdir / "b").mkdir()
234+
(tmpdir / "b" / "b.jpeg").touch()
235+
236+
(tmpdir / "c").mkdir()
237+
(tmpdir / "c" / "c.unknown").touch()
238+
239+
with pytest.raises(FileNotFoundError, match=expected_error_msg):
240+
make_dataset(str(tmpdir), **kwargs)
241+
242+
217243
if __name__ == "__main__":
218244
pytest.main([__file__])

test/test_prototype_builtin_datasets.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@ def test_coverage():
2323
class TestCommon:
2424
@parametrize_dataset_mocks(DATASET_MOCKS)
2525
def test_smoke(self, dataset_mock, config):
26-
dataset, _ = dataset_mock.load(config)
26+
with dataset_mock.prepare(config):
27+
dataset = datasets.load(dataset_mock.name, **config)
28+
2729
if not isinstance(dataset, IterDataPipe):
2830
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
2931

3032
@parametrize_dataset_mocks(DATASET_MOCKS)
3133
def test_sample(self, dataset_mock, config):
32-
dataset, _ = dataset_mock.load(config)
34+
with dataset_mock.prepare(config):
35+
dataset = datasets.load(dataset_mock.name, **config)
3336

3437
try:
3538
sample = next(iter(dataset))
@@ -44,7 +47,8 @@ def test_sample(self, dataset_mock, config):
4447

4548
@parametrize_dataset_mocks(DATASET_MOCKS)
4649
def test_num_samples(self, dataset_mock, config):
47-
dataset, mock_info = dataset_mock.load(config)
50+
with dataset_mock.prepare(config) as mock_info:
51+
dataset = datasets.load(dataset_mock.name, **config)
4852

4953
num_samples = 0
5054
for _ in dataset:
@@ -54,7 +58,8 @@ def test_num_samples(self, dataset_mock, config):
5458

5559
@parametrize_dataset_mocks(DATASET_MOCKS)
5660
def test_decoding(self, dataset_mock, config):
57-
dataset, _ = dataset_mock.load(config)
61+
with dataset_mock.prepare(config):
62+
dataset = datasets.load(dataset_mock.name, **config)
5863

5964
undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
6065
if undecoded_features:
@@ -65,7 +70,8 @@ def test_decoding(self, dataset_mock, config):
6570

6671
@parametrize_dataset_mocks(DATASET_MOCKS)
6772
def test_no_vanilla_tensors(self, dataset_mock, config):
68-
dataset, _ = dataset_mock.load(config)
73+
with dataset_mock.prepare(config):
74+
dataset = datasets.load(dataset_mock.name, **config)
6975

7076
vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
7177
if vanilla_tensors:
@@ -76,7 +82,8 @@ def test_no_vanilla_tensors(self, dataset_mock, config):
7682

7783
@parametrize_dataset_mocks(DATASET_MOCKS)
7884
def test_transformable(self, dataset_mock, config):
79-
dataset, _ = dataset_mock.load(config)
85+
with dataset_mock.prepare(config):
86+
dataset = datasets.load(dataset_mock.name, **config)
8087

8188
next(iter(dataset.map(transforms.Identity())))
8289

@@ -89,7 +96,8 @@ def test_transformable(self, dataset_mock, config):
8996
},
9097
)
9198
def test_traversable(self, dataset_mock, config):
92-
dataset, _ = dataset_mock.load(config)
99+
with dataset_mock.prepare(config):
100+
dataset = datasets.load(dataset_mock.name, **config)
93101

94102
traverse(dataset)
95103

@@ -108,7 +116,8 @@ def scan(graph):
108116
yield node
109117
yield from scan(sub_graph)
110118

111-
dataset, _ = dataset_mock.load(config)
119+
with dataset_mock.prepare(config):
120+
dataset = datasets.load(dataset_mock.name, **config)
112121

113122
for dp in scan(traverse(dataset)):
114123
if type(dp) is annotation_dp_type:
@@ -120,7 +129,8 @@ def scan(graph):
120129
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
121130
class TestQMNIST:
122131
def test_extra_label(self, dataset_mock, config):
123-
dataset, _ = dataset_mock.load(config)
132+
with dataset_mock.prepare(config):
133+
dataset = datasets.load(dataset_mock.name, **config)
124134

125135
sample = next(iter(dataset))
126136
for key, type in (

0 commit comments

Comments
 (0)