Skip to content

Commit e4e3182

Browse files
tutmann and tensorflower-gardener
authored and committed
Internal change
PiperOrigin-RevId: 519933844
1 parent 9457c2a commit e4e3182

File tree

1 file changed

+10
-22
lines changed

1 file changed

+10
-22
lines changed

official/core/input_reader.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""A common dataset reader."""
16+
import dataclasses
1617
import random
1718
from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union
1819

@@ -159,20 +160,20 @@ def _read_tfds(tfds_name: Text,
159160
cycle_length: Optional[int] = None,
160161
block_length: Optional[int] = None) -> tf.data.Dataset:
161162
"""Reads a dataset from tfds."""
162-
repeat_filenames = is_training and not cache
163+
read_config = tfds.ReadConfig(
164+
interleave_cycle_length=cycle_length,
165+
interleave_block_length=block_length,
166+
input_context=input_context,
167+
shuffle_seed=seed,
168+
repeat_filenames=is_training and not cache,
169+
skip_prefetch=True)
170+
163171
decoders = {}
164172
if tfds_skip_decoding_feature:
165173
for skip_feature in tfds_skip_decoding_feature.split(','):
166174
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
167175

168176
if tfds_name.startswith('mldataset.'):
169-
read_config = tfds.ReadConfig(
170-
interleave_cycle_length=cycle_length,
171-
interleave_block_length=block_length,
172-
input_context=input_context,
173-
shuffle_seed=seed,
174-
repeat_filenames=repeat_filenames,
175-
skip_prefetch=True)
176177
dataset = tfds.load(name=tfds_name,
177178
split=tfds_split,
178179
as_supervised=tfds_as_supervised,
@@ -196,25 +197,12 @@ def _read_tfds(tfds_name: Text,
196197
# The number of files in the dataset split is smaller than the number of
197198
# input pipelines. We read the entire dataset first and then shard in the
198199
# host memory.
199-
read_config = tfds.ReadConfig(
200-
interleave_cycle_length=cycle_length,
201-
interleave_block_length=block_length,
202-
input_context=None,
203-
shuffle_seed=seed,
204-
repeat_filenames=repeat_filenames,
205-
skip_prefetch=True)
200+
read_config = dataclasses.replace(read_config, input_context=None)
206201
load_kwargs.update({'read_config': read_config})
207202
dataset = tfds.load(**load_kwargs)
208203
dataset = dataset.shard(input_context.num_input_pipelines,
209204
input_context.input_pipeline_id)
210205
else:
211-
read_config = tfds.ReadConfig(
212-
interleave_cycle_length=cycle_length,
213-
interleave_block_length=block_length,
214-
input_context=input_context,
215-
shuffle_seed=seed,
216-
repeat_filenames=repeat_filenames,
217-
skip_prefetch=True)
218206
load_kwargs.update({'read_config': read_config})
219207
dataset = tfds.load(**load_kwargs)
220208
return dataset

0 commit comments

Comments (0)