Skip to content

Commit cb10f71

Browse files
committed
Adding state parser utility that can be used for modifying worker states
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 958eeb0 commit cb10f71

File tree

3 files changed

+154
-1
lines changed

3 files changed

+154
-1
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from torch.testing._internal.common_utils import TestCase
10+
11+
from torch.utils.data import Dataset, IterableDataset
12+
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader, StateParserUtil
13+
14+
15+
class StatefulIterableDataset(IterableDataset, Stateful):
    """Iterable dataset that is its own iterator and counts ``__next__`` calls.

    Each call to ``__next__`` increments ``num_calls`` and yields the new
    value. The counter is the only state saved and restored.
    """

    def __init__(self):
        # Number of items produced so far; also the last value yielded.
        self.num_calls = 0

    def __iter__(self):
        # The dataset acts as its own iterator.
        return self

    def __next__(self):
        self.num_calls = self.num_calls + 1
        return self.num_calls

    def state_dict(self):
        """Return the current counter as a checkpointable dict."""
        return dict(num_calls=self.num_calls)

    def load_state_dict(self, state_dict):
        """Restore the counter from a dict produced by ``state_dict``."""
        self.num_calls = state_dict["num_calls"]
31+
32+
33+
def identity(x):
    """Return ``x`` unchanged (used below as a pass-through ``collate_fn``)."""
    return x
35+
36+
37+
class TestIteratorDataset(TestCase):
    """Checks that a dataloader state can be edited with StateParserUtil and reloaded."""

    def test_increasing_worker(self):
        """Take a state from a 2-worker loader, extend it to 4 workers and resume."""
        dataset = StatefulIterableDataset()
        loader = StatefulDataLoader(dataset, num_workers=2, collate_fn=identity)
        batch_iter = iter(loader)
        next(batch_iter)
        state = loader.state_dict()
        print(state)
        del loader

        # Add dataset states for two extra workers (ids 2 and 3).
        parser = StateParserUtil(state)
        per_worker = parser.fetch_dataset_state()
        per_worker[2] = {"num_calls": 2}
        per_worker[3] = {"num_calls": 3}
        parser.set_dataset_state(per_worker)

        # worker state doesn't equal num workers setting
        with self.assertRaises(AssertionError):
            parser.get_state_dict()
        parser.set_num_workers(4)

        # last worker yielded id is greater than num workers
        parser.set_last_worker_yielded_id(10)
        with self.assertRaises(AssertionError):
            parser.get_state_dict()
        parser.set_last_worker_yielded_id(0)

        # load the modified state
        modified_state = parser.get_state_dict()
        print(modified_state)
        loader = StatefulDataLoader(dataset, num_workers=4, collate_fn=identity)
        loader.load_state_dict(modified_state)
        batch_iter = iter(loader)
        values = []
        for _ in range(4):
            batch = next(batch_iter)
            values.extend(batch)
        print(values)
        self.assertEqual(values, [1, 3, 4, 2])
75+
76+
77+
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()

torchdata/stateful_dataloader/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from .state_parser import StateParserUtil
78
from .stateful import Stateful
89
from .stateful_dataloader import StatefulDataLoader
910

