import args_parse

SUPPORTED_MODELS = [
    'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
    'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34',
    'resnet50', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13',
    'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'
]

MODEL_OPTS = {
    '--model': {
        'choices': SUPPORTED_MODELS,
        'default': 'resnet50',
    },
    '--test_set_batch_size': {
        'type': int,
    },
    '--lr_scheduler_type': {
        'type': str,
    },
    '--lr_scheduler_divide_every_n_epochs': {
        'type': int,
    },
    '--lr_scheduler_divisor': {
        'type': int,
    },
    '--test_only_at_end': {
        'action': 'store_true',
    },
    '--num_warmup_epochs': {
        'type': float,
        'default': 0.9,
    },
    '--eval_interval': {
        'type': int,
        'default': 1,
    },
    '--flatten_parameters': {
        'action': 'store_true',
    },
    '--use_nested_fsdp': {
        'action': 'store_true',
    },
    '--use_gradient_checkpointing': {
        'action': 'store_true',
    },
}

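# Any flag passed as None below is filled in after parsing from
# DEFAULT_KWARGS / MODEL_SPECIFIC_DEFAULTS further down.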
FLAGS = args_parse.parse_common_options(
    datadir='/tmp/imagenet',
    batch_size=None,
    num_epochs=None,
    momentum=None,
    lr=None,
    target_accuracy=None,
    profiler_port=9012,
    opts=MODEL_OPTS.items(),
)

import os
import sys
import schedulers
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module

DEFAULT_KWARGS = dict(
    batch_size=128,
    test_set_batch_size=64,
    num_epochs=18,
    momentum=0.9,
    lr=0.1,
    target_accuracy=0.0,
)
MODEL_SPECIFIC_DEFAULTS = {
    # Override some of the args in DEFAULT_KWARGS, or add them to the dict
    # if they don't exist.
    'resnet50':
        dict(
            DEFAULT_KWARGS, **{
                'lr': 0.5,
                'lr_scheduler_divide_every_n_epochs': 20,
                'lr_scheduler_divisor': 5,
            })
}

# Set any args that were not explicitly given by the user.
default_value_dict = MODEL_SPECIFIC_DEFAULTS.get(FLAGS.model, DEFAULT_KWARGS)
for arg, value in default_value_dict.items():
  if getattr(FLAGS, arg) is None:
    setattr(FLAGS, arg, value)


def get_model_property(key):
  default_model_property = {
      'img_dim': 224,
      'model_fn': getattr(torchvision.models, FLAGS.model)
  }
  model_properties = {
      'inception_v3': {
          'img_dim': 299,
          'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False)
      },
  }
  return model_properties.get(FLAGS.model, default_model_property)[key]


def _train_update(device, step, loss, tracker, epoch, writer):
  test_utils.print_training_update(
      device,
      step,
      loss.item(),
      tracker.rate(),
      tracker.global_rate(),
      epoch,
      summary_writer=writer)


def train_imagenet():
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of the ImageNet dataset.
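    # xu.SampleGenerator repeatedly yields the same fake batch, so no
    # host-side dataset I/O is needed.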
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.test_set_batch_size //
        xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms, except that Torchvision
        # always resizes to 256; that crashes during training for models with
        # larger inputs (e.g. 299x299 for inception_v3), so we resize to
        # max(img_dim, 256) instead.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

  train_sampler, test_sampler = None, None
  if xm.xrt_world_size() > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS.batch_size,
      sampler=train_sampler,
      drop_last=FLAGS.drop_last,
      shuffle=train_sampler is None,
      persistent_workers=True,
      num_workers=FLAGS.num_workers)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS.test_set_batch_size,
      sampler=test_sampler,
      drop_last=FLAGS.drop_last,
      shuffle=False,
      persistent_workers=True,
      num_workers=FLAGS.num_workers)

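  # Seed every process identically so that all replicas build the same
  # initial weights before FSDP shards them.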
  torch.manual_seed(42)

  device = xm.xla_device()
  model = get_model_property('model_fn')()
  # Wrap the model with FSDP.
  # You may wrap all, a subset, or none of the sub-modules with inner FSDPs:
  # - to implement ZeRO-2, wrap none of the sub-modules
  # - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP)
  # - you may wrap sub-modules at different granularity (e.g. at each resnet
  #   stage, each residual block, or each conv layer).
  fsdp_wrap = lambda m: FSDP(
      m.to(device), flatten_parameters=FLAGS.flatten_parameters)
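  # flatten_parameters=True fuses each wrapper's parameters into one flat
  # buffer, which can make the all-gather/reduce-scatter ops fewer and larger.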
  # Apply gradient checkpointing to sub-modules if specified.
  grad_ckpt_wrap = checkpoint_module if FLAGS.use_gradient_checkpointing else (
      lambda x: x)
  if FLAGS.use_nested_fsdp:
    # Here we apply inner FSDP at the level of child modules for ZeRO-3, which
    # corresponds to different stages in resnet (i.e. stages 1 to 5).
    for submodule_name, submodule in model.named_children():
      if sum(p.numel() for p in submodule.parameters()) == 0:
        # Skip submodules without parameters (no need to shard them).
        continue
      # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP.
      m_fsdp = fsdp_wrap(grad_ckpt_wrap(getattr(model, submodule_name)))
      setattr(model, submodule_name, m_fsdp)
  # Always wrap the base model with an outer FSDP.
  model = fsdp_wrap(model)
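  # The outer FSDP instance shards whatever parameters were not already
  # claimed by an inner (nested) wrapper.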

  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(FLAGS.logdir)
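  # model.parameters() now returns only this process's parameter shards, so
  # the SGD momentum buffers are sharded across processes as well.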
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=1e-4)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * xm.xrt_world_size())
  lr_scheduler = schedulers.WarmupAndExponentialDecayScheduler(
      optimizer,
      num_steps_per_epoch=num_training_steps_per_epoch,
      divide_every_n_epochs=FLAGS.lr_scheduler_divide_every_n_epochs,
      divisor=FLAGS.lr_scheduler_divisor,
      num_warmup_epochs=FLAGS.num_warmup_epochs,
      summary_writer=writer)
  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
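      # Call optimizer.step() directly instead of xm.optimizer_step(optimizer):
      # FSDP already reduce-scatters gradients during the backward pass, so an
      # extra all-reduce over the sharded gradients would be incorrect.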
      optimizer.step()  # do not reduce gradients on sharded params
      tracker.add(FLAGS.batch_size)
      if lr_scheduler:
        lr_scheduler.step()
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            _train_update, args=(device, step, loss, tracker, epoch, writer))

  def test_loop_fn(loader, epoch):
    total_samples, correct = 0, 0
    model.eval()
    for step, (data, target) in enumerate(loader):
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            test_utils.print_test_update, args=(device, None, epoch, step))
    accuracy = 100.0 * correct.item() / total_samples
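    # Average the per-process accuracy across all replicas.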
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy

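  # MpDeviceLoader wraps the host-side loaders and transfers each batch to the
  # XLA device in the background, stepping the XLA graph between batches.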
  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader, epoch)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
    run_eval = ((not FLAGS.test_only_at_end and
                 epoch % FLAGS.eval_interval == 0) or epoch == FLAGS.num_epochs)
    if run_eval:
      accuracy = test_loop_fn(test_device_loader, epoch)
      xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
          epoch, test_utils.now(), accuracy))
      max_accuracy = max(accuracy, max_accuracy)
      test_utils.write_to_summary(
          writer,
          epoch,
          dict_to_write={'Accuracy/test': accuracy},
          write_xla_metrics=True)
    if FLAGS.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy


def _mp_fn(index, flags):
  global FLAGS
  FLAGS = flags
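  # Ensure newly created tensors default to CPU float32 in this process.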
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy = train_imagenet()
  if accuracy < FLAGS.target_accuracy:
    print('Accuracy {} is below target {}'.format(accuracy,
                                                  FLAGS.target_accuracy))
    sys.exit(21)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)