-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
Copy pathengine.py
1388 lines (1155 loc) · 56.8 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2019 The Microsoft DeepSpeed Team
import time
import logging
import copy
import os
from types import MethodType
from numpy import prod
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from deepspeed.utils.logging import logger
from deepspeed.utils.timer import SynchronizedWallClockTimer, ThroughputTimer
from deepspeed.inference.engine import InferenceEngine
from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
from ..utils import PartitionedTensor, ensure_directory_exists
from ..dataloader import RepeatingLoader
from .module import PipelineModule, PipelineError, TiedLayerSpec
from . import p2p
from . import schedule
# Debug-logging selectors. log_for_device() prints only when LOG_STAGE matches
# the local stage id (or is -1) and DATA_PARALLEL_ID matches the local
# data-parallel id (or is -1). -2 presumably matches no real id, i.e. logging
# is off by default. TARGET_ID is not referenced in the visible part of this file.
TARGET_ID, LOG_STAGE, DATA_PARALLEL_ID = -2, -2, -2
def is_even(number):
    """Return True when ``number`` leaves no remainder after division by two."""
    remainder = number % 2
    return remainder == 0
# Module-level memory counters, both starting at zero. They are not read in
# the visible part of this file (presumably used by a memory-status helper).
mem_alloced, mem_cached = 0, 0
def _tensor_bytes(tensor):
    """Return how many bytes the elements of ``tensor`` occupy (numel * element size)."""
    element_count = tensor.numel()
    bytes_per_element = tensor.element_size()
    return element_count * bytes_per_element
class PipelineEngine(DeepSpeedEngine):
    """ A training engine for hybrid pipeline, data, and model parallel training.

    This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`
    is provided.
    """
    # Wire encoding of tensor dtypes for p2p metadata: _send_tensor_meta sends
    # a dtype as its index in this list and _recv_tensor_meta maps the index
    # back, so sender and receiver must agree on this exact order.
    ID_TO_DTYPE = [
        torch.float32,
        torch.float64,
        torch.complex64,
        torch.complex128,
        torch.float16,
        torch.bfloat16,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.bool
    ]
    # Inverse of ID_TO_DTYPE: dtype -> wire id.
    DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}
def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
    """Set up pipeline-parallel state on top of :class:`DeepSpeedEngine`.

    Args:
        has_bool_tensors (bool): stored as ``self.has_bool_tensors``; it is not
            read in this part of the file (presumably it affects how bool
            tensors are transferred between stages -- TODO confirm).
        *super_args, **super_kwargs: forwarded unchanged to
            ``DeepSpeedEngine.__init__``.

    Note:
        Construction is collective: it all-reduces parameter counts over the
        model-parallel group and exchanges a warm-up ``self.loss`` tensor with
        the neighbouring pipeline stages, so every rank must build the engine.
    """
    super().__init__(*super_args, **super_kwargs)
    assert isinstance(self.module, PipelineModule), "model must base PipelineModule"

    assert self.zero_optimization_stage() < 2, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"

    # We schedule the all-reduces, so disable it in super().backward()
    self.enable_backward_allreduce = False
    self.has_bool_tensors = has_bool_tensors
    # When True, the last stage keeps its forward outputs in self.outputs (see eval_batch).
    self.eval_return_logits = False
    self.outputs = None

    # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
    self.pipeline_enable_backward_allreduce = True

    assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
        " with pipeline parallelism."

    # pipeline step for logging
    self.log_batch_step_id = -1

    self.micro_batch_size = self.train_micro_batch_size_per_gpu()
    self.micro_batches = self.gradient_accumulation_steps()

    # Set Grid and Communication Groups
    self.grid = self.module._grid
    if self.grid.get_global_rank() == 0:
        logger.info(f'CONFIG: micro_batches={self.micro_batches} '
                    f'micro_batch_size={self.micro_batch_size}')

    self.global_rank = self.grid.get_global_rank()

    # Sanity check: global batch = micro-batch size * micro-batches * DP degree.
    assert self.dp_world_size == self.grid.data_parallel_size
    assert self.train_batch_size() == \
        self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size

    # Set Stage Info
    self.num_stages = self.grid.pipe_parallel_size
    self.stage_id = self.grid.get_stage_id()
    self.prev_stage = self.stage_id - 1
    self.next_stage = self.stage_id + 1

    self.data_iterator = None
    self.batch_fn = None

    # Returned by is_gradient_accumulation_boundary(); _exec_reduce_grads sets
    # it to True while it reduces gradients.
    self._force_grad_boundary = False

    self.batch_timer = ThroughputTimer(batch_size=self.micro_batch_size *
                                       self.micro_batches,
                                       num_workers=self.dp_world_size,
                                       logging_fn=self.tput_log,
                                       monitor_memory=False,
                                       steps_per_output=self.steps_per_print())

    # PipelineEngine needs to handle data loading specially due to only the first
    # and last stages loading inputs/labels. We construct a sampler that uses
    # the data-parallel rank and world size, so each data-parallel replica reads
    # its own shard (see _build_data_iter).
    if self.training_data:
        self._build_data_iter(self.training_data)

    self.is_pipe_parallel = self.grid.pipe_parallel_size > 1
    self.is_data_parallel = self.grid.data_parallel_size > 1
    self.is_model_parallel = self.grid.model_parallel_size > 1

    # Partition input/output buffers
    # XXX temporarily disable while I revert some partition hacks.
    self.is_pipe_partitioned = self.is_model_parallel
    self.is_grad_partitioned = self.is_model_parallel

    model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
    num_params = sum([p.numel() for p in model_parameters])
    unique_params = num_params
    # Subtract tied parameters if we don't own them
    if self.module.tied_comms:
        tied_params = 0
        for key, d in self.module.tied_comms.items():
            # Only the lowest rank of a tie group counts the shared weights.
            if self.global_rank != min(d['ranks']):
                tied_params += sum(p.numel() for p in d['module'].parameters())
        unique_params -= tied_params
    params_tensor = torch.LongTensor(data=[num_params,
                                           unique_params]).to(self.device)
    # Sum the per-rank counts over the model-parallel group to get totals.
    dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())
    params_tensor = params_tensor.tolist()
    total_params = params_tensor[0]
    unique_params = params_tensor[1]
    if self.grid.data_parallel_id == 0:
        logger.info(f'RANK={self.global_rank} '
                    f'STAGE={self.stage_id} '
                    f'LAYERS={self.module._local_stop - self.module._local_start} '
                    f'[{self.module._local_start}, {self.module._local_stop}) '
                    f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '
                    f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '
                    f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')

    #initialize peer-2-peer communication and allreduce groups
    if self.is_pipe_parallel:
        p2p.init_process_groups(self.grid)

    # Pipeline buffers
    self.num_pipe_buffers = 0
    self.pipe_buffers = {
        'inputs' : [],   # batch input and received activations
        'labels' : [],   # labels from batch input
        'outputs' : [],  # activations
        'output_tensors' : [], # tensor object to preserve backward graph
    }
    self.pipe_recv_buf = None
    self.grad_layer = None

    self.meta_buffer = None

    self.first_output_send = True
    self.first_gradient_send = True

    #stores the loss for the current micro batch being processed
    self.loss = torch.tensor(0.0).to(self.device)

    #stores the loss for the entire batch
    self.total_loss = None
    self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
    self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)

    if self._config.pipeline['activation_checkpoint_interval'] > 0:
        self.module.activation_checkpoint_interval = self._config.pipeline[
            'activation_checkpoint_interval']

    if self.is_last_stage():
        self.loss_model = self.module.loss_fn

    self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe'
    # Initialize pipeline communicators. Just send a 0.
    # Even stages send first and then receive; odd stages receive first and
    # then send. Each send is therefore paired with a receive that the
    # neighbour posts in the matching order (presumably to avoid deadlock if
    # p2p.send/recv block -- confirm in p2p).
    if is_even(self.stage_id):
        if not self.is_last_stage():
            p2p.send(self.loss, self.next_stage)
        if not self.is_first_stage():
            p2p.recv(self.loss, self.prev_stage)
    else:
        if not self.is_first_stage():
            p2p.recv(self.loss, self.prev_stage)
        if not self.is_last_stage():
            p2p.send(self.loss, self.next_stage)

    # XXX look into timer reporting timing
    # Initialize some timers because of early weirdness.
    if self.wall_clock_breakdown():
        self.timers('forward_microstep').start()
        self.timers('forward_microstep').stop()
        self.timers('backward_microstep').start()
        self.timers('backward_microstep').stop()
        self.timers('backward_inner_microstep').start()
        self.timers('backward_inner_microstep').stop()
        self.timers('backward_allreduce_microstep').start()
        self.timers('backward_allreduce_microstep').stop()
        self.timers('backward_allreduce').start()
        self.timers('backward_allreduce').stop()
        self.timers('step_microstep').start()
        self.timers('step_microstep').stop()