10-
__all__ = ["Stateful", "StatefulDataLoader"]
11+
__all__ = ["Stateful", "StatefulDataLoader", "StateParserUtil"]
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import logging
7+
from typing import Any, Dict, Union
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class StateParserUtil:
    """
    Utility class that can be used to modify state returned by the dataloader.

    The wrapped state dict is edited in place. A state dict that contains a
    ``"_snapshot"`` key is treated as a multiprocess state, with per-worker
    entries stored under ``_snapshot["_worker_snapshots"]`` and keyed as
    ``"worker_<id>"``. Any other state dict is treated as a single process
    state whose dataset state lives under ``"dataset_state"``.
    """

    def __init__(self, state_dict: Dict[str, Any]):
        """Wrap ``state_dict``; no copy is made, so edits mutate the caller's dict."""
        self._state_dict = state_dict
        self._is_multiprocess_state = "_snapshot" in self._state_dict

    def fetch_dataset_state(self) -> Dict[int, Any]:
        """Return the dataset state.

        For a single process state this is the ``"dataset_state"`` value
        itself. For a multiprocess state it is a new dict mapping each
        worker id to that worker's dataset state.
        """
        if not self._is_multiprocess_state:
            return self._state_dict["dataset_state"]
        worker_snapshots = self._state_dict["_snapshot"]["_worker_snapshots"]
        return {state["worker_id"]: state["dataset_state"] for state in worker_snapshots.values()}

    def set_last_worker_yielded_id(self, last_worker_yielded: int) -> None:
        """Set the last yielded worker id of a multiprocess state.

        The value is not validated here; ``get_state_dict`` checks it against
        the number of workers. On a single process state this only logs a
        warning.
        """
        if not self._is_multiprocess_state:
            logger.warning("Cannot set last worker yielded id on a single process state dict")
            return
        self._state_dict["_snapshot"]["_last_yielded_worker_id"] = last_worker_yielded

    def set_num_workers(self, num_workers: int) -> None:
        """Set the number of workers recorded in a multiprocess state.

        On a single process state this only logs a warning.
        """
        if not self._is_multiprocess_state:
            logger.warning("Cannot set num_workers on a single process state dict")
            return
        self._state_dict["_snapshot"]["_main_snapshot"]["_num_workers"] = num_workers

    def set_dataset_state(self, dataset_state: Union[Dict[int, Any], Any]) -> None:
        """Store new dataset state.

        For a single process state, ``dataset_state`` replaces the stored
        value as-is. For a multiprocess state, ``dataset_state`` must map
        worker ids to dataset states: existing worker entries get their
        ``"dataset_state"`` replaced, and unknown worker ids get a new entry
        whose ``"fetcher_state"`` is ``None``.
        """
        if not self._is_multiprocess_state:
            self._state_dict["dataset_state"] = dataset_state
            return

        worker_states = self._state_dict["_snapshot"]["_worker_snapshots"]
        for worker_id, state in dataset_state.items():
            worker_key = f"worker_{worker_id}"
            if worker_key in worker_states:
                worker_states[worker_key]["dataset_state"] = state
            else:
                worker_states[worker_key] = {"worker_id": worker_id, "dataset_state": state, "fetcher_state": None}

    def get_state_dict(self) -> Dict[str, Any]:
        """Return the (possibly modified) state dict after validating it.

        A single process state is returned without checks. For a multiprocess
        state the following must hold:
          a) the number of worker snapshots equals ``_num_workers``
          b) a snapshot exists for every worker id in ``[0, _num_workers)``
          c) the last yielded worker id is strictly less than ``_num_workers``

        Raises:
            AssertionError: if any of the validations fails. It is raised
                explicitly (not via ``assert``) so the checks still run when
                Python is started with ``-O``.
        """
        if not self._is_multiprocess_state:
            return self._state_dict

        snapshot = self._state_dict["_snapshot"]
        last_yielded_id = snapshot["_last_yielded_worker_id"]
        num_workers = snapshot["_main_snapshot"]["_num_workers"]
        worker_keys = set(snapshot["_worker_snapshots"])

        if len(worker_keys) != num_workers:
            raise AssertionError(
                f"Number of worker states {len(worker_keys)} should be equal to num_workers setting {num_workers}"
            )
        # Dict keys are always unique, so a uniqueness check proves nothing;
        # compare against the exact set of keys expected for ids 0..num_workers-1.
        expected_keys = {f"worker_{worker_id}" for worker_id in range(num_workers)}
        if worker_keys != expected_keys:
            raise AssertionError(
                f"Worker state for all from [0, {num_workers}) should be present. "
                f"Instead found state for only {sorted(worker_keys)} workers"
            )
        if last_yielded_id >= num_workers:
            raise AssertionError("Last yielded id should be strictly within the number of workers")
        return self._state_dict

0 commit comments

Comments
 (0)