13
13
# limitations under the License.
14
14
15
15
"""A common dataset reader."""
16
+ import dataclasses
16
17
import random
17
18
from typing import Any , Callable , Dict , List , Optional , Sequence , Text , Union
18
19
@@ -159,20 +160,20 @@ def _read_tfds(tfds_name: Text,
159
160
cycle_length : Optional [int ] = None ,
160
161
block_length : Optional [int ] = None ) -> tf .data .Dataset :
161
162
"""Reads a dataset from tfds."""
162
- repeat_filenames = is_training and not cache
163
+ read_config = tfds .ReadConfig (
164
+ interleave_cycle_length = cycle_length ,
165
+ interleave_block_length = block_length ,
166
+ input_context = input_context ,
167
+ shuffle_seed = seed ,
168
+ repeat_filenames = is_training and not cache ,
169
+ skip_prefetch = True )
170
+
163
171
decoders = {}
164
172
if tfds_skip_decoding_feature :
165
173
for skip_feature in tfds_skip_decoding_feature .split (',' ):
166
174
decoders [skip_feature .strip ()] = tfds .decode .SkipDecoding ()
167
175
168
176
if tfds_name .startswith ('mldataset.' ):
169
- read_config = tfds .ReadConfig (
170
- interleave_cycle_length = cycle_length ,
171
- interleave_block_length = block_length ,
172
- input_context = input_context ,
173
- shuffle_seed = seed ,
174
- repeat_filenames = repeat_filenames ,
175
- skip_prefetch = True )
176
177
dataset = tfds .load (name = tfds_name ,
177
178
split = tfds_split ,
178
179
as_supervised = tfds_as_supervised ,
@@ -196,25 +197,12 @@ def _read_tfds(tfds_name: Text,
196
197
# The number of files in the dataset split is smaller than the number of
197
198
# input pipelines. We read the entire dataset first and then shard in the
198
199
# host memory.
199
- read_config = tfds .ReadConfig (
200
- interleave_cycle_length = cycle_length ,
201
- interleave_block_length = block_length ,
202
- input_context = None ,
203
- shuffle_seed = seed ,
204
- repeat_filenames = repeat_filenames ,
205
- skip_prefetch = True )
200
+ read_config = dataclasses .replace (read_config , input_context = None )
206
201
load_kwargs .update ({'read_config' : read_config })
207
202
dataset = tfds .load (** load_kwargs )
208
203
dataset = dataset .shard (input_context .num_input_pipelines ,
209
204
input_context .input_pipeline_id )
210
205
else :
211
- read_config = tfds .ReadConfig (
212
- interleave_cycle_length = cycle_length ,
213
- interleave_block_length = block_length ,
214
- input_context = input_context ,
215
- shuffle_seed = seed ,
216
- repeat_filenames = repeat_filenames ,
217
- skip_prefetch = True )
218
206
load_kwargs .update ({'read_config' : read_config })
219
207
dataset = tfds .load (** load_kwargs )
220
208
return dataset
0 commit comments