def set_has_attention_mask(self, value):
    """Record whether the model's inputs include an attention mask.

    Args:
        value (bool): must be an actual ``bool``; other values fail the assert.
    """
    value_is_bool = isinstance(value, bool)
    assert value_is_bool
    self.has_attention_mask = value
def _build_data_iter(self, dataset):
    """Install a repeating, data-parallel-sharded loader over ``dataset``.

    A non-shuffling ``DistributedSampler`` gives each data-parallel rank its
    own shard. The loader built by ``deepspeed_io`` is wrapped in
    ``RepeatingLoader`` so that iteration restarts instead of raising
    ``StopIteration``, and the result is passed to ``set_dataloader``.
    """
    dp_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=self.dp_world_size,
        rank=self.mpu.get_data_parallel_rank(),
        shuffle=False)
    base_loader = self.deepspeed_io(dataset, data_sampler=dp_sampler)
    self.set_dataloader(RepeatingLoader(base_loader))
def _exec_reduce_tied_grads(self):
    """All-reduce the gradients of tied weights within each tie group.

    When ZeRO gradient partitioning is on, the optimizer's
    ``overlapping_partition_gradients_reduce_epilogue`` is called first. This
    engine sets ``enable_backward_allreduce`` to False, so the epilogue that
    writes ``averaged_gradients`` would otherwise never run. The original
    authors suspected efficiency problems there and flagged that more
    profiling is needed
    (see https://github.com/EleutherAI/gpt-neox/issues/62#issuecomment-761471944).

    With bf16 enabled the reduction uses each weight's ``_hp_grad``;
    otherwise it uses ``.grad``.
    """
    if self.zero_optimization_partition_gradients():
        self.optimizer.overlapping_partition_gradients_reduce_epilogue()

    for tied_weight, comm_group in self.module.get_tied_weights_and_groups():
        if self.bfloat16_enabled():
            tied_grad = tied_weight._hp_grad
        else:
            tied_grad = tied_weight.grad
        dist.all_reduce(tied_grad, group=comm_group)
