Commit 1e4bf10

marcenacp authored and The TensorFlow Datasets Authors committed
Make tfds.data_source pickable.
PiperOrigin-RevId: 636824581
1 parent 6bbba45 commit 1e4bf10

File tree: 7 files changed (+110 -44 lines)


Diff for: tensorflow_datasets/core/data_sources/array_record.py

+2 -16

@@ -20,13 +20,8 @@
 """

 import dataclasses
-from typing import Any, Optional

-from tensorflow_datasets.core import dataset_info as dataset_info_lib
-from tensorflow_datasets.core import decode
-from tensorflow_datasets.core import splits as splits_lib
 from tensorflow_datasets.core.data_sources import base
-from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source


@@ -42,18 +37,9 @@ class ArrayRecordDataSource(base.BaseDataSource):
   source.
   """

-  dataset_info: dataset_info_lib.DatasetInfo
-  split: splits_lib.Split = None
-  decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
-      None
-  )
-  # In order to lazy load array_record, we don't load
-  # `array_record_data_source.ArrayRecordDataSource` here.
-  data_source: Any = dataclasses.field(init=False)
-  length: int = dataclasses.field(init=False)
-
   def __post_init__(self):
-    file_instructions = base.file_instructions(self.dataset_info, self.split)
+    dataset_info = self.dataset_builder.info
+    file_instructions = base.file_instructions(dataset_info, self.split)
     self.data_source = array_record_data_source.ArrayRecordDataSource(
         file_instructions
     )
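
Together with the change in base.py below, ArrayRecordDataSource is now constructed from a dataset builder rather than a DatasetInfo. A minimal sketch of the new call shape, reusing the testing helpers that base_test.py relies on (the dataset and split are illustrative, not part of this commit):

import tensorflow_datasets as tfds
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core.data_sources import array_record

with tfds.testing.tmp_dir() as data_dir:
  builder = tfds.testing.DummyDataset(data_dir=data_dir)
  builder.download_and_prepare(
      file_format=file_adapters.FileFormat.ARRAY_RECORD
  )
  # The data source now receives the builder itself, not builder.info.
  source = array_record.ArrayRecordDataSource(builder, split='train')
  print(source[0])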

Diff for: tensorflow_datasets/core/data_sources/base.py

+24 -8

@@ -17,12 +17,14 @@

 from collections.abc import MappingView, Sequence
 import dataclasses
+import functools
 import typing
 from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar

 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import splits as splits_lib
+from tensorflow_datasets.core.features import top_level_feature
 from tensorflow_datasets.core.utils import shard_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tree

@@ -54,6 +56,14 @@ def file_instructions(
   return split_dict[split].file_instructions


+class _DatasetBuilder(Protocol):
+  """Protocol for the DatasetBuilder to avoid cyclic imports."""
+
+  @property
+  def info(self) -> dataset_info_lib.DatasetInfo:
+    ...
+
+
 @dataclasses.dataclass
 class BaseDataSource(MappingView, Sequence):
   """Base DataSource to override all dunder methods with the deserialization.

@@ -64,22 +74,28 @@ class BaseDataSource(MappingView, Sequence):
   deserialization/decoding.

   Attributes:
-    dataset_info: The DatasetInfo of the
+    dataset_builder: The dataset builder.
     split: The split to load in the data source.
     decoders: Optional decoders for decoding.
     data_source: The underlying data source to initialize in the __post_init__.
   """

-  dataset_info: dataset_info_lib.DatasetInfo
+  dataset_builder: _DatasetBuilder
   split: splits_lib.Split | None = None
   decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
   data_source: DataSource[Any] = dataclasses.field(init=False)

+  @functools.cached_property
+  def _features(self) -> top_level_feature.TopLevelFeature:
+    """Caches features because we log the use of dataset_builder.info."""
+    features = self.dataset_builder.info.features
+    if not features:
+      raise ValueError('No feature defined in the dataset builder.')
+    return features
+
   def __getitem__(self, key: SupportsIndex) -> Any:
     record = self.data_source[key.__index__()]
-    return self.dataset_info.features.deserialize_example_np(
-        record, decoders=self.decoders
-    )
+    return self._features.deserialize_example_np(record, decoders=self.decoders)

   def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     """Retrieves items by batch.

@@ -98,24 +114,24 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     if not keys:
       return []
     records = self.data_source.__getitems__(keys)
-    features = self.dataset_info.features
     if len(keys) != len(records):
       raise IndexError(
           f'Requested {len(keys)} records but got'
           f' {len(records)} records.'
           f'{keys=}, {records=}'
       )
     return [
-        features.deserialize_example_np(record, decoders=self.decoders)
+        self._features.deserialize_example_np(record, decoders=self.decoders)
         for record in records
     ]

   def __repr__(self) -> str:
     decoders_repr = (
         tree.map_structure(type, self.decoders) if self.decoders else None
     )
+    name = self.dataset_builder.info.name
     return (
-        f'{self.__class__.__name__}(name={self.dataset_info.name}, '
+        f'{self.__class__.__name__}(name={name}, '
         f'split={self.split!r}, '
         f'decoders={decoders_repr})'
     )
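
The _DatasetBuilder Protocol exists only so base.py can type the new dataset_builder attribute without importing dataset_builder.py, which would create a circular import. A minimal, self-contained sketch of the structural-typing idea (FakeBuilder and describe are illustrative names, not part of TFDS):

import dataclasses
from typing import Any, Protocol


class _DatasetBuilder(Protocol):
  """Anything exposing an `info` property satisfies this Protocol."""

  @property
  def info(self) -> Any:
    ...


@dataclasses.dataclass
class FakeBuilder:
  """No inheritance needed: matching the shape of the Protocol is enough."""

  name: str = 'dummy'

  @property
  def info(self) -> Any:
    return {'name': self.name}


def describe(builder: _DatasetBuilder) -> str:
  # Type checkers accept FakeBuilder here because it structurally matches.
  return f'builder with info={builder.info!r}'


print(describe(FakeBuilder()))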

Diff for: tensorflow_datasets/core/data_sources/base_test.py

+39 -10

@@ -15,13 +15,15 @@

 """Tests for all data sources."""

+import pickle
 from unittest import mock

+import cloudpickle
 from etils import epath
 import pytest
 import tensorflow_datasets as tfds
 from tensorflow_datasets import testing
-from tensorflow_datasets.core import dataset_builder
+from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import file_adapters

@@ -77,7 +79,7 @@ def mocked_parquet_dataset():
 )
 def test_read_write(
     tmp_path: epath.Path,
-    builder_cls: dataset_builder.DatasetBuilder,
+    builder_cls: dataset_builder_lib.DatasetBuilder,
     file_format: file_adapters.FileFormat,
 ):
   builder = builder_cls(data_dir=tmp_path, file_format=file_format)

@@ -106,28 +108,36 @@ def test_read_write(
   ]


-def create_dataset_info(file_format: file_adapters.FileFormat):
+def create_dataset_builder(
+    file_format: file_adapters.FileFormat,
+) -> dataset_builder_lib.DatasetBuilder:
   with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
     split_mock.return_value.name = 'train'
     split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
     dataset_info = mock.create_autospec(dataset_info_lib.DatasetInfo)
     dataset_info.file_format = file_format
     dataset_info.splits = {'train': split_mock()}
     dataset_info.name = 'dataset_name'
-    return dataset_info
+
+    dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
+    dataset_builder.info = dataset_info
+
+    return dataset_builder


 @pytest.mark.parametrize(
     'data_source_cls',
     _DATA_SOURCE_CLS,
 )
 def test_missing_split_raises_error(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   with pytest.raises(
       ValueError,
       match="Unknown split 'doesnotexist'.",
   ):
-    data_source_cls(dataset_info, split='doesnotexist')
+    data_source_cls(dataset_builder, split='doesnotexist')


 @pytest.mark.usefixtures(*_FIXTURES)

@@ -136,8 +146,10 @@ def test_missing_split_raises_error(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
-  source = data_source_cls(dataset_info, split='train')
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
+  source = data_source_cls(dataset_builder, split='train')
   name = data_source_cls.__name__
   assert (
       repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"

@@ -150,9 +162,11 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   source = data_source_cls(
-      dataset_info,
+      dataset_builder,
       split='train',
       decoders={'my_feature': decode.SkipDecoding()},
   )

@@ -181,3 +195,18 @@ def test_data_source_is_sliceable():
   file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
   assert file_instructions[0].skip == 0
   assert file_instructions[0].take == 30000
+
+
+# PyGrain requires that data sources are picklable.
+@pytest.mark.parametrize(
+    'file_format',
+    file_adapters.FileFormat.with_random_access(),
+)
+@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
+def test_data_source_is_picklable_after_use(file_format, pickle_module):
+  with tfds.testing.tmp_dir() as data_dir:
+    builder = tfds.testing.DummyDataset(data_dir=data_dir)
+    builder.download_and_prepare(file_format=file_format)
+    data_source = builder.as_data_source(split='train')
+    assert data_source[0] == {'id': 0}
+    assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}
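
The new test exercises the property named in the commit title: a data source must survive a pickle round-trip even after it has been read from, because PyGrain hands data sources to spawned worker processes. A hedged sketch of the user-facing flow through the public tfds.data_source API ('mnist' is only an example and must already be prepared in a random-access file format):

import pickle

import tensorflow_datasets as tfds

data_source = tfds.data_source('mnist', split='train')
example = data_source[0]  # Reading opens the underlying files lazily.
# After this commit, the source can still be pickled after use.
restored = pickle.loads(pickle.dumps(data_source))
assert len(restored) == len(data_source)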

Diff for: tensorflow_datasets/core/data_sources/parquet.py

+2 -1

@@ -57,7 +57,8 @@ class ParquetDataSource(base.BaseDataSource):
   """ParquetDataSource to read from a ParquetDataset."""

   def __post_init__(self):
-    file_instructions = base.file_instructions(self.dataset_info, self.split)
+    dataset_info = self.dataset_builder.info
+    file_instructions = base.file_instructions(dataset_info, self.split)
     filenames = [
         file_instruction.filename for file_instruction in file_instructions
     ]

Diff for: tensorflow_datasets/core/dataset_builder.py

+2 -2

@@ -774,13 +774,13 @@ def build_single_data_source(
     file_format = self.info.file_format
     if file_format == file_adapters.FileFormat.ARRAY_RECORD:
       return array_record.ArrayRecordDataSource(
-          self.info,
+          self,
          split=split,
          decoders=decoders,
      )
     elif file_format == file_adapters.FileFormat.PARQUET:
       return parquet.ParquetDataSource(
-          self.info,
+          self,
          split=split,
          decoders=decoders,
      )

Diff for: tensorflow_datasets/testing/mocking.py

+32 -7

@@ -120,9 +120,7 @@ def _getitems(
       _getitem(self, record_key, generator, serialized=serialized)
       for record_key in record_keys
   ]
-  if serialized:
-    return np.array(items)
-  return items
+  return np.asarray(items)


 def _deserialize_example_np(serialized_example, *, decoders=None):

@@ -173,6 +171,7 @@ def mock_data(
     as_data_source_fn: Optional[Callable[..., Sequence[Any]]] = None,
     data_dir: Optional[str] = None,
     mock_array_record_data_source: Optional[PickableDataSourceMock] = None,
+    use_in_multiprocessing: bool = False,
 ) -> Iterator[None]:
   """Mock tfds to generate random data.

@@ -262,6 +261,10 @@ def as_dataset(self, *args, **kwargs):
     mock_array_record_data_source: Overwrite a mock for the underlying
       ArrayRecord data source if it is used. Note: If used the same mock will be
       used for all data sources loaded within this context.
+    use_in_multiprocessing: If True, the mock will use a multiprocessing-safe
+      approach to generate the data. It's notably useful for PyGrain. The goal
+      is to migrate the codebase to this mode by default. Find a more detailed
+      explanation of this parameter in a comment in the code below.

   Yields:
     None

@@ -361,9 +364,31 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
    if split is None:
      split = {s: s for s in self.info.splits}

-    generator_cls, features, _, _ = _get_fake_data_components(
-        decoders, self.info.features
-    )
+    features = self.info.features
+    if use_in_multiprocessing:
+      # In multiprocessing, we generate serialized data. The data is then
+      # re-deserialized by the feature as it would normally happen in TFDS. In
+      # this approach, we don't need to monkey-patch workers to propagate the
+      # information that deserialize_example_np should be a no-op. Indeed, doing
+      # so is difficult as PyGrain uses the `spawn` multiprocessing mode. Users
+      # of tfds.testing.mock_data in the codebase started relying on the
+      # function not serializing (for example, they don't have TensorFlow in
+      # their dependency), so we cannot have use_in_multiprocessing by default.
+      #            ┌─────────────┐
+      #            │ Main process│
+      #            └─┬──────┬────┘
+      #      ┌───────▼─┐  ┌─▼───────┐
+      #      │ worker1 │  │ worker2 │ ...
+      #      └───────┬─┘  └─┬───────┘
+      #        serialized data by the generator
+      #      ┌───────▼─┐  ┌─▼───────┐
+      #      │ tfds 1  │  │ tfds 2  │ ...
+      #      └───────┬─┘  └─┬───────┘
+      #         deserialized data
+      generator_cls = SerializedRandomFakeGenerator
+    else:
+      # We generate already deserialized data with the generator.
+      generator_cls, _, _, _ = _get_fake_data_components(decoders, features)
    generator = generator_cls(features, num_examples)

    if actual_policy == MockPolicy.USE_CODE:

@@ -399,7 +424,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):

    def build_single_data_source(split):
      single_data_source = array_record.ArrayRecordDataSource(
-          dataset_info=self.info, split=split, decoders=decoders
+          dataset_builder=self, split=split, decoders=decoders
      )
      return single_data_source

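
A hedged sketch of how the new flag is meant to be used: with use_in_multiprocessing=True the mocked source generates serialized examples that are then deserialized by the real feature code, so nothing has to be monkey-patched inside spawned workers. The dataset name and counts below are illustrative, mirroring mocking_test.py:

import tensorflow_datasets as tfds

with tfds.testing.mock_data(num_examples=4, use_in_multiprocessing=True):
  data_source = tfds.data_source('imagenet2012', split='train')
  assert len(data_source) == 4
  example = data_source[0]  # A dict of randomly generated numpy arrays.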

Diff for: tensorflow_datasets/testing/mocking_test.py

+9

@@ -392,3 +392,12 @@ def test_as_data_source_fn():
   assert imagenet[0] == 'foo'
   assert imagenet[1] == 'bar'
   assert imagenet[2] == 'baz'
+
+
+# PyGrain requires that data sources are picklable.
+def test_mocked_data_source_is_pickable():
+  with tfds.testing.mock_data(num_examples=2):
+    data_source = tfds.data_source('imagenet2012', split='train')
+    pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
+    assert len(pickled_and_unpickled_data_source) == 2
+    assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)
