
Commit 3c83269

Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA (#3431)
* Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA
* move the FSDP module to `torch_xla.distributed`
* add `mark_step_on_freeing` as a temporary workaround to #3455
* check in `__init__` whether the module is already FSDP; fix exception types
* add `optimization_barrier_` (#3493) to avoid fusion of full parameter reconstruction with subsequent freeing
* also apply `xm.optimization_barrier_` to the FSDP output's gradients
* deprecate `mark_step_on_freeing` (since we have the optimization barrier now)
* add an option to run a dummy forward pass in FSDP
* add `_shard_size_multiple` to make sharded parameters a multiple of 128 for efficient all-gather (see #3510 (comment))
* refactor `optimization_barrier_` so that it is applied separately to `_rebuild_full_params` and `_free_full_params` in the forward and backward passes
* seal off more relevant ops with `optimization_barrier_` to avoid undesired fusion
* remove the obsolete `mark_step_on_freeing` and `use_all_gather_via_all_reduce` configs; unpin the layout for all_reduce; add a wrapper for gradient checkpointing on modules; remove the redundant `param_names`
* handle keyword arguments in `checkpoint_module`
* add a gradient checkpointing option to the MNIST and ImageNet FSDP examples
* refactor `optimization_barrier` and only apply it in forward or backward when specified
* refactor the command-line tool to consolidate sharded checkpoints
* address reviewers' comments from GitHub
* add more user instructions for checkpoint consolidation
* change the `flatten_parameters` default to False, since it didn't bring an actual speed-up in tests and breaks optimizer groups
* documentation refinement
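For orientation, the core wrapping pattern introduced here looks roughly like the sketch below, condensed from `test/test_train_mp_imagenet_fsdp.py` in this diff. The toy two-layer model, the shapes, and the single-process setup are placeholders for illustration only; the real examples run one process per XLA device via `xmp.spawn`.

```python
# Minimal sketch of the FSDP usage pattern (illustrative toy model, not the
# actual example script).
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module


class ToyModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.block1 = nn.Linear(1024, 1024)
    self.block2 = nn.Linear(1024, 10)

  def forward(self, x):
    return self.block2(torch.relu(self.block1(x)))


device = xm.xla_device()
model = ToyModel()

# ZeRO-3 style: wrap each parameterized child with an inner FSDP instance
# (optionally applying `checkpoint_module` first, as in the ImageNet example),
# then always wrap the full model with an outer FSDP.
for name, child in model.named_children():
  if sum(p.numel() for p in child.parameters()) > 0:
    setattr(model, name, FSDP(child.to(device)))
model = FSDP(model.to(device), flatten_parameters=False)

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss = model(torch.randn(8, 1024, device=device)).sum()
loss.backward()
optimizer.step()  # gradients are already sharded; no extra all-reduce
xm.mark_step()
```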
1 parent 6e3992a commit 3c83269

8 files changed (+2810, -0 lines)

test/test_train_mp_imagenet_fsdp.py

Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
import args_parse

SUPPORTED_MODELS = [
    'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
    'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34',
    'resnet50', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13',
    'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'
]

MODEL_OPTS = {
    '--model': {
        'choices': SUPPORTED_MODELS,
        'default': 'resnet50',
    },
    '--test_set_batch_size': {
        'type': int,
    },
    '--lr_scheduler_type': {
        'type': str,
    },
    '--lr_scheduler_divide_every_n_epochs': {
        'type': int,
    },
    '--lr_scheduler_divisor': {
        'type': int,
    },
    '--test_only_at_end': {
        'action': 'store_true',
    },
    '--num_warmup_epochs': {
        'type': float,
        'default': 0.9,
    },
    '--eval_interval': {
        'type': int,
        'default': 1,
    },
    '--flatten_parameters': {
        'action': 'store_true',
    },
    '--use_nested_fsdp': {
        'action': 'store_true',
    },
    '--use_gradient_checkpointing': {
        'action': 'store_true',
    },
}

FLAGS = args_parse.parse_common_options(
    datadir='/tmp/imagenet',
    batch_size=None,
    num_epochs=None,
    momentum=None,
    lr=None,
    target_accuracy=None,
    profiler_port=9012,
    opts=MODEL_OPTS.items(),
)

import os
import sys
import schedulers
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module

DEFAULT_KWARGS = dict(
    batch_size=128,
    test_set_batch_size=64,
    num_epochs=18,
    momentum=0.9,
    lr=0.1,
    target_accuracy=0.0,
)
MODEL_SPECIFIC_DEFAULTS = {
    # Override some of the args in DEFAULT_KWARGS, or add them to the dict
    # if they don't exist.
    'resnet50':
        dict(
            DEFAULT_KWARGS, **{
                'lr': 0.5,
                'lr_scheduler_divide_every_n_epochs': 20,
                'lr_scheduler_divisor': 5,
            })
}

# Set any args that were not explicitly given by the user.
default_value_dict = MODEL_SPECIFIC_DEFAULTS.get(FLAGS.model, DEFAULT_KWARGS)
for arg, value in default_value_dict.items():
  if getattr(FLAGS, arg) is None:
    setattr(FLAGS, arg, value)


def get_model_property(key):
  default_model_property = {
      'img_dim': 224,
      'model_fn': getattr(torchvision.models, FLAGS.model)
  }
  model_properties = {
      'inception_v3': {
          'img_dim': 299,
          'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False)
      },
  }
  model_fn = model_properties.get(FLAGS.model, default_model_property)[key]
  return model_fn


def _train_update(device, step, loss, tracker, epoch, writer):
  test_utils.print_training_update(
      device,
      step,
      loss.item(),
      tracker.rate(),
      tracker.global_rate(),
      epoch,
      summary_writer=writer)


def train_imagenet():
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

  train_sampler, test_sampler = None, None
  if xm.xrt_world_size() > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS.batch_size,
      sampler=train_sampler,
      drop_last=FLAGS.drop_last,
      shuffle=False if train_sampler else True,
      persistent_workers=True,
      num_workers=FLAGS.num_workers)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS.test_set_batch_size,
      sampler=test_sampler,
      drop_last=FLAGS.drop_last,
      shuffle=False,
      persistent_workers=True,
      num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  device = xm.xla_device()
  model = get_model_property('model_fn')()
  # Wrap the model with FSDP
  # You may wrap all, a subset, or none of the sub-modules with inner FSDPs
  # - to implement ZeRO-2, wrap none of the sub-modules
  # - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP)
  # - you may wrap sub-modules at different granularity (e.g. at each resnet
  #   stage or each residual block or each conv layer).
  fsdp_wrap = lambda m: FSDP(
      m.to(device), flatten_parameters=FLAGS.flatten_parameters)
  # Apply gradient checkpointing to sub-modules if specified
  grad_ckpt_wrap = checkpoint_module if FLAGS.use_gradient_checkpointing else (
      lambda x: x)
  if FLAGS.use_nested_fsdp:
    # Here we apply inner FSDP at the level of child modules for ZeRO-3, which
    # corresponds to different stages in resnet (i.e. Stage 1 to 5).
    for submodule_name, submodule in model.named_children():
      if sum(p.numel() for p in submodule.parameters()) == 0:
        # Skip those submodules without parameters (i.e. no need to shard them)
        continue
      # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
      m_fsdp = fsdp_wrap(grad_ckpt_wrap(getattr(model, submodule_name)))
      setattr(model, submodule_name, m_fsdp)
  # Always wrap the base model with an outer FSDP
  model = fsdp_wrap(model)

  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(FLAGS.logdir)
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=1e-4)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * xm.xrt_world_size())
  lr_scheduler = schedulers.WarmupAndExponentialDecayScheduler(
      optimizer,
      num_steps_per_epoch=num_training_steps_per_epoch,
      divide_every_n_epochs=FLAGS.lr_scheduler_divide_every_n_epochs,
      divisor=FLAGS.lr_scheduler_divisor,
      num_warmup_epochs=FLAGS.num_warmup_epochs,
      summary_writer=writer)
  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
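      # By this point FSDP has already reduced the gradients across ranks and
      # left each process with the gradient shard for its own parameters, so a
      # plain optimizer.step() below is enough; do not all-reduce the gradients
      # again (e.g. via xm.optimizer_step).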
      optimizer.step()  # do not reduce gradients on sharded params
      tracker.add(FLAGS.batch_size)
      if lr_scheduler:
        lr_scheduler.step()
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            _train_update, args=(device, step, loss, tracker, epoch, writer))

  def test_loop_fn(loader, epoch):
    total_samples, correct = 0, 0
    model.eval()
    for step, (data, target) in enumerate(loader):
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            test_utils.print_test_update, args=(device, None, epoch, step))
    accuracy = 100.0 * correct.item() / total_samples
277+
return accuracy
278+
279+
train_device_loader = pl.MpDeviceLoader(train_loader, device)
280+
test_device_loader = pl.MpDeviceLoader(test_loader, device)
281+
accuracy, max_accuracy = 0.0, 0.0
282+
for epoch in range(1, FLAGS.num_epochs + 1):
283+
xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
284+
train_loop_fn(train_device_loader, epoch)
285+
xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
286+
run_eval = ((not FLAGS.test_only_at_end and
287+
epoch % FLAGS.eval_interval == 0) or epoch == FLAGS.num_epochs)
288+
if run_eval:
289+
accuracy = test_loop_fn(test_device_loader, epoch)
290+
xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
291+
epoch, test_utils.now(), accuracy))
292+
max_accuracy = max(accuracy, max_accuracy)
293+
test_utils.write_to_summary(
294+
writer,
295+
epoch,
296+
dict_to_write={'Accuracy/test': accuracy},
297+
write_xla_metrics=True)
298+
if FLAGS.metrics_debug:
299+
xm.master_print(met.metrics_report())
300+
301+
test_utils.close_summary_writer(writer)
302+
xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
303+
return max_accuracy
304+
305+
306+
def _mp_fn(index, flags):
307+
global FLAGS
308+
FLAGS = flags
309+
torch.set_default_tensor_type('torch.FloatTensor')
310+
accuracy = train_imagenet()
311+
if accuracy < FLAGS.target_accuracy:
312+
print('Accuracy {} is below target {}'.format(accuracy,
313+
FLAGS.target_accuracy))
314+
sys.exit(21)
315+
316+
317+
if __name__ == '__main__':
318+
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