def _exec_reduce_grads(self):
    """Reduce gradients across data-parallel ranks as a scheduled pipeline step.

    ``_force_grad_boundary`` is set to True for the duration of the
    reduction, so that :meth:`is_gradient_accumulation_boundary` reports a
    boundary. It is always reset to False afterwards, even if the reduction
    raises. Nothing is reduced when ``pipeline_enable_backward_allreduce`` is
    False (e.g. with 1-bit Adam/LAMB).

    Raises:
        NotImplementedError: for bf16 combined with ZeRO stage 1.
        AssertionError: for bf16 combined with any ZeRO stage other than 0 or 1.
    """
    self._force_grad_boundary = True
    try:
        if self.pipeline_enable_backward_allreduce:
            if self.bfloat16_enabled():
                if self.zero_optimization_stage() == 0:
                    self._bf16_reduce_grads()
                else:
                    assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported"
                    raise NotImplementedError()
            else:
                self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
    finally:
        # Always reset the flag: is_gradient_accumulation_boundary() returns it
        # directly, so leaving it True would make every later micro-step look
        # like a boundary.
        self._force_grad_boundary = False
def _bf16_reduce_grads(self):
    """All-reduce the optimizer-provided gradients used on the bf16 path.

    The gradient list comes from ``self.optimizer.get_grads_for_reduction()``
    (originally described as the optimizer's FP32 gradients -- not verifiable
    here). It is reduced with ``buffered_allreduce_fallback`` in buffers of
    ``MEMORY_OPT_ALLREDUCE_SIZE`` elements.
    """
    self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(),
                                     elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
def _reserve_pipe_buffers(self, num_buffers):
    """Grow every pipeline buffer list to at least ``num_buffers`` slots.

    New slots are filled with ``None``; no tensors are allocated. If the
    buffers already have enough slots, nothing changes.

    Args:
        num_buffers (int): The number of buffers to reserve.
    """
    shortfall = num_buffers - self.num_pipe_buffers
    if shortfall <= 0:
        return
    for slots in self.pipe_buffers.values():
        slots.extend([None] * shortfall)
    self.num_pipe_buffers = num_buffers
def reset_activation_shape(self):
    """Drop cached activation/gradient buffers before their shapes change.

    Call this whenever the shape of activations and gradients is about to
    change, e.g. when curriculum learning changes the sequence length of
    each sample. It clears the receive, gradient and meta buffers and sets
    ``first_output_send`` back to True.
    """
    # NOTE(review): ``first_gradient_send`` is initialised alongside
    # ``first_output_send`` in __init__ but is not reset here; confirm the
    # gradient-send path does not need it re-armed after a shape change.
    for cached_attr in ('pipe_recv_buf', 'grad_layer', 'meta_buffer'):
        setattr(self, cached_attr, None)
    self.first_output_send = True
def train_batch(self, data_iter=None):
    """Progress the pipeline to train the next batch of data. The engine will ingest
    ``self.train_batch_size()`` total samples collectively across all workers.

    An iterator over training data should be provided as an argument
    unless ``deepspeed.initialize()`` was provided a training set. In that event,
    the training data will automatically be read.

    .. warning::
        A total of ``self.gradient_accumulation_steps()`` entries will be pulled
        from ``data_iter`` by each pipeline. There must be sufficient
        data left in ``data_iter`` or else a ``StopIteration`` will halt training.

        DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
        that wraps data loaders to automatically restart upon a ``StopIteration``.

    Args:
        data_iter (Iterator, optional): Iterator of training data.

    Returns:
        The arithmetic mean of the losses computed this batch, as returned by
        ``_aggregate_total_loss`` (also stored in ``self.agg_train_loss``).

    Raises:
        RuntimeError: if autograd is disabled (e.g. inside ``torch.no_grad()``).
    """
    if not torch._C.is_grad_enabled():
        raise RuntimeError(
            f'train_batch() requires gradients enabled. Use eval_batch() instead.')

    # Curriculum learning could change activation shape
    if self.curriculum_enabled():
        new_difficulty = self.curriculum_scheduler.update_difficulty( \
            self.global_steps + 1)
        if self.global_steps == 0 or self.curriculum_scheduler.first_step:
            self.reset_activation_shape()
            self.curriculum_scheduler.first_step = False
        elif new_difficulty != self.curriculum_scheduler.get_difficulty( \
            self.global_steps):
            self.reset_activation_shape()

    if data_iter:
        self.set_dataiterator(data_iter)

    self.module.train()
    self.total_loss = None
    self._compute_loss = True

    # Do the work
    self.timers('train_batch').start()
    sched = schedule.TrainSchedule(micro_batches=self.micro_batches,
                                   stages=self.num_stages,
                                   stage_id=self.stage_id)
    self._exec_schedule(sched)
    # Every stage receives the aggregated loss (it is broadcast from the last stage).
    self.agg_train_loss = self._aggregate_total_loss()

    self.timers('train_batch').stop()

    if self.global_steps % self.steps_per_print() == 0:
        if self.global_rank == 0:
            # Elapsed time covers the steps since the last print, since the
            # timer is reset here.
            elapsed = self.timers('train_batch').elapsed(reset=True) / 1000.0
            iter_time = elapsed / self.steps_per_print()
            tput = self.train_batch_size() / iter_time
            print(f'steps: {self.global_steps} '
                  f'loss: {self.agg_train_loss:0.4f} '
                  f'iter time (s): {iter_time:0.3f} '
                  f'samples/sec: {tput:0.3f}')

    # Tensorboard
    if self.tensorboard_enabled():
        if self.global_rank == 0:
            self.summary_events = [(f'Train/Samples/train_loss',
                                    self.agg_train_loss.mean().item(),
                                    self.global_samples)]
            for event in self.summary_events:  # write_summary_events
                self.summary_writer.add_scalar(event[0], event[1], event[2])
            if self.global_steps % self.steps_per_print() == 0:
                self.summary_writer.flush()

    if self.wall_clock_breakdown(
    ) and self.global_steps % self.steps_per_print() == 0:
        self.timers.log([
            'pipe_send_output',
            'pipe_send_grad',
            'pipe_recv_input',
            'pipe_recv_grad'
        ])

    # TODO: should return precisely what loss returned and allow others to be queried?
    return self.agg_train_loss
