Skip to content

Commit 9457c2a

Browse files
tutmanntensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 519933430
1 parent 270ed2a commit 9457c2a

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

official/core/input_reader.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def _read_tfds(tfds_name: Text,
159159
cycle_length: Optional[int] = None,
160160
block_length: Optional[int] = None) -> tf.data.Dataset:
161161
"""Reads a dataset from tfds."""
162+
repeat_filenames = is_training and not cache
162163
decoders = {}
163164
if tfds_skip_decoding_feature:
164165
for skip_feature in tfds_skip_decoding_feature.split(','):
@@ -170,6 +171,7 @@ def _read_tfds(tfds_name: Text,
170171
interleave_block_length=block_length,
171172
input_context=input_context,
172173
shuffle_seed=seed,
174+
repeat_filenames=repeat_filenames,
173175
skip_prefetch=True)
174176
dataset = tfds.load(name=tfds_name,
175177
split=tfds_split,
@@ -199,6 +201,7 @@ def _read_tfds(tfds_name: Text,
199201
interleave_block_length=block_length,
200202
input_context=None,
201203
shuffle_seed=seed,
204+
repeat_filenames=repeat_filenames,
202205
skip_prefetch=True)
203206
load_kwargs.update({'read_config': read_config})
204207
dataset = tfds.load(**load_kwargs)
@@ -210,12 +213,10 @@ def _read_tfds(tfds_name: Text,
210213
interleave_block_length=block_length,
211214
input_context=input_context,
212215
shuffle_seed=seed,
216+
repeat_filenames=repeat_filenames,
213217
skip_prefetch=True)
214218
load_kwargs.update({'read_config': read_config})
215219
dataset = tfds.load(**load_kwargs)
216-
217-
if is_training and not cache:
218-
dataset = dataset.repeat()
219220
return dataset
220221

221222

0 commit comments

Comments
 (0)