SDXL text to image trainer #4401
Conversation
very prototype...
Replaced code with the first prototype
The documentation is not available anymore as the PR was closed or merged.
model_pred = unet(
    concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
This is interesting: we're only supporting the MSE loss here, but there is an SNR gamma implementation available;
if args.snr_gamma is None:
logging.debug(f"Calculating loss")
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="mean"
)
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps, noise_scheduler)
if torch.any(torch.isnan(snr)):
print("snr contains NaN values")
if torch.any(snr == 0):
print("snr contains zero values")
mse_loss_weights = (
torch.stack(
[snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1
).min(dim=1)[0]
/ snr
)
if torch.any(torch.isnan(mse_loss_weights)):
print("mse_loss_weights contains NaN values")
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="none"
)
loss = (
loss.mean(dim=list(range(1, len(loss.shape))))
* mse_loss_weights
)
loss = loss.mean()
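
For reference, a minimal sketch of what the compute_snr helper used above could look like, assuming a DDPM-style noise scheduler that exposes alphas_cumprod (diffusers ships a similar utility, though the exact signature has varied between versions; this is not necessarily the code used in this PR):

import torch

def compute_snr(timesteps, noise_scheduler):
    # Per-timestep signal-to-noise ratio used for Min-SNR loss weighting
    # (Section 3.4 of https://arxiv.org/abs/2303.09556).
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
    alpha = alphas_cumprod[timesteps] ** 0.5          # sqrt(alpha_bar_t)
    sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5  # sqrt(1 - alpha_bar_t)
    snr = (alpha / sigma) ** 2                        # alpha_bar_t / (1 - alpha_bar_t)
    return snr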
Made some changes in line with @bghira's code review.
Hi there!
Thanks for kickstarting this!
I think you could refer to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py to lay out most of the logic.
The major change would be how we're computing the embeddings for the UNet. You could refer to the ControlNet training script for that: https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet_sdxl.py.
We would also want to pre-compute the text embeddings, and potentially also the VAE encodings of the images, to save the maximum amount of memory.
Let me know if you're feeling stuck in implementing these. More than happy to provide more directions!
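
As an illustration of what pre-computing the text embeddings could look like for SDXL, here is a rough sketch of encoding a batch of captions with both text encoders ahead of the training loop. The function name and arguments below are placeholders, not the actual API of this PR or of the referenced scripts:

import torch

@torch.no_grad()
def encode_prompts(captions, tokenizers, text_encoders, device):
    # SDXL conditions the UNet on the concatenation of the per-token embeddings
    # from both text encoders, plus the pooled embedding of the second encoder.
    prompt_embeds_list = []
    pooled_prompt_embeds = None
    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        input_ids = tokenizer(
            captions,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device)
        output = text_encoder(input_ids, output_hidden_states=True)
        # Only the pooled output of the final (projection) text encoder ends up being used.
        pooled_prompt_embeds = output[0]
        # The penultimate hidden state is what SDXL was trained against.
        prompt_embeds_list.append(output.hidden_states[-2])
    prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
    return prompt_embeds, pooled_prompt_embeds

These tensors (and, analogously, the VAE latents of the images) can be computed once, cached, and the text encoders and VAE moved off the GPU before training starts.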
Hey @CaptnSeraph! Thanks for your efforts! Would it be okay if I went into your PR and made some changes to get this merged? I will likely be able to do that early next week.
Feel free to make any changes you want... I am very much exceeding my knowledge base on this one, but I think the basic logic might work?
At the risk of sounding really stupid, I personally found it very difficult to cross-reference the three scripts, as it's clear different people worked on them and there is no single way of handling things. I've spent about a day comparing the text-embed creation in one vs. the other, and trying to remove or even identify any extra channels added to the embeds, with little to no further understanding. I do now understand that the batch size is in the embed, but is it supposed to be? Nonetheless, that's not where the error was occurring anyway; it was when the time embeds are being added into the text embeds, so I traced it back to the stack operation. But the encode_prompts look like they behave mostly the same between all three scripts, so it doesn't look like the problem arises there.
Note that I've bypassed all of the precomputing of the embeds for now, because it was frustrating to wait so long for it to fail at the forward pass :D
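
For context on where the time embeds meet the text embeds in SDXL: the pooled text embedding and a small vector of size/crop coordinates are passed next to (not inside) the per-token embeddings, via added_cond_kwargs. A rough sketch, with names borrowed from the other SDXL training scripts; the surrounding training loop (prompt_embeds, noisy_latents, unet, etc.) is assumed, so treat this as an illustration rather than this PR's code:

import torch

def compute_time_ids(original_size, crops_coords_top_left, target_size, dtype, device):
    # SDXL micro-conditioning: (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    return torch.tensor([add_time_ids], dtype=dtype, device=device)

add_time_ids = compute_time_ids((1024, 1024), (0, 0), (1024, 1024), weight_dtype, device)
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
model_pred = unet(
    noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=added_cond_kwargs
).sample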
add_time_ids = compute_time_ids()

def preprocess_train(examples):

with accelerator.main_process_first():
Might preprocess_train have been omitted by mistake?
Exciting!
original_image_embeds = image_mask * original_image_embeds

# Concatenate the `original_image_embeds` with the `noisy_latents`.
concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
Text-to-image training may not need original_image_embeds.
Yes, in fact this will make the image input 8 channels when it should be 4.
We need to pull the add-noise method from the LoRA script.
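
For reference, that noising step looks roughly like this in the SDXL LoRA script (a sketch, assuming a DDPMScheduler-style noise_scheduler and 4-channel VAE latents coming from the surrounding training loop):

import torch

# `latents` are the 4-channel VAE latents of the training images, shape [B, 4, h, w].
noise = torch.randn_like(latents)
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
).long()

# Forward-diffuse the clean latents to the sampled timesteps.
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# The noisy latents stay 4 channels and go straight into the UNet; no
# concatenation with image embeds is needed for plain text-to-image training.
model_pred = unet(
    noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
).sample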
OK. Can you push an updated version of this PR? In addition, the preprocess_train function at line 790 may be missing.
@CaptnSeraph I took the liberty of building on top of this PR and starting a new one here: #4505. To honor your contributions, I have also made you a co-author of the commits :-) So, I am closing this PR and instead, let's work on #4505. I am currently running experiments with the current changes.
What does this PR do?
This is a modified pix2pix script that should enable a basic text-to-image trainer for SDXL.
Fixes issue #4366, as requested by @sayakpaul.
There are some small TODOs in the code, and it doesn't seem to work properly at the moment.