def eval_batch(self,
               data_iter,
               return_logits=False,
               compute_loss=True,
               reduce_output='avg'):
    """Evaluate the pipeline on a batch of data from ``data_iter``. The
    engine will evaluate ``self.train_batch_size()`` total samples
    collectively across all workers.

    This method is equivalent to:

    .. code-block:: python

        module.eval()
        with torch.no_grad():
            output = module(batch)

    .. warning::
        A total of ``self.gradient_accumulation_steps()`` entries will be pulled
        from ``data_iter`` by each pipeline. There must be sufficient
        data left in ``data_iter`` or else a ``StopIteration`` will halt training.

        DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
        that wraps data loaders to automatically restart upon a ``StopIteration``.

    Args:
        data_iter (Iterator): Iterator of data to evaluate.
        return_logits (bool): if True, also return the forward outputs kept by
            the last stage (those of the final micro-batch; ``None`` on other
            stages).
        compute_loss (bool): if True, the last stage applies the loss function
            and the reduced result is broadcast to every pipeline stage.
        reduce_output (str or None): reduction applied on the last stage to
            the per-micro-batch results (``'avg'`` or ``None``).

    Returns:
        The reduced evaluation output (the arithmetic mean of the losses for
        ``reduce_output='avg'``), or ``(eval_output, outputs)`` when
        ``return_logits`` is True. Without ``compute_loss``, stages other than
        the last get ``None`` as ``eval_output``.
    """
    self.eval_return_logits = return_logits
    self.module.eval()

    # Curriculum learning could change activation shape
    if self.curriculum_enabled():
        new_difficulty = self.curriculum_scheduler.update_difficulty( \
            self.global_steps + 1)
        if self.global_steps == 0 or self.curriculum_scheduler.first_step:
            self.reset_activation_shape()
            self.curriculum_scheduler.first_step = False
        elif new_difficulty != self.curriculum_scheduler.get_difficulty( \
            self.global_steps):
            self.reset_activation_shape()

    eval_output = None

    self._compute_loss = compute_loss

    # Use the provided data iterator
    train_iterator = self.data_iterator
    self.set_dataiterator(data_iter)

    # Do the work
    sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
                                       stages=self.num_stages,
                                       stage_id=self.stage_id)
    with torch.no_grad():
        self._exec_schedule(sched)

    if self.is_last_stage():
        eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output)

    if compute_loss:
        eval_output = self._bcast_pipe_scalar(eval_output)

    if self.tensorboard_enabled():
        # Global rank 0 is not necessarily the last stage: without compute_loss
        # it holds ``None`` here (and reduce_output=None yields a list), so
        # only log when there is a tensor to summarise.
        if self.global_rank == 0 and torch.is_tensor(eval_output):
            self.summary_events = [(f'Train/Samples/eval_loss',
                                    eval_output.mean().item(),
                                    self.global_samples)]
            for event in self.summary_events:  # write_summary_events
                self.summary_writer.add_scalar(event[0], event[1], event[2])
            self.summary_writer.flush()

    # Restore the training iterator
    self.set_dataiterator(train_iterator)

    # Reset any buffers that may have been populated during the forward passes.
    #ds_checkpointing.reset()
    self.eval_return_logits = False
    if return_logits:
        outputs = self.outputs
        self.outputs = None
        return eval_output, outputs
    return eval_output
def set_train_batch_size(self, train_batch_size):
    """Change the global batch size by changing the number of micro-batches.

    Only the micro-batch count (gradient accumulation steps) changes; the
    per-micro-batch size (``train_micro_batch_size_per_gpu``) stays fixed.

    Args:
        train_batch_size (int): The new global batch size for training.

    Raises:
        ValueError: if ``train_batch_size`` is not divisible by the
            configured micro-batch size and data parallelism.
    """
    super().set_train_batch_size(train_batch_size)
    # The pipeline schedules are built from self.micro_batches, so refresh it
    # from the updated gradient-accumulation setting.
    refreshed_micro_batches = self.gradient_accumulation_steps()
    self.micro_batches = refreshed_micro_batches
def is_first_stage(self):
    """Return True if this process runs stage 0, the head of the pipeline."""
    return 0 == self.stage_id
def is_last_stage(self):
    """Return True if this process runs the final stage of the pipeline."""
    last_stage_id = self.num_stages - 1
    return self.stage_id == last_stage_id
