
[LoRA] feat: lora support for SANA. #10234


Merged: 18 commits merged into main on Dec 18, 2024

Conversation

sayakpaul (Member) commented Dec 16, 2024

What does this PR do?

  • LoRA support for SANA
  • LoRA fine-tuning script
  • Tests for the training script
  • README for the training script

Example LoRA fine-tuning command:

CUDA_VISIBLE_DEVICES=0 accelerate launch train_dreambooth_lora_sana.py \
  --pretrained_model_name_or_path=Efficient-Large-Model/Sana_1600M_1024px_diffusers \
  --dataset_name=Norod78/Yarn-art-style --instance_prompt="a puppy, yarn art style" \
  --output_dir=yarn_art_lora_sana \
  --mixed_precision=bf16 --use_8bit_adam \
  --weighting_scheme=none \
  --resolution=1024 --train_batch_size=1 --repeats=1 \
  --learning_rate=1e-4 --report_to=wandb \
  --gradient_accumulation_steps=1 --gradient_checkpointing \
  --lr_scheduler=constant --lr_warmup_steps=0 --rank=4 \
  --max_train_steps=700 --checkpointing_steps=2000 --seed=0 \
  --validation_prompt="a puppy in a pond, yarn art style" --validation_epochs=1 \
  --push_to_hub

Notes

  • mixed_precision="fp16" leads to NaN loss values despite the recommendation to use FP16 for "Efficient-Large-Model/Sana_1600M_1024px_diffusers".
  • The VAE is always kept in FP32.
  • When FP16 mixed-precision is used, we cast the LoRA params to FP32 while keeping the transformer in FP16 (see the sketch after this list).
  • Only the QKV projections are targeted for LoRA.
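
For reference, here is a minimal sketch of the FP16 handling described above. It is not the exact script code: the args, accelerator, and transformer names are assumed from the usual diffusers DreamBooth LoRA script layout, and the bf16/no-mixed-precision cases are simplified.

import torch

# Simplified dtype selection; the real script also handles mixed_precision="no".
weight_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16

# The frozen transformer weights stay in the half-precision compute dtype ...
transformer.to(accelerator.device, dtype=weight_dtype)

# ... but the trainable LoRA parameters are upcast to FP32 so their gradients
# and optimizer states do not underflow under fp16 mixed precision.
if args.mixed_precision == "fp16":
    for param in transformer.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.float32)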

Results

https://wandb.ai/sayakpaul/dreambooth-sana-lora/runs/tf9fo8o6

[validation image from the run above]

Pre-trained LoRA: https://huggingface.co/sayakpaul/yarn_art_lora_sana
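
A minimal inference sketch for trying the pre-trained LoRA. It assumes the standard load_lora_weights API that this PR wires up for SanaPipeline; the prompt and dtype choice are illustrative.

import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
).to("cuda")

# Load the LoRA produced by the training command above.
pipe.load_lora_weights("sayakpaul/yarn_art_lora_sana")

image = pipe("a puppy in a pond, yarn art style").images[0]
image.save("yarn_art_puppy.png")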


return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in Sana.")
sayakpaul (Member, Author):

The skipped tests are the same as in Mochi.

Comment on lines +100 to +101
"prompt": "",
"negative_prompt": "",

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

a-r-r-o-w (Member):

I was not requested for review, but I saw the latest commit in my email notifications about the attention_kwargs rename to cross_attention_kwargs. Is it possible to consistently name this simply attention_kwargs across the library?

I had shared this concern in a previous LoRA refactor PR and in this comment. I often find myself having to refer to the documentation for different pipelines instead of being able to use one consistent parameter name to pass the LoRA scale with, and it is frustrating because you wait for the pipeline to load only to find it fail immediately. I'm not sure if others resonate with this, but anyone using LoRAs often will have faced it.

We have attention_kwargs, cross_attention_kwargs and joint_attention_kwargs. IMO these can all just be called attention_kwargs and we don't need to make this distinction. I know we can't just change it for all the old pipelines until a release where we are okay with backwards-breaking changes (1.0.0), but going forward, let's maybe try using the same name everywhere? WDYT?
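
For context, a small sketch of the naming inconsistency being described. The pipeline and checkpoint choices are illustrative, the scale value is arbitrary, and the scale only has an effect once a LoRA is actually loaded.

import torch
from diffusers import StableDiffusionXLPipeline, SanaPipeline

# Older pipelines expose the LoRA scale through cross_attention_kwargs ...
sdxl = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
image = sdxl("a puppy", cross_attention_kwargs={"scale": 0.8}).images[0]

# ... SD3/Flux take it as joint_attention_kwargs, while newer pipelines such as
# SANA accept the same thing as attention_kwargs.
sana = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
).to("cuda")
image = sana("a puppy", attention_kwargs={"scale": 0.8}).images[0]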

sayakpaul (Member, Author):

Yeah I don't mind.

@sayakpaul sayakpaul requested a review from a-r-r-o-w December 16, 2024 08:32
@sayakpaul sayakpaul marked this pull request as ready for review December 16, 2024 08:40
@sayakpaul sayakpaul added the roadmap (Add to current release roadmap) and lora labels Dec 16, 2024
a-r-r-o-w (Member) left a review comment:

Thanks for the super fast work! Looks good to merge after some of the more important reviews are addressed

vae.to(dtype=torch.float32)
transformer.to(accelerator.device, dtype=weight_dtype)
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)
a-r-r-o-w (Member):

I think we could instead load with torch_dtype=torch.bfloat16 and keep the same comment. This is because weight casting this way ignores _keep_modules_in_fp32. I could not get our numerical precision unit tests to match when using the two different ways when working on the integration PR.
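
A rough sketch of the suggestion, assuming the script loads the Gemma2 text encoder with transformers.AutoModel from a text_encoder subfolder (both are assumptions about the layout, not the exact script code):

import torch
from transformers import AutoModel

# Passing torch_dtype at load time lets any keep-in-fp32 module handling apply,
# unlike casting with .to(dtype=torch.bfloat16) after the fact.
text_encoder = AutoModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
)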

sayakpaul (Member, Author):

Oh this is coming from the example provided in https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana#diffusers.SanaPipeline.__call__.example. In this case, we're doing the exact same thing and we are not fine-tuning the text encoder.

)

# VAE should always be kept in fp32 for SANA (?)
vae.to(dtype=torch.float32)
a-r-r-o-w (Member):

FP32 should be good, but I'm not 100% sure. I think the AutoencoderDC checkpoints were all trained in bf16. Maybe @lawrence-cj can comment.

sayakpaul (Member, Author):

This is just to be sure the VAE's precision isn't a bottleneck for getting good-quality training runs. It's a small VAE anyway, so it won't matter too much, I guess.

lawrence-cj (Contributor) commented Dec 18, 2024:

Yes. AutoencoderDC is trained in BF16, and testing in FP32 is also fine; it will just cost a lot of additional GPU memory in FP32.

sayakpaul (Member, Author):

We're offloading it to the CPU when it's not used, and when cache_latents is supplied through the CLI we precompute the latents and delete the VAE. So I guess it's okay for now?
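
A condensed sketch of the cache_latents path being referred to. The train_dataloader, accelerator, args, and vae names are assumed from the usual training-script layout, and the .latent attribute assumes AutoencoderDC's deterministic encode output.

import gc
import torch

latents_cache = []
if args.cache_latents:
    with torch.no_grad():
        for batch in train_dataloader:
            pixels = batch["pixel_values"].to(accelerator.device, dtype=torch.float32)
            # Only encoded latents are needed for training, so cache them once ...
            latents_cache.append(vae.encode(pixels).latent)
    # ... then drop the VAE entirely to free GPU memory.
    del vae
    gc.collect()
    torch.cuda.empty_cache()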

lawrence-cj (Contributor):

I mean that the VAE.decode() part will consume much more GPU memory if it runs in FP32, especially when the batch_size is more than 1; it's not the VAE model itself. Not sure if I understand right.

sayakpaul (Member, Author):

Oh I think we should be good as we barely make use of decode() in training.

lawrence-cj (Contributor):

OK, that's cool. Then the only remaining concern is when we visualize the training results during training.

@@ -185,6 +189,7 @@ def encode_prompt(
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
lora_scale: Optional[float] = None,
a-r-r-o-w (Member):

Are we training the text encoder? If not, maybe we can remove these changes.

sayakpaul (Member, Author):

This is so there are no surprises for our users when text-encoder training support is merged. It's common for the encode_prompt() method to be equipped to handle lora_scale.
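
For reference, a condensed sketch of the usual lora_scale handling inside a pipeline's encode_prompt, following the common pattern in diffusers pipelines (the guard conditions and the encoding itself are simplified here):

from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def encode_prompt(self, prompt, lora_scale=None, **kwargs):
    # Temporarily apply the requested scale to the text encoder's LoRA layers ...
    if lora_scale is not None and USE_PEFT_BACKEND:
        scale_lora_layers(self.text_encoder, lora_scale)

    text_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.text_encoder.device)
    prompt_embeds = self.text_encoder(**text_inputs)[0]

    # ... and undo it afterwards so the scale does not leak into later calls.
    if lora_scale is not None and USE_PEFT_BACKEND:
        unscale_lora_layers(self.text_encoder, lora_scale)
    return prompt_embeds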

@sayakpaul sayakpaul changed the title [WIP][LoRA] feat: lora support for SANA. [LoRA] feat: lora support for SANA. Dec 18, 2024
sayakpaul (Member, Author):

@a-r-r-o-w your comments have been addressed. @lawrence-cj could you review / test the training script if you want?

@sayakpaul sayakpaul requested a review from yiyixuxu December 18, 2024 01:52
yiyixuxu (Collaborator) left a review comment:

thanks!

sayakpaul (Member, Author):

The failing tests are unrelated and can safely be ignored. Will add a training test in a follow-up PR.

@sayakpaul sayakpaul merged commit 9408aa2 into main Dec 18, 2024
13 of 15 checks passed
@sayakpaul sayakpaul deleted the sana-lora branch December 18, 2024 02:52
lawrence-cj (Contributor):

> @a-r-r-o-w your comments have been addressed. @lawrence-cj could you review / test the training script if you want?

Working on it. Will fine-tune the model using your pokemon dataset.

Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
* feat: lora support for SANA.

* make fix-copies

* rename test class.

* attention_kwargs -> cross_attention_kwargs.

* Revert "attention_kwargs -> cross_attention_kwargs."

This reverts commit 23433bf.

* exhaust 119 max line limit

* sana lora fine-tuning script.

* readme

* add a note about the supported models.

* Apply suggestions from code review

Co-authored-by: Aryan <[email protected]>

* style

* docs for attention_kwargs.

* remove lora_scale from pag pipeline.

* copy fix

---------

Co-authored-by: Aryan <[email protected]>
sayakpaul added a commit that referenced this pull request Dec 23, 2024
Labels: lora, roadmap (Add to current release roadmap)

5 participants