@@ -159,6 +159,7 @@ def _read_tfds(tfds_name: Text,
159
159
cycle_length : Optional [int ] = None ,
160
160
block_length : Optional [int ] = None ) -> tf .data .Dataset :
161
161
"""Reads a dataset from tfds."""
162
+ repeat_filenames = is_training and not cache
162
163
decoders = {}
163
164
if tfds_skip_decoding_feature :
164
165
for skip_feature in tfds_skip_decoding_feature .split (',' ):
@@ -170,6 +171,7 @@ def _read_tfds(tfds_name: Text,
170
171
interleave_block_length = block_length ,
171
172
input_context = input_context ,
172
173
shuffle_seed = seed ,
174
+ repeat_filenames = repeat_filenames ,
173
175
skip_prefetch = True )
174
176
dataset = tfds .load (name = tfds_name ,
175
177
split = tfds_split ,
@@ -199,6 +201,7 @@ def _read_tfds(tfds_name: Text,
199
201
interleave_block_length = block_length ,
200
202
input_context = None ,
201
203
shuffle_seed = seed ,
204
+ repeat_filenames = repeat_filenames ,
202
205
skip_prefetch = True )
203
206
load_kwargs .update ({'read_config' : read_config })
204
207
dataset = tfds .load (** load_kwargs )
@@ -210,12 +213,10 @@ def _read_tfds(tfds_name: Text,
210
213
interleave_block_length = block_length ,
211
214
input_context = input_context ,
212
215
shuffle_seed = seed ,
216
+ repeat_filenames = repeat_filenames ,
213
217
skip_prefetch = True )
214
218
load_kwargs .update ({'read_config' : read_config })
215
219
dataset = tfds .load (** load_kwargs )
216
-
217
- if is_training and not cache :
218
- dataset = dataset .repeat ()
219
220
return dataset
220
221
221
222
0 commit comments