def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True):
    """Reduce per-micro-batch forward results into a single result.

    Args:
        outputs (list): one entry per micro-batch. Each entry is either a
            tensor or a list/tuple of tensors, and all entries share the same
            structure.
        reduce (str or None): ``'avg'`` (case-insensitive) sums over
            micro-batches, averages via ``_scale_loss_by_gas`` and optionally
            averages over data-parallel ranks. ``None`` returns ``outputs``
            unchanged.
        reduce_dp (bool): also average across the data-parallel group.

    Returns:
        A tensor, or a list with one reduced tensor per output position.

    Raises:
        NotImplementedError: for any other ``reduce`` value.
    """
    if reduce is None:
        return outputs

    if reduce.lower() == 'avg':
        # first sum over all microbatches
        if torch.is_tensor(outputs[0]):
            reduced = sum(outputs)
        else:
            assert isinstance(outputs, (list, tuple))
            reduced = [torch.zeros_like(o) for o in outputs[0]]
            # Accumulate position-wise: every micro-batch contributes one
            # tensor per output position.
            for micro_batch_outputs in outputs:
                for idx, out in enumerate(micro_batch_outputs):
                    reduced[idx] += out

        # Average over the microbatches
        reduced = self._scale_loss_by_gas(reduced)

        # Average over DP groups
        if reduce_dp and self.is_data_parallel:
            if torch.is_tensor(reduced):
                dist.all_reduce(reduced, group=self.mpu.get_data_parallel_group())
                reduced /= self.dp_world_size
            else:
                for idx in range(len(reduced)):
                    dist.all_reduce(reduced[idx],
                                    group=self.mpu.get_data_parallel_group())
                    reduced[idx] /= self.dp_world_size

        return reduced
    else:
        raise NotImplementedError(f'reduction type {reduce} not supported.')
def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32):
    """Broadcast a single-element tensor from ``src_rank`` over the pipeline group.

    Args:
        data (torch.Tensor): value to send; only read on ``src_rank``. It should
            hold one element, because receivers allocate a 1-element buffer.
        src_rank (int, optional): global rank of the sender. Defaults to the
            global rank of the last pipeline stage (e.g. for broadcasting loss).
        dtype (torch.dtype): dtype used for the transfer on *every* rank.

    Returns:
        torch.Tensor: a ``dtype`` tensor on ``self.device`` holding the value.
    """
    # Default to last stage (e.g., for broadcasting loss)
    if src_rank is None:
        src_rank = self.grid.stage_to_global(self.num_stages - 1)
    assert src_rank in self.grid.pp_group

    if self.global_rank == src_rank:
        # Cast to the agreed dtype/device. Receivers allocate a ``dtype``
        # buffer below, so sending e.g. a bf16/fp16 loss as-is would not
        # match the receivers' buffer size.
        result = data.clone().detach().type(dtype).to(self.device)
    else:
        result = torch.Tensor([0.]).type(dtype).to(self.device)

    dist.broadcast(tensor=result,
                   src=src_rank,
                   group=self.mpu.get_pipe_parallel_group())

    return result
def _aggregate_total_loss(self):
    """Return the batch loss averaged over data-parallel ranks, on every stage.

    On the last stage:
      - ``self.total_loss`` is scaled with ``_scale_loss_by_gas`` and stored
        in ``self.dp_group_loss``;
      - a copy is all-reduced over the data-parallel group and divided by
        ``dp_world_size``;
      - both values are broadcast over the pipeline group.

    The other stages receive both values from the last stage and update
    ``self.dp_group_loss``.

    Returns:
        torch.Tensor: the data-parallel-averaged loss.
    """
    # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group
    if self.is_last_stage():
        # NOTE(review): _exec_forward_pass can leave total_loss as a *list* of
        # tensors (multi-output loss); .clone() below would then fail. Confirm
        # only single-tensor losses reach this path.
        loss = self._scale_loss_by_gas(self.total_loss)
        self.dp_group_loss = loss.clone().detach()

        ## Average loss across all data-parallel groups
        agg_loss = self.dp_group_loss.clone().detach()
        #print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True)
        if self.is_data_parallel:
            dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
            agg_loss /= self.dp_world_size

        assert self.global_rank in self.grid.pp_group
        # Pack both values into one float tensor so a single broadcast suffices;
        # the receivers below allocate a matching 2-element tensor.
        losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
        # NOTE(review): the sender uses self.mpu.get_pipe_parallel_group() while
        # the receivers use self.grid.get_pipe_parallel_group(); presumably the
        # same group (mpu is the grid for pipeline models) -- confirm.
        dist.broadcast(tensor=losses,
                       src=self.global_rank,
                       group=self.mpu.get_pipe_parallel_group())

    else:
        # Get loss from last stage
        src_rank = self.grid.stage_to_global(self.num_stages - 1)
        assert src_rank in self.grid.pp_group
        losses = torch.Tensor([0., 0.]).to(self.device)
        dist.broadcast(tensor=losses,
                       src=src_rank,
                       group=self.grid.get_pipe_parallel_group())
        self.dp_group_loss = losses[0].clone().detach()
        agg_loss = losses[1].clone().detach()

    return agg_loss
def set_dataloader(self, loader):
    """Use ``loader`` as the training data source on the data-loading stages.

    Only the first and last stages load data; on any other stage this call
    does nothing. On the first/last stage, ``training_dataloader`` is set to
    ``loader`` and ``data_iterator`` becomes a fresh iterator over it.
    """
    if not (self.is_first_stage() or self.is_last_stage()):
        return
    self.training_dataloader = loader
    self.data_iterator = iter(self.training_dataloader)
def set_dataiterator(self, iterator):
    """ Store an iterator to sample for training data.

    Applies only on the first and last stages (the ones that load data);
    there any previously set dataloader is forgotten. Other stages ignore
    the call.
    """
    loads_data = self.is_first_stage() or self.is_last_stage()
    if loads_data:
        self.training_dataloader = None
        self.data_iterator = iterator
def set_batch_fn(self, fn):
    """Register ``fn`` as a post-processing hook for every fetched batch.

    ``_next_batch`` calls ``fn(batch)`` when ``fn`` is truthy.
    """
    setattr(self, 'batch_fn', fn)
def is_gradient_accumulation_boundary(self):
    """Report whether gradient reduction and optimizer steps should happen now.

    Overrides :class:`DeepSpeedEngine` so the answer is driven by the pipeline
    schedule. It is True only while the engine has forced a boundary
    (``_exec_reduce_grads`` sets the flag around its reduction).

    Returns:
        bool: whether reductions and optimizer steps should occur.
    """
    forced_boundary = self._force_grad_boundary
    return forced_boundary
