10
10
import pickle
11
11
import random
12
12
import tempfile
13
+ import unittest .mock
13
14
import xml .etree .ElementTree as ET
14
15
from collections import defaultdict , Counter , UserDict
15
16
21
22
from torch .nn .functional import one_hot
22
23
from torch .testing import make_tensor as _make_tensor
23
24
from torchvision .prototype import datasets
24
- from torchvision .prototype .datasets ._api import DEFAULT_DECODER_MAP , DEFAULT_DECODER , find
25
+ from torchvision .prototype .datasets ._api import find
26
+ from torchvision .prototype .utils ._internal import sequence_to_str
25
27
26
28
make_tensor = functools .partial (_make_tensor , device = "cpu" )
27
29
make_scalar = functools .partial (make_tensor , ())
@@ -49,7 +51,7 @@ class DatasetMock:
49
51
def __init__(self, name, mock_data_fn, *, configs=None):
    """Set up the mock for the dataset registered under ``name``.

    Args:
        name: Name used to look the dataset up through ``find``.
        mock_data_fn: Callable that is later invoked as
            ``mock_data_fn(info, root, config)`` to create the mock files.
        configs: Configs to use for this mock. Any falsy value falls back
            to ``self.info._configs``.
    """
    # Results of the mock data function, keyed by config.
    self._cache = dict()
    self.mock_data_fn = mock_data_fn

    dataset = find(name)
    self.dataset = dataset
    self.root = TEST_HOME / dataset.name

    # `self.info` is read last, after `self.dataset` has been assigned.
    self.configs = configs if configs else self.info._configs
55
57
@@ -61,77 +63,71 @@ def info(self):
61
63
def name (self ):
62
64
return self .info .name
63
65
64
def _parse_mock_data(self, config, mock_infos):
    """Validate and normalize the return value of the mock data function.

    Args:
        config: The config the mock data function was called with. It is used
            as key if ``mock_infos`` is not already keyed by ``DatasetConfig``.
        mock_infos: Raw return value of the mock data function. Either an
            integer number of samples, a dictionary with a ``"num_samples"``
            entry, or a dictionary that maps ``DatasetConfig`` objects to one
            of the former two.

    Returns:
        A new dictionary that maps each config to a mock info dictionary
        containing at least the key ``"num_samples"``. The passed
        ``mock_infos`` object is not modified.

    Raises:
        pytest.UsageError: If ``mock_infos`` is ``None``, if it mixes
            ``DatasetConfig`` keys with keys of other types, if one of its
            configs is already present in ``self._cache``, or if a mock info
            is neither an integer nor a dictionary with a ``"num_samples"``
            entry.
    """
    if mock_infos is None:
        raise pytest.UsageError(
            f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
            f"integer indicating the number of samples for the current `config`."
        )

    # Only dictionaries can be keyed by configs; everything else has no key types.
    key_types = {type(key) for key in mock_infos} if isinstance(mock_infos, dict) else set()
    if datasets.utils.DatasetConfig not in key_types:
        # The return value describes only the requested config.
        mock_infos = {config: mock_infos}
    elif len(key_types) > 1:
        raise pytest.UsageError(
            f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
            f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
        )

    # Collect the normalized entries in a fresh dictionary rather than
    # rewriting the object handed in by the mock data function.
    parsed = {}
    for config_, mock_info in mock_infos.items():
        if config_ in self._cache:
            raise pytest.UsageError(
                f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
                f"already exists in the cache."
            )
        if isinstance(mock_info, int):
            # Shorthand: a bare integer is the number of samples.
            mock_info = dict(num_samples=mock_info)
        elif not isinstance(mock_info, dict):
            # Report the type of the offending entry, not of the enclosing mapping.
            raise pytest.UsageError(
                f"The mock data function for dataset '{self.name}' returned a {type(mock_info)} for `config` "
                f"{config_}. The returned object should be a dictionary containing at least the number of "
                f"samples for the key `'num_samples'`. If no additional information is required for specific "
                f"tests, the number of samples can also be returned as an integer."
            )
        elif "num_samples" not in mock_info:
            raise pytest.UsageError(
                f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
                f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
            )
        parsed[config_] = mock_info

    return parsed
