# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# sampler.py
import math
from typing import Callable, Iterator, Optional
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data.dataset import Dataset
from torch.utils.data.sampler import Sampler
def _shard_wrapped_indices_across_workers(dataset_index_list, num_shards, num_samples_per_shard):
"""Yield successive num_shards-sized chunks from dataset_index_list."""
num_samples = max(1, num_samples_per_shard)
num_elements = num_samples * num_shards
current_lst = []
for i in range(num_elements):
current_lst.append(dataset_index_list[i % len(dataset_index_list)])
if len(current_lst) == num_shards:
yield current_lst
current_lst = []
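# Illustrative example (not exercised by this module): with dataset_index_list=[0, 1, 2, 3, 4],
# num_shards=2 and num_samples_per_shard=3, the generator above yields the chunks
# [0, 1], [2, 3], [4, 0] -- the index list wraps around so that every shard gets 3 samples.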
def shard_wrapped_indices_for_worker(dataset_index_list, shard_id, num_shards):
"""Shard wrapped around dataset_index_list across num_shards and return the indices for this shard_id"""
num_samples_per_worker = (len(dataset_index_list) + num_shards - 1) // num_shards
sharded_indices = list(
_shard_wrapped_indices_across_workers(dataset_index_list, num_shards, num_samples_per_worker)
)
return [sharded_indices[i][shard_id] for i in range(len(sharded_indices))]
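# Illustrative example (same assumed inputs as above): shard_wrapped_indices_for_worker(
# [0, 1, 2, 3, 4], shard_id=1, num_shards=2) returns [1, 3, 0], i.e. the second element
# of every wrapped chunk.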
# Implementation is adapted from bagua/load_balancing_data_loader.py
# https://github.com/BaguaSys/bagua/blob/01874a7c3f90904c37c5612a9db866b5d4b8b5ed/bagua/torch_api/contrib/load_balancing_data_loader.py#L12
class LoadBalancingDistributedSampler:
r"""Sampler that balances the data load across workers based on the sample's complexity.
This sampler uses a :attr:`complexity_fn` to calculate each sample's computational
complexity and make each batch get similar computational complexity.
This is useful in scenarios like speech and NLP, where each batch has variable
length and distributed training suffers from straggler problem. In such scenarios,
the complexity function could be defined to return the length of the input sample sequence.
The usage is similar to `torch.utils.data.DistributedSampler`, where each process loads a
subset of the original dataset that is exclusive to it.
The sampler sorts the dataset in increasing order of complexity. If :attr:`group_size` is
provided, the sorting happens within dataset groups of size :attr:`group_size`, the group
order is then shuffled, and the data is sharded across workers (see the second usage
example below). If :attr:`group_size` is not provided, the data is sharded across workers
first and the data indices for each worker are then shuffled deterministically.
.. note::
Dataset is assumed to be of constant size (map-style dataset).
Args:
dataset: Dataset (map-style) used for sampling.
complexity_fn (Callable): A function that takes a sample as input and returns an integer
measuring the computational complexity of that sample.
world_size (int, optional): Number of processes participating in
distributed training. By default, :attr:`world_size` is retrieved from the
current distributed group.
rank (int, optional): Rank of the current process within :attr:`world_size`.
By default, :attr:`rank` is retrieved from the current distributed
group.
shuffle (bool, optional): If ``True`` (default), the sampler will shuffle the
indices within the dataset when :attr:`group_size` is None, and will shuffle
the groups when :attr:`group_size` is not None.
group_size (int, optional): If provided, the dataset will be broken down into
:attr:`group_size` sized groups. Indices will only be sorted within the groups
and not across the entire dataset. If :attr:`shuffle` is ``True`` and
:attr:`group_size` is not ``None``, the position of each group in the dataset
will be shuffled. Default: ``None``
seed (int, optional): random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Default: 0.
drop_last (bool, optional): if ``True``, then the sampler will drop the
tail of the data to make it evenly divisible across the number of
shards. If ``False``, the sampler will add extra indices to make
the data evenly divisible across the shards. Default: ``False``.
random_level (float, optional): A float between 0 and 1 that controls the extent
of load balancing. 0 means the best load balance, while 1 means the least. The value
scales a random offset that is added to each sample's complexity before sorting.
.. warning::
In distributed mode, calling the :meth:`set_epoch` method at
the beginning of each epoch **before** creating the `torch.utils.data.DataLoader` iterator
is necessary to make shuffling work properly across multiple epochs. Otherwise,
the same ordering will always be used.
Example::
Define your :attr:`complexity_fn`, which accepts a dataset sample as its input and produces an integer
as the sample's computational complexity:
>>> dataset = MyVariableSequenceLengthDataset(dataset_samples)
>>> complexity_fn = lambda x: len(x)
Below is the usage of :class:`LoadBalancingDistributedSampler`
and `torch.utils.data.DataLoader`:
>>> sampler = onnxruntime.training.utils.data.LoadBalancingDistributedSampler(
... dataset,
... complexity_fn=complexity_fn)
>>> loader = torch.utils.data.DataLoader(dataset,
... sampler=sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
... if is_distributed:
... sampler.set_epoch(epoch)
... train(loader)
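A grouped variant is sketched below (illustrative only; the parameter values are
arbitrary, not recommendations):
>>> sampler = onnxruntime.training.utils.data.LoadBalancingDistributedSampler(
... dataset,
... complexity_fn=complexity_fn,
... group_size=32, # sort complexities only within groups of 32 samples
... random_level=0.2) # add mild random noise to the complexities before sorting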
"""
def __init__(
self,
dataset: Dataset,
complexity_fn: Callable[..., int],
world_size: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
group_size: Optional[int] = None,
seed: int = 0,
drop_last: bool = False,
random_level: float = 0,
) -> None:
if world_size is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
world_size = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= world_size or rank < 0:
raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {world_size - 1}]")
self.dataset = dataset
self.world_size = world_size
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
self.group_size = group_size
# If the dataset length is evenly divisible by number of shards, then there
# is no need to drop any data, since the dataset will be split equally.
dataset_len = len(self.dataset)
if self.drop_last and dataset_len % self.world_size != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = dataset_len // self.world_size
else:
self.num_samples = math.ceil(dataset_len / self.world_size)
self.total_size = self.num_samples * self.world_size
self.shuffle = shuffle
self.seed = seed
self.complexity_fn = complexity_fn
self.sample_complexities = None
self.ordered_sample_complexities = None
if random_level < 0.0 or random_level > 1.0:
raise ValueError(f"Invalid random level {random_level}, shoule be in the range [0.0, 1.0]")
self.random_level = random_level
self.random_number = None
def _sort_shard_and_shuffle_dataset(self):
# This method returns the sharded index chunks together with the order in which
# this rank should read them, after the dataset has been sorted, sharded and shuffled.
# The sorting of the dataset happens based on the group_size and the complexity
# of each sample.
# Sharding happens across the number of workers.
# Shuffling is done either on the group order before sharding (if group_size is provided)
# or on the sharded chunk order after sharding (if group_size is not provided).
def sort_in_groups(sample_complexities, group_size):
"""Sort the dataset samples indices inside each group of size group_size."""
# If the group_size is None, the entire dataset is considered as a single group
if group_size is None:
group_size = len(sample_complexities)
# Sort the dataset samples inside each group of the dataset based on sample complexity.
for group_begin_index in range(0, len(sample_complexities), group_size):
group_end_index = min(group_begin_index + group_size, len(sample_complexities))
sorted_indices = group_begin_index + np.argsort(
sample_complexities[group_begin_index:group_end_index, 1]
)
sample_complexities[group_begin_index:group_end_index, :] = sample_complexities[sorted_indices]
return sample_complexities
# Get the samples and their complexities from the complexity_fn
if self.sample_complexities is None:
self.sample_complexities = np.empty((len(self.dataset), 2), dtype=np.int64)
for sample_index in range(len(self.dataset)):
self.sample_complexities[sample_index][0] = sample_index
self.sample_complexities[sample_index][1] = self.complexity_fn(self.dataset[sample_index])
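# The random offset range scales with the spread of complexities: random_level == 0 adds
# no noise (best load balance), while random_level == 1 allows offsets up to the full
# complexity range observed in the dataset.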
if self.random_number is None:
max_complexity = max(self.sample_complexities, key=lambda t: t[1])[1]
min_complexity = min(self.sample_complexities, key=lambda t: t[1])[1]
self.random_number = int((max_complexity - min_complexity) * self.random_level + 1)
sample_complexities = self.sample_complexities.copy()
# Control the degree of load balancing by modifying the complexities of
# all samples using the random_number.
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
if self.random_number > 1:
complexity_random_ints = torch.randint(
self.random_number, (len(sample_complexities),), generator=g
).tolist()
for index, random_int in enumerate(complexity_random_ints):
sample_complexities[index][1] += random_int
# Sort the data based on the computed complexities and group sizes.
# Sort only once if random_number <= 1, else sort every time.
if self.ordered_sample_complexities is None or self.random_number > 1:
self.ordered_sample_complexities = sort_in_groups(sample_complexities, self.group_size)
ordered_sample_complexities = self.ordered_sample_complexities
# If group_size is not None, shuffle the index of each group instead
# of shuffling the data indices.
if self.shuffle and self.group_size is not None:
num_groups = (len(self.sample_complexities) + self.group_size - 1) // self.group_size
group_order = torch.randperm(num_groups, generator=g).tolist()
end = 0
sample_complexities_copy = ordered_sample_complexities.copy()
for group_index in group_order:
original_list_begin_index = self.group_size * group_index
original_list_end_index = min(original_list_begin_index + self.group_size, len(sample_complexities))
begin = end
end = begin + (original_list_end_index - original_list_begin_index)
# Copy each group from the sorted (ordered) complexities so that within-group
# ordering is preserved even when the cached sort is reused across epochs.
sample_complexities_copy[begin:end, :] = ordered_sample_complexities[
original_list_begin_index:original_list_end_index, :
]
ordered_sample_complexities = sample_complexities_copy
# Shard the data across the different workers.
index_chunks = list(
_shard_wrapped_indices_across_workers(
[index_complexity_tuple[0] for index_complexity_tuple in ordered_sample_complexities],
self.world_size,
self.num_samples,
)
)
# Shuffle the sharded data indices deterministically based on epoch and seed.
chunk_indices = list(range(len(index_chunks)))
if self.shuffle and self.group_size is None:
chunk_indices = torch.randperm(len(index_chunks), generator=g).tolist()
if not self.drop_last:
# Add extra samples to make it evenly divisible
padding_size = self.num_samples - len(chunk_indices)
if padding_size <= len(chunk_indices):
chunk_indices += chunk_indices[:padding_size]
else:
chunk_indices += (chunk_indices * math.ceil(padding_size / len(chunk_indices)))[:padding_size]
else:
# Remove tail of data to make it evenly divisible.
chunk_indices = chunk_indices[: self.num_samples]
assert len(chunk_indices) == self.num_samples
return index_chunks, chunk_indices
def __iter__(self) -> Iterator:
index_chunks, chunk_indices = self._sort_shard_and_shuffle_dataset()
# Extract indices based on current rank.
indices = [index_chunks[i][self.rank] for i in chunk_indices]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
r"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all shards use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
class LoadBalancingDistributedBatchSampler(Sampler):
r"""Wraps another load balance sampler to yield variable sized mini-batches.
Args:
sampler (LoadBalancingDistributedSampler): Load balance sampler.
batch_fn (Callable): Callable to yield mini-batch indices.
drop_last (bool): If ``True``, the sampler will drop the trailing batches on each replica
so that every replica yields the least number of batches among replicas; otherwise,
the number of batches on each replica will be padded up to the largest count.
:attr:`batch_fn` will have the signature of::
def batch_fn(indices: List[int]) -> List[List[int]]
Example::
>>> from onnxruntime.training.utils.data import LoadBalancingDistributedSampler, \
... LoadBalancingDistributedBatchSampler
>>>
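>>> # A minimal batch_fn sketch (assumed here, not provided by this module): pack the
>>> # rank-local indices into mini-batches of at most 8 samples each.
>>> batch_fn = lambda indices: [indices[i : i + 8] for i in range(0, len(indices), 8)]
>>>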
>>> sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
>>> batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
>>> loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
... batch_sampler.set_epoch(epoch)
... train(loader)
"""
def __init__(
self,
sampler: LoadBalancingDistributedSampler,
batch_fn,
drop_last: bool = False,
) -> None:
if not isinstance(sampler, LoadBalancingDistributedSampler):
raise ValueError("sampler should be of LoadBalancingDistributedSampler type.")
if sampler.drop_last:
raise ValueError("drop_last of sampler should be False")
self.sampler = sampler
self.batch_fn = batch_fn
self.drop_last = drop_last
self.world_size = self.sampler.world_size
self.rank = self.sampler.rank
self.generate_batches()
def generate_batches(self):
index_chunks, chunk_indices = self.sampler._sort_shard_and_shuffle_dataset()
batches = []
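# Build the batch list for every rank (not just this one) so that total_batch is
# computed consistently across all ranks.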
for rank in range(self.world_size):
sub_indices = [index_chunks[i][rank] for i in chunk_indices]
batches.append(self.batch_fn(sub_indices))
self.total_batch = max([len(b) for b in batches]) if not self.drop_last else min([len(b) for b in batches])
# When drop_last is True, any rank holding more than total_batch batches has its trailing
# batches dropped; otherwise shorter ranks are padded by repeating their leading batches,
# so that every rank yields exactly total_batch batches.
self.padded_batches = [batch[: self.total_batch] + batch[: max(0, self.total_batch - len(batch))] for batch in batches]
def __iter__(self):
return iter(self.padded_batches[self.rank])
def __len__(self):
return self.total_batch
def set_epoch(self, epoch: int) -> None:
r"""
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.sampler.set_epoch(epoch)
self.generate_batches()
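

# The block below is a minimal, self-contained usage sketch (illustrative only, not part of
# the public API). It sidesteps torch.distributed initialization by passing world_size and
# rank explicitly, and uses a plain Python list as a stand-in for a map-style Dataset.
if __name__ == "__main__":
    # Toy dataset: eight "sequences" of varying length; the length doubles as the complexity.
    toy_dataset = [[0] * length for length in [5, 1, 9, 3, 7, 2, 8, 4]]

    sampler = LoadBalancingDistributedSampler(
        toy_dataset,
        complexity_fn=len,  # sequence length as the complexity measure
        world_size=2,
        rank=0,
        shuffle=True,
        seed=0,
    )
    sampler.set_epoch(0)
    print("rank-0 sample indices:", list(sampler))

    # Pack the rank-local indices into mini-batches of at most 2 samples each.
    batch_sampler = LoadBalancingDistributedBatchSampler(
        sampler,
        batch_fn=lambda indices: [indices[i : i + 2] for i in range(0, len(indices), 2)],
    )
    batch_sampler.set_epoch(0)
    print("rank-0 batches:", list(batch_sampler))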