def log_for_device(self, *msg):
    """Print ``msg`` with a rank/stage/micro-batch prefix on selected devices.

    Output is filtered by the module-level ``LOG_STAGE`` and
    ``DATA_PARALLEL_ID``. Each must equal this process's stage id /
    data-parallel id, or be -1 (match everything).
    """
    stage_selected = LOG_STAGE == self.stage_id or LOG_STAGE == -1
    if not stage_selected:
        return
    dp_selected = DATA_PARALLEL_ID == self.grid.data_parallel_id or DATA_PARALLEL_ID == -1
    if not dp_selected:
        return
    print(
        f'RANK={dist.get_rank()} '
        f'PIPE-ID={self.stage_id} '
        f'DATA-ID={self.grid.data_parallel_id} '
        f'MBATCH-ID={self.microbatch_id} '
        f'STEP-ID={self.log_batch_step_id} '
        '::',
        *msg,
        flush=True)
def tput_log(self, *msg):
    """Print throughput messages, only on global rank 0 and only on print steps."""
    if self.global_rank != 0:
        return
    if self.global_steps % self.steps_per_print() == 0:
        print(*msg)
def _next_batch(self):
    """Return the next batch, post-processed by ``self.batch_fn`` if one is set.

    Ranks without a data iterator start from ``None``; with 3D parallelism
    only some first-stage ranks may do IO. ``batch_fn`` still runs on them,
    e.g. to receive a batch broadcast across a slice-parallel group.
    """
    batch = next(self.data_iterator) if self.data_iterator is not None else None
    return self.batch_fn(batch) if self.batch_fn else batch
def _exec_forward_pass(self, buffer_id):
    """Run the forward pass for the micro-batch in buffer slot ``buffer_id``.

    Steps:
      - The inputs are cloned from ``pipe_buffers['inputs']``. On non-first
        stages with pipe partitioning, the partitioned input is reassembled
        first.
      - The outputs are stored in ``pipe_buffers['outputs']``. When pipe
        partitioning is on and this is not the last stage, they are stored
        partitioned for sending.
      - On the last stage, the loss (or the raw outputs when no loss is
        computed) goes into ``self.loss`` and is accumulated into
        ``self.fwd_outputs`` and ``self.total_loss``.
    """
    self.tput_timer.start()
    self.mem_status('BEFORE FWD', reset_max=True)

    # Work on copies of the buffered inputs.
    if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
        inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
    else:
        inputs = self.pipe_buffers['inputs'][buffer_id].clone()

    # collect the partitioned input from the previous stage
    if self.is_pipe_partitioned and not self.is_first_stage():
        # inputs[0] is the partition metadata and inputs[1] this rank's local
        # part (the layout produced on the sending side below).
        part_input = PartitionedTensor.from_meta(
            meta=inputs[0],
            local_part=inputs[1],
            group=self.grid.get_slice_parallel_group())

        inputs = (part_input.full(), *inputs[2:])
        inputs[0].requires_grad = True
        # skip mask
        #inputs[1].requires_grad = True
        part_input = None
        inputs = inputs[0] if len(inputs) == 1 else inputs
        self.pipe_buffers['inputs'][buffer_id] = inputs

    # Zero out the gradients each time we use the tensor because only the data in
    # tensor changes across batches
    self._zero_grads(inputs)

    outputs = super().forward(inputs)

    # Partition the outputs if we are not the last stage
    if self.is_pipe_partitioned and not self.is_last_stage():
        if isinstance(outputs, tuple):
            first_output = outputs[0]
            # TODO: Improve pipe partitioning to pass multiple tensors that require grads
            assert all([
                torch.is_tensor(elt) and elt.requires_grad is False
                for elt in outputs[1:]
            ])
            outputs_tail = outputs[1:]
        elif torch.is_tensor(outputs):
            first_output = outputs
            outputs_tail = []
        else:
            raise ValueError("expecting a tensor or a tuple of tensors")
        part = PartitionedTensor(tensor=first_output,
                                 group=self.grid.get_slice_parallel_group())
        # Clear the large output data, but save the computation graph
        first_output.data = torch.zeros(1)
        # Kept so _exec_backward_pass can restore the data and run backward.
        self.pipe_buffers['output_tensors'][buffer_id] = first_output
        # Inject the partitioned tensor into the output before sending
        outputs = (part.to_meta(), part.data(), *outputs_tail)
        part = None

    self.pipe_buffers['outputs'][buffer_id] = outputs

    # Optionally compute loss on the last device
    if self.is_last_stage():
        if self._compute_loss and self.loss_model is not None:
            labels = self.pipe_buffers['labels'][buffer_id]
            self.loss = self.loss_model(outputs, labels)
        else:
            # Some models just return loss from forward()
            self.loss = outputs
        if self.eval_return_logits:
            # Overwritten on every micro-batch: only the last one's outputs survive.
            self.outputs = outputs
        if isinstance(self.loss, torch.Tensor):
            self.fwd_outputs.append(self.loss.detach())

            if self.total_loss is None:
                self.total_loss = torch.zeros_like(self.loss)
            self.total_loss += self.loss.detach()
        else:
            # Multiple losses: accumulate each position separately.
            self.fwd_outputs.append([l.detach() for l in self.loss])

            if self.total_loss is None:
                self.total_loss = [torch.zeros_like(l) for l in self.loss]
            for idx, l in enumerate(self.loss):
                self.total_loss[idx] += l.detach()
