import argparse
+ import copy
import logging
import math
import os

import torch.nn.functional as F
import torch.utils.checkpoint

+ import datasets
+ import diffusers
+ import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")

- logger = get_logger(__name__)
+ logger = get_logger(__name__, log_level="INFO")


def parse_args():
@@ -171,7 +175,25 @@ def parse_args():
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
+     parser.add_argument(
+         "--allow_tf32",
+         action="store_true",
+         help=(
+             "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+         ),
+     )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+     parser.add_argument(
+         "--non_ema_revision",
+         type=str,
+         default=None,
+         required=False,
+         help=(
+             "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+             " remote repository specified with --pretrained_model_name_or_path."
+         ),
+     )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -247,6 +269,10 @@ def parse_args():
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

+     # default to using the same revision for the non-ema model if not specified
+     if args.non_ema_revision is None:
+         args.non_ema_revision = args.revision
+
    return args
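The new --non_ema_revision option falls back to --revision whenever it is left unset. A minimal standalone sketch of that defaulting behaviour (the "fp16" revision value here is only a hypothetical example):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--revision", type=str, default=None)
parser.add_argument("--non_ema_revision", type=str, default=None, required=False)
args = parser.parse_args(["--revision", "fp16"])  # hypothetical example values

# default to using the same revision for the non-ema model if not specified
if args.non_ema_revision is None:
    args.non_ema_revision = args.revision

print(args.non_ema_revision)  # -> fp16
```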
@@ -275,6 +301,8 @@ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

+         self.collected_params = None
+
        self.decay = decay
        self.optimization_step = 0
@@ -322,6 +350,55 @@ def to(self, device=None, dtype=None) -> None:
            for p in self.shadow_params
        ]

+     def state_dict(self) -> dict:
+         r"""
+         Returns the state of the ExponentialMovingAverage as a dict.
+         This method is used by accelerate during checkpointing to save the ema state dict.
+         """
+         # Following PyTorch conventions, references to tensors are returned:
+         # "returns a reference to the state and not its copy!" -
+         # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+         return {
+             "decay": self.decay,
+             "optimization_step": self.optimization_step,
+             "shadow_params": self.shadow_params,
+             "collected_params": self.collected_params,
+         }
+
+     def load_state_dict(self, state_dict: dict) -> None:
+         r"""
+         Loads the ExponentialMovingAverage state.
+         This method is used by accelerate during checkpointing to load the ema state dict.
+         Args:
+             state_dict (dict): EMA state. Should be an object returned
+                 from a call to :meth:`state_dict`.
+         """
+         # deepcopy, to be consistent with module API
+         state_dict = copy.deepcopy(state_dict)
+
+         self.decay = state_dict["decay"]
+         if self.decay < 0.0 or self.decay > 1.0:
+             raise ValueError("Decay must be between 0 and 1")
+
+         self.optimization_step = state_dict["optimization_step"]
+         if not isinstance(self.optimization_step, int):
+             raise ValueError("Invalid optimization_step")
+
+         self.shadow_params = state_dict["shadow_params"]
+         if not isinstance(self.shadow_params, list):
+             raise ValueError("shadow_params must be a list")
+         if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+             raise ValueError("shadow_params must all be Tensors")
+
+         self.collected_params = state_dict["collected_params"]
+         if self.collected_params is not None:
+             if not isinstance(self.collected_params, list):
+                 raise ValueError("collected_params must be a list")
+             if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                 raise ValueError("collected_params must all be Tensors")
+             if len(self.collected_params) != len(self.shadow_params):
+                 raise ValueError("collected_params and shadow_params must have the same length")
+


def main():
    args = parse_args()
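The state_dict/load_state_dict pair added above exists so that accelerate can include the EMA weights in its checkpoints. A minimal sketch of that interaction, assuming the EMAModel class defined in this script plus a hypothetical toy model and checkpoint directory:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)        # hypothetical toy model
ema = EMAModel(model.parameters())   # the in-script EMAModel patched above

# Objects registered here have state_dict()/load_state_dict() called by
# accelerator.save_state / accelerator.load_state, which is why both methods were added.
accelerator.register_for_checkpointing(ema)
accelerator.save_state("ema-checkpoint")   # hypothetical output directory
accelerator.load_state("ema-checkpoint")
```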
@@ -339,6 +416,15 @@ def main():
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
+     logger.info(accelerator.state, main_process_only=False)
+     if accelerator.is_local_main_process:
+         datasets.utils.logging.set_verbosity_warning()
+         transformers.utils.logging.set_verbosity_info()
+         diffusers.utils.logging.set_verbosity_info()
+     else:
+         datasets.utils.logging.set_verbosity_error()
+         transformers.utils.logging.set_verbosity_error()
+         diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
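For context on the logging change: get_logger returns a multi-process-aware logger that only emits on the main process unless main_process_only=False is passed, while the verbosity calls above silence the datasets/transformers/diffusers libraries on non-main ranks. A small standalone sketch of that behaviour:

```python
from accelerate import Accelerator
from accelerate.logging import get_logger

accelerator = Accelerator()
logger = get_logger(__name__, log_level="INFO")

# Emitted by every process, e.g. to show each rank's state.
logger.info(f"Process index: {accelerator.process_index}", main_process_only=False)

# Emitted once, by the main process only (the default behaviour).
logger.info("Starting training")
```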
@@ -361,39 +447,44 @@ def main():
    elif args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

-     # Load models and create wrapper for stable diffusion
+     # Load scheduler, tokenizer and models.
+     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
-         args.pretrained_model_name_or_path,
-         subfolder="text_encoder",
-         revision=args.revision,
-     )
-     vae = AutoencoderKL.from_pretrained(
-         args.pretrained_model_name_or_path,
-         subfolder="vae",
-         revision=args.revision,
+         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
+     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
-         args.pretrained_model_name_or_path,
-         subfolder="unet",
-         revision=args.revision,
+         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
    )

+     # Freeze vae and text_encoder
+     vae.requires_grad_(False)
+     text_encoder.requires_grad_(False)
+
+     # Create EMA for the unet.
+     if args.use_ema:
+         ema_unet = UNet2DConditionModel.from_pretrained(
+             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+         )
+         ema_unet = EMAModel(ema_unet.parameters())
+
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

-     # Freeze vae and text_encoder
-     vae.requires_grad_(False)
-     text_encoder.requires_grad_(False)
-
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

+     # Enable TF32 for faster training on Ampere GPUs,
+     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+     if args.allow_tf32:
+         torch.backends.cuda.matmul.allow_tf32 = True
+
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
@@ -419,7 +510,6 @@ def main():
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
-     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -482,9 +572,10 @@ def tokenize_captions(examples, is_train=True):
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
-         inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
-         input_ids = inputs.input_ids
-         return input_ids
+         inputs = tokenizer(
+             captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+         )
+         return inputs.input_ids

    train_transforms = transforms.Compose(
        [
@@ -500,7 +591,6 @@ def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
-
        return examples

    with accelerator.main_process_first():
@@ -512,13 +602,8 @@ def preprocess_train(examples):
    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-         input_ids = [example["input_ids"] for example in examples]
-         padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
-         return {
-             "pixel_values": pixel_values,
-             "input_ids": padded_tokens.input_ids,
-             "attention_mask": padded_tokens.attention_mask,
-         }
+         input_ids = torch.stack([example["input_ids"] for example in examples])
+         return {"pixel_values": pixel_values, "input_ids": input_ids}

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
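The two hunks above go together: because tokenize_captions now pads every caption to tokenizer.model_max_length and returns PyTorch tensors, all examples share one shape and collate_fn can simply stack them, with no tokenizer.pad call and no attention_mask. A small standalone sketch (the CLIP checkpoint name and captions are only illustrative):

```python
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # illustrative checkpoint
captions = ["a photo of a cat", "an oil painting of a windmill"]

# padding="max_length" pads every caption to the same length (model_max_length, 77 here),
# so the per-example rows can be stacked directly in collate_fn.
inputs = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(inputs.input_ids.shape)  # torch.Size([2, 77])

batch = torch.stack(list(inputs.input_ids))  # what collate_fn now does with the per-example rows
print(batch.shape)  # torch.Size([2, 77])
```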
@@ -541,23 +626,22 @@ def collate_fn(examples):
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )
-     accelerator.register_for_checkpointing(lr_scheduler)
+     if args.use_ema:
+         accelerator.register_for_checkpointing(ema_unet)

+     # For mixed precision training we cast the text_encoder and vae weights to half-precision
+     # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

-     # Move text_encode and vae to gpu.
-     # For mixed precision training we cast the text_encoder and vae weights to half-precision
-     # as these models are only used for inference, keeping weights in full precision is not required.
+     # Move text_encode and vae to gpu and cast to weight_dtype
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
-
-     # Create EMA for the unet.
    if args.use_ema:
-         ema_unet = EMAModel(unet.parameters())
+         ema_unet.to(accelerator.device)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)