diff --git a/examples/research_projects/multi_subject_dreambooth/README.md b/examples/research_projects/multi_subject_dreambooth/README.md
index cf7dd31d0797..d1a7705cfebb 100644
--- a/examples/research_projects/multi_subject_dreambooth/README.md
+++ b/examples/research_projects/multi_subject_dreambooth/README.md
@@ -86,6 +86,53 @@ This example shows training for 2 subjects, but please note that the model can b
 Note also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed to be good when more intuitive words are used.
 
+**Important**: New parameters have been added to the script, making it possible to validate training progress by
+generating images at specified steps. Also, since a comma-separated list in a free-text prompt field is never a good
+idea (commas are simply too common in regular prompt text), we introduce the `concepts_list` parameter: it lets you
+specify a JSON file where you define a separate configuration for each subject that you want to train.
+
+An example of how to generate the file:
+```python
+import json
+
+# Here we use parameters for prior preservation and validation as well.
+concepts_list = [
+    {
+        "instance_prompt": "drawing of a t@y meme",
+        "class_prompt": "drawing of a meme",
+        "instance_data_dir": "/some_folder/meme_toy",
+        "class_data_dir": "/data/meme",
+        "validation_prompt": "drawing of a t@y meme about football in Uruguay",
+        "validation_negative_prompt": "black and white"
+    },
+    {
+        "instance_prompt": "drawing of a sks sir",
+        "class_prompt": "drawing of a sir",
+        "instance_data_dir": "/some_other_folder/sir_sks",
+        "class_data_dir": "/data/sir",
+        "validation_prompt": "drawing of a sks sir with the Uruguayan sun in his chest",
+        "validation_negative_prompt": "an old man",
+        "validation_guidance_scale": 20,
+        "validation_number_images": 3,
+        "validation_inference_steps": 10
+    }
+]
+
+with open("concepts_list.json", "w") as f:
+    json.dump(concepts_list, f, indent=4)
+```
+
+And then just point to the file when executing the script:
+
+```bash
+# exports...
+accelerate launch train_multi_subject_dreambooth.py \
+  --concepts_list="concepts_list.json" \
+  # more parameters...
+```
+
+Run the script with `--help` to get a better sense of each parameter.
+
 ### Inference
 
 Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
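Since a malformed `concepts_list.json` only surfaces as an error once training has already started, it can be worth sanity-checking the file first. Below is a minimal sketch of such a check, assuming the keys shown in the README example above; `check_concepts_list` is a hypothetical helper and not part of the script:

```python
import json
from pathlib import Path

# Keys every concept is assumed to need (taken from the README example above);
# the validation_* keys are optional per subject.
REQUIRED_KEYS = {"instance_prompt", "instance_data_dir"}


def check_concepts_list(path: str) -> None:
    """Fail fast on an obviously malformed concepts_list.json (hypothetical helper)."""
    with open(path) as f:
        concepts = json.load(f)
    if not isinstance(concepts, list) or not concepts:
        raise ValueError("concepts_list must be a non-empty JSON list of objects.")
    for i, concept in enumerate(concepts):
        missing = REQUIRED_KEYS - concept.keys()
        if missing:
            raise ValueError(f"Concept {i} is missing required keys: {sorted(missing)}")
        data_dir = Path(concept["instance_data_dir"])
        if not data_dir.is_dir():
            raise ValueError(f"Concept {i}: instance_data_dir does not exist: {data_dir}")


check_concepts_list("concepts_list.json")
```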
diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
index f24c6057fd8c..c75a0a9acc64 100644
--- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
+++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -1,13 +1,18 @@
 import argparse
 import hashlib
 import itertools
+import json
 import logging
 import math
-import os
+import uuid
 import warnings
+from os import environ, listdir, makedirs
+from os.path import basename, join
 from pathlib import Path
+from typing import List
 
 import datasets
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -17,24 +22,140 @@
 from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from PIL import Image
+from torch import dtype
+from torch.nn import Module
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
 
 import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DiffusionPipeline,
+    DPMSolverMultistepScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.13.0.dev0")
 
 logger = get_logger(__name__)
 
 
+def log_validation_images_to_tracker(
+    images: List[Image.Image], label: str, validation_prompt: str, accelerator: Accelerator, epoch: int
+):
+    logger.info(f"Logging images to tracker for validation prompt: {validation_prompt}.")
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{label}_{epoch}_{i}: {validation_prompt}")
+                        for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+
+# TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings`
+# argument is implemented.
+def generate_validation_images(
+    text_encoder: Module,
+    tokenizer: Module,
+    unet: Module,
+    vae: Module,
+    arguments: argparse.Namespace,
+    accelerator: Accelerator,
+    weight_dtype: dtype,
+):
+    logger.info("Generating validation images.")
+
+    pipeline_args = {}
+
+    if text_encoder is not None:
+        pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+
+    if vae is not None:
+        pipeline_args["vae"] = vae
+
+    # create pipeline (note: unet and vae are loaded again in float32)
+    pipeline = DiffusionPipeline.from_pretrained(
+        arguments.pretrained_model_name_or_path,
+        tokenizer=tokenizer,
+        unet=accelerator.unwrap_model(unet),
+        revision=arguments.revision,
+        torch_dtype=weight_dtype,
+        **pipeline_args,
+    )
+
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the
+    # scheduler to ignore it
+    scheduler_args = {}
+
+    if "variance_type" in pipeline.scheduler.config:
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        scheduler_args["variance_type"] = variance_type
+
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    generator = (
+        None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed)
+    )
+
+    images_sets = []
+    for vp, nvi, vnp, vis, vgs in zip(
+        arguments.validation_prompt,
+        arguments.validation_number_images,
+        arguments.validation_negative_prompt,
+        arguments.validation_inference_steps,
+        arguments.validation_guidance_scale,
+    ):
+        images = []
+        if vp is not None:
+            logger.info(
+                f"Generating {nvi} images with prompt: '{vp}', negative prompt: '{vnp}', inference steps: {vis}, "
+                f"guidance scale: {vgs}."
+            )
+
+            pipeline_args = {"prompt": vp, "negative_prompt": vnp, "num_inference_steps": vis, "guidance_scale": vgs}
+
+            # run inference
+            # TODO: it would be good to measure whether it's faster to run inference on all images at once, one at a
+            # time or in small batches
+            for _ in range(nvi):
+                with torch.autocast("cuda"):
+                    image = pipeline(**pipeline_args, num_images_per_prompt=1, generator=generator).images[0]
+                images.append(image)
+
+        images_sets.append(images)
+
+    del pipeline
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return images_sets
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
@@ -81,7 +202,7 @@ def parse_args(input_args=None):
         "--instance_data_dir",
         type=str,
         default=None,
-        required=True,
+        required=False,
         help="A folder containing the training data of instance images.",
     )
     parser.add_argument(
@@ -95,7 +216,7 @@ def parse_args(input_args=None):
         "--instance_prompt",
         type=str,
         default=None,
-        required=True,
+        required=False,
         help="The prompt with identifier specifying the instance",
     )
     parser.add_argument(
@@ -272,6 +393,52 @@ def parse_args(input_args=None):
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
         ),
     )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=None,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt(s) `validation_prompt` "
+            "multiple times (`validation_number_images`) and logging the images."
+        ),
+    )
+    parser.add_argument(
+        "--validation_prompt",
+        type=str,
+        default=None,
+        help="A prompt that is used during validation to verify that the model is learning. You can use commas to "
+        "define multiple prompts. This parameter can also be defined within the file given by the "
+        "`concepts_list` parameter for the respective subject.",
+    )
+    parser.add_argument(
+        "--validation_number_images",
+        type=int,
+        default=4,
+        help="Number of images that should be generated during validation with the given validation parameters. This "
+        "can be defined within the file given by the `concepts_list` parameter for the respective subject.",
+    )
+    parser.add_argument(
+        "--validation_negative_prompt",
+        type=str,
+        default=None,
+        help="A negative prompt that is used during validation to verify that the model is learning. You can use commas"
+        " to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can "
+        "also be defined within the file given by the `concepts_list` parameter for the respective subject.",
+    )
+    parser.add_argument(
+        "--validation_inference_steps",
+        type=int,
+        default=25,
+        help="Number of inference steps (denoising steps) to run during validation. This can be defined within the "
+        "file given by the `concepts_list` parameter for the respective subject.",
+    )
+    parser.add_argument(
+        "--validation_guidance_scale",
+        type=float,
+        default=7.5,
+        help="Controls how much the image generation process follows the text prompt. This can be defined within the "
+        "file given by the `concepts_list` parameter for the respective subject.",
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
@@ -297,27 +464,80 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--set_grads_to_none",
+        action="store_true",
+        help=(
+            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+            " behaviors, so disable this argument if it causes any problems. More info:"
+            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+        ),
+    )
+    parser.add_argument(
+        "--concepts_list",
+        type=str,
+        default=None,
+        help="Path to a JSON file containing a list of multiple concepts; it will overwrite parameters like "
+        "instance_prompt, class_prompt, etc.",
+    )
 
-    if input_args is not None:
+    if input_args:
         args = parser.parse_args(input_args)
     else:
         args = parser.parse_args()
 
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if not args.concepts_list and (not args.instance_data_dir or not args.instance_prompt):
+        raise ValueError(
+            "You must specify either instance parameters (data directory, prompt, etc.) or use "
+            "the `concepts_list` parameter and specify them within the file."
+        )
+
+    if args.concepts_list:
+        if args.instance_prompt:
+            raise ValueError("If you are using the `concepts_list` parameter, define the instance prompt within the file.")
+        if args.instance_data_dir:
+            raise ValueError(
+                "If you are using the `concepts_list` parameter, define the instance data directory within the file."
+            )
+        if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt):
+            raise ValueError(
+                "If you are using the `concepts_list` parameter, define validation parameters for "
+                "each subject within the file:\n - `validation_prompt`."
+                "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`."
+                "\n - `validation_number_images`."
+                "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one "
+                "that needs to be defined outside the file."
+            )
+
+    env_local_rank = int(environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
     if args.with_prior_preservation:
-        if args.class_data_dir is None:
-            raise ValueError("You must specify a data directory for class images.")
-        if args.class_prompt is None:
-            raise ValueError("You must specify prompt for class images.")
+        if not args.concepts_list:
+            if not args.class_data_dir:
+                raise ValueError("You must specify a data directory for class images.")
+            if not args.class_prompt:
+                raise ValueError("You must specify a prompt for class images.")
+        else:
+            if args.class_data_dir:
+                raise ValueError(
+                    "If you are using the `concepts_list` parameter, define the class data directory within the file."
+                )
+            if args.class_prompt:
+                raise ValueError(
+                    "If you are using the `concepts_list` parameter, define the class prompt within the file."
                )
     else:
         # logger is not available yet
-        if args.class_data_dir is not None:
-            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
-        if args.class_prompt is not None:
-            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+        if args.class_data_dir:
+            warnings.warn(
+                "Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`."
+            )
+        if args.class_prompt:
+            warnings.warn(
+                "Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`."
+            )
 
     return args
 
@@ -325,7 +545,7 @@ def parse_args(input_args=None):
 class DreamBoothDataset(Dataset):
     """
     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-    It pre-processes the images and the tokenizes prompts.
+    It pre-processes the images and then tokenizes the prompts.
     """
 
     def __init__(
@@ -346,7 +566,7 @@ def __init__(
         self.instance_images_path = []
         self.num_instance_images = []
         self.instance_prompt = []
-        self.class_data_root = []
+        self.class_data_root = [] if class_data_root is not None else None
         self.class_images_path = []
         self.num_class_images = []
         self.class_prompt = []
@@ -371,8 +591,6 @@ def __init__(
             self._length -= self.num_instance_images[i]
             self._length += self.num_class_images[i]
             self.class_prompt.append(class_prompt[i])
-        else:
-            self.class_data_root = None
 
         self.image_transforms = transforms.Compose(
             [
@@ -446,7 +664,7 @@ def collate_fn(num_instances, examples, with_prior_preservation=False):
 
 
 class PromptDataset(Dataset):
-    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
 
     def __init__(self, prompt, num_samples):
         self.prompt = prompt
@@ -474,6 +692,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    if args.report_to == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -483,23 +705,84 @@ def main(args):
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )
 
-    # Parse instance and class inputs, and double check that lengths match
-    instance_data_dir = args.instance_data_dir.split(",")
-    instance_prompt = args.instance_prompt.split(",")
-    assert all(
-        x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
-    ), "Instance data dir and prompt inputs are not of the same length."
+    instance_data_dir = []
+    instance_prompt = []
+    class_data_dir = [] if args.with_prior_preservation else None
+    class_prompt = [] if args.with_prior_preservation else None
+    if args.concepts_list:
+        with open(args.concepts_list, "r") as f:
+            concepts_list = json.load(f)
+
+        if args.validation_steps:
+            args.validation_prompt = []
+            args.validation_number_images = []
+            args.validation_negative_prompt = []
+            args.validation_inference_steps = []
+            args.validation_guidance_scale = []
+
+        for concept in concepts_list:
+            instance_data_dir.append(concept["instance_data_dir"])
+            instance_prompt.append(concept["instance_prompt"])
+
+            if args.with_prior_preservation:
+                try:
+                    class_data_dir.append(concept["class_data_dir"])
+                    class_prompt.append(concept["class_prompt"])
+                except KeyError:
+                    raise KeyError(
+                        "`class_data_dir` or `class_prompt` not found in concepts_list while using "
+                        "`with_prior_preservation`."
+                    )
+            else:
+                if "class_data_dir" in concept:
+                    warnings.warn(
+                        "Ignoring `class_data_dir` key, to use it you need to enable `with_prior_preservation`."
+                    )
+                if "class_prompt" in concept:
+                    warnings.warn(
+                        "Ignoring `class_prompt` key, to use it you need to enable `with_prior_preservation`."
+                    )
 
-    if args.with_prior_preservation:
-        class_data_dir = args.class_data_dir.split(",")
-        class_prompt = args.class_prompt.split(",")
-        assert all(
-            x == len(instance_data_dir)
-            for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)]
-        ), "Instance & class data dir or prompt inputs are not of the same length."
+            if args.validation_steps:
+                args.validation_prompt.append(concept.get("validation_prompt", None))
+                args.validation_number_images.append(concept.get("validation_number_images", 4))
+                args.validation_negative_prompt.append(concept.get("validation_negative_prompt", None))
+                args.validation_inference_steps.append(concept.get("validation_inference_steps", 25))
+                args.validation_guidance_scale.append(concept.get("validation_guidance_scale", 7.5))
     else:
-        class_data_dir = args.class_data_dir
-        class_prompt = args.class_prompt
+        # Parse instance and class inputs, and double check that lengths match
+        instance_data_dir = args.instance_data_dir.split(",")
+        instance_prompt = args.instance_prompt.split(",")
+        assert all(
+            x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
+        ), "Instance data dir and prompt inputs are not of the same length."
+
+        if args.with_prior_preservation:
+            class_data_dir = args.class_data_dir.split(",")
+            class_prompt = args.class_prompt.split(",")
+            assert all(
+                x == len(instance_data_dir)
+                for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)]
+            ), "Instance & class data dir or prompt inputs are not of the same length."
+
+        if args.validation_steps:
+            validation_prompts = args.validation_prompt.split(",")
+            num_of_validation_prompts = len(validation_prompts)
+            args.validation_prompt = validation_prompts
+            args.validation_number_images = [args.validation_number_images] * num_of_validation_prompts
+
+            negative_validation_prompts = [None] * num_of_validation_prompts
+            if args.validation_negative_prompt:
+                negative_validation_prompts = args.validation_negative_prompt.split(",")
+                while len(negative_validation_prompts) < num_of_validation_prompts:
+                    negative_validation_prompts.append(None)
+            args.validation_negative_prompt = negative_validation_prompts
+
+            assert num_of_validation_prompts == len(
+                negative_validation_prompts
+            ), "The number of negative prompts for validation exceeds the number of validation prompts."
+            args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
+            args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -559,21 +842,24 @@ def main(args):
             ):
                 images = pipeline(example["prompt"]).images
 
-                for i, image in enumerate(images):
+                for ii, image in enumerate(images):
                     hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                     image_filename = (
-                        class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+                        class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg"
                     )
                     image.save(image_filename)
 
+        # Clean up memory by deleting one-time-use variables.
         del pipeline
+        del sample_dataloader
+        del sample_dataset
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
     # Handle the repository creation
     if accelerator.is_main_process:
         if args.output_dir is not None:
-            os.makedirs(args.output_dir, exist_ok=True)
+            makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
             repo_id = create_repo(
@@ -581,6 +867,7 @@ def main(args):
             ).repo_id
 
     # Load the tokenizer
+    tokenizer = None
     if args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
     elif args.pretrained_model_name_or_path:
@@ -658,7 +945,7 @@ def main(args):
     train_dataset = DreamBoothDataset(
         instance_data_root=instance_data_dir,
         instance_prompt=instance_prompt,
-        class_data_root=class_data_dir if args.with_prior_preservation else None,
+        class_data_root=class_data_dir,
         class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=args.resolution,
@@ -720,7 +1007,7 @@ def main(args):
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
+    # The trackers initialize automatically on the main process.
     if accelerator.is_main_process:
         accelerator.init_trackers("dreambooth", config=vars(args))
 
@@ -741,10 +1028,10 @@ def main(args):
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
-            path = os.path.basename(args.resume_from_checkpoint)
+            path = basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
-            dirs = os.listdir(args.output_dir)
+            dirs = listdir(args.output_dir)
             dirs = [d for d in dirs if d.startswith("checkpoint")]
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
             path = dirs[-1] if len(dirs) > 0 else None
@@ -756,7 +1043,7 @@ def main(args):
             args.resume_from_checkpoint = None
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
-            accelerator.load_state(os.path.join(args.output_dir, path))
+            accelerator.load_state(join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
             resume_global_step = global_step * args.gradient_accumulation_steps
@@ -787,24 +1074,26 @@ def main(args):
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-                timesteps = timesteps.long()
+                time_steps = torch.randint(
+                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+                )
+                time_steps = time_steps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, time_steps)
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, time_steps, encoder_hidden_states).sample
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    target = noise_scheduler.get_velocity(latents, noise, time_steps)
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@@ -834,19 +1123,34 @@ def main(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
-                optimizer.zero_grad()
+                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
 
-                if global_step % args.checkpointing_steps == 0:
-                    if accelerator.is_main_process:
-                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                if accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
+                        save_path = join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
+                    if (
+                        args.validation_steps
+                        and any(args.validation_prompt)
+                        and global_step % args.validation_steps == 0
+                    ):
+                        images_set = generate_validation_images(
+                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype
+                        )
+                        for images, validation_prompt in zip(images_set, args.validation_prompt):
+                            if len(images) > 0:
+                                label = str(uuid.uuid1())[:8]  # generate an id for each set of images
+                                log_validation_images_to_tracker(
+                                    images, label, validation_prompt, accelerator, global_step
+                                )
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
@@ -854,7 +1158,7 @@ def main(args):
         if global_step >= args.max_train_steps:
             break
 
-    # Create the pipeline using using the trained modules and save it.
+    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        pipeline = DiffusionPipeline.from_pretrained(
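For context on how the resulting checkpoint is then used, here is a minimal inference sketch along the lines of the README's Inference section. It is not part of the diff: `path_to_saved_model` is a placeholder for the training output directory, and the `sks`/`t@y` identifiers come from the README example:

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the fine-tuned multi-subject checkpoint (directory name is a placeholder).
pipe = StableDiffusionPipeline.from_pretrained("path_to_saved_model", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Prompts should include the learned identifiers (`sks` and `t@y` in the README example).
prompt = "a drawing of a sks sir looking at a t@y meme"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("multi_subject_result.png")
```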