def _exec_backward_pass(self, buffer_id):
    """Run the backward pass for the micro-batch in buffer slot ``buffer_id``.

    The last stage backpropagates ``self.loss`` through
    ``DeepSpeedEngine.backward``. Other stages backpropagate the stored
    forward outputs using the gradients received in ``self.grad_layer``,
    reassembling partitioned outputs/gradients first. They then free the
    output buffers for this slot.

    Raises:
        AssertionError: if no optimizer was provided at initialisation.
    """
    assert self.optimizer is not None, "must provide optimizer during " \
                                       "init in order to use backward"

    self.mem_status('BEFORE BWD', reset_max=True)

    # The last stage just runs backward on the loss using DeepSpeed's typical
    # mechanisms.
    if self.is_last_stage():
        super().backward(self.loss)
        self.mem_status('AFTER BWD')
        return

    outputs = self.pipe_buffers['outputs'][buffer_id]

    if self.wall_clock_breakdown():
        self.timers('backward_microstep').start()
        self.timers('backward').start()
        self.timers('backward_inner_microstep').start()
        self.timers('backward_inner').start()

    # Reconstruct if we previously partitioned the output. We must be
    # careful to also restore the computational graph of the tensors we partitioned.
    if self.is_pipe_partitioned:
        if self.is_grad_partitioned:
            part_output = PartitionedTensor.from_meta(
                meta=outputs[0],
                local_part=outputs[1],
                group=self.grid.get_slice_parallel_group())
            # Put the full data back into the tensor that still carries the graph.
            self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
            outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:])
        else:
            # Already restored from partition
            self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0]
            outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:])

    grad_tensors = self.grad_layer
    if self.is_grad_partitioned:
        #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
        part_grad = PartitionedTensor.from_meta(
            meta=self.grad_layer[0],
            local_part=self.grad_layer[1],
            group=self.grid.get_slice_parallel_group())
        grad_tensors = (part_grad.full(), *grad_tensors[2:])
        part_grad = None
        #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')

    # NOTE(review): the ``not self.is_last_stage()`` checks below are always
    # True here, since the last stage returned early above.
    if self.bfloat16_enabled() and not self.is_last_stage():
        # manually call because we don't call optimizer.backward()
        self.optimizer.clear_lp_grads()

    # This handles either a single tensor or tuple of tensors.
    if isinstance(outputs, tuple):
        # Only floating-point outputs take part in backward; grad_tensors must
        # provide exactly one gradient for each of them.
        out_tensors = [t for t in outputs if t.is_floating_point()]
        assert len(out_tensors) == len(grad_tensors)
        torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors)
    else:
        torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))

    if self.bfloat16_enabled() and not self.is_last_stage():
        # manually call because we don't call optimizer.backward()
        self.optimizer.update_hp_grads(clear_lp_grads=False)

    # Free up the memory from the output of forward()
    self.pipe_buffers['output_tensors'][buffer_id] = None
    self.pipe_buffers['outputs'][buffer_id] = None
    grad_tensors = None

    if self.wall_clock_breakdown():
        self.timers('backward_inner').stop()
        self.timers('backward_inner_microstep').stop()
        self.timers('backward').stop()
        self.timers('backward_microstep').stop()

    self.mem_status('AFTER BWD')
def _exec_load_micro_batch(self, buffer_id):
    """Load the next micro-batch into pipeline buffer slot ``buffer_id``.

    The batch is indexed as ``batch[0]`` (inputs) and ``batch[1]`` (labels).

    - First stage: the inputs (a tensor, or a list/tuple of tensors) are
      copied to ``self.device`` and stored in ``pipe_buffers['inputs']``;
      sequences are stored as a tuple. Floating-point inputs get
      ``requires_grad=True``.
    - Last stage: tensor (or tuple-of-tensor) labels are moved to
      ``self.device`` and stored in ``pipe_buffers['labels']``. Labels of any
      other type are stored unchanged.
    - Middle stages: nothing is stored.
    """
    if self.wall_clock_breakdown():
        self.timers('batch_input').start()

    batch = self._next_batch()

    if self.is_first_stage():
        loaded = None
        if torch.is_tensor(batch[0]):
            loaded = batch[0].clone().to(self.device).detach()
            loaded.requires_grad = loaded.is_floating_point()
        else:
            # Accept a list or a tuple of tensors. It is normalised to a tuple
            # because _exec_forward_pass checks for ``tuple``.
            assert isinstance(batch[0], (list, tuple))
            loaded = []
            for x in batch[0]:
                assert torch.is_tensor(x)
                mine = x.clone().detach().to(self.device)
                mine.requires_grad = mine.is_floating_point()
                loaded.append(mine)
            loaded = tuple(loaded)

        self.pipe_buffers['inputs'][buffer_id] = loaded

    if self.is_last_stage():
        loaded = batch[1]
        if torch.is_tensor(batch[1]):
            loaded = batch[1].to(self.device)
        elif isinstance(batch[1], tuple):
            loaded = []
            for x in batch[1]:
                assert torch.is_tensor(x)
                x = x.to(self.device).detach()
                loaded.append(x)
            loaded = tuple(loaded)

        self.pipe_buffers['labels'][buffer_id] = loaded

    if self.wall_clock_breakdown():
        self.timers('batch_input').stop()
