add: train to text image with sdxl script. #4505


Merged into main: 39 commits, Aug 16, 2023

Conversation

@sayakpaul (Member) commented Aug 7, 2023

Closes #4366 and builds on top of #4401.

Many thanks to @CaptnSeraph for laying out the foundations here.

To test:

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE \
  --dataset_name=$DATASET_NAME \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --center_crop --random_flip \
  --proportion_empty_prompts=0.2 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=15000 \
  --use_8bit_adam \
  --learning_rate=1e-05 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --report_to="wandb" \
  --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
  --checkpointing_steps=5000 \
  --output_dir="sdxl-pokemon-model" \
  --push_to_hub

TODOs

  • Tests
  • Docs
  • Share results

@HuggingFaceDocBuilderDev commented Aug 7, 2023

The documentation is not available anymore as the PR was closed or merged.

torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
Contributor:

can use precomputed embeds here

Member Author:

Explain.

Contributor:

sorry, i pinged the wrong line. when we run validations, we don't need to load the text encoder onto GPU.

@sayakpaul (Member Author), Aug 7, 2023:

We don't support such device placements when initializing and using a pipeline. When we call to() on a pipeline, all the nn.Module components are placed on the same device.

Contributor:

yeah! i just modified my local copy to accept None for text encoders, similar to how the Kandinsky and DeepFloyd pipelines work.
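
As a reference point, here is a minimal sketch of that pattern. It assumes `unet` and `vae` are already in scope from the training loop, and `compute_embeddings` is a hypothetical helper standing in for the script's own embedding code; at the time of this thread, accepting None text encoders required a local patch like the one described above.

import torch
from diffusers import StableDiffusionXLPipeline

# Precompute the prompt embeddings before freeing the text encoders.
embeds, neg_embeds, pooled, neg_pooled = compute_embeddings(
    "a cute Sundar Pichai creature"
)

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    vae=vae,
    text_encoder=None,   # never placed on the GPU
    text_encoder_2=None,
    torch_dtype=torch.float16,
).to("cuda")

# __call__ accepts precomputed embeddings instead of a raw prompt.
image = pipeline(
    prompt_embeds=embeds,
    negative_prompt_embeds=neg_embeds,
    pooled_prompt_embeds=pooled,
    negative_pooled_prompt_embeds=neg_pooled,
    num_inference_steps=25,
).images[0]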

@bghira (Contributor) commented Aug 7, 2023

you might want to precompute the latents; they don't take much more room, and using the VAE during tuning really hampers the max batch size. in fact, the text embeds take a lot more disk space than the latents.

Comment on lines +770 to +771
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
Contributor:

nice! that may save some VRAM. but my approach was to go with data bucketing. i don't notice much issue with 1536x1024 training data at batch size 10 with the #4474 fix

crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left)
image = train_transforms(image)
all_images.append(image)
Contributor:

ouch, this might run out of memory.

Member Author:

# fingerprint used by the cache for the other processes to load the result
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
new_fingerprint = Hasher.hash(args)
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
Contributor:

if we're not writing these to disk, that's also a lot of memory to consume.

on a 138,000 image dataset i've processed, it uses 44G of system memory for text embeds, and 20GB of memory for the VAE latents. it's really not viable to hold them all in memory for large jobs.
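
Those figures are roughly consistent with a back-of-envelope estimate using standard SDXL shapes (the fp16 assumption below is mine, not a number from this thread):

# SDXL text embeds: 77 tokens x (768 + 1280) = 77 x 2048 features per caption.
bytes_per_text_embed = 77 * 2048 * 2        # fp16 -> ~315 KB per caption
# VAE latents at 1024x1024: 4 channels x 128 x 128.
bytes_per_latent = 4 * 128 * 128 * 2        # fp16 -> ~131 KB per image

n = 138_000
print(n * bytes_per_text_embed / 1e9)       # ~43.5 GB of text embeds
print(n * bytes_per_latent / 1e9)           # ~18 GB of latents (more for multi-aspect)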

Member Author:

In that case, one would serialize those, correct? But I would like to bring your attention to this comment here :-)

#4505 (comment)

@sayakpaul (Member Author):
@bghira thanks for the suggestions. For v1, I would like to keep things as they are. While your suggestions are pretty nice, we, as maintainers of the library, follow our guidelines here: https://huggingface.co/docs/diffusers/main/en/training/overview.

So this means that, in many cases, we will prioritize simplicity over exhaustiveness. This is why we try to keep the scripts simple enough that others can easily customize them to their needs.

@bghira (Contributor) commented Aug 7, 2023

i appreciate simplicity, but it limits the utility of the script to the point of requiring users to make extensive changes. as i had a feeling this would happen, i didn't bother doing the work of writing the PR myself, instead opting to implement all of my suggestions in https://github.com/bghira/SimpleTuner.

@sayakpaul (Member Author):
@bghira let's revisit some of your suggestions here, as in retrospect I think they make a lot of sense :-)

on a 138,000 image dataset i've processed, it uses 44G of system memory for text embeds, and 20GB of memory for the VAE latents. it's really not viable to hold them all in memory for large jobs.

What would you recommend here?

@bghira (Contributor) commented Aug 7, 2023

i have a vae_cache folder in my implementation where i write .pt files for all of the embeds, and the encode function reads from disk if the file is there instead of recomputing it.

@bghira (Contributor) commented Aug 7, 2023

here is a naive implementation that uses multiprocessing (lol, it does not, i was thinking of my data loader) but it even has a progress bar! :D

import hashlib, os, torch, logging
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms

logger = logging.getLogger("VAECache")
logger.setLevel("INFO")


class VAECache:
    def __init__(self, vae, accelerator, cache_dir="vae_cache", resolution: int = 1024):
        self.vae = vae
        self.vae.enable_slicing()
        self.accelerator = accelerator
        self.cache_dir = cache_dir
        self.resolution = resolution
        os.makedirs(self.cache_dir, exist_ok=True)

    def create_hash(self, filename):
        # Create a sha256 hash
        sha256_hash = hashlib.sha256()

        # Feed the hash function with the filename
        sha256_hash.update(filename.encode())

        # Get the hexadecimal representation of the hash
        return sha256_hash.hexdigest()

    def save_to_cache(self, filename, embeddings):
        torch.save(embeddings, filename)

    def load_from_cache(self, filename):
        return torch.load(filename)

    def encode_image(self, pixel_values, filepath: str):
        file_hash = self.create_hash(filepath)
        filename = os.path.join(self.cache_dir, file_hash + ".pt")
        logger.debug(f'Created file_hash {file_hash} from filepath {filepath} for resulting .pt filename.')
        if os.path.exists(filename):
            latents = self.load_from_cache(filename)
            logger.debug(
                f"Loading latents of shape {latents.shape} from existing cache file: {filename}"
            )
        else:
            with torch.no_grad():
                latents = self.vae.encode(
                    # cast to the VAE's own dtype rather than a hard-coded
                    # bfloat16, so the encode works whatever precision the
                    # VAE was loaded in
                    pixel_values.unsqueeze(0).to(
                        self.accelerator.device, dtype=self.vae.dtype
                    )
                ).latent_dist.sample()
                logger.debug(
                    f"Using shape {latents.shape}, creating new latent cache: {filename}"
                )
            latents = latents * self.vae.config.scaling_factor
            logger.debug(f"Latent shape after re-scale: {latents.shape}")
            self.save_to_cache(filename, latents.squeeze())

        output_latents = latents.squeeze().to(
            self.accelerator.device, dtype=self.vae.dtype
        )
        logger.debug(f"Output latents shape: {output_latents.shape}")
        return output_latents

    def process_directory(self, directory):
        # Define a transform to convert the image to tensor
        transform = transforms.ToTensor()

        # Get a list of all the files to process (customize as needed)
        files_to_process = []
        logger.debug(f"Beginning processing of VAECache directory {directory}")
        for subdir, _, files in os.walk(directory):
            for file in files:
                if file.endswith((".png", ".jpg", ".jpeg")):
                    logger.debug(f"Discovered image: {os.path.join(subdir, file)}")
                    files_to_process.append(os.path.join(subdir, file))

        # Iterate through the files, displaying a progress bar
        for filepath in tqdm(files_to_process, desc="Processing images"):
            # Create a hash based on the filename
            file_hash = self.create_hash(filepath)
            filename = os.path.join(self.cache_dir, file_hash + ".pt")

            # If processed file already exists, skip processing for this image
            if os.path.exists(filename):
                logger.debug(
                    f"Skipping processing for {filepath} as cached file {filename} already exists."
                )
                continue

            # Open the image using PIL
            try:
                logger.debug(f"Loading image: {filepath}")
                image = Image.open(filepath)
                image = image.convert("RGB")
                image = self._resize_for_condition_image(image, self.resolution)
            except Exception as e:
                logger.error(f"Encountered error opening image: {e}")
                os.remove(filepath)
                continue

            # Convert the image to a tensor
            try:
                pixel_values = transform(image).to(
                    self.accelerator.device, dtype=self.vae.dtype
                )
            except OSError as e:
                logger.error(f"Encountered error converting image to tensor: {e}")
                continue

            # Process the image with the VAE
            self.encode_image(pixel_values, filepath)

            logger.debug(f"Processed image {filepath}")

    def _resize_for_condition_image(self, input_image: Image.Image, resolution: int):
        input_image = input_image.convert("RGB")
        W, H = input_image.size
        aspect_ratio = round(W / H, 3)
        msg = f"Inspecting image of aspect {aspect_ratio} and size {W}x{H} to "
        if W < H:
            W = resolution
            H = int(resolution / aspect_ratio)  # Calculate the new height
        elif H < W:
            H = resolution
            W = int(resolution * aspect_ratio)  # Calculate the new width
        if W == H:
            W = resolution
            H = resolution
        msg = f"{msg} {W}x{H}."
        logger.debug(msg)
        img = input_image.resize((W, H), resample=Image.BICUBIC)
        return img
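
A hypothetical usage sketch for the class above; the VAE and accelerator would come from the training script's existing setup, and the directory path is illustrative.

from accelerate import Accelerator
from diffusers import AutoencoderKL

accelerator = Accelerator()
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix").to(accelerator.device)

cache = VAECache(vae, accelerator, cache_dir="vae_cache", resolution=1024)
# One-time pass over the dataset; later epochs read the cached .pt files from disk.
cache.process_directory("data/train_images")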

@sayakpaul (Member Author):
Thanks for providing the snippets! If we compute the VAE encodings like this, then it creates a problem during batch preparation, as the images are no longer of uniform shape. I guess we also need to apply a crop (to the training resolution) here, no?

Also, for my own understanding:

  • At this stage of pre-computing, I think we also need to store the original_size and crop_top_lefts information in the .pt files (see the sketch after this list).
  • What about pre-computing the text embeddings too? How would you suggest approaching it?
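
On the first point, a hedged sketch of what such a cache entry could look like; every key name and variable here is made up for illustration.

import torch

# Store everything the SDXL micro-conditioning needs alongside the latents,
# so the add_time_ids can be rebuilt without reopening the original image.
entry = {
    "latents": latents.squeeze().cpu(),
    "original_size": (orig_h, orig_w),
    "crop_top_left": (y1, x1),
    "target_size": (args.resolution, args.resolution),
}
torch.save(entry, cache_path)

# At training time:
entry = torch.load(cache_path)
latents, crop_top_left = entry["latents"], entry["crop_top_left"]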

@sayakpaul (Member Author):
As it is now, with my repository's approach I can do multi-aspect base-1024px images that stretch up to 1.5 megapixel, on a 48G GPU with batch size 4.

Elaborate on one part. When you say multi-aspect 1024px, do you mean multiple aspect ratios while keeping the base resolution at 1024?

@bghira (Contributor) commented Aug 10, 2023

i use the condition resize method that keeps images at 64px step increments, and resizes the smaller edge to 1024px.

i then scale 1024 by the image's aspect ratio to get the other side's length. by doing it this way, an entire batch has the same pixel-perfect sizing.

i'm not doing what StabilityAI did, which is to precompute 1-megapixel resolutions at various aspect ratios and then conform the images to that. my images can go quite large, because the VAE was the true limiting factor there.
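
A minimal sketch of that resize rule as described (the helper name is illustrative): pin the short edge to 1024 and snap the long edge down to a 64px multiple, so every image that lands on the same snapped size can be batched together pixel-perfectly.

from PIL import Image

def resize_to_bucket(img: Image.Image, short_edge: int = 1024, step: int = 64) -> Image.Image:
    w, h = img.size
    scale = short_edge / min(w, h)
    # The short edge becomes exactly `short_edge` (itself a multiple of 64);
    # the long edge is floored to the nearest multiple of `step`.
    new_w = (round(w * scale) // step) * step
    new_h = (round(h * scale) // step) * step
    return img.resize((new_w, new_h), Image.BICUBIC)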

@sayakpaul (Member Author):
Thanks for explaining! Do you have a code reference for me? I would love to understand and visualize this, preferably in a notebook. It could be a valuable resource for the community!

@bghira (Contributor) commented Aug 10, 2023

it has some problems, and the code is more convoluted than i'd like. it can be greatly simplified, but here is what i've been using.

a notable difference between my implementation and others is that i don't do random sampling of images. we process each image once, put it into a 'seen' list, and don't reference it again until the next epoch. i've noted that the Diffusers sampler defaults often tend to over-sample images.
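
A rough sketch of that exhaust-once-per-epoch behavior; this is a simplification for illustration, not SimpleTuner's actual sampler.

import random

class SeenListSampler:
    """Yield every image exactly once per epoch; shuffle order between epochs."""

    def __init__(self, image_paths):
        self.image_paths = list(image_paths)

    def __iter__(self):
        unseen = self.image_paths.copy()
        random.shuffle(unseen)      # order varies, but coverage is guaranteed
        seen = []
        while unseen:
            path = unseen.pop()
            seen.append(path)       # not revisited until the next epoch
            yield path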

@sayakpaul (Member Author):
a notable difference between my implementation and others is that i don't do random sampling of images. we process each image once, put it into a 'seen' list, and don't reference it again until the next epoch. i've noted that the Diffusers sampler defaults often tend to over-sample images.

Elaborate a bit? How much does this impact or degrade quality? Have you experienced it?

Thanks for providing the reference!

@bghira (Contributor) commented Aug 10, 2023

I have experienced it.

oversampling some images relative to others can lead to an uneven distribution of timesteps per image, which results in some of the training data overfitting and the rest underfitting.

this is exacerbated with aspect buckets, because you might have an uneven distribution of images in each bucket. random sampling of aspects then occurs on top of random sampling of images, so some of the training data may never be seen at all.

it depends on how large your dataset is and how much time you can really dedicate to the task. it is made worse when tuning the text encoder alongside the u-net, especially when you're not doing caption dropout. things are more likely to overfit on captions as well as image features.

@bghira (Contributor) commented Aug 10, 2023

a large multi-aspect dataset from LAION might have more than 60% of its images in the 1.0 square aspect.

so by ensuring you can safely sample each bucket sequentially, the chronically underfilled ones get sampled entirely, and the bulk of the remaining training time is spent on the majority of square images. you could also slice the buckets so that they're all evenly filled.

but a colleague ( @kaibioinfo ) mentioned an interesting idea where we could crop images down and train on tiled versions of high-res images with their complete coordinates available as conditioning inputs.

honestly, a lot of my problems with data bucketing could be resolved through clever utilisation of cropping.

@sayakpaul (Member Author):
I think all of this could make an interesting utility repository for different dataloaders for SD training haha. Thanks for sharing your experience and wisdom.

@bghira (Contributor) commented Aug 10, 2023

there's so much to test and so little GPU hours to go around 😂 thanks for being receptive to these changes

@bghira (Contributor) commented Aug 10, 2023

another idea i had was an aspect bucket for each base resolution (256, 512, 768, and 1024), so that we can make the best use of SDXL's conditioning values and open the training data pool up in a massive way.
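
A hedged sketch of what that bucket grid could look like; the step size and maximum aspect ratio below are arbitrary choices for illustration.

def make_buckets(base_resolutions=(256, 512, 768, 1024), step=64, max_aspect=2.0):
    """Enumerate (width, height) buckets at 64px steps for each base resolution."""
    buckets = set()
    for base in base_resolutions:
        long_edge = base
        while long_edge <= int(base * max_aspect):
            buckets.add((base, long_edge))   # portrait
            buckets.add((long_edge, base))   # landscape
            long_edge += step
    return sorted(buckets)

print(make_buckets()[:5])  # smallest buckets: (256, 256), (256, 320), ...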

@sayakpaul sayakpaul marked this pull request as ready for review August 11, 2023 12:10
@yiyixuxu (Collaborator) left a comment:

thanks!

@sayakpaul sayakpaul merged commit 5175d3d into main Aug 16, 2023
@sayakpaul sayakpaul deleted the feat/training-sdxl-text-to-image branch August 16, 2023 03:32
@AmericanPresidentJimmyCarter (Contributor):
I think maybe you should keep the precomputation steps in a community examples section and allow the current training script to use them. Precomputing embeds and latents is something I do to finetune most models I work on, not just SDXL, so a more general solution, a training script that can directly use precomputed embeds/latents, would be useful for me. It could be a pair of CLI options for the training script, --precomputed-text-embeds and --precompute-image-latents, which skip loading the text encoder and VAE and pull them from the dataset instead (a sketch follows below).
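
For what it's worth, a sketch of how those proposed flags might wire in; the flag names come from the comment above, and the loader helpers are hypothetical.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--precomputed-text-embeds", action="store_true",
                    help="skip loading the text encoders; read prompt embeds from the dataset")
parser.add_argument("--precompute-image-latents", action="store_true",
                    help="skip loading the VAE; read latents from the dataset")
args = parser.parse_args()

# `load_text_encoders` and `load_vae` are hypothetical stand-ins for the
# script's existing model-loading code.
text_encoders = None if args.precomputed_text_embeds else load_text_encoders()
vae = None if args.precompute_image_latents else load_vae()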

@little-misfit:
even if not precomputed, keeping the late

So, you mean for the first epoch we first generate and save them and free the VAE, and for the subsequent epochs we read from the disk?

hi, I want to know whether precomputed VAE embeddings make it impossible to use image data augmentation (because precomputing locks in the image pixel content). thanks :)

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add: train to text image with sdxl script.

Co-authored-by: CaptnSeraph <[email protected]>

* fix: partial func.

* fix: default value of output_dir.

* make style

* set num inference steps to 25.

* remove mentions of LoRA.

* up min version

* add: ema cli arg

* run device placement while running step.

* precompute vae encodings too.

* fix

* debug

* should work now.

* debug

* debug

* goes alright?

* style

* debugging

* debugging

* debugging

* debugging

* fix

* reinit scheduler if prediction_type was passed.

* akways cast vae in float32

* better handling of snr.

Co-authored-by: bghira <[email protected]>

* the vae should be also passed

* add: docs.

* add: sdlx t2i tests

* save the pipeline

* autocast.

* fix: save_model_card

* fix: save_model_card.

---------

Co-authored-by: CaptnSeraph <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: bghira <[email protected]>
@cfeng16 commented Sep 29, 2024

thanks for your efforts! just wondering: if i want to finetune SDXL on a large dataset (say 400k images), would precomputing the embeddings and keeping them in memory cause memory issues? how should i deal with that?

Development

Successfully merging this pull request may close these issues.

[SD-XL] Fine-tuning text-to-image script for full U-net
8 participants