add: train to text image with sdxl script. #4505
Conversation
Co-authored-by: CaptnSeraph <[email protected]>
The documentation is not available anymore as the PR was closed or merged.
    torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
can use precomputed embeds here
Explain.
sorry, i pinged the wrong line. when we run validations, we don't need to load the text encoder onto the GPU.
We don't support such device placements when initializing and using a pipeline. When we call to() on a pipeline, all the nn.Module components are placed on the same device.
yeah! i just modified my local copy to accept None for text encoders, similar to how the Kandinsky and DeepFloyd pipelines work.
you might want to precompute the latents too: they don't take much more room, and the use of the VAE during tuning really hampers the max batch size. in fact, the text embeds take a lot more disk space than the latents.
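For illustration only, here is a sketch of what reusing precomputed prompt embeddings at validation time could look like. The `.pt` filenames are made up; `args`, `unet`, `vae`, `weight_dtype`, and `accelerator` are assumed to come from the training script's scope; and, as discussed above, the SDXL pipeline did not accept `None` text encoders at the time, so this only avoids re-encoding the prompts, not the device placement itself:

```python
# Sketch: run validation with prompt embeddings that were computed once up
# front, instead of encoding the validation prompts on every run.
import torch
from diffusers import StableDiffusionXLPipeline

# Assumption: these files were produced earlier with the two SDXL text encoders.
prompt_embeds = torch.load("validation_prompt_embeds.pt")                 # typically [1, 77, 2048]
pooled_prompt_embeds = torch.load("validation_pooled_embeds.pt")          # typically [1, 1280]
negative_prompt_embeds = torch.load("validation_negative_embeds.pt")      # typically [1, 77, 2048]
negative_pooled_embeds = torch.load("validation_negative_pooled_embeds.pt")  # typically [1, 1280]

pipeline = StableDiffusionXLPipeline.from_pretrained(
    args.pretrained_model_name_or_path, unet=unet, vae=vae, torch_dtype=weight_dtype
).to(accelerator.device)

images = pipeline(
    prompt_embeds=prompt_embeds.to(accelerator.device, dtype=weight_dtype),
    pooled_prompt_embeds=pooled_prompt_embeds.to(accelerator.device, dtype=weight_dtype),
    negative_prompt_embeds=negative_prompt_embeds.to(accelerator.device, dtype=weight_dtype),
    negative_pooled_prompt_embeds=negative_pooled_embeds.to(accelerator.device, dtype=weight_dtype),
    num_inference_steps=25,
).images
```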
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
nice! that may save some VRAM. but my approach was to go with data bucketing. i don't notice much issue with 1536x1024 training data at batch size 10 with the #4474 fix
crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left)
image = train_transforms(image)
all_images.append(image)
ouch, this might run out of memory.
# fingerprint used by the cache for the other processes to load the result
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
new_fingerprint = Hasher.hash(args)
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
if we're not writing these to disk, that's also a lot of memory to consume.
on a 138,000-image dataset i've processed, it uses 44GB of system memory for the text embeds and 20GB for the VAE latents. it's really not viable to hold them all in memory for large jobs.
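One way this could be serialized (a rough sketch, reusing `compute_embeddings_fn` and `new_fingerprint` from the diff above; the cache directory name is made up, not an existing script option) is to write the mapped dataset to disk once and reload it memory-mapped afterwards:

```python
# Sketch: persist the precomputed embeddings once, then reload them
# memory-mapped via the datasets Arrow format instead of recomputing.
import os
from datasets import load_from_disk

embeds_dir = os.path.join(args.output_dir, "precomputed_embeds")  # illustrative path
if os.path.isdir(embeds_dir):
    train_dataset = load_from_disk(embeds_dir)
else:
    train_dataset = train_dataset.map(
        compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
    )
    train_dataset.save_to_disk(embeds_dir)
```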
In that case, one would serialize those, correct. But I would like to bring your attention to this comment here :-)
@bghira thanks for the suggestions. For v1, I would like to keep things as they are. While your suggestions are pretty nice, we, as maintainers of the library, follow our guidelines here: https://huggingface.co/docs/diffusers/main/en/training/overview. This means that, in many cases, we will prioritize simplicity over exhaustiveness. This is why we try to keep the scripts simple enough that others can easily customize them to their needs.
i appreciate simplicity, but it gimps the utility of the script to the point of requiring users to make extensive changes. since i had a feeling this would happen, i didn't bother doing the work of writing the PR, and instead opted to implement all of my suggestions in https://github.com/bghira/SimpleTuner.
@bghira let's revisit some of your suggestions here, as in retrospect, I think they make a lot of sense :-)
What would you recommend here?
i have a vae_cache folder in my implementation where i write .pt files for all of the embeds, and the encode function reads from disk if the file is there instead of computing it.
here is a naive implementation:

```python
import hashlib, os, torch, logging
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms

logger = logging.getLogger("VAECache")
logger.setLevel("INFO")


class VAECache:
    def __init__(self, vae, accelerator, cache_dir="vae_cache", resolution: int = 1024):
        self.vae = vae
        self.vae.enable_slicing()
        self.accelerator = accelerator
        self.cache_dir = cache_dir
        self.resolution = resolution
        os.makedirs(self.cache_dir, exist_ok=True)

    def create_hash(self, filename):
        # Create a sha256 hash
        sha256_hash = hashlib.sha256()
        # Feed the hash function with the filename
        sha256_hash.update(filename.encode())
        # Get the hexadecimal representation of the hash
        return sha256_hash.hexdigest()

    def save_to_cache(self, filename, embeddings):
        torch.save(embeddings, filename)

    def load_from_cache(self, filename):
        return torch.load(filename)

    def encode_image(self, pixel_values, filepath: str):
        file_hash = self.create_hash(filepath)
        filename = os.path.join(self.cache_dir, file_hash + ".pt")
        logger.debug(f"Created file_hash {file_hash} from filepath {filepath} for resulting .pt filename.")
        if os.path.exists(filename):
            latents = self.load_from_cache(filename)
            logger.debug(f"Loading latents of shape {latents.shape} from existing cache file: {filename}")
        else:
            with torch.no_grad():
                latents = self.vae.encode(
                    pixel_values.unsqueeze(0).to(self.accelerator.device, dtype=torch.bfloat16)
                ).latent_dist.sample()
                logger.debug(f"Using shape {latents.shape}, creating new latent cache: {filename}")
            latents = latents * self.vae.config.scaling_factor
            logger.debug(f"Latent shape after re-scale: {latents.shape}")
            self.save_to_cache(filename, latents.squeeze())
        output_latents = latents.squeeze().to(self.accelerator.device, dtype=self.vae.dtype)
        logger.debug(f"Output latents shape: {output_latents.shape}")
        return output_latents

    def process_directory(self, directory):
        # Define a transform to convert the image to tensor
        transform = transforms.ToTensor()
        # Get a list of all the files to process (customize as needed)
        files_to_process = []
        logger.debug(f"Beginning processing of VAECache directory {directory}")
        for subdir, _, files in os.walk(directory):
            for file in files:
                if file.endswith((".png", ".jpg", ".jpeg")):
                    logger.debug(f"Discovered image: {os.path.join(subdir, file)}")
                    files_to_process.append(os.path.join(subdir, file))
        # Iterate through the files, displaying a progress bar
        for filepath in tqdm(files_to_process, desc="Processing images"):
            # Create a hash based on the filename
            file_hash = self.create_hash(filepath)
            filename = os.path.join(self.cache_dir, file_hash + ".pt")
            # If processed file already exists, skip processing for this image
            if os.path.exists(filename):
                logger.debug(f"Skipping processing for {filepath} as cached file {filename} already exists.")
                continue
            # Open the image using PIL
            try:
                logger.debug(f"Loading image: {filepath}")
                image = Image.open(filepath)
                image = image.convert("RGB")
                image = self._resize_for_condition_image(image, self.resolution)
            except Exception as e:
                logger.error(f"Encountered error opening image: {e}")
                os.remove(filepath)
                continue
            # Convert the image to a tensor
            try:
                pixel_values = transform(image).to(self.accelerator.device, dtype=self.vae.dtype)
            except OSError as e:
                logger.error(f"Encountered error converting image to tensor: {e}")
                continue
            # Process the image with the VAE
            self.encode_image(pixel_values, filepath)
            logger.debug(f"Processed image {filepath}")

    def _resize_for_condition_image(self, input_image: Image, resolution: int):
        input_image = input_image.convert("RGB")
        W, H = input_image.size
        aspect_ratio = round(W / H, 3)
        msg = f"Inspecting image of aspect {aspect_ratio} and size {W}x{H} to "
        if W < H:
            W = resolution
            H = int(resolution / aspect_ratio)  # Calculate the new height
        elif H < W:
            H = resolution
            W = int(resolution * aspect_ratio)  # Calculate the new width
        if W == H:
            W = resolution
            H = resolution
        msg = f"{msg} {W}x{H}."
        logger.debug(msg)
        img = input_image.resize((W, H), resample=Image.BICUBIC)
        return img
```
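A possible way to wire this up (a usage sketch, not part of the original comment; the model ID, paths, and dtype are illustrative):

```python
# Usage sketch for the VAECache class above.
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL

accelerator = Accelerator()
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae", torch_dtype=torch.bfloat16
).to(accelerator.device)
vae.requires_grad_(False)

cache = VAECache(vae, accelerator, cache_dir="vae_cache", resolution=1024)
cache.process_directory("/path/to/training/images")  # pre-fill the cache once

# later, inside the training loop, a single sample's latents can be fetched,
# where `pixel_values` is the transformed image tensor and `filepath` its path:
# latents = cache.encode_image(pixel_values, filepath)
```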
Thanks for providing the snippets! If we compute the VAE encodings like this, then it creates a problem during batch preparation, as the images are no longer of uniform shape. I guess we need to also apply a cropping step after this? Also, for my own understanding:
Elaborate on one part: when you say multi-aspect 1024px, do you mean multiple aspect ratios while keeping the base resolution at 1024?
i use the condition-resize method that keeps images at 64px step increments and resizes the smaller edge to 1024px. i then scale 1024 by the image aspect ratio to get the other side's length. by doing it this way, an entire batch will have the same pixel-perfect sizing. i'm not doing what StabilityAI did, which is to precompute 1-megapixel resolutions at various aspect ratios and then conform the images to those. my images can go quite large, because the VAE was the true limiting factor there.
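Not the actual SimpleTuner code, just a toy sketch of the sizing arithmetic described above (shorter edge to 1024px, both edges snapped down to 64px steps):

```python
# Sketch of the bucketing size computation: shorter edge -> 1024, then both
# edges rounded down to a multiple of 64 so every image in a bucket matches.
def bucket_size(width: int, height: int, base: int = 1024, step: int = 64):
    aspect = width / height
    if width <= height:
        new_w, new_h = base, int(base / aspect)
    else:
        new_w, new_h = int(base * aspect), base
    # snap to the nearest lower multiple of `step`
    return (new_w // step) * step, (new_h // step) * step

print(bucket_size(1536, 1024))  # -> (1536, 1024)
print(bucket_size(1000, 1500))  # -> (1024, 1536)
```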
Thanks for explaining! Do you have a code reference for me? I would love to understand and visualize this, preferably in a notebook. It could be a valuable resource for the community!
it has some problems, and the code is more convoluted than i'd like. it can be greatly simplified, but here is what i've been using. a notable difference between my implementation and others is that i don't do random sampling of images: we sample each image once, put it into a 'seen' list, and don't reference it again until the next epoch. i've noted that the Diffusers sampler defaults often tend to over-sample images.
Elaborate a bit? How much does this impact / degrade quality? Have you experienced it? Thanks for providing the reference!
I have experienced it. oversampling some images over others can lead to an uneven distribution of timesteps per image, which results in some of the training data overfitting and the rest underfitting. this is exacerbated with aspect buckets, because you might have an uneven distribution of images in each bucket, and random sampling of aspects then occurs on top of random sampling of images, so you may never see some of the training data at all. it depends on how large your dataset is and how much time you can really dedicate to the task. it is made worse when tuning the text encoder alongside the u-net, especially because you're not doing caption dropout: things are more likely to overfit on captions as well as image features.
a large multi-aspect dataset from LAION might have more than 60% of the images in the 1.0 (square) aspect. so by ensuring you can safely sample each bucket sequentially, the chronically underfilled ones are sampled entirely, and the bulk of the remaining training time goes to the majority of square images. you could slice the buckets so that they're all evenly filled, but a colleague (@kaibioinfo) mentioned an interesting idea where we could crop images down and train on tiled versions of high-res images, with their complete coordinates available as conditioning inputs. honestly, a lot of my problems with data bucketing could be resolved through clever utilisation of cropping.
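For illustration, a rough sketch of the exhaust-each-bucket-once-per-epoch idea described above (not the SimpleTuner implementation):

```python
# Sketch: walk each aspect bucket sequentially, yielding same-shaped batches and
# seeing every image exactly once per epoch, instead of randomly resampling.
def epoch_batches(buckets, batch_size):
    """`buckets` maps an aspect key (e.g. 1.0) to a list of image paths."""
    for aspect, paths in buckets.items():
        pending = []
        for path in paths:
            pending.append(path)
            if len(pending) == batch_size:
                yield aspect, pending
                pending = []
        if pending:  # flush a final partial batch for underfilled buckets
            yield aspect, pending

buckets = {1.0: ["a.png", "b.png", "c.png"], 1.5: ["d.png", "e.png"]}
for aspect, batch in epoch_batches(buckets, batch_size=2):
    print(aspect, batch)
```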
I think all of this could make an interesting utility repository for different dataloaders for SD training haha. Thanks for sharing your experience and wisdom.
there's so much to test and so few GPU hours to go around 😂 thanks for being receptive to these changes
another idea i had was an aspect bucket for each base resolution (256, 512, 768 and 1024), so that we can make the best use of SDXL's conditioning values and open the training data pool up in a massive way.
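For reference, SDXL's size conditioning is a small vector built from the original size, crop offsets, and target size, so per-bucket base resolutions can flow straight into it. A rough sketch of assembling such a vector (value ordering shown is illustrative):

```python
# Rough sketch: SDXL's micro-conditioning is a 6-value vector of original size,
# crop top-left, and target size, so a bucket's base resolution can be passed
# through it directly.
import torch

def compute_time_ids(original_size, crop_top_left, target_size):
    return torch.tensor([list(original_size) + list(crop_top_left) + list(target_size)])

print(compute_time_ids((1536, 1024), (0, 0), (1024, 1024)))
# tensor([[1536, 1024,    0,    0, 1024, 1024]])
```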
Co-authored-by: bghira <[email protected]>
thanks!
I think maybe you should leave the precomputation steps in a community examples section and allow the current training script to use them. Precomputation of embeds and latents is something I do to finetune most models I work on, not just SDXL, so for me a more general solution in the training script that can directly use precomputed embeds/latents would be useful. It could be a CLI option for the training script.
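For example, a hypothetical CLI surface for this (these flags do not exist in the script; the names are made up, and `parser` is assumed to be the script's ArgumentParser):

```python
# Hypothetical flags sketching how precomputed embeds/latents could be opt-in.
parser.add_argument(
    "--precomputed_text_embeds_dir",
    type=str,
    default=None,
    help="If set, load cached text embeddings from this directory instead of running the text encoders.",
)
parser.add_argument(
    "--precomputed_vae_latents_dir",
    type=str,
    default=None,
    help="If set, load cached VAE latents from this directory instead of encoding pixels during training.",
)
```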
hi, I want to know whether precomputed VAE embeddings rule out image data augmentation (because precomputing the latents locks in the image pixel content). thanks :)
* add: train to text image with sdxl script. Co-authored-by: CaptnSeraph <[email protected]>
* fix: partial func.
* fix: default value of output_dir.
* make style
* set num inference steps to 25.
* remove mentions of LoRA.
* up min version
* add: ema cli arg
* run device placement while running step.
* precompute vae encodings too.
* fix
* debug
* should work now.
* debug
* debug
* goes alright?
* style
* debugging
* debugging
* debugging
* debugging
* fix
* reinit scheduler if prediction_type was passed.
* akways cast vae in float32
* better handling of snr. Co-authored-by: bghira <[email protected]>
* the vae should be also passed
* add: docs.
* add: sdlx t2i tests
* save the pipeline
* autocast.
* fix: save_model_card
* fix: save_model_card.

---------

Co-authored-by: CaptnSeraph <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: bghira <[email protected]>
thanks for your efforts! just wondering: if i want to finetune SDXL on a large dataset (say 400k images), would precomputing the embeddings and keeping them in memory cause memory issues? how can i deal with that?
Closes #4366 and builds on top of #4401.
Many thanks to @CaptnSeraph for laying out the foundations here.
To test:
TODOs