
Commit 62608a9

patil-suraj and pcuenca authored
[train_text_to_image] allow using non-ema weights for training (#1834)
* allow using non-ema weights for training
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <[email protected]>
* address more review comments
* reorganise a few lines
* always pad text to max_length to match original training
* fix collate_fn
* remove unused code
* don't prepare ema_unet, don't register lr scheduler
* style
* assert => ValueError
* add allow_tf32
* set log level
* fix comment

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent e4fe941 commit 62608a9

1 file changed: +120 −36 lines

Diff for: examples/text_to_image/train_text_to_image.py

@@ -1,4 +1,5 @@
 import argparse
+import copy
 import logging
 import math
 import os
@@ -11,6 +12,9 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint
 
+import datasets
+import diffusers
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -28,7 +32,7 @@
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")
 
-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level="INFO")
 
 
 def parse_args():
@@ -171,7 +175,25 @@ def parse_args():
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
     parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+    parser.add_argument(
+        "--non_ema_revision",
+        type=str,
+        default=None,
+        required=False,
+        help=(
+            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+            " remote repository specified with --pretrained_model_name_or_path."
+        ),
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -247,6 +269,10 @@ def parse_args():
     if args.dataset_name is None and args.train_data_dir is None:
         raise ValueError("Need either a dataset name or a training folder.")
 
+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
     return args
 
 
@@ -275,6 +301,8 @@ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]
 
+        self.collected_params = None
+
         self.decay = decay
         self.optimization_step = 0
 
@@ -322,6 +350,55 @@ def to(self, device=None, dtype=None) -> None:
             for p in self.shadow_params
         ]
 
+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.optimization_step = state_dict["optimization_step"]
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
+
 
 
 def main():
     args = parse_args()
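Note: the two methods above exist so that accelerate can checkpoint the EMA weights, as the docstrings say. Below is a minimal round-trip sketch of that contract, not code from this commit. It assumes `EMAModel` is importable from the training script (run from examples/text_to_image with the example's requirements installed); the linear layer is just a stand-in for the UNet.

import io

import torch

from train_text_to_image import EMAModel  # assumption: script is on the import path

model = torch.nn.Linear(4, 4)        # stand-in for the UNet
ema = EMAModel(model.parameters())   # shadow copies of the current weights

# Serialize the EMA state the same way a checkpoint would.
buffer = io.BytesIO()
torch.save(ema.state_dict(), buffer)

# Restore it into a fresh EMAModel; load_state_dict validates types and lengths.
buffer.seek(0)
restored = EMAModel(model.parameters())
restored.load_state_dict(torch.load(buffer))

assert restored.optimization_step == ema.optimization_step
assert all(torch.equal(a, b) for a, b in zip(restored.shadow_params, ema.shadow_params))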
@@ -339,6 +416,15 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
 
     # If passed along, set the training seed now.
     if args.seed is not None:
@@ -361,39 +447,44 @@ def main():
     elif args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)
 
-    # Load models and create wrapper for stable diffusion
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
     text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )
 
+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        )
+        ema_unet = EMAModel(ema_unet.parameters())
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
-    # Freeze vae and text_encoder
-    vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
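Note: the hunk above is the core of the change: the UNet that receives gradients is loaded from --non_ema_revision, while the EMA shadow copy is seeded from the weights at --revision. A sketch of that split with concrete values, not code from this commit; the model id and the "non-ema" branch name are assumptions and must exist in the repository you actually train from.

from diffusers import UNet2DConditionModel

model_id = "CompVis/stable-diffusion-v1-4"  # hypothetical example checkpoint

# Trainable weights: a non-EMA branch (equivalent to passing --non_ema_revision non-ema).
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision="non-ema")

# Seed for the EMA shadow copy: the default branch weights (equivalent to leaving --revision unset).
ema_src = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision=None)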
@@ -419,7 +510,6 @@ def main():
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 
     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -482,9 +572,10 @@ def tokenize_captions(examples, is_train=True):
             raise ValueError(
                 f"Caption column `{caption_column}` should contain either strings or lists of strings."
             )
-        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
-        input_ids = inputs.input_ids
-        return input_ids
+        inputs = tokenizer(
+            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        return inputs.input_ids
 
     train_transforms = transforms.Compose(
         [
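Note: padding every caption to tokenizer.model_max_length (77 for the CLIP tokenizer) is what lets the simplified collate_fn further down simply torch.stack the ids, and, per the commit message, matches the original training setup. A quick standalone check, not code from this commit; the tokenizer id and captions are illustrative assumptions.

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumption: SD's text tokenizer
captions = ["a photo of a cat", "an astronaut riding a horse on the moon"]

input_ids = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
).input_ids

print(input_ids.shape)  # torch.Size([2, 77]): every caption ends up with the same fixed length
print(torch.stack([input_ids[0], input_ids[1]]).shape)  # torch.Size([2, 77]): what collate_fn now does per example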
@@ -500,7 +591,6 @@ def preprocess_train(examples):
         images = [image.convert("RGB") for image in examples[image_column]]
         examples["pixel_values"] = [train_transforms(image) for image in images]
         examples["input_ids"] = tokenize_captions(examples)
-
         return examples
 
     with accelerator.main_process_first():
@@ -512,13 +602,8 @@ def preprocess_train(examples):
     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-        input_ids = [example["input_ids"] for example in examples]
-        padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
-        return {
-            "pixel_values": pixel_values,
-            "input_ids": padded_tokens.input_ids,
-            "attention_mask": padded_tokens.attention_mask,
-        }
+        input_ids = torch.stack([example["input_ids"] for example in examples])
+        return {"pixel_values": pixel_values, "input_ids": input_ids}
 
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
@@ -541,23 +626,22 @@ def collate_fn(examples):
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
-    accelerator.register_for_checkpointing(lr_scheduler)
+    if args.use_ema:
+        accelerator.register_for_checkpointing(ema_unet)
 
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    # Move text_encode and vae to gpu.
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # Move text_encode and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
-
-    # Create EMA for the unet.
     if args.use_ema:
-        ema_unet = EMAModel(unet.parameters())
+        ema_unet.to(accelerator.device)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
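Note: with the lr_scheduler no longer registered (objects passed through accelerator.prepare are expected to be checkpointed by accelerate itself) and ema_unet registered instead, a save/restore cycle would look roughly like the sketch below. This resume logic is not part of this commit; the checkpoint path is hypothetical and EMAModel is assumed importable from the training script.

import torch
from accelerate import Accelerator

from train_text_to_image import EMAModel  # assumption: run from examples/text_to_image

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)                      # stand-in for the UNet
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)

ema = EMAModel(model.parameters())
accelerator.register_for_checkpointing(ema)        # possible now that EMAModel exposes (load_)state_dict

accelerator.save_state("ckpt")                     # hypothetical directory; also saves model/optimizer/RNG state
accelerator.load_state("ckpt")                     # restores them and calls ema.load_state_dict(...)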
