
SDXL text to image trainer #4401


Closed
wants to merge 5 commits

Conversation

CaptnSeraph
Contributor

What does this PR do?

This is a modified pix2pix script that should enable a basic text-to-image trainer for SDXL.

Fixes issue #4366, as requested by @sayakpaul.

There are a few small TODOs in the code, and it doesn't seem to work properly at the moment.

Very much a prototype...
Replaced code with the first prototype
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Aug 1, 2023

The documentation is not available anymore as the PR was closed or merged.

model_pred = unet(
concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
Contributor

This is interesting; we're only supporting the MSE loss here, but there is an SNR gamma implementation available:

                if args.snr_gamma is None:
                    logging.debug(f"Calculating loss")
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="mean"
                    )
                else:
                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                    # This is discussed in Section 4.2 of the same paper.
                    snr = compute_snr(timesteps, noise_scheduler)

                    if torch.any(torch.isnan(snr)):
                        print("snr contains NaN values")
                    if torch.any(snr == 0):
                        print("snr contains zero values")

                    mse_loss_weights = (
                        torch.stack(
                            [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1
                        ).min(dim=1)[0]
                        / snr
                    )

                    if torch.any(torch.isnan(mse_loss_weights)):
                        print("mse_loss_weights contains NaN values")
                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
                    # rebalance the sample-wise losses with their respective loss weights.
                    # Finally, we take the mean of the rebalanced loss.
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="none"
                    )
                    loss = (
                        loss.mean(dim=list(range(1, len(loss.shape))))
                        * mse_loss_weights
                    )
                    loss = loss.mean()
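
For context, here is a minimal sketch of what a compute_snr helper along these lines returns, assuming a DDPM-style scheduler that exposes alphas_cumprod (the helper name and argument order simply mirror the snippet above; treat it as an illustration, not the script's final API):

import torch

def compute_snr(timesteps, noise_scheduler):
    # Signal-to-noise ratio per sampled timestep for a DDPM-style scheduler:
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t), with alpha_bar_t the cumulative alpha product.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(device=timesteps.device)
    alpha_bar = alphas_cumprod[timesteps].float()
    return alpha_bar / (1.0 - alpha_bar)

This would also explain the zero/NaN guards above: with a zero-terminal-SNR schedule, for example, alpha_bar reaches 0 at the final timestep, so snr becomes 0 and the division inside mse_loss_weights blows up.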

Made some changes in line with @bghira's code review.
Member

@sayakpaul left a comment


Hi there!

Thanks for kickstarting this!

I think you could refer to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py to lay out most of the logic.

The major change would be how we're computing the embeddings for the UNet. You could refer to the ControlNet training script for that: https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet_sdxl.py.

We would also want to pre-compute the text embeddings and potentially also the VAE encodings of the images to save the maximum amount of memory.
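
For reference, a rough sketch of how the SDXL scripts typically assemble the two text-encoder outputs; the function name and preprocessing details here are assumptions based on the ControlNet SDXL script, not a prescription for this PR:

import torch

def encode_prompt(prompt_batch, text_encoders, tokenizers):
    # SDXL conditions on two text encoders: the per-token hidden states are concatenated
    # along the channel dimension (768 + 1280 = 2048) and the pooled output of the second
    # (projection) encoder is kept as the "add_text_embeds" conditioning.
    prompt_embeds_list = []
    pooled_prompt_embeds = None
    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_inputs = tokenizer(
            prompt_batch,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            output = text_encoder(
                text_inputs.input_ids.to(text_encoder.device), output_hidden_states=True
            )
        pooled_prompt_embeds = output[0]                      # pooled output; last encoder wins
        prompt_embeds_list.append(output.hidden_states[-2])   # penultimate layer, [bsz, 77, dim]
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)  # [bsz, 77, 2048]
    pooled_prompt_embeds = pooled_prompt_embeds.view(prompt_embeds.shape[0], -1)  # [bsz, 1280]
    return prompt_embeds, pooled_prompt_embeds

Pre-computing these once (along with the VAE latents) means neither the text encoders nor the VAE have to stay on the GPU during the training loop, which is where most of the memory savings come from.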

Let me know if you're feeling stuck in implementing these. More than happy to provide more directions!

@sayakpaul
Member

Hey @CaptnSeraph! Thanks for your efforts! Would it be okay if I went into your PR and made some changes to get this merged? I will likely be able to do that early next week.

@CaptnSeraph
Contributor Author

Feel free to make any changes you want... I am very much exceeding my knowledge base on this one, but I think the basic logic might work?

@bghira
Contributor

bghira commented Aug 2, 2023

At the risk of sounding really stupid, I personally found it very difficult to cross-reference the three scripts; it's clear different people worked on them and there is no single way things are handled. Comparing the text embed creation in one versus the other, and trying to remove or even identify any extra channels added to the embeds, is something I've spent about a day working on with little to no further understanding.

I do now understand that the batch size is in the embed, but is it supposed to be [bsize, channelcount, seq_len, embeds_per_token], or is channelcount not supposed to be there? I see [4, 1, 77, 2048] for a batch size of 4.

Nonetheless, that's not where the error was occurring. It happens when the time embeds are concatenated with the text embeds via torch.concat: one tensor has 2 dimensions and the other has 3, because it's [4, 1536] for the time_embeds and [4, 1, 1280] for the text embeds, which frankly just seems wrong.

So I traced it back to the stack operation inside the encode_prompts method, which adds the batch size as the front dimension. This seems correct, logically, as it would need all four image-captions in the batch.

But encode_prompts looks like it behaves mostly the same across all three scripts, so it doesn't look like the problem arises there.

see:

08/02/2023 01:31:46 - INFO - __main__ - ***** Running training *****
08/02/2023 01:31:46 - INFO - __main__ -   Num examples = 5
08/02/2023 01:31:46 - INFO - __main__ -   Num Epochs = 100
08/02/2023 01:31:46 - INFO - __main__ -   Instantaneous batch size per device = 4
08/02/2023 01:31:46 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 4
08/02/2023 01:31:46 - INFO - __main__ -   Gradient Accumulation steps = 1
08/02/2023 01:31:46 - INFO - __main__ -   Total optimization steps = 378100
Steps:   0%|                                                                                                                                                                                                               | 0/378100 [00:00<?, ?it/s]08/02/2023 01:31:47 - DEBUG - __main__ - Running collate_fn
08/02/2023 01:31:47 - DEBUG - __main__ - Beginning compute_embeddings_for_prompts: ["croplands san california bishop (calif.) roofs buildings obispo meadows landscape s.e. county peak mountains bishop's photographs from luis", 'norris road. flowers along grove costa california contra county dandelions trees clouds canyon landscape walnut photographs in winter grasses', 'seas (calif.) green coastlines mountains california pacific monterey county oceans sur coast at landscape cliffs point photographs ocean velvet beaches', 'canyons mesas chelly (ariz.) channels junction overlook county streams arizona apache canyon landscape cliffs photographs from rock de formations']
08/02/2023 01:31:47 - DEBUG - __main__ - Returning computed embeddings: tensor([[[-3.8750, -2.5312,  4.7812,  ...,  0.1904,  0.4238, -0.2969],
         [-0.1348, -0.1123,  0.2812,  ...,  1.1484, -0.3359,  0.1016],
         [-0.4121, -0.0820,  0.5547,  ...,  0.4004,  0.3730, -0.2969],
         ...,
         [-0.2061,  0.4883, -0.2656,  ..., -0.4492,  0.0645,  0.2891],
         [-0.1865,  0.4766, -0.2754,  ..., -0.5078, -0.0781,  0.2715],
         [-0.1953,  0.4062, -0.2070,  ..., -0.4453,  0.1533,  0.1719]],

        [[-3.8750, -2.5312,  4.7812,  ...,  0.1904,  0.4238, -0.2969],
         [-0.3809,  0.0430, -0.6055,  ..., -0.1289,  0.3711, -0.1611],
         [ 0.0215, -0.7773, -1.4531,  ...,  0.4570,  0.5078,  0.1699],
         ...,
         [-0.0820,  0.5781, -0.3242,  ...,  0.0186,  0.1670,  0.3477],
         [-0.0840,  0.5781, -0.3203,  ..., -0.0166,  0.0430,  0.3008],
         [-0.1396,  0.5938, -0.3438,  ...,  0.2207,  0.0059,  0.4258]],

        [[-3.8750, -2.5312,  4.7812,  ...,  0.1904,  0.4238, -0.2969],
         [ 0.1191, -0.0278,  0.3867,  ...,  0.6602,  0.5195, -0.6445],
         [ 1.1250, -0.2715,  0.2656,  ..., -0.2656, -0.2256,  0.5977],
         ...,
         [ 0.0645,  0.5469, -0.2539,  ..., -0.6172,  0.3008,  0.2480],
         [ 0.0688,  0.5312, -0.2754,  ..., -0.7109,  0.1230,  0.2178],
         [ 0.0410,  0.5391, -0.2695,  ..., -0.5117,  0.2598, -0.0859]],

        [[-3.8750, -2.5312,  4.7812,  ...,  0.1904,  0.4238, -0.2969],
         [-0.4961,  0.5391,  0.4434,  ...,  0.8906,  0.7852, -0.5352],
         [-0.3633,  0.1943,  0.5547,  ..., -0.3340,  0.4961, -0.2891],
         ...,
         [ 0.1816,  0.5859,  0.0723,  ...,  0.2109, -0.1387,  0.3906],
         [ 0.1826,  0.5664,  0.0625,  ...,  0.1953, -0.2188,  0.3496],
         [ 0.1553,  0.5312,  0.0156,  ...,  0.2188,  0.1943,  0.3867]]],
       device='cuda:0', dtype=torch.bfloat16), tensor([[[ 0.4414,  0.4160, -0.7383,  ..., -0.4824, -0.1436, -0.0537]],

        [[ 0.2500,  0.1475, -0.0184,  ..., -0.2773, -1.8203,  0.9883]],

        [[-0.3906,  0.0806, -0.4922,  ..., -0.3125,  0.4531,  2.4219]],

        [[ 1.1406,  0.2793,  0.1572,  ..., -0.1040, -0.3047, -0.2617]]],
       device='cuda:0', dtype=torch.bfloat16)
08/02/2023 01:31:47 - DEBUG - __main__ - Returning collate_fn results.
08/02/2023 01:31:48 - DEBUG - __main__ - Running collate_fn
08/02/2023 01:31:48 - DEBUG - __main__ - Beginning compute_embeddings_for_prompts: ['california rock mono co. lake craters (calif.) calif. landscape mountains plants photographs county formations', 'photographs evansville indiana flowers flowering landscape and white dogwoods trees identification county grasses of dogwood vanderburgh w pink', 'hip automobiles - poles (notices) mississippi trucks vicksburg of coca-cola cans waterfronts streets county end utility warren plants porches south vicksburg. company signs seven-up houses roofs front dogs river', 'echo lake calif.) nev.) mountains l. (calif. toward dorado california (el ponds north county tahoe trees us landscape below placer photographs from summit just lakes looking 50 and']
08/02/2023 01:31:48 - DEBUG - __main__ - Returning computed embeddings: tensor([[[-3.8750e+00, -2.5312e+00,  4.7812e+00,  ...,  1.9043e-01,
           4.2383e-01, -2.9688e-01],
         [ 2.7539e-01, -9.7266e-01, -5.6250e-01,  ..., -2.5391e-01,
           8.6719e-01, -3.2031e-01],
         [ 1.8359e-01, -3.4375e-01, -7.4219e-02,  ...,  1.0234e+00,
          -1.1182e-01,  6.8359e-03],
         ...,
         [-6.3477e-02,  3.6719e-01, -7.2656e-01,  ..., -1.7871e-01,
           2.4902e-01,  2.6172e-01],
         [-5.1514e-02,  3.5742e-01, -7.2656e-01,  ..., -2.4414e-01,
          -3.5156e-02,  1.9531e-01],
         [-3.7842e-02,  3.4766e-01, -6.9531e-01,  ...,  5.9570e-02,
           3.7109e-02,  2.6562e-01]],

        [[-3.8750e+00, -2.5312e+00,  4.7812e+00,  ...,  1.9043e-01,
           4.2383e-01, -2.9688e-01],
         [-2.9102e-01,  1.9141e-01,  5.7812e-01,  ...,  7.0312e-02,
           5.1562e-01, -2.2949e-01],
         [ 1.5625e-01,  2.6367e-01,  3.9062e-02,  ...,  3.6328e-01,
           1.0742e-01, -5.5078e-01],
         ...,
         [ 2.0703e-01,  2.7344e-01, -7.5000e-01,  ..., -2.3633e-01,
           4.1602e-01,  5.2344e-01],
         [ 2.0020e-01,  2.6562e-01, -7.3047e-01,  ..., -3.1445e-01,
           2.6367e-01,  4.5898e-01],
         [ 1.2256e-01,  2.8320e-01, -8.2812e-01,  ..., -2.0703e-01,
           5.1562e-01,  2.6172e-01]],

        [[-3.8750e+00, -2.5312e+00,  4.7812e+00,  ...,  1.9043e-01,
           4.2383e-01, -2.9688e-01],
         [-1.0312e+00, -2.6953e-01, -4.5508e-01,  ..., -5.2734e-01,
          -8.6328e-01,  6.9336e-02],
         [ 8.6719e-01, -3.0273e-01, -1.8359e-01,  ...,  7.8125e-01,
          -2.8516e-01,  3.3594e-01],
         ...,
         [ 2.7148e-01,  3.8672e-01, -6.3965e-02,  ..., -3.2031e-01,
          -1.4648e-01,  3.4766e-01],
         [ 2.4707e-01,  4.2188e-01, -7.5684e-02,  ..., -2.4902e-01,
          -2.5000e-01,  2.6172e-01],
         [ 3.6719e-01,  3.0469e-01,  2.0996e-02,  ..., -3.3789e-01,
           4.6875e-02,  4.8828e-02]],

        [[-3.8750e+00, -2.5312e+00,  4.7812e+00,  ...,  1.9043e-01,
           4.2383e-01, -2.9688e-01],
         [ 8.4961e-02,  9.1406e-01,  4.7852e-01,  ...,  4.7852e-01,
           6.1328e-01,  3.7500e-01],
         [ 1.3281e+00, -2.5781e-01, -2.7930e-01,  ..., -1.0742e-01,
           6.3281e-01, -3.4766e-01],
         ...,
         [ 2.0996e-01,  2.7539e-01, -4.1406e-01,  ..., -3.1250e-01,
           4.0039e-02,  3.5547e-01],
         [ 2.0703e-01,  2.7539e-01, -4.2578e-01,  ..., -3.2812e-01,
          -8.8867e-02,  2.9883e-01],
         [ 2.1094e-01,  2.0215e-01, -4.3750e-01,  ..., -1.5332e-01,
           1.9531e-02,  1.9531e-03]]], device='cuda:0', dtype=torch.bfloat16), tensor([[[ 0.7188,  0.7227,  0.4023,  ..., -0.0205, -1.8125, -0.1357]],

        [[ 0.8359,  0.0864, -0.5664,  ..., -2.1875, -1.7891, -0.4141]],

        [[-0.7617, -0.9883, -0.1138,  ..., -0.7734,  0.7031,  0.2461]],

        [[ 0.3887, -0.5430, -0.3848,  ..., -0.4688, -0.0189,  0.1816]]],
       device='cuda:0', dtype=torch.bfloat16)
08/02/2023 01:31:48 - DEBUG - __main__ - Returning collate_fn results.
08/02/2023 01:31:50 - DEBUG - __main__ - Encoder hidden states: torch.Size([4, 77, 2048])
08/02/2023 01:31:50 - DEBUG - __main__ - Added text embeds: torch.Size([4, 1, 1280])
08/02/2023 01:31:50 - DEBUG - __main__ - Conditioning dropout: None
08/02/2023 01:31:50 - DEBUG - __main__ - Concatenate the `original_image_embeds` torch.Size([4, 4, 128, 187]) with the `noisy_latents` torch.Size([4, 4, 128, 187]).
08/02/2023 01:31:50 - DEBUG - __main__ - Using epsilon prediction.
08/02/2023 01:31:51 - DEBUG - __main__ - add_text_embeds: torch.Size([4, 1, 1280]), time_ids: tensor([[1024., 1024.,    0.,    0., 1024., 1024.],
        [1024., 1024.,    0.,    0., 1024., 1024.],
        [1024., 1024.,    0.,    0., 1024., 1024.],
        [1024., 1024.,    0.,    0., 1024., 1024.]], device='cuda:0',
       dtype=torch.bfloat16)
Time ids: tensor([[1024., 1024.,    0.,    0., 1024., 1024.],
        [1024., 1024.,    0.,    0., 1024., 1024.],
        [1024., 1024.,    0.,    0., 1024., 1024.],
        [1024., 1024.,    0.,    0., 1024., 1024.]], device='cuda:0',
       dtype=torch.bfloat16)
Time embeds before reshape by ((4, -1)): torch.Size([24, 256])
Time embeds shape: torch.Size([4, 1536])
Text embeds shape: torch.Size([4, 1, 1280])
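
If it helps pin down the mismatch above, here is a small self-contained illustration of how the SDXL UNet combines the added conditions (assuming the standard "text_time" addition embedding); the squeeze is only meant to show where the stray dimension would have to go, not a tested fix:

import torch

# Shapes from the log:
#   encoder_hidden_states: [4, 77, 2048]  (fine)
#   add_text_embeds:       [4, 1, 1280]   (one dimension too many; the UNet expects [4, 1280])
#   time_embeds:           [4, 1536]      (6 time ids per sample, each projected to 256 dims)
add_text_embeds = torch.randn(4, 1, 1280)
time_embeds = torch.randn(4, 1536)

# Inside the UNet the two are concatenated on the last dimension and then projected back down
# to the time-embedding width, which is why a 3-D text embed breaks the concat:
add_text_embeds = add_text_embeds.squeeze(1)                        # [4, 1, 1280] -> [4, 1280]
add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)   # [4, 2816]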

@bghira
Contributor

bghira commented Aug 2, 2023

Note that I've bypassed all of the pre-computing of the embeds for now, because it was frustrating to wait so long for it to fail at the forward pass :D

add_time_ids = compute_time_ids()
def preprocess_train(examples):

with accelerator.main_process_first():

Might preprocess_train have been omitted by mistake?

@patrickvonplaten
Contributor

Exciting

original_image_embeds = image_mask * original_image_embeds

# Concatenate the `original_image_embeds` with the `noisy_latents`.
concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)

Text-to-image may not need original_image_embeds.

Contributor

Yes, in fact this will make the image input 8 channels when it should be 4.

We need to pull the add-noise method from the LoRA script.
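
For comparison, a minimal sketch of the plain text-to-image noising step from the existing text_to_image / LoRA examples, which keeps the UNet input at 4 latent channels instead of concatenating image embeds (latents, noise_scheduler, unet, encoder_hidden_states and added_cond_kwargs are assumed to come from the surrounding training loop):

import torch
import torch.nn.functional as F

# latents: [bsz, 4, h, w] VAE latents; we add noise directly, so the UNet only ever sees 4 channels.
noise = torch.randn_like(latents)
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
).long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

model_pred = unet(
    noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
).sample
target = noise  # epsilon prediction
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")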


OK. Can you push an updated version of this PR? In addition, the function preprocess_train at line 790 may be missing.
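
In case it helps, a rough sketch of the preprocess_train that seems to be missing, modeled on the existing text_to_image example (args.resolution, the crop/flip flags, image_column, and dataset are assumptions carried over from that script):

from torchvision import transforms

train_transforms = transforms.Compose(
    [
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

def preprocess_train(examples):
    # Convert the dataset images to tensors normalized to [-1, 1] for the VAE.
    images = [image.convert("RGB") for image in examples[image_column]]
    examples["pixel_values"] = [train_transforms(image) for image in images]
    return examples

with accelerator.main_process_first():
    train_dataset = dataset["train"].with_transform(preprocess_train)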

@sayakpaul
Member

@CaptnSeraph I took the liberty of building on top of this PR and starting a new one here: #4505.

To honor your contributions, I have also made you a co-author of the commits :-)

So I am closing this PR; instead, let's work on #4505. I am currently running experiments with the current changes.

@sayakpaul closed this on Aug 7, 2023