From 2116de29663b31df7b2aef5326151ab29ec3a3a8 Mon Sep 17 00:00:00 2001 From: vhakobyan Date: Fri, 9 Feb 2024 09:32:08 +0000 Subject: [PATCH 01/14] wip: training script --- examples/inpainting/README.md | 326 ++++++ examples/inpainting/requirements.txt | 8 + examples/inpainting/train_inpainting.py | 1216 +++++++++++++++++++++++ 3 files changed, 1550 insertions(+) create mode 100644 examples/inpainting/README.md create mode 100644 examples/inpainting/requirements.txt create mode 100644 examples/inpainting/train_inpainting.py diff --git a/examples/inpainting/README.md b/examples/inpainting/README.md new file mode 100644 index 000000000000..a56cccbcf5d7 --- /dev/null +++ b/examples/inpainting/README.md @@ -0,0 +1,326 @@ +# Stable Diffusion text-to-image fine-tuning + +The `train_text_to_image.py` script shows how to fine-tune stable diffusion model on your own dataset. + +___Note___: + +___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___ + + +## Running locally with PyTorch +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd in the example folder and run +```bash +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. + +### Pokemon example + +You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. + +You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). + +Run the following command to authenticate your token + +```bash +huggingface-cli login +``` + +If you have already cloned the repo, then you won't need to go through these steps. + +
+ +#### Hardware +With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory. + +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +accelerate launch --mixed_precision="fp16" train_text_to_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" +``` + + + +To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). +If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export TRAIN_DIR="path_to_your_dataset" + +accelerate launch --mixed_precision="fp16" train_text_to_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$TRAIN_DIR \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" +``` + + +Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline` + +```python +import torch +from diffusers import StableDiffusionPipeline + +model_path = "path_to_saved_model" +pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) +pipe.to("cuda") + +image = pipe(prompt="yoda").images[0] +image.save("yoda-pokemon.png") +``` + +Checkpoints only save the unet, so to run inference from a checkpoint, just load the unet + +```python +import torch +from diffusers import StableDiffusionPipeline, UNet2DConditionModel + +model_path = "path_to_saved_model" +unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-/unet", torch_dtype=torch.float16) + +pipe = StableDiffusionPipeline.from_pretrained("", unet=unet, torch_dtype=torch.float16) +pipe.to("cuda") + +image = pipe(prompt="yoda").images[0] +image.save("yoda-pokemon.png") +``` + +#### Training with multiple GPUs + +`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch) +for running distributed training with `accelerate`. Here is an example command: + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" +``` + + +#### Training with Min-SNR weighting + +We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence +by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended +value when using it is 5.0. + +You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups: + +* Training without the Min-SNR weighting strategy +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0) +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0) + +For our small Pokemons dataset, the effects of Min-SNR weighting strategy might not appear to be pronounced, but for larger datasets, we believe the effects will be more pronounced. + +Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. + +## Training with LoRA + +Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. + +In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: + +- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). +- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable. +- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter. + +[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. + +With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset +on consumer GPUs like Tesla T4, Tesla V100. + +### Training + +First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). + +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + +**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___** + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export DATASET_NAME="lambdalabs/pokemon-blip-captions" +``` + +For this example we want to directly store the trained LoRA embeddings on the Hub, so +we need to be logged in and add the `--push_to_hub` flag. + +```bash +huggingface-cli login +``` + +Now we can start training! + +```bash +accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME --caption_column="text" \ + --resolution=512 --random_flip \ + --train_batch_size=1 \ + --num_train_epochs=100 --checkpointing_steps=5000 \ + --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --output_dir="sd-pokemon-model-lora" \ + --validation_prompt="cute dragon creature" --report_to="wandb" +``` + +The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases. + +**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run `train_text_to_image_lora.py` in consumer GPUs like T4 or V100.___** + +The final LoRA embedding weights have been uploaded to [sayakpaul/sd-model-finetuned-lora-t4](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4). **___Note: [The final weights](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) are only 3 MB in size, which is orders of magnitudes smaller than the original model.___** + +You can check some inference samples that were logged during the course of the fine-tuning process [here](https://wandb.ai/sayakpaul/text2image-fine-tune/runs/q4lc0xsw). + +### Inference + +Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline` after loading the trained LoRA weights. You +need to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-pokemon-model-lora`. + +```python +from diffusers import StableDiffusionPipeline +import torch + +model_path = "sayakpaul/sd-model-finetuned-lora-t4" +pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16) +pipe.unet.load_attn_procs(model_path) +pipe.to("cuda") + +prompt = "A pokemon with green eyes and red legs." +image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] +image.save("pokemon.png") +``` + +If you are loading the LoRA parameters from the Hub and if the Hub repository has +a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then +you can do: + +```py +from huggingface_hub.repocard import RepoCard + +lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4" +card = RepoCard.load(lora_model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) +... +``` + +## Training with Flax/JAX + +For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. + +**___Note: The flax example doesn't yet support features like gradient checkpoint, gradient accumulation etc, so to use flax for faster training we will need >30GB cards or TPU v3.___** + + +Before running the scripts, make sure to install the library's training dependencies: + +```bash +pip install -U -r requirements_flax.txt +``` + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +python train_text_to_image_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --mixed_precision="fp16" \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --output_dir="sd-pokemon-model" +``` + +To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). +If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script. + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export TRAIN_DIR="path_to_your_dataset" + +python train_text_to_image_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$TRAIN_DIR \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --mixed_precision="fp16" \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --output_dir="sd-pokemon-model" +``` + +### Training with xFormers: + +You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. + +xFormers training is not available for Flax/JAX. + +**Note**: + +According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training in some GPUs. If you observe that problem, please install a development version as indicated in that comment. + +## Stable Diffusion XL + +* We support fine-tuning the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) via the `train_text_to_image_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md). +* We also support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md). diff --git a/examples/inpainting/requirements.txt b/examples/inpainting/requirements.txt new file mode 100644 index 000000000000..0dd164fc2035 --- /dev/null +++ b/examples/inpainting/requirements.txt @@ -0,0 +1,8 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +datasets +ftfy +tensorboard +Jinja2 +peft==0.7.0 \ No newline at end of file diff --git a/examples/inpainting/train_inpainting.py b/examples/inpainting/train_inpainting.py new file mode 100644 index 000000000000..0d52380a61f6 --- /dev/null +++ b/examples/inpainting/train_inpainting.py @@ -0,0 +1,1216 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers +from functools import partial + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + + + +if is_wandb_available(): + import wandb + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.26.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def save_model_card( + args, + repo_id: str, + images=None, + prompts=None, + masks=None, + repo_folder=None, +): + img_str = "" + if len(images) > 0: + image_grid = make_image_grid(images+masks, 2, args.validation_size) + image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) + img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {args.pretrained_model_name_or_path} +datasets: +- {args.dataset_name} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +inference: true +--- + """ + model_card = f""" +# Text-to-image finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {prompts}: \n +{img_str} + +## Pipeline usage + +You can use the pipeline like so: + +```python +from diffusers import DiffusionPipeline +import torch + +pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16) +prompt = "{prompts[0]}" +image = pipeline(prompt).images[0] +image.save("my_image.png") +``` + +## Training info + +These are the key hyperparameters used during training: + +* Epochs: {args.num_train_epochs} +* Learning rate: {args.learning_rate} +* Batch size: {args.train_batch_size} +* Gradient accumulation steps: {args.gradient_accumulation_steps} +* Image resolution: {args.resolution} +* Mixed-precision: {args.mixed_precision} + +""" + wandb_info = "" + if is_wandb_available(): + wandb_run_url = None + if wandb.run is not None: + wandb_run_url = wandb.run.url + + if wandb_run_url is not None: + wandb_info = f""" +More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). +""" + + model_card += wandb_info + + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): + logger.info("Running validation... ") + + pipeline = StableDiffusionInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=accelerator.unwrap_model(vae), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + safety_checker=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + prompts = [] + masks = [] + image_transform = transforms.ToPILImage() + for _, batch in enumerate(validation_dataloader): + + mask = image_transform(batch['masks'][0]*255) + init_image= image_transform(batch['pixel_values'][0]) + prompt = batch['prompts'][0] + + with torch.autocast("cuda"): + #### UPDATE PIPELINE HERE + image = pipeline( + prompt, + image=init_image, + mask_image=mask, + num_inference_steps=20, + generator=generator + ).images[0] + + prompts.append(prompt) + images.append(image) + masks.append(mask) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + elif tracker.name == "wandb": + + tracker.log( + { + "validation": [ + wandb.Image(make_image_grid([image, masks[i]], 1, 2), caption=f"{i}: {prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + torch.cuda.empty_cache() + + return images + + +def create_custom_binary_mask_torch(total_size, max_square_size, probability_threshold, center_mask=False): + """ + Creates a binary mask tensor based on provided parameters using PyTorch. + The output tensor has values of 0 or 1, where 1 represents the masked area. + + :param total_size: The size of the sides of the square black background image. + :param max_square_size: The maximum size of the sides of the white square. + :param probability_threshold: The threshold for deciding between a completely white mask and a random mask. + :param center_mask: If True, generates a centered white square mask 3/4 the size of total_size. + :return: A PyTorch tensor representing the binary mask. + """ + # Create a new black image (mask) as a PyTorch tensor + mask = torch.zeros(total_size, total_size, dtype=torch.uint8) + + if center_mask: + # Create a centered white square mask 3/4 the size of total_size + center_square_size = int(total_size * 3 / 4) + start = (total_size - center_square_size) // 2 + end = start + center_square_size + mask[start:end, start:end] = 1 + else: + # Generate a random number between 0 and 1 + random_number = random.random() + + if random_number < probability_threshold: + # Create a completely white (binary 1) mask + mask.fill_(1) + else: + # Create a mask with a random white square + square_size = random.randint(max_square_size // 2, max_square_size) + max_val = total_size - square_size + top_left_x = random.randint(0, max_val) + top_left_y = random.randint(0, max_val) + mask[top_left_y:top_left_y + square_size, top_left_x:top_left_x + square_size] = 1 + + return mask + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_size", + type=int, + default=0, + help=("Number of images to be evaluated `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + +def main(): + args = parse_args() + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + ) + + # Freeze vae and text_encoder and set unet to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.train() + + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + + val_dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + split=datasets.ReadInstruction('train', to=args.validation_size, unit='abs') + ) + train_dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + split=datasets.ReadInstruction('train', from_=args.validation_size, unit='abs') + ) + else: + raise ValueError("Current Implementation supports only datasets from the hub.") + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = train_dataset.column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + val_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor() + ] + ) + + def preprocess_dataset_train(examples, is_train=True): + + images = [image.convert("RGB") for image in examples[image_column]] + if is_train: + examples["pixel_values"] = [train_transforms(image) for image in images] + else: + examples["pixel_values"] = [val_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = train_dataset.with_transform(partial(preprocess_dataset_train, is_train=True)) + val_dataset = val_dataset.with_transform(partial(preprocess_dataset_train, is_train=False)) + + def collate_fn(examples, center_mask): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + + masks = torch.stack([ + create_custom_binary_mask_torch( + pixel_values.shape[2], + int(3/4*pixel_values.shape[2]), + 0.25, + center_mask + ) for ix in range(pixel_values.shape[0]) + ]).unsqueeze(1) + + prompts = [example[caption_column] for example in examples] + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + "masks": masks, + "prompts": prompts + } + + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=partial(collate_fn, center_mask=False), + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + val_dataloader = torch.utils.data.DataLoader( + val_dataset, + shuffle=False, + collate_fn=partial(collate_fn, center_mask=True), + batch_size=1, + num_workers=0, + ) + + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + vae_downscaling_factor = 2 ** (len(vae.config.block_out_channels) - 1) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + mask_orig = batch["masks"].to(weight_dtype) + + masked_latents = vae.encode(pixel_values * (1.0 - mask_orig)).latent_dist.sample() + masked_latents = masked_latents * vae.config.scaling_factor + + mask = torch.nn.functional.interpolate( + mask_orig, + size=( + int(mask_orig.shape[3] // vae_downscaling_factor), + int(mask_orig.shape[2] // vae_downscaling_factor) + ) + ).to(dtype=weight_dtype, device=accelerator.device) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + if args.input_perturbation: + new_noise = noise + args.input_perturbation * torch.randn_like(noise) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.input_perturbation: + noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Predict the noise residual and compute loss + combined_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1) + model_pred = unet(combined_latents, timesteps, encoder_hidden_states, return_dict=False)[0] + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = mse_loss_weights / snr + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = mse_loss_weights / (snr + 1) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_size > 0 and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + log_validation( + val_dataloader, + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + pipeline = StableDiffusionInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + revision=args.revision, + variant=args.variant, + ) + pipeline.save_pretrained(args.output_dir) + + # Run a final round of inference. + images = [] + if args.validation_size > 0: + logger.info("Running inference for collecting generated images...") + pipeline = pipeline.to(accelerator.device) + pipeline.torch_dtype = weight_dtype + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + prompts = [] + masks = [] + image_transform = transforms.ToPILImage() + for _, batch in enumerate(val_dataloader): + + mask = image_transform(batch['masks'][0]*255) + init_image= image_transform(batch['pixel_values'][0]) + prompt = batch['prompts'][0] + + with torch.autocast("cuda"): + #### UPDATE PIPELINE HERE + image = pipeline( + prompt, + image=init_image, + mask_image=mask, + num_inference_steps=20, + generator=generator + ).images[0] + + prompts.append(prompt) + images.append(image) + masks.append(mask) + + if args.push_to_hub: + save_model_card(args, repo_id, images, prompts, masks, repo_folder=args.output_dir) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() From 882cb67b68e8703019726c25a8704f9f73198c08 Mon Sep 17 00:00:00 2001 From: vhakobyan Date: Fri, 9 Feb 2024 11:31:41 +0000 Subject: [PATCH 02/14] wip: update documentation --- examples/inpainting/README.md | 243 ++++++++-------------------------- 1 file changed, 56 insertions(+), 187 deletions(-) diff --git a/examples/inpainting/README.md b/examples/inpainting/README.md index a56cccbcf5d7..6cd7148ed888 100644 --- a/examples/inpainting/README.md +++ b/examples/inpainting/README.md @@ -36,7 +36,7 @@ Note also that we use PEFT library as backend for LoRA training, make sure to ha ### Pokemon example -You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. +You need to accept the model license before downloading or using the weights. In this example we'll use model version `sd-v1-5-inpainting` from runwayml, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-inpainting), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). @@ -53,25 +53,27 @@ If you have already cloned the repo, then you won't need to go through these ste #### Hardware With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory. -**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) 768x768 model.___** ```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export MODEL_NAME="runwayml/stable-diffusion-inpainting" export DATASET_NAME="lambdalabs/pokemon-blip-captions" -accelerate launch --mixed_precision="fp16" train_text_to_image.py \ + +accelerate launch --mixed_precision="fp16" train_inpainting.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --dataset_name=$DATASET_NAME \ --use_ema \ --resolution=512 --center_crop --random_flip \ - --train_batch_size=1 \ - --gradient_accumulation_steps=4 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=2 \ --gradient_checkpointing \ --max_train_steps=15000 \ --learning_rate=1e-05 \ - --max_grad_norm=1 \ + --max_grad_norm=1 --seed=42 \ --lr_scheduler="constant" --lr_warmup_steps=0 \ - --output_dir="sd-pokemon-model" + --output_dir="sd-pokemon-model-inpaint" \ + --validation_size=3 ``` @@ -80,53 +82,73 @@ To run on your own training files prepare the dataset according to the format re If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script. ```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export TRAIN_DIR="path_to_your_dataset" +export MODEL_NAME="runwayml/stable-diffusion-inpainting" +export DATASET_NAME="path_to_your_dataset" (NOT IMPLEMENTED) + -accelerate launch --mixed_precision="fp16" train_text_to_image.py \ +accelerate launch --mixed_precision="fp16" train_inpainting.py \ --pretrained_model_name_or_path=$MODEL_NAME \ - --train_data_dir=$TRAIN_DIR \ + --dataset_name=$DATASET_NAME \ --use_ema \ --resolution=512 --center_crop --random_flip \ - --train_batch_size=1 \ - --gradient_accumulation_steps=4 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=2 \ --gradient_checkpointing \ --max_train_steps=15000 \ --learning_rate=1e-05 \ - --max_grad_norm=1 \ + --max_grad_norm=1 --seed=42 \ --lr_scheduler="constant" --lr_warmup_steps=0 \ - --output_dir="sd-pokemon-model" + --output_dir="sd-pokemon-model-inpaint" \ + --validation_size=3 ``` -Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline` +Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model-inpaint`. To load the fine-tuned model for inference just pass that path to `StableDiffusionInpaintPipeline` ```python import torch -from diffusers import StableDiffusionPipeline +from PIL import Image +from diffusers import StableDiffusionInpaintPipeline + +init_image = Image.open("path_to_image").resize((512, 512)) +mask_image = Image.open("path_to_mask").resize((512, 512)) model_path = "path_to_saved_model" -pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) +pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, torch_dtype=torch.float16) pipe.to("cuda") -image = pipe(prompt="yoda").images[0] -image.save("yoda-pokemon.png") +inpainted_image = pipe( + prompt = "yoda", + image = init_image, + mask_image = mask_image, +).images[0] + +inpainted_image.save("inpainted-yoda-pokemon.png") ``` Checkpoints only save the unet, so to run inference from a checkpoint, just load the unet ```python import torch -from diffusers import StableDiffusionPipeline, UNet2DConditionModel +from PIL import Image +from diffusers import StableDiffusionInpaintPipeline, UNet2DConditionModel + +init_image = Image.open("path_to_image").resize((512, 512)) +mask_image = Image.open("path_to_mask").resize((512, 512)) model_path = "path_to_saved_model" unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-/unet", torch_dtype=torch.float16) -pipe = StableDiffusionPipeline.from_pretrained("", unet=unet, torch_dtype=torch.float16) +pipe = StableDiffusionInpaintPipeline.from_pretrained("", unet=unet, torch_dtype=torch.float16) pipe.to("cuda") -image = pipe(prompt="yoda").images[0] -image.save("yoda-pokemon.png") +inpainted_image = pipe( + prompt = "yoda", + image = init_image, + mask_image = mask_image, +).images[0] + +inpainted_image.save("inpainted-yoda-pokemon.png") ``` #### Training with multiple GPUs @@ -135,22 +157,24 @@ image.save("yoda-pokemon.png") for running distributed training with `accelerate`. Here is an example command: ```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export MODEL_NAME="runwayml/stable-diffusion-inpainting" export DATASET_NAME="lambdalabs/pokemon-blip-captions" -accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \ + +accelerate launch --mixed_precision="fp16" --multi_gpu train_inpainting.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --dataset_name=$DATASET_NAME \ --use_ema \ --resolution=512 --center_crop --random_flip \ - --train_batch_size=1 \ - --gradient_accumulation_steps=4 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=2 \ --gradient_checkpointing \ --max_train_steps=15000 \ --learning_rate=1e-05 \ - --max_grad_norm=1 \ + --max_grad_norm=1 --seed=42 \ --lr_scheduler="constant" --lr_warmup_steps=0 \ - --output_dir="sd-pokemon-model" + --output_dir="sd-pokemon-model-inpaint" \ + --validation_size=3 ``` @@ -168,159 +192,4 @@ You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/tex For our small Pokemons dataset, the effects of Min-SNR weighting strategy might not appear to be pronounced, but for larger datasets, we believe the effects will be more pronounced. -Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. - -## Training with LoRA - -Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. - -In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: - -- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). -- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable. -- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter. - -[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. - -With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset -on consumer GPUs like Tesla T4, Tesla V100. - -### Training - -First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). - -**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** - -**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___** - -```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export DATASET_NAME="lambdalabs/pokemon-blip-captions" -``` - -For this example we want to directly store the trained LoRA embeddings on the Hub, so -we need to be logged in and add the `--push_to_hub` flag. - -```bash -huggingface-cli login -``` - -Now we can start training! - -```bash -accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --dataset_name=$DATASET_NAME --caption_column="text" \ - --resolution=512 --random_flip \ - --train_batch_size=1 \ - --num_train_epochs=100 --checkpointing_steps=5000 \ - --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ - --seed=42 \ - --output_dir="sd-pokemon-model-lora" \ - --validation_prompt="cute dragon creature" --report_to="wandb" -``` - -The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases. - -**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run `train_text_to_image_lora.py` in consumer GPUs like T4 or V100.___** - -The final LoRA embedding weights have been uploaded to [sayakpaul/sd-model-finetuned-lora-t4](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4). **___Note: [The final weights](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) are only 3 MB in size, which is orders of magnitudes smaller than the original model.___** - -You can check some inference samples that were logged during the course of the fine-tuning process [here](https://wandb.ai/sayakpaul/text2image-fine-tune/runs/q4lc0xsw). - -### Inference - -Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline` after loading the trained LoRA weights. You -need to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-pokemon-model-lora`. - -```python -from diffusers import StableDiffusionPipeline -import torch - -model_path = "sayakpaul/sd-model-finetuned-lora-t4" -pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16) -pipe.unet.load_attn_procs(model_path) -pipe.to("cuda") - -prompt = "A pokemon with green eyes and red legs." -image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] -image.save("pokemon.png") -``` - -If you are loading the LoRA parameters from the Hub and if the Hub repository has -a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then -you can do: - -```py -from huggingface_hub.repocard import RepoCard - -lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4" -card = RepoCard.load(lora_model_id) -base_model_id = card.data.to_dict()["base_model"] - -pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) -... -``` - -## Training with Flax/JAX - -For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. - -**___Note: The flax example doesn't yet support features like gradient checkpoint, gradient accumulation etc, so to use flax for faster training we will need >30GB cards or TPU v3.___** - - -Before running the scripts, make sure to install the library's training dependencies: - -```bash -pip install -U -r requirements_flax.txt -``` - -```bash -export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" -export DATASET_NAME="lambdalabs/pokemon-blip-captions" - -python train_text_to_image_flax.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --dataset_name=$DATASET_NAME \ - --resolution=512 --center_crop --random_flip \ - --train_batch_size=1 \ - --mixed_precision="fp16" \ - --max_train_steps=15000 \ - --learning_rate=1e-05 \ - --max_grad_norm=1 \ - --output_dir="sd-pokemon-model" -``` - -To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). -If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script. - -```bash -export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" -export TRAIN_DIR="path_to_your_dataset" - -python train_text_to_image_flax.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --train_data_dir=$TRAIN_DIR \ - --resolution=512 --center_crop --random_flip \ - --train_batch_size=1 \ - --mixed_precision="fp16" \ - --max_train_steps=15000 \ - --learning_rate=1e-05 \ - --max_grad_norm=1 \ - --output_dir="sd-pokemon-model" -``` - -### Training with xFormers: - -You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. - -xFormers training is not available for Flax/JAX. - -**Note**: - -According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training in some GPUs. If you observe that problem, please install a development version as indicated in that comment. - -## Stable Diffusion XL - -* We support fine-tuning the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) via the `train_text_to_image_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md). -* We also support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md). +Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. \ No newline at end of file From 89854ee90bc32fdc2d5283e15aefb37ec2b1d566 Mon Sep 17 00:00:00 2001 From: vhakobyan Date: Fri, 9 Feb 2024 11:32:32 +0000 Subject: [PATCH 03/14] fix: README --- examples/inpainting/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/inpainting/README.md b/examples/inpainting/README.md index 6cd7148ed888..685b0cfaa518 100644 --- a/examples/inpainting/README.md +++ b/examples/inpainting/README.md @@ -1,6 +1,6 @@ # Stable Diffusion text-to-image fine-tuning -The `train_text_to_image.py` script shows how to fine-tune stable diffusion model on your own dataset. +The `train_inpainting.py` script shows how to fine-tune stable diffusion model on your own dataset. ___Note___: From 969605f1896d434fc0ab5d75322366003b1c6eb4 Mon Sep 17 00:00:00 2001 From: vhakobyan Date: Fri, 9 Feb 2024 11:33:10 +0000 Subject: [PATCH 04/14] fix: README title --- examples/inpainting/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/inpainting/README.md b/examples/inpainting/README.md index 685b0cfaa518..2ecf0bffd94f 100644 --- a/examples/inpainting/README.md +++ b/examples/inpainting/README.md @@ -1,4 +1,4 @@ -# Stable Diffusion text-to-image fine-tuning +# Stable Diffusion Inpainting fine-tuning The `train_inpainting.py` script shows how to fine-tune stable diffusion model on your own dataset. From 69d4494fbebbf169bfbdb7c2a8b8ceaa8f9dcf2c Mon Sep 17 00:00:00 2001 From: vhakobyan Date: Mon, 26 Feb 2024 13:44:13 +0000 Subject: [PATCH 05/14] wip: integrating LAMA masking --- examples/inpainting/requirements.txt | 3 +- examples/inpainting/sd_train_inpaint.sh | 17 ++ examples/inpainting/train_inpainting.py | 350 +++++++++++++++--------- src/diffusers/loaders/single_file.py | 1 + 4 files changed, 243 insertions(+), 128 deletions(-) create mode 100755 examples/inpainting/sd_train_inpaint.sh diff --git a/examples/inpainting/requirements.txt b/examples/inpainting/requirements.txt index 0dd164fc2035..a6794f3dbfb3 100644 --- a/examples/inpainting/requirements.txt +++ b/examples/inpainting/requirements.txt @@ -5,4 +5,5 @@ datasets ftfy tensorboard Jinja2 -peft==0.7.0 \ No newline at end of file +peft==0.7.0 +opencv-python==4.8.0.76 \ No newline at end of file diff --git a/examples/inpainting/sd_train_inpaint.sh b/examples/inpainting/sd_train_inpaint.sh new file mode 100755 index 000000000000..70298bc4386d --- /dev/null +++ b/examples/inpainting/sd_train_inpaint.sh @@ -0,0 +1,17 @@ +export MODEL_NAME="stabilityai/stable-diffusion-2-inpainting" +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + + +CUDA_VISIBLE_DEVICES="1" accelerate launch --num_processes 1 --main_process_port 29502 --mixed_precision="fp16" train_inpainting.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --use_ema \ + --resolution=768 --center_crop --random_flip \ + --train_batch_size=4 \ + --gradient_accumulation_steps=2 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 --seed=42 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model-inpaint-2" --validation_size=3 --validation_epochs=1 --report_to="wandb" \ No newline at end of file diff --git a/examples/inpainting/train_inpainting.py b/examples/inpainting/train_inpainting.py index 0d52380a61f6..5c76cf0ba729 100644 --- a/examples/inpainting/train_inpainting.py +++ b/examples/inpainting/train_inpainting.py @@ -14,11 +14,13 @@ # See the License for the specific language governing permissions and import argparse +import cv2 import logging import math import os import random import shutil +from functools import partial from pathlib import Path import accelerate @@ -39,7 +41,6 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from transformers.utils import ContextManagers -from functools import partial import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel @@ -50,7 +51,6 @@ from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -64,6 +64,173 @@ "lambdalabs/pokemon-blip-captions": ("image", "text"), } +def make_random_irregular_mask(shape, max_angle, max_len, max_width, min_times, max_times): + + height, width = shape + mask = np.zeros((height, width), np.float32) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + start_x = np.random.randint(width) + start_y = np.random.randint(height) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_len) + brush_w = 5 + np.random.randint(max_width) + end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width) + end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height) + + cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w) + start_x, start_y = end_x, end_y + return torch.from_numpy(mask[None, ...]).squeeze(0).byte() + +def make_random_rectangle_mask(shape, margin, bbox_min_size, bbox_max_size, min_times, max_times): + + height, width = shape + mask = np.zeros((height, width), np.float32) + bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + box_width = np.random.randint(bbox_min_size, bbox_max_size) + box_height = np.random.randint(bbox_min_size, bbox_max_size) + start_x = np.random.randint(margin, width - margin - box_width + 1) + start_y = np.random.randint(margin, height - margin - box_height + 1) + mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1 + return torch.from_numpy(mask[None, ...]).squeeze(0).byte() + + +class RandomIrregularMaskGenerator: + """ + Initializes the RandomIrregularMaskGenerator with the provided parameters. + + Parameters: + max_angle (int): The maximum angle for the line segments, influencing the irregularity of the shapes. + max_len (int): The maximum length for each line segment, affecting the size of the irregular shapes. + max_width (int): The maximum width for each line segment, determining the thickness of the irregular shapes. + min_times (int): The minimum number of irregular shapes to be generated on the mask. + max_times (int): The maximum number of irregular shapes to be generated on the mask. + """ + def __init__(self, max_angle, max_len, max_width, min_times, max_times): + self.max_angle = max_angle + self.max_len = max_len + self.max_width = max_width + self.min_times = min_times + self.max_times = max_times + + def __call__(self, img_shape): + """ + Generates a mask with random irregular shapes when called with an image. + + Parameters: + img (tuple): Tuple of image dimensions, excluding channels. + + Returns: + np.array: A mask array with the same height and width as the input image, containing random irregular shapes. + """ + cur_max_len = int(max(1, self.max_len)) + cur_max_width = int(max(1, self.max_width)) + cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times)) + return make_random_irregular_mask(img_shape, max_angle=self.max_angle, max_len=cur_max_len, + max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times) + + +class RandomRectangleMaskGenerator: + """ + A generator class for creating masks with random rectangular shapes on images. + The rectangles are defined within specified constraints for margins, size, and the number of times they appear. + + Attributes: + margin (int): The minimum distance between the rectangle edges and the image boundaries. + bbox_min_size (int): The minimum size for the width and height of the rectangles. + bbox_max_size (int): The maximum size for the width and height of the rectangles. + min_times (int): The minimum number of rectangles to be generated on the mask. + max_times (int): The maximum number of rectangles to be generated on the mask. + """ + def __init__(self, margin, bbox_min_size, bbox_max_size, min_times, max_times): + self.margin = margin + self.bbox_min_size = bbox_min_size + self.bbox_max_size = bbox_max_size + self.min_times = min_times + self.max_times = max_times + + def __call__(self, img_shape): + """ + Generates a mask with random rectangles when called with an image. + + Parameters: + img (tuple): Tuple of image dimensions, excluding channels. + + Returns: + np.array: A mask array with the same height and width as the input image, containing random rectangles. + """ + cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size)) + cur_max_times = int(self.min_times + (self.max_times - self.min_times)) + return make_random_rectangle_mask(img_shape, margin=self.margin, bbox_min_size=self.bbox_min_size, + bbox_max_size=cur_bbox_max_size, min_times=self.min_times, + max_times=cur_max_times) + + + +def create_center_square_binary_mask(total_size): + """ + Creates a binary mask tensor based on provided parameters using PyTorch. + The output tensor has values of 0 or 1, where 1 represents the masked area. + + :param total_size: The size of the sides of the square black background image. + :param max_square_size: The maximum size of the sides of the white square. + :param probability_threshold: The threshold for deciding between a completely white mask and a random mask. + :param center_mask: If True, generates a centered white square mask 3/4 the size of total_size. + :return: A PyTorch tensor representing the binary mask. + """ + # Create a new black image (mask) as a PyTorch tensor + mask = torch.zeros(total_size, total_size, dtype=torch.uint8) + + # Create a centered white square mask 3/4 the size of total_size + center_square_size = int(total_size * 3 / 4) + start = (total_size - center_square_size) // 2 + end = start + center_square_size + mask[start:end, start:end] = 1 + + return mask + + +def create_mask(complete_mask_prob, total_size, center_mask, args): + + + if center_mask: + mask = create_center_square_binary_mask(total_size) + + else: + random_number = random.random() + if random_number < complete_mask_prob: + mask = torch.zeros(total_size, total_size, dtype=torch.uint8) + mask.fill_(1) + else: + + if random_number < complete_mask_prob + (1 - complete_mask_prob)/2: + mask_generator = RandomIrregularMaskGenerator( + max_angle=args.i_max_angle, + max_len=args.i_max_len, + max_width=args.i_max_width, + min_times=args.i_min_times, + max_times=args.i_max_times + ) + mask = mask_generator((total_size, total_size)) + else: + mask_generator = RandomRectangleMaskGenerator( + margin=args.r_margin, + bbox_min_size=args.r_bbox_min_size, + bbox_max_size=args.r_bbox_max_size, + min_times=args.r_min_times, + max_times=args.r_max_times + ) + mask = mask_generator((total_size, total_size)) + + return mask + + + def save_model_card( args, @@ -75,7 +242,7 @@ def save_model_card( ): img_str = "" if len(images) > 0: - image_grid = make_image_grid(images+masks, 2, args.validation_size) + image_grid = make_image_grid(images + masks, 2, args.validation_size) image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" @@ -94,7 +261,7 @@ def save_model_card( --- """ model_card = f""" -# Text-to-image finetuning - {repo_id} +# Inpainting finetuning - {repo_id} This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {prompts}: \n {img_str} @@ -109,7 +276,9 @@ def save_model_card( pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16) prompt = "{prompts[0]}" -image = pipeline(prompt).images[0] +mask = "{masks[0]}" +init_image = "{images[0]}" +image = pipeline(prompt, mask_image=mask, image=init_image).images[0] image.save("my_image.png") ``` @@ -172,19 +341,14 @@ def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, ar masks = [] image_transform = transforms.ToPILImage() for _, batch in enumerate(validation_dataloader): - - mask = image_transform(batch['masks'][0]*255) - init_image= image_transform(batch['pixel_values'][0]) - prompt = batch['prompts'][0] - + mask = image_transform(batch["masks"][0] * 255) + init_image = image_transform(batch["pixel_values"][0]) + prompt = batch["prompts"][0] + with torch.autocast("cuda"): #### UPDATE PIPELINE HERE image = pipeline( - prompt, - image=init_image, - mask_image=mask, - num_inference_steps=20, - generator=generator + prompt, image=init_image, mask_image=mask, num_inference_steps=20, generator=generator ).images[0] prompts.append(prompt) @@ -196,7 +360,6 @@ def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, ar np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") elif tracker.name == "wandb": - tracker.log( { "validation": [ @@ -214,42 +377,8 @@ def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, ar return images -def create_custom_binary_mask_torch(total_size, max_square_size, probability_threshold, center_mask=False): - """ - Creates a binary mask tensor based on provided parameters using PyTorch. - The output tensor has values of 0 or 1, where 1 represents the masked area. - - :param total_size: The size of the sides of the square black background image. - :param max_square_size: The maximum size of the sides of the white square. - :param probability_threshold: The threshold for deciding between a completely white mask and a random mask. - :param center_mask: If True, generates a centered white square mask 3/4 the size of total_size. - :return: A PyTorch tensor representing the binary mask. - """ - # Create a new black image (mask) as a PyTorch tensor - mask = torch.zeros(total_size, total_size, dtype=torch.uint8) - - if center_mask: - # Create a centered white square mask 3/4 the size of total_size - center_square_size = int(total_size * 3 / 4) - start = (total_size - center_square_size) // 2 - end = start + center_square_size - mask[start:end, start:end] = 1 - else: - # Generate a random number between 0 and 1 - random_number = random.random() - if random_number < probability_threshold: - # Create a completely white (binary 1) mask - mask.fill_(1) - else: - # Create a mask with a random white square - square_size = random.randint(max_square_size // 2, max_square_size) - max_val = total_size - square_size - top_left_x = random.randint(0, max_val) - top_left_y = random.randint(0, max_val) - mask[top_left_y:top_left_y + square_size, top_left_x:top_left_x + square_size] = 1 - return mask def parse_args(): @@ -553,6 +682,7 @@ def parse_args(): return args + def main(): args = parse_args() @@ -611,33 +741,13 @@ def main(): tokenizer = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) - - def deepspeed_zero_init_disabled_context_manager(): - """ - returns either a context list that includes one that will disable zero.Init or an empty context list - """ - deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None - if deepspeed_plugin is None: - return [] - - return [deepspeed_plugin.zero3_init_context_manager(enable=False)] - - # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. - # For this to work properly all models must be run through `accelerate.prepare`. But accelerate - # will try to assign the same optimizer with the same weights to all models during - # `deepspeed.initialize`, which of course doesn't work. - # - # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 - # frozen models from being partitioned during `zero.Init` which gets called during - # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding - # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. - with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant - ) + + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision @@ -750,14 +860,14 @@ def load_model_hook(models, input_dir): args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir, - split=datasets.ReadInstruction('train', to=args.validation_size, unit='abs') + split=datasets.ReadInstruction("train", to=args.validation_size, unit="abs"), ) train_dataset = load_dataset( args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir, - split=datasets.ReadInstruction('train', from_=args.validation_size, unit='abs') + split=datasets.ReadInstruction("train", from_=args.validation_size, unit="abs"), ) else: raise ValueError("Current Implementation supports only datasets from the hub.") @@ -814,16 +924,15 @@ def tokenize_captions(examples, is_train=True): transforms.Normalize([0.5], [0.5]), ] ) - + val_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.ToTensor() + transforms.ToTensor(), ] ) def preprocess_dataset_train(examples, is_train=True): - images = [image.convert("RGB") for image in examples[image_column]] if is_train: examples["pixel_values"] = [train_transforms(image) for image in images] @@ -843,24 +952,18 @@ def collate_fn(examples, center_mask): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = torch.stack([example["input_ids"] for example in examples]) - - masks = torch.stack([ - create_custom_binary_mask_torch( - pixel_values.shape[2], - int(3/4*pixel_values.shape[2]), - 0.25, - center_mask - ) for ix in range(pixel_values.shape[0]) - ]).unsqueeze(1) - - prompts = [example[caption_column] for example in examples] - return { - "pixel_values": pixel_values, - "input_ids": input_ids, - "masks": masks, - "prompts": prompts - } + masks = torch.stack( + [ + create_custom_binary_mask_torch( + pixel_values.shape[2], int(3 / 4 * pixel_values.shape[2]), 0.25, center_mask + ) + for ix in range(pixel_values.shape[0]) + ] + ).unsqueeze(1) + + prompts = [example[caption_column] for example in examples] + return {"pixel_values": pixel_values, "input_ids": input_ids, "masks": masks, "prompts": prompts} # DataLoaders creation: train_dataloader = torch.utils.data.DataLoader( @@ -870,7 +973,7 @@ def collate_fn(examples, center_mask): batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, ) - + val_dataloader = torch.utils.data.DataLoader( val_dataset, shuffle=False, @@ -878,7 +981,6 @@ def collate_fn(examples, center_mask): batch_size=1, num_workers=0, ) - # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -983,7 +1085,7 @@ def unwrap_model(model): # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) - + vae_downscaling_factor = 2 ** (len(vae.config.block_out_channels) - 1) for epoch in range(first_epoch, args.num_train_epochs): @@ -992,23 +1094,23 @@ def unwrap_model(model): with accelerator.accumulate(unet): # Convert images to latent space pixel_values = batch["pixel_values"].to(weight_dtype) - + latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor - - mask_orig = batch["masks"].to(weight_dtype) - + + mask_orig = batch["masks"].to(weight_dtype) + masked_latents = vae.encode(pixel_values * (1.0 - mask_orig)).latent_dist.sample() masked_latents = masked_latents * vae.config.scaling_factor - + mask = torch.nn.functional.interpolate( - mask_orig, - size=( - int(mask_orig.shape[3] // vae_downscaling_factor), - int(mask_orig.shape[2] // vae_downscaling_factor) - ) - ).to(dtype=weight_dtype, device=accelerator.device) - + mask_orig, + size=( + int(mask_orig.shape[3] // vae_downscaling_factor), + int(mask_orig.shape[2] // vae_downscaling_factor), + ), + ).to(dtype=weight_dtype, device=accelerator.device) + # Sample noise that we'll add to the latents noise = torch.randn_like(latents) if args.noise_offset: @@ -1090,7 +1192,6 @@ def unwrap_model(model): train_loss = 0.0 if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: @@ -1121,7 +1222,7 @@ def unwrap_model(model): if global_step >= args.max_train_steps: break - + if accelerator.is_main_process: if args.validation_size > 0 and epoch % args.validation_epochs == 0: if args.use_ema: @@ -1142,7 +1243,7 @@ def unwrap_model(model): if args.use_ema: # Switch back to the original UNet parameters. ema_unet.restore(unet.parameters()) - + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -1181,19 +1282,14 @@ def unwrap_model(model): masks = [] image_transform = transforms.ToPILImage() for _, batch in enumerate(val_dataloader): - - mask = image_transform(batch['masks'][0]*255) - init_image= image_transform(batch['pixel_values'][0]) - prompt = batch['prompts'][0] - + mask = image_transform(batch["masks"][0] * 255) + init_image = image_transform(batch["pixel_values"][0]) + prompt = batch["prompts"][0] + with torch.autocast("cuda"): #### UPDATE PIPELINE HERE image = pipeline( - prompt, - image=init_image, - mask_image=mask, - num_inference_steps=20, - generator=generator + prompt, image=init_image, mask_image=mask, num_inference_steps=20, generator=generator ).images[0] prompts.append(prompt) diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 6068994d913c..e68f26c75f8a 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -50,6 +50,7 @@ def build_sub_model_components( image_size=None, **kwargs, ): + if component_name in pipeline_components: return {} From 07c8fd1a96dfb6312fea978cc2c371f82ba87e5b Mon Sep 17 00:00:00 2001 From: vhakobyan Date: Sat, 2 Mar 2024 07:06:13 +0000 Subject: [PATCH 06/14] wip: final fixes --- examples/inpainting/sd_train_inpaint.sh | 17 --- examples/inpainting/train_inpainting.py | 186 +++++++++++++++--------- 2 files changed, 118 insertions(+), 85 deletions(-) delete mode 100755 examples/inpainting/sd_train_inpaint.sh diff --git a/examples/inpainting/sd_train_inpaint.sh b/examples/inpainting/sd_train_inpaint.sh deleted file mode 100755 index 70298bc4386d..000000000000 --- a/examples/inpainting/sd_train_inpaint.sh +++ /dev/null @@ -1,17 +0,0 @@ -export MODEL_NAME="stabilityai/stable-diffusion-2-inpainting" -export DATASET_NAME="lambdalabs/pokemon-blip-captions" - - -CUDA_VISIBLE_DEVICES="1" accelerate launch --num_processes 1 --main_process_port 29502 --mixed_precision="fp16" train_inpainting.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --dataset_name=$DATASET_NAME \ - --use_ema \ - --resolution=768 --center_crop --random_flip \ - --train_batch_size=4 \ - --gradient_accumulation_steps=2 \ - --gradient_checkpointing \ - --max_train_steps=15000 \ - --learning_rate=1e-05 \ - --max_grad_norm=1 --seed=42 \ - --lr_scheduler="constant" --lr_warmup_steps=0 \ - --output_dir="sd-pokemon-model-inpaint-2" --validation_size=3 --validation_epochs=1 --report_to="wandb" \ No newline at end of file diff --git a/examples/inpainting/train_inpainting.py b/examples/inpainting/train_inpainting.py index 5c76cf0ba729..41ffb081137d 100644 --- a/examples/inpainting/train_inpainting.py +++ b/examples/inpainting/train_inpainting.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and import argparse -import cv2 import logging import math import os @@ -24,6 +23,7 @@ from pathlib import Path import accelerate +import cv2 import datasets import numpy as np import torch @@ -32,7 +32,6 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.state import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder @@ -40,13 +39,13 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer -from transformers.utils import ContextManagers import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, compute_snr from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -64,8 +63,8 @@ "lambdalabs/pokemon-blip-captions": ("image", "text"), } -def make_random_irregular_mask(shape, max_angle, max_len, max_width, min_times, max_times): +def make_random_irregular_mask(shape, max_angle, max_len, max_width, min_times, max_times): height, width = shape mask = np.zeros((height, width), np.float32) times = np.random.randint(min_times, max_times + 1) @@ -85,8 +84,8 @@ def make_random_irregular_mask(shape, max_angle, max_len, max_width, min_times, start_x, start_y = end_x, end_y return torch.from_numpy(mask[None, ...]).squeeze(0).byte() + def make_random_rectangle_mask(shape, margin, bbox_min_size, bbox_max_size, min_times, max_times): - height, width = shape mask = np.zeros((height, width), np.float32) bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) @@ -96,9 +95,9 @@ def make_random_rectangle_mask(shape, margin, bbox_min_size, bbox_max_size, min_ box_height = np.random.randint(bbox_min_size, bbox_max_size) start_x = np.random.randint(margin, width - margin - box_width + 1) start_y = np.random.randint(margin, height - margin - box_height + 1) - mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1 + mask[start_y : start_y + box_height, start_x : start_x + box_width] = 1 return torch.from_numpy(mask[None, ...]).squeeze(0).byte() - + class RandomIrregularMaskGenerator: """ @@ -111,6 +110,7 @@ class RandomIrregularMaskGenerator: min_times (int): The minimum number of irregular shapes to be generated on the mask. max_times (int): The maximum number of irregular shapes to be generated on the mask. """ + def __init__(self, max_angle, max_len, max_width, min_times, max_times): self.max_angle = max_angle self.max_len = max_len @@ -131,8 +131,14 @@ def __call__(self, img_shape): cur_max_len = int(max(1, self.max_len)) cur_max_width = int(max(1, self.max_width)) cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times)) - return make_random_irregular_mask(img_shape, max_angle=self.max_angle, max_len=cur_max_len, - max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times) + return make_random_irregular_mask( + img_shape, + max_angle=self.max_angle, + max_len=cur_max_len, + max_width=cur_max_width, + min_times=self.min_times, + max_times=cur_max_times, + ) class RandomRectangleMaskGenerator: @@ -147,6 +153,7 @@ class RandomRectangleMaskGenerator: min_times (int): The minimum number of rectangles to be generated on the mask. max_times (int): The maximum number of rectangles to be generated on the mask. """ + def __init__(self, margin, bbox_min_size, bbox_max_size, min_times, max_times): self.margin = margin self.bbox_min_size = bbox_min_size @@ -166,12 +173,16 @@ def __call__(self, img_shape): """ cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size)) cur_max_times = int(self.min_times + (self.max_times - self.min_times)) - return make_random_rectangle_mask(img_shape, margin=self.margin, bbox_min_size=self.bbox_min_size, - bbox_max_size=cur_bbox_max_size, min_times=self.min_times, - max_times=cur_max_times) - - - + return make_random_rectangle_mask( + img_shape, + margin=self.margin, + bbox_min_size=self.bbox_min_size, + bbox_max_size=cur_bbox_max_size, + min_times=self.min_times, + max_times=cur_max_times, + ) + + def create_center_square_binary_mask(total_size): """ Creates a binary mask tensor based on provided parameters using PyTorch. @@ -195,27 +206,24 @@ def create_center_square_binary_mask(total_size): return mask -def create_mask(complete_mask_prob, total_size, center_mask, args): - - +def create_mask(total_size, center_mask, args): if center_mask: mask = create_center_square_binary_mask(total_size) - + else: random_number = random.random() - if random_number < complete_mask_prob: + if random_number < args.complete_mask_prob: mask = torch.zeros(total_size, total_size, dtype=torch.uint8) mask.fill_(1) else: - - if random_number < complete_mask_prob + (1 - complete_mask_prob)/2: + if random_number < args.complete_mask_prob + (1.0 - args.complete_mask_prob) / 2: mask_generator = RandomIrregularMaskGenerator( - max_angle=args.i_max_angle, - max_len=args.i_max_len, - max_width=args.i_max_width, - min_times=args.i_min_times, - max_times=args.i_max_times - ) + max_angle=args.i_max_angle, + max_len=args.i_max_len, + max_width=args.i_max_width, + min_times=args.i_min_times, + max_times=args.i_max_times, + ) mask = mask_generator((total_size, total_size)) else: mask_generator = RandomRectangleMaskGenerator( @@ -223,13 +231,11 @@ def create_mask(complete_mask_prob, total_size, center_mask, args): bbox_min_size=args.r_bbox_min_size, bbox_max_size=args.r_bbox_max_size, min_times=args.r_min_times, - max_times=args.r_max_times + max_times=args.r_max_times, ) mask = mask_generator((total_size, total_size)) - - return mask - + return mask def save_model_card( @@ -241,26 +247,12 @@ def save_model_card( repo_folder=None, ): img_str = "" - if len(images) > 0: - image_grid = make_image_grid(images + masks, 2, args.validation_size) + if images is not None > 0: + image_grid = make_image_grid(images, 1, len(args.validation_prompts)) image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" - yaml = f""" ---- -license: creativeml-openrail-m -base_model: {args.pretrained_model_name_or_path} -datasets: -- {args.dataset_name} -tags: -- stable-diffusion -- stable-diffusion-diffusers -- text-to-image -- diffusers -inference: true ---- - """ - model_card = f""" + model_description = f""" # Inpainting finetuning - {repo_id} This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {prompts}: \n @@ -305,10 +297,21 @@ def save_model_card( More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). """ - model_card += wandb_info + model_description += wandb_info + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=args.pretrained_model_name_or_path, + model_description=model_description, + inference=True, + ) + + tags = ["stable-diffusion", "stable-diffusion-diffusers", "inpainting", "diffusers"] + model_card = populate_model_card(model_card, tags=tags) - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) + model_card.save(os.path.join(repo_folder, "README.md")) def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): @@ -325,6 +328,7 @@ def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, ar variant=args.variant, torch_dtype=weight_dtype, ) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) @@ -377,10 +381,6 @@ def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, ar return images - - - - def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -667,6 +667,43 @@ def parse_args(): ), ) + parser.add_argument( + "--complete_mask_prob", type=float, default=0.25, help="The probability of using a complete mask." + ) + parser.add_argument("--i_max_angle", type=int, default=4, help="The maximum angle for the line segments.") + parser.add_argument("--i_max_len", type=int, default=400, help="The maximum length for each line segment.") + parser.add_argument("--i_max_width", type=int, default=200, help="The maximum width for each line segment.") + parser.add_argument( + "--i_min_times", + type=int, + default=1, + help="The minimum number of irregular shapes to be generated on the mask.", + ) + parser.add_argument( + "--i_max_times", + type=int, + default=5, + help="The maximum number of irregular shapes to be generated on the mask.", + ) + parser.add_argument( + "--r_margin", + type=int, + default=10, + help="The minimum distance between the rectangle edges and the image boundaries.", + ) + parser.add_argument( + "--r_bbox_min_size", type=int, default=60, help="The minimum size for the width and height of the rectangles." + ) + parser.add_argument( + "--r_bbox_max_size", type=int, default=300, help="The maximum size for the width and height of the rectangles." + ) + parser.add_argument( + "--r_min_times", type=int, default=1, help="The minimum number of rectangles to be generated on the mask." + ) + parser.add_argument( + "--r_max_times", type=int, default=4, help="The maximum number of rectangles to be generated on the mask." + ) + args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -731,7 +768,7 @@ def main(): if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - if args.push_to_hub: + if args.report_to == "wandb" and args.hub_token is not None: repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id @@ -741,7 +778,7 @@ def main(): tokenizer = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) - + text_encoder = CLIPTextModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) @@ -753,6 +790,27 @@ def main(): args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) + # InstructPix2Pix uses an additional image for conditioning. To accommodate that, + # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is + # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized + # from the pre-trained checkpoints. For the extra channels added to the first layer, they are + # initialized to zero. + + # when most likely a text2img pretrained model is used + if unet.conv_in.in_channels == 4: + logger.info("Initializing the Inpainting UNet from the pretrained 4 channel UNet .") + in_channels = 9 + out_channels = unet.conv_in.out_channels + unet.register_to_config(in_channels=in_channels) + + with torch.no_grad(): + new_conv_in = torch.nn.Conv2d( + in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding + ) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + unet.conv_in = new_conv_in + # Freeze vae and text_encoder and set unet to trainable vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -760,10 +818,7 @@ def main(): # Create EMA for the unet. if args.use_ema: - ema_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant - ) - ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -954,12 +1009,7 @@ def collate_fn(examples, center_mask): input_ids = torch.stack([example["input_ids"] for example in examples]) masks = torch.stack( - [ - create_custom_binary_mask_torch( - pixel_values.shape[2], int(3 / 4 * pixel_values.shape[2]), 0.25, center_mask - ) - for ix in range(pixel_values.shape[0]) - ] + [create_mask(pixel_values.shape[2], center_mask, args) for _ in range(pixel_values.shape[0])] ).unsqueeze(1) prompts = [example[caption_column] for example in examples] From c1c3a0e3582a27c5cc4017a3a16b77c3a6cb82e9 Mon Sep 17 00:00:00 2001 From: vhakobyan Date: Sat, 2 Mar 2024 07:12:25 +0000 Subject: [PATCH 07/14] wip: updating README --- examples/inpainting/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/inpainting/README.md b/examples/inpainting/README.md index 2ecf0bffd94f..8a83c74fab19 100644 --- a/examples/inpainting/README.md +++ b/examples/inpainting/README.md @@ -36,7 +36,7 @@ Note also that we use PEFT library as backend for LoRA training, make sure to ha ### Pokemon example -You need to accept the model license before downloading or using the weights. In this example we'll use model version `sd-v1-5-inpainting` from runwayml, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-inpainting), read the license and tick the checkbox if you agree. +You need to accept the model license before downloading or using the weights. In this example we'll use model version `sd-v1-5-inpainting` or `v1-5` from runwayml, so you'll need to visit [inpainting card](https://huggingface.co/runwayml/stable-diffusion-inpainting) or [v1-5 card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). @@ -53,7 +53,7 @@ If you have already cloned the repo, then you won't need to go through these ste #### Hardware With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory. -**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) 768x768 model.___** +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) or [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** ```bash export MODEL_NAME="runwayml/stable-diffusion-inpainting" From 94d877cbd48ad451c33f28e81d56e284aff95c25 Mon Sep 17 00:00:00 2001 From: vhakobyan Date: Sat, 2 Mar 2024 13:34:47 +0000 Subject: [PATCH 08/14] wip: last inference step with log_validation --- examples/inpainting/train_inpainting.py | 44 +++++++------------------ 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/examples/inpainting/train_inpainting.py b/examples/inpainting/train_inpainting.py index 41ffb081137d..37420bb4929d 100644 --- a/examples/inpainting/train_inpainting.py +++ b/examples/inpainting/train_inpainting.py @@ -378,7 +378,7 @@ def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, ar del pipeline torch.cuda.empty_cache() - return images + return images, prompts, masks def parse_args(): @@ -1312,39 +1312,19 @@ def unwrap_model(model): pipeline.save_pretrained(args.output_dir) # Run a final round of inference. - images = [] if args.validation_size > 0: logger.info("Running inference for collecting generated images...") - pipeline = pipeline.to(accelerator.device) - pipeline.torch_dtype = weight_dtype - pipeline.set_progress_bar_config(disable=True) - - if args.enable_xformers_memory_efficient_attention: - pipeline.enable_xformers_memory_efficient_attention() - - if args.seed is None: - generator = None - else: - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - - images = [] - prompts = [] - masks = [] - image_transform = transforms.ToPILImage() - for _, batch in enumerate(val_dataloader): - mask = image_transform(batch["masks"][0] * 255) - init_image = image_transform(batch["pixel_values"][0]) - prompt = batch["prompts"][0] - - with torch.autocast("cuda"): - #### UPDATE PIPELINE HERE - image = pipeline( - prompt, image=init_image, mask_image=mask, num_inference_steps=20, generator=generator - ).images[0] - - prompts.append(prompt) - images.append(image) - masks.append(mask) + images, prompts, masks = log_validation( + val_dataloader, + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) if args.push_to_hub: save_model_card(args, repo_id, images, prompts, masks, repo_folder=args.output_dir) From 5dd28bd4af1f9ba01531d1e86019008f05884929 Mon Sep 17 00:00:00 2001 From: vhakobyan Date: Sun, 3 Mar 2024 06:14:44 +0000 Subject: [PATCH 09/14] wip: fixing log_validation, tests --- examples/inpainting/test_inpainting.py | 384 ++++++++++++++++++++++++ examples/inpainting/train_inpainting.py | 41 ++- src/diffusers/loaders/single_file.py | 1 - 3 files changed, 402 insertions(+), 24 deletions(-) create mode 100644 examples/inpainting/test_inpainting.py diff --git a/examples/inpainting/test_inpainting.py b/examples/inpainting/test_inpainting.py new file mode 100644 index 000000000000..de5c8908b1d8 --- /dev/null +++ b/examples/inpainting/test_inpainting.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import shutil +import sys +import tempfile + +from diffusers import StableDiffusionInpaintPipeline, UNet2DConditionModel # noqa: E402 +from diffusers.utils import load_image + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class Inpainting(ExamplesTestsAccelerate): + def test_inpainting(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/inpainting/train_inpainting.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --r_bbox_min_size 6 + --r_bbox_max_size 40 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + + def test_inpainting_checkpointing(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-sd-pipe" + prompt = "a prompt" + init_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png" + ).resize((64, 64)) + mask = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png" + ).resize((64, 64)) + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/inpainting/train_inpainting.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --r_bbox_min_size 6 + --r_bbox_max_size 40 + --output_dir {tmpdir} + --checkpointing_steps=2 + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = StableDiffusionInpaintPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=1) + + print(os.listdir(tmpdir)) + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # check can run an intermediate checkpoint + unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") + pipe = StableDiffusionInpaintPipeline.from_pretrained( + pretrained_model_name_or_path, unet=unet, safety_checker=None + ) + pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=1) + + # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming + shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) + + # Run training script for 2 total steps resuming from checkpoint 4 + + resume_run_args = f""" + examples/inpainting/train_inpainting.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --r_bbox_min_size 6 + --r_bbox_max_size 40 + --output_dir {tmpdir} + --checkpointing_steps=1 + --resume_from_checkpoint=checkpoint-4 + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check can run new fully trained pipeline + pipe = StableDiffusionInpaintPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=1) + + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + print(os.listdir(tmpdir)) + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-5"}, + ) + + def test_inpainting_checkpointing_use_ema(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + init_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png" + ).resize((64, 64)) + mask = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png" + ).resize((64, 64)) + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/inpainting/train_inpainting.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --r_bbox_min_size 6 + --r_bbox_max_size 40 + --output_dir {tmpdir} + --checkpointing_steps=2 + --use_ema + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = StableDiffusionInpaintPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # check can run an intermediate checkpoint + unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") + pipe = StableDiffusionInpaintPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None) + pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=1) + + # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming + shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) + + # Run training script for 2 total steps resuming from checkpoint 4 + + resume_run_args = f""" + examples/inpainting/train_inpainting.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --r_bbox_min_size 6 + --r_bbox_max_size 40 + --output_dir {tmpdir} + --checkpointing_steps=1 + --resume_from_checkpoint=checkpoint-4 + --use_ema + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check can run new fully trained pipeline + pipe = StableDiffusionInpaintPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=1) + + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-5"}, + ) + + def test_inpainting_checkpointing_checkpoints_total_limit(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + init_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png" + ).resize((64, 64)) + mask = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png" + ).resize((64, 64)) + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 6 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --r_bbox_min_size 6 + --r_bbox_max_size 40 + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = StableDiffusionInpaintPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=1) + + # check checkpoint directories exist + # checkpoint-2 should have been deleted + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) + + def test_inpainting_checkpoints_total_limit_removes_multiple_checkpoints(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + init_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png" + ).resize((64, 64)) + mask = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png" + ).resize((64, 64)) + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/inpainting/train_inpainting.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --r_bbox_min_size 6 + --r_bbox_max_size 40 + --output_dir {tmpdir} + --checkpointing_steps=2 + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = StableDiffusionInpaintPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=1) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # resume and we should try to checkpoint at 6, where we'll have to remove + # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint + + resume_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 8 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --r_bbox_min_size 6 + --r_bbox_max_size 40 + --output_dir {tmpdir} + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + pipe = StableDiffusionInpaintPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=1) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8"}, + ) diff --git a/examples/inpainting/train_inpainting.py b/examples/inpainting/train_inpainting.py index 37420bb4929d..18bd525bfe07 100644 --- a/examples/inpainting/train_inpainting.py +++ b/examples/inpainting/train_inpainting.py @@ -55,7 +55,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.26.0.dev0") +check_min_version("0.27.0.dev0") logger = get_logger(__name__, log_level="INFO") @@ -314,7 +314,18 @@ def save_model_card( model_card.save(os.path.join(repo_folder, "README.md")) -def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): +def log_validation( + validation_dataloader, + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + epoch, + is_final_validation=False, +): logger.info("Running validation... ") pipeline = StableDiffusionInpaintPipeline.from_pretrained( @@ -350,7 +361,6 @@ def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, ar prompt = batch["prompts"][0] with torch.autocast("cuda"): - #### UPDATE PIPELINE HERE image = pipeline( prompt, image=init_image, mask_image=mask, num_inference_steps=20, generator=generator ).images[0] @@ -360,13 +370,14 @@ def log_validation(validation_dataloader, vae, text_encoder, tokenizer, unet, ar masks.append(mask) for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") elif tracker.name == "wandb": tracker.log( { - "validation": [ + phase_name: [ wandb.Image(make_image_grid([image, masks[i]], 1, 2), caption=f"{i}: {prompts[i]}") for i, image in enumerate(images) ] @@ -1280,15 +1291,7 @@ def unwrap_model(model): ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) log_validation( - val_dataloader, - vae, - text_encoder, - tokenizer, - unet, - args, - accelerator, - weight_dtype, - global_step, + val_dataloader, vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, global_step ) if args.use_ema: # Switch back to the original UNet parameters. @@ -1315,15 +1318,7 @@ def unwrap_model(model): if args.validation_size > 0: logger.info("Running inference for collecting generated images...") images, prompts, masks = log_validation( - val_dataloader, - vae, - text_encoder, - tokenizer, - unet, - args, - accelerator, - weight_dtype, - global_step, + val_dataloader, vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, global_step, True ) if args.push_to_hub: diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 69c96251c609..875858ce7761 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -51,7 +51,6 @@ def build_sub_model_components( torch_dtype=None, **kwargs, ): - if component_name in pipeline_components: return {} From 5179539e7154945d29f9016315181b59a2f02362 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 3 Mar 2024 11:58:25 +0530 Subject: [PATCH 10/14] run quality --- examples/inpainting/test_inpainting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/inpainting/test_inpainting.py b/examples/inpainting/test_inpainting.py index de5c8908b1d8..09a70d9470d3 100644 --- a/examples/inpainting/test_inpainting.py +++ b/examples/inpainting/test_inpainting.py @@ -208,7 +208,9 @@ def test_inpainting_checkpointing_use_ema(self): # check can run an intermediate checkpoint unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") - pipe = StableDiffusionInpaintPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None) + pipe = StableDiffusionInpaintPipeline.from_pretrained( + pretrained_model_name_or_path, unet=unet, safety_checker=None + ) pipe(prompt, image=init_image, mask_image=mask, num_inference_steps=1) # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming From 8f33ed1959b532ea5893141071ce487c4e4013fc Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 11 Apr 2024 07:04:05 -1000 Subject: [PATCH 11/14] Update examples/inpainting/README.md Co-authored-by: Suraj Patil --- examples/inpainting/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/inpainting/README.md b/examples/inpainting/README.md index 8a83c74fab19..b332b00a2966 100644 --- a/examples/inpainting/README.md +++ b/examples/inpainting/README.md @@ -1,6 +1,6 @@ # Stable Diffusion Inpainting fine-tuning -The `train_inpainting.py` script shows how to fine-tune stable diffusion model on your own dataset. +The `train_inpainting.py` script shows how to train/fine-tune stable diffusion model for inpainting on your own dataset. ___Note___: From f2b04e32ad8256e2c0e9290f23ea1002c9827e47 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 11 Apr 2024 07:04:17 -1000 Subject: [PATCH 12/14] Update examples/inpainting/train_inpainting.py Co-authored-by: Suraj Patil --- examples/inpainting/train_inpainting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/inpainting/train_inpainting.py b/examples/inpainting/train_inpainting.py index 18bd525bfe07..e7af6f96f574 100644 --- a/examples/inpainting/train_inpainting.py +++ b/examples/inpainting/train_inpainting.py @@ -807,7 +807,6 @@ def main(): # from the pre-trained checkpoints. For the extra channels added to the first layer, they are # initialized to zero. - # when most likely a text2img pretrained model is used if unet.conv_in.in_channels == 4: logger.info("Initializing the Inpainting UNet from the pretrained 4 channel UNet .") in_channels = 9 From 7dc6bfb3bd023834b62dfd03e9152b3a6d644c25 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 11 Apr 2024 07:04:30 -1000 Subject: [PATCH 13/14] Update examples/inpainting/train_inpainting.py Co-authored-by: Suraj Patil --- examples/inpainting/train_inpainting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/inpainting/train_inpainting.py b/examples/inpainting/train_inpainting.py index e7af6f96f574..9bcf757d8fa8 100644 --- a/examples/inpainting/train_inpainting.py +++ b/examples/inpainting/train_inpainting.py @@ -801,9 +801,9 @@ def main(): args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) - # InstructPix2Pix uses an additional image for conditioning. To accommodate that, + # For inpainting an additional image is used for conditioning. To accommodate that, # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is - # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized + # then fine-tuned on the custom inpainting dataset. This modified UNet is initialized # from the pre-trained checkpoints. For the extra channels added to the first layer, they are # initialized to zero. From d11619a27b5a67f2ed75359cde3c1af8b49c5453 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 11 Apr 2024 07:04:51 -1000 Subject: [PATCH 14/14] Update examples/inpainting/train_inpainting.py Co-authored-by: Suraj Patil --- examples/inpainting/train_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/inpainting/train_inpainting.py b/examples/inpainting/train_inpainting.py index 9bcf757d8fa8..fe62be5b7587 100644 --- a/examples/inpainting/train_inpainting.py +++ b/examples/inpainting/train_inpainting.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.