
Commit 52f4d72

marcenacp authored and The TensorFlow Datasets Authors committed
Make tfds.data_source pickable.
PiperOrigin-RevId: 636824581
1 parent 6bbba45 commit 52f4d72
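PyGrain loads examples in separate worker processes, so any data source handed to it must survive a pickle round trip; this commit makes the objects returned by tfds.data_source (and the builders they now reference) satisfy that requirement. A minimal sketch of the intended usage, assuming a small dataset that has already been downloaded and prepared locally (the dataset name here is only illustrative):

import pickle

import tensorflow_datasets as tfds

# After this change, the random-access data source can be pickled, which is
# what PyGrain does when it ships the source to its worker processes.
data_source = tfds.data_source('mnist', split='train')
restored = pickle.loads(pickle.dumps(data_source))
assert len(restored) == len(data_source)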

File tree

7 files changed: +94 −40 lines changed


tensorflow_datasets/core/data_sources/array_record.py

+2 −16

@@ -20,13 +20,8 @@
 """

 import dataclasses
-from typing import Any, Optional

-from tensorflow_datasets.core import dataset_info as dataset_info_lib
-from tensorflow_datasets.core import decode
-from tensorflow_datasets.core import splits as splits_lib
 from tensorflow_datasets.core.data_sources import base
-from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source


@@ -42,18 +37,9 @@ class ArrayRecordDataSource(base.BaseDataSource):
   source.
   """

-  dataset_info: dataset_info_lib.DatasetInfo
-  split: splits_lib.Split = None
-  decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
-      None
-  )
-  # In order to lazy load array_record, we don't load
-  # `array_record_data_source.ArrayRecordDataSource` here.
-  data_source: Any = dataclasses.field(init=False)
-  length: int = dataclasses.field(init=False)
-
   def __post_init__(self):
-    file_instructions = base.file_instructions(self.dataset_info, self.split)
+    dataset_info = self.dataset_builder.info
+    file_instructions = base.file_instructions(dataset_info, self.split)
     self.data_source = array_record_data_source.ArrayRecordDataSource(
         file_instructions
     )

tensorflow_datasets/core/data_sources/base.py

+24 −8

@@ -17,12 +17,14 @@

 from collections.abc import MappingView, Sequence
 import dataclasses
+import functools
 import typing
 from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar

 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import splits as splits_lib
+from tensorflow_datasets.core.features import top_level_feature
 from tensorflow_datasets.core.utils import shard_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tree
@@ -54,6 +56,14 @@ def file_instructions(
   return split_dict[split].file_instructions


+class _DatasetBuilder(Protocol):
+  """Protocol for the DatasetBuilder to avoid cyclic imports."""
+
+  @property
+  def info(self) -> dataset_info_lib.DatasetInfo:
+    ...
+
+
 @dataclasses.dataclass
 class BaseDataSource(MappingView, Sequence):
   """Base DataSource to override all dunder methods with the deserialization.
@@ -64,22 +74,28 @@ class BaseDataSource(MappingView, Sequence):
   deserialization/decoding.

   Attributes:
-    dataset_info: The DatasetInfo of the
+    dataset_builder: The dataset builder.
     split: The split to load in the data source.
     decoders: Optional decoders for decoding.
     data_source: The underlying data source to initialize in the __post_init__.
   """

-  dataset_info: dataset_info_lib.DatasetInfo
+  dataset_builder: _DatasetBuilder
   split: splits_lib.Split | None = None
   decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
   data_source: DataSource[Any] = dataclasses.field(init=False)

+  @functools.cached_property
+  def _features(self) -> top_level_feature.TopLevelFeature:
+    """Caches features because we log the use of dataset_builder.info."""
+    features = self.dataset_builder.info.features
+    if not features:
+      raise ValueError('No feature defined in the dataset builder.')
+    return features
+
   def __getitem__(self, key: SupportsIndex) -> Any:
     record = self.data_source[key.__index__()]
-    return self.dataset_info.features.deserialize_example_np(
-        record, decoders=self.decoders
-    )
+    return self._features.deserialize_example_np(record, decoders=self.decoders)

   def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     """Retrieves items by batch.
@@ -98,24 +114,24 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     if not keys:
       return []
     records = self.data_source.__getitems__(keys)
-    features = self.dataset_info.features
     if len(keys) != len(records):
       raise IndexError(
           f'Requested {len(keys)} records but got'
           f' {len(records)} records.'
           f'{keys=}, {records=}'
       )
     return [
-        features.deserialize_example_np(record, decoders=self.decoders)
+        self._features.deserialize_example_np(record, decoders=self.decoders)
         for record in records
     ]

   def __repr__(self) -> str:
     decoders_repr = (
         tree.map_structure(type, self.decoders) if self.decoders else None
     )
+    name = self.dataset_builder.info.name
     return (
-        f'{self.__class__.__name__}(name={self.dataset_info.name}, '
+        f'{self.__class__.__name__}(name={name}, '
         f'split={self.split!r}, '
         f'decoders={decoders_repr})'
     )
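Two standard-library tools do the heavy lifting in base.py: a typing.Protocol lets BaseDataSource accept any object that exposes an `.info` property without importing the concrete DatasetBuilder (which would create a cyclic import), and functools.cached_property resolves the features once per instance rather than on every record access. A stripped-down sketch of that pattern, using illustrative names (SupportsInfo, Source) that are not part of the TFDS API:

import dataclasses
import functools
from typing import Any, Protocol


class SupportsInfo(Protocol):
  """Structural type: any object exposing an `.info` property matches."""

  @property
  def info(self) -> Any:
    ...


@dataclasses.dataclass
class Source:
  builder: SupportsInfo  # no import of the concrete builder class is needed

  @functools.cached_property
  def features(self) -> Any:
    # Looked up lazily and cached per instance, so repeated item access does
    # not go through builder.info every time.
    return self.builder.info.features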

tensorflow_datasets/core/data_sources/base_test.py

+39 −10

@@ -15,13 +15,15 @@

 """Tests for all data sources."""

+import pickle
 from unittest import mock

+import cloudpickle
 from etils import epath
 import pytest
 import tensorflow_datasets as tfds
 from tensorflow_datasets import testing
-from tensorflow_datasets.core import dataset_builder
+from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import file_adapters
@@ -77,7 +79,7 @@ def mocked_parquet_dataset():
 )
 def test_read_write(
     tmp_path: epath.Path,
-    builder_cls: dataset_builder.DatasetBuilder,
+    builder_cls: dataset_builder_lib.DatasetBuilder,
     file_format: file_adapters.FileFormat,
 ):
   builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -106,28 +108,36 @@
 ]


-def create_dataset_info(file_format: file_adapters.FileFormat):
+def create_dataset_builder(
+    file_format: file_adapters.FileFormat,
+) -> dataset_builder_lib.DatasetBuilder:
   with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
     split_mock.return_value.name = 'train'
     split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
     dataset_info = mock.create_autospec(dataset_info_lib.DatasetInfo)
     dataset_info.file_format = file_format
     dataset_info.splits = {'train': split_mock()}
     dataset_info.name = 'dataset_name'
-    return dataset_info
+
+    dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
+    dataset_builder.info = dataset_info
+
+    return dataset_builder


 @pytest.mark.parametrize(
     'data_source_cls',
     _DATA_SOURCE_CLS,
 )
 def test_missing_split_raises_error(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   with pytest.raises(
       ValueError,
       match="Unknown split 'doesnotexist'.",
   ):
-    data_source_cls(dataset_info, split='doesnotexist')
+    data_source_cls(dataset_builder, split='doesnotexist')


 @pytest.mark.usefixtures(*_FIXTURES)
@@ -136,8 +146,10 @@ def test_missing_split_raises_error(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
-  source = data_source_cls(dataset_info, split='train')
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
+  source = data_source_cls(dataset_builder, split='train')
   name = data_source_cls.__name__
   assert (
       repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
@@ -150,9 +162,11 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   source = data_source_cls(
-      dataset_info,
+      dataset_builder,
       split='train',
       decoders={'my_feature': decode.SkipDecoding()},
   )
@@ -181,3 +195,18 @@ def test_data_source_is_sliceable():
   file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
   assert file_instructions[0].skip == 0
   assert file_instructions[0].take == 30000
+
+
+# PyGrain requires that data sources are picklable.
+@pytest.mark.parametrize(
+    'file_format',
+    file_adapters.FileFormat.with_random_access(),
+)
+@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
+def test_data_source_is_picklable_after_use(file_format, pickle_module):
+  with tfds.testing.tmp_dir() as data_dir:
+    builder = tfds.testing.DummyDataset(data_dir=data_dir)
+    builder.download_and_prepare(file_format=file_format)
+    data_source = builder.as_data_source(split='train')
+    assert data_source[0] == {'id': 0}
+    assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}

tensorflow_datasets/core/data_sources/parquet.py

+2 −1

@@ -57,7 +57,8 @@ class ParquetDataSource(base.BaseDataSource):
   """ParquetDataSource to read from a ParquetDataset."""

   def __post_init__(self):
-    file_instructions = base.file_instructions(self.dataset_info, self.split)
+    dataset_info = self.dataset_builder.info
+    file_instructions = base.file_instructions(dataset_info, self.split)
     filenames = [
         file_instruction.filename for file_instruction in file_instructions
     ]

tensorflow_datasets/core/dataset_builder.py

+17 −4

@@ -333,10 +333,23 @@ def code_path(cls) -> Optional[epath.Path]:
     return epath.Path(filepath)

   def __getstate__(self):
-    return self._original_state
+    state = {"original_state": self._original_state}
+    features = self.info.features
+    if hasattr(features, "deserialize_example_np"):
+      # See the comment in __setstate__ to understand why we do this.
+      state["deserialize_example_np"] = features.deserialize_example_np
+    return state

   def __setstate__(self, state):
-    self.__init__(**state)
+    self.__init__(**state["original_state"])
+
+    # This is a hack. We explicitly set deserialize_example_np to propagate any
+    # mock on this function to PyGrain workers in multiprocessing. Indeed,
+    # mock.patch cannot be used in multiprocessing since the builder is created
+    # in a totally different process.
+    deserialize_example_np = state.get("deserialize_example_np")
+    if deserialize_example_np:
+      self.info.features.deserialize_example_np = deserialize_example_np

   @functools.cached_property
   def canonical_version(self) -> utils.Version:
@@ -774,13 +787,13 @@ def build_single_data_source(
     file_format = self.info.file_format
     if file_format == file_adapters.FileFormat.ARRAY_RECORD:
       return array_record.ArrayRecordDataSource(
-          self.info,
+          self,
           split=split,
           decoders=decoders,
       )
     elif file_format == file_adapters.FileFormat.PARQUET:
       return parquet.ParquetDataSource(
-          self.info,
+          self,
           split=split,
           decoders=decoders,
       )
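The builder was already being rebuilt from its original constructor arguments on unpickling; the new lines only add deserialize_example_np on top of that so a patched deserialization function (for example a test mock) also reaches PyGrain worker processes, where mock.patch cannot. A self-contained sketch of the underlying rebuild-on-unpickle pattern, with illustrative names (Rebuildable, heavy_resource) that are not the real TFDS classes:

import pickle


class Rebuildable:
  """Pickles only its constructor arguments and rebuilds itself on load."""

  def __init__(self, name, data_dir=None):
    self._original_state = dict(name=name, data_dir=data_dir)
    self.name = name
    self.data_dir = data_dir
    self.heavy_resource = object()  # stand-in for state that cannot be pickled

  def __getstate__(self):
    # Only the constructor arguments travel through pickle.
    return {'original_state': self._original_state}

  def __setstate__(self, state):
    # Rebuild the object from scratch on the receiving side.
    self.__init__(**state['original_state'])


restored = pickle.loads(pickle.dumps(Rebuildable('demo')))
assert restored.name == 'demo'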

tensorflow_datasets/testing/mocking.py

+1 −1

@@ -399,7 +399,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):

   def build_single_data_source(split):
     single_data_source = array_record.ArrayRecordDataSource(
-        dataset_info=self.info, split=split, decoders=decoders
+        dataset_builder=self, split=split, decoders=decoders
     )
     return single_data_source

tensorflow_datasets/testing/mocking_test.py

+9 −0

@@ -392,3 +392,12 @@ def test_as_data_source_fn():
   assert imagenet[0] == 'foo'
   assert imagenet[1] == 'bar'
   assert imagenet[2] == 'baz'
+
+
+# PyGrain requires that data sources are picklable.
+def test_mocked_data_source_is_pickable():
+  with tfds.testing.mock_data(num_examples=2):
+    data_source = tfds.data_source('imagenet2012', split='train')
+    pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
+    assert len(pickled_and_unpickled_data_source) == 2
+    assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)