def _send_tensor_meta(self, buffer, recv_stage):
    """ Communicate metadata about upcoming p2p transfers.

    Metadata is sent in the order read by ``_recv_tensor_meta``:
        * type (0: tensor, 1: list, 2: tuple)
        * num_tensors if type is list or tuple
        foreach tensor in buffer:
            * dtype id (list/tuple only; index into ``ID_TO_DTYPE``)
            * ndims
            * shape

    Args:
        buffer (torch.Tensor, list or tuple of torch.Tensor): what will be sent.
        recv_stage (int): the pipeline stage that will receive it.

    Raises:
        NotImplementedError: if ``buffer`` is not a tensor, list or tuple.
    """
    if isinstance(buffer, torch.Tensor):
        type_tensor = torch.LongTensor(data=[0]).to(self.device)
        p2p.send(type_tensor, recv_stage)
        send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
        send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
        p2p.send(send_ndims, recv_stage)
        p2p.send(send_shape, recv_stage)
    elif isinstance(buffer, (list, tuple)):
        # Lists and tuples share one wire format. The receiver reads a dtype
        # for both type ids and only differs in the container it returns.
        type_id = 1 if isinstance(buffer, list) else 2
        type_tensor = torch.LongTensor(data=[type_id]).to(self.device)
        p2p.send(type_tensor, recv_stage)
        count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
        p2p.send(count_tensor, recv_stage)
        for tensor in buffer:
            assert isinstance(tensor, torch.Tensor)
            send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
            send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
            send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(
                self.device)
            p2p.send(send_dtype, recv_stage)
            p2p.send(send_ndims, recv_stage)
            p2p.send(send_shape, recv_stage)
    else:
        raise NotImplementedError(f'Could not send meta type {type(buffer)}')
def _recv_tensor_meta(self, send_stage):
    """Receive metadata about upcoming p2p transfers and return allocated buffers.

    Metadata is communicated in this order:
        * type (0: tensor, 1: list, 2: tuple)
        * num_tensors if type is list or tuple
        foreach tensor in buffer:
            * dtype (list/tuple only; a single tensor sends no dtype)
            * ndims
            * shape

    Every field arrives as a LongTensor received in place on ``self.device``
    via ``p2p.recv``.

    Args:
        send_stage: the pipeline stage the metadata is received from.

    Returns:
        Allocated buffer for receiving from send_stage: for type 0 the first
        buffer from ``self._allocate_buffer``; for types 1/2 the first buffer
        set from ``self._allocate_buffers``, converted to a tuple for type 2.

    Raises:
        NotImplementedError: if the received type id is not 0, 1, or 2.
    """

    def _recv_longs(count):
        # Allocate a LongTensor with `count` elements on our device and fill it
        # in place from send_stage; the initial values are overwritten.
        buf = torch.LongTensor([0] * count).to(self.device)
        p2p.recv(buf, send_stage)
        return buf

    def _recv_shape():
        # ndims is sent first so the receiver knows how long the shape is.
        ndims = _recv_longs(1).item()
        return _recv_longs(ndims).tolist()

    recv_type = _recv_longs(1).item()

    # A single tensor will be sent.
    if recv_type == 0:
        recv_shape = _recv_shape()
        return self._allocate_buffer(recv_shape, num_buffers=1)[0]

    # List or tuple of tensors
    if recv_type in (1, 2):
        num_tensors = _recv_longs(1).item()
        recv_shapes_and_dtypes = []
        for _ in range(num_tensors):
            # Order must mirror the sender: dtype, then ndims, then shape.
            recv_dtype = self.ID_TO_DTYPE[_recv_longs(1).item()]
            recv_shape = _recv_shape()
            recv_shapes_and_dtypes.append((recv_shape, recv_dtype))
        buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
        # Convert to tuples if requested.
        if recv_type == 2:
            buffers = tuple(buffers)
        return buffers

    # Report the actual type id; type(recv_type) would always be <class 'int'>.
    raise NotImplementedError(f'Could not receive type {recv_type}')
def _exec_send_activations(self, buffer_id):
    """Send the cached forward outputs for ``buffer_id`` to the next stage.

    On the first call, tensor metadata is sent ahead of the data via
    ``self._send_tensor_meta`` so the receiver can allocate matching buffers.

    Args:
        buffer_id: index into ``self.pipe_buffers['outputs']``.

    Raises:
        NotImplementedError: if the outputs are neither a tensor nor a tuple.
    """
    if self.wall_clock_breakdown():
        self.timers('pipe_send_output').start()

    outputs = self.pipe_buffers['outputs'][buffer_id]

    # NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
    # We could do char, but with half() we can eventually flatten with other fp16
    # messages (TODO)
    # The cast builds a new tuple bound only to this local name, so the entry in
    # self.pipe_buffers keeps its uncast last tensor and nothing needs restoring
    # after the send.
    if self.has_attention_mask or self.has_bool_tensors:
        outputs = list(outputs)
        outputs[-1] = outputs[-1].half()
        outputs = tuple(outputs)

    if self.first_output_send:
        self.first_output_send = False
        self._send_tensor_meta(outputs, self.next_stage)

    if isinstance(outputs, torch.Tensor):
        p2p.send(outputs, self.next_stage)
    elif isinstance(outputs, tuple):
        for buffer in outputs:
            p2p.send(buffer, self.next_stage)
    else:
        raise NotImplementedError('Could not send output of type '
                                  f'{type(outputs)}')

    if self.wall_clock_breakdown():
        self.timers('pipe_send_output').stop()
def _exec_send_grads(self, buffer_id):
if self.wall_clock_breakdown():
self.timers('pipe_send_grad').start()
inputs = self.pipe_buffers['inputs'][buffer_id]
# Partition the gradient
if self.is_grad_partitioned:
if isinstance(inputs, tuple):
first_input = inputs[0]
assert all([torch.is_tensor(elt) for elt in inputs[1:]])
inputs_grad_tail = [
elt.grad for elt in inputs[1:] if elt.grad is not None
]