107
104
108
def _prepare_resources(self, config):
    """Create the mock files for ``config`` and return its mock info.

    The result is served from ``self._cache`` if ``config`` was prepared
    before. Otherwise the mock data function is called, its return value is
    parsed by ``_parse_mock_data``, and every returned config is cached after
    checking that all files required by ``self.dataset.resources`` exist
    directly inside ``self.root``.

    Raises:
        pytest.UsageError: If a required file was not created by the mock
            data function.
    """
    try:
        return self._cache[config]
    except KeyError:
        pass

    self.root.mkdir(exist_ok=True)
    raw_mock_infos = self.mock_data_fn(self.info, self.root, config)
    mock_infos = self._parse_mock_data(config, raw_mock_infos)

    # Snapshot of the files present after the mock data function ran.
    existing = {path.name for path in self.root.glob("*")}
    for config_, mock_info in mock_infos.items():
        required = {resource.file_name for resource in self.dataset.resources(config_)}
        missing = required.difference(existing)
        if missing:
            raise pytest.UsageError(
                f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing))} "
                f"for {config_}, but they were not created by the mock data function."
            )

        self._cache[config_] = mock_info

    return self._cache[config]
121
125
122
@contextlib.contextmanager
def prepare(self, config):
    """Context manager that prepares the mock data for ``config``.

    The mock resources are created first. While the context is active,
    ``torchvision.prototype.datasets._api.home`` is patched to return
    ``str(TEST_HOME)``.

    Yields:
        The mock info returned by ``_prepare_resources`` for ``config``.
    """
    info_for_config = self._prepare_resources(config)
    home_patcher = unittest.mock.patch(
        "torchvision.prototype.datasets._api.home",
        return_value=str(TEST_HOME),
    )
    with home_patcher:
        yield info_for_config
135
131
136
132
137
133
def config_id (name , config ):
@@ -1000,7 +996,7 @@ def dtd(info, root, _):
1000
996
def fer2013 (info , root , config ):
1001
997
num_samples = 5 if config .split == "train" else 3
1002
998
1003
- path = root / f"{ config .split } .txt "
999
+ path = root / f"{ config .split } .csv "
1004
1000
with open (path , "w" , newline = "" ) as file :
1005
1001
field_names = ["emotion" ] if config .split == "train" else []
1006
1002
field_names .append ("pixels" )
@@ -1061,7 +1057,7 @@ def clevr(info, root, config):
1061
1057
file ,
1062
1058
)
1063
1059
1064
- make_zip (root , f"{ data_folder .name } .zip" )
1060
+ make_zip (root , f"{ data_folder .name } .zip" , data_folder )
1065
1061
1066
1062
return {config_ : num_samples_map [config_ .split ] for config_ in info ._configs }
1067
1063
@@ -1121,8 +1117,8 @@ def generate(self, root):
1121
1117
for path in segmentation_files :
1122
1118
path .with_name (f".{ path .name } " ).touch ()
1123
1119
1124
- make_tar (root , "images.tar" )
1125
- make_tar (root , anns_folder .with_suffix (".tar" ).name )
1120
+ make_tar (root , "images.tar.gz" , compression = "gz " )
1121
+ make_tar (root , anns_folder .with_suffix (".tar.gz " ).name , compression = "gz" )
1126
1122
1127
1123
return num_samples_map
1128
1124
@@ -1211,7 +1207,7 @@ def _make_segmentations(cls, root, image_files):
1211
1207
size = [1 , * make_tensor ((2 ,), low = 3 , dtype = torch .int ).tolist ()],
1212
1208
)
1213
1209
1214
- make_tar (root , segmentations_folder .with_suffix (".tgz" ).name )
1210
+ make_tar (root , segmentations_folder .with_suffix (".tgz" ).name , compression = "gz" )
1215
1211
1216
1212
@classmethod
1217
1213
def generate (cls , root ):
@@ -1298,3 +1294,23 @@ def generate(cls, root):
1298
1294
def cub200 (info , root , config ):
1299
1295
num_samples_map = (CUB2002011MockData if config .year == "2011" else CUB2002010MockData ).generate (root )
1300
1296
return {config_ : num_samples_map [config_ .split ] for config_ in info ._configs if config_ .year == config .year }
1297
+
1298
+
1299
@DATASET_MOCKS.set_from_named_callable
def svhn(info, root, config):
    """Write a random ``<split>_32x32.mat`` file into ``root``.

    The file holds ``"X"``, a ``uint8`` array of shape
    ``(32, 32, 3, num_samples)`` with values in ``[0, 256)``, and ``"y"``, a
    ``uint8`` array of shape ``(num_samples,)`` with values in ``[0, 10)``.

    Returns:
        The number of samples written for ``config.split``: 2 for ``"train"``,
        3 for ``"test"`` and 4 for ``"extra"``.
    """
    # Imported locally, as in the original; scipy is only needed by this mock.
    import scipy.io as sio

    samples_per_split = {"train": 2, "test": 3, "extra": 4}
    num_samples = samples_per_split[config.split]

    mat_path = root / f"{config.split}_32x32.mat"
    # Images are drawn before labels to keep the random call order unchanged.
    images = np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8)
    labels = np.random.randint(10, size=(num_samples,), dtype=np.uint8)
    sio.savemat(mat_path, {"X": images, "y": labels})

    return num_samples
0 commit comments