@@ -15,13 +15,15 @@
 
 """Tests for all data sources."""
 
+import pickle
 from unittest import mock
 
+import cloudpickle
 from etils import epath
 import pytest
 import tensorflow_datasets as tfds
 from tensorflow_datasets import testing
-from tensorflow_datasets.core import dataset_builder
+from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import file_adapters
@@ -77,7 +79,7 @@ def mocked_parquet_dataset():
 )
 def test_read_write(
     tmp_path: epath.Path,
-    builder_cls: dataset_builder.DatasetBuilder,
+    builder_cls: dataset_builder_lib.DatasetBuilder,
     file_format: file_adapters.FileFormat,
 ):
   builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -106,28 +108,36 @@ def test_read_write(
 ]
 
 
-def create_dataset_info(file_format: file_adapters.FileFormat):
+def create_dataset_builder(
+    file_format: file_adapters.FileFormat,
+) -> dataset_builder_lib.DatasetBuilder:
   with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
     split_mock.return_value.name = 'train'
     split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
     dataset_info = mock.create_autospec(dataset_info_lib.DatasetInfo)
     dataset_info.file_format = file_format
     dataset_info.splits = {'train': split_mock()}
     dataset_info.name = 'dataset_name'
-    return dataset_info
+
+    dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
+    dataset_builder.info = dataset_info
+
+    return dataset_builder
 
 
 @pytest.mark.parametrize(
     'data_source_cls',
     _DATA_SOURCE_CLS,
 )
 def test_missing_split_raises_error(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   with pytest.raises(
       ValueError,
       match="Unknown split 'doesnotexist'.",
   ):
-    data_source_cls(dataset_info, split='doesnotexist')
+    data_source_cls(dataset_builder, split='doesnotexist')
 
 
 @pytest.mark.usefixtures(*_FIXTURES)
@@ -136,8 +146,10 @@ def test_missing_split_raises_error(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
-  source = data_source_cls(dataset_info, split='train')
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
+  source = data_source_cls(dataset_builder, split='train')
   name = data_source_cls.__name__
   assert (
       repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
@@ -150,9 +162,11 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   source = data_source_cls(
-      dataset_info,
+      dataset_builder,
       split='train',
       decoders={'my_feature': decode.SkipDecoding()},
   )
@@ -181,3 +195,18 @@ def test_data_source_is_sliceable():
   file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
   assert file_instructions[0].skip == 0
   assert file_instructions[0].take == 30000
+
+
+# PyGrain requires that data sources are picklable.
+@pytest.mark.parametrize(
+    'file_format',
+    file_adapters.FileFormat.with_random_access(),
+)
+@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
+def test_data_source_is_picklable_after_use(file_format, pickle_module):
+  with tfds.testing.tmp_dir() as data_dir:
+    builder = tfds.testing.DummyDataset(data_dir=data_dir)
+    builder.download_and_prepare(file_format=file_format)
+    data_source = builder.as_data_source(split='train')
+    assert data_source[0] == {'id': 0}
+    assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}
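
For context (not part of the commit): the reason the new test insists on picklability is that PyGrain worker processes receive the data source by pickling it across a process boundary. Below is a minimal illustrative sketch of that round trip, reusing the `DummyDataset` and `tmp_dir` test fixtures seen in the diff; passing `file_format='array_record'` to the builder constructor is an assumption based on the `builder_cls(data_dir=..., file_format=...)` call above.

```python
# Illustrative sketch, not the committed test: demonstrate that a TFDS
# data source survives a pickle round trip into a worker process, which
# is what PyGrain does internally.
import multiprocessing
import pickle

import tensorflow_datasets as tfds


def _read_first_record(pickled_source):
  # Runs in the worker process: unpickle the source and read one record.
  source = pickle.loads(pickled_source)
  return source[0]


if __name__ == '__main__':
  with tfds.testing.tmp_dir() as data_dir:
    # Assumed: the constructor accepts file_format, as in test_read_write.
    builder = tfds.testing.DummyDataset(
        data_dir=data_dir, file_format='array_record'
    )
    builder.download_and_prepare()
    source = builder.as_data_source(split='train')
    # Ship the pickled source to a worker and read back the first record.
    with multiprocessing.Pool(processes=1) as pool:
      record = pool.apply(_read_first_record, (pickle.dumps(source),))
    assert record == {'id': 0}
```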