[LoRA] feat: lora support for SANA. #10234
Conversation
return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in Sana.")
The skipped tests are the same as those for Mochi.
"prompt": "", | ||
"negative_prompt": "", |
Check this internal thread:
https://huggingface.slack.com/archives/C065E480NN9/p1734324025408149
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
So, I was not requested for review, but I saw the latest commit (the attention_kwargs -> cross_attention_kwargs rename) via email notifications. I had shared this concern in a previous LoRA refactor PR and in this comment. This is because I often find myself having to refer to the documentation for different pipelines instead of just being able to use one consistent parameter name to pass the LoRA scale with, and it is frustrating because you wait for the pipeline to load only to have the call fail immediately. I'm not sure if others resonate with this, but anyone using LoRAs often will have faced this. We have …
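To make the inconsistency concrete, here is a minimal sketch (the LoRA repo ids are placeholders, and the per-pipeline parameter names are my assumption based on the rename/revert commits in this PR rather than something verified against every pipeline):

```python
import torch
from diffusers import SanaPipeline, StableDiffusionPipeline

prompt = "a photo of a cat wearing a wizard hat"

# Stable Diffusion pipelines take the LoRA scale through `cross_attention_kwargs`.
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
sd_pipe.load_lora_weights("user/some-sd-lora")  # placeholder LoRA repo id
image = sd_pipe(prompt, cross_attention_kwargs={"scale": 0.7}).images[0]

# Sana (with this PR) takes it through `attention_kwargs` instead, so the same
# call pattern fails until you look up the pipeline-specific parameter name.
sana_pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")
sana_pipe.load_lora_weights("user/some-sana-lora")  # placeholder LoRA repo id
image = sana_pipe(prompt, attention_kwargs={"scale": 0.7}).images[0]
```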
Yeah, I don't mind.
Thanks for the super fast work! Looks good to merge after some of the more important reviews are addressed.
vae.to(dtype=torch.float32)
transformer.to(accelerator.device, dtype=weight_dtype)
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)
I think we could instead load with torch_dtype=torch.bfloat16 and keep the same comment. This is because weight casting via .to() ignores _keep_modules_in_fp32. I could not get our numerical precision unit tests to match between the two approaches when working on the integration PR.
Oh this is coming from the example provided in https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana#diffusers.SanaPipeline.__call__.example. In this case, we're doing the exact same thing and we are not fine-tuning the text encoder.
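A minimal sketch of the suggestion, assuming the Gemma2 text encoder sits under a text_encoder subfolder of the checkpoint (the loading call is illustrative, not the script's exact code):

```python
import torch
from transformers import AutoModel

model_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"

# Suggested: pass the dtype at load time so any modules flagged to stay in FP32
# (the `_keep_modules_in_fp32` mechanism mentioned above) are not downcast.
text_encoder = AutoModel.from_pretrained(
    model_id,
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,  # because Gemma2 is particularly suited for bfloat16.
)

# As opposed to loading in full precision and blanket-casting afterwards,
# which bypasses that mechanism:
# text_encoder = AutoModel.from_pretrained(model_id, subfolder="text_encoder")
# text_encoder.to(dtype=torch.bfloat16)
```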
)

# VAE should always be kept in fp32 for SANA (?)
vae.to(dtype=torch.float32)
FP32 should be good, but I'm not 100% sure. I think the AutoencoderDC models were all trained in bf16. Maybe @lawrence-cj can comment.
This is just to be sure the VAE's precision isn't a bottleneck for getting good-quality training runs. It's a small VAE anyway, so it won't matter too much, I guess.
Yes. AutoencoderDC was trained in BF16, and FP32 also tests fine; it just costs a lot of additional GPU memory in FP32.
We're offloading it to CPU when it's not in use, and when cache_latents is supplied through the CLI, we precompute the latents and delete the VAE. So, I guess that's okay for now?
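For context, roughly what the cache_latents path amounts to (a sketch with illustrative names, assuming AutoencoderDC's encode() exposes the latent via a .latent attribute; not the script's exact code):

```python
import gc
from typing import List

import torch


def cache_vae_latents(vae, dataloader, device) -> List[torch.Tensor]:
    """Precompute VAE latents once so the VAE can be deleted before training."""
    vae.to(device)
    latents = []
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch["pixel_values"].to(device, dtype=vae.dtype)
            # AutoencoderDC's encode() is deterministic; `.latent` is assumed here.
            latents.append(vae.encode(pixel_values).latent.cpu())
    return latents


# latents = cache_vae_latents(vae, train_dataloader, accelerator.device)
# del vae; gc.collect(); torch.cuda.empty_cache()
```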
I mean that the VAE.decode() part, not the VAE model itself, will consume much more GPU memory if it runs in FP32, especially when the batch size is more than 1. Not sure if I understand right.
Oh, I think we should be good, as we barely make use of decode() in training.
OK, that's cool. Then the only remaining concern is when we visualize the training results during training, since that is where decode() actually runs.
@@ -185,6 +189,7 @@ def encode_prompt(
    clean_caption: bool = False,
    max_sequence_length: int = 300,
    complex_human_instruction: Optional[List[str]] = None,
    lora_scale: Optional[float] = None,
Are we training the text encoder? If not, maybe we can remove these changes.
This was to avoid surprises for our users when text encoder training support is merged. It's common to see the encode_prompt() method being equipped to handle lora_scale.
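For reference, the common pattern looks roughly like this (a sketch only; `_get_prompt_embeds` is a hypothetical helper standing in for the actual tokenization and text encoder forward pass):

```python
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers


def encode_prompt(self, prompt, lora_scale=None, **kwargs):
    # Temporarily scale the text encoder's LoRA layers when a scale is requested.
    if lora_scale is not None and USE_PEFT_BACKEND:
        scale_lora_layers(self.text_encoder, lora_scale)

    # `_get_prompt_embeds` is a hypothetical helper for tokenization + encoding.
    prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(prompt, **kwargs)

    # Undo the scaling so subsequent calls see the original LoRA weights.
    if lora_scale is not None and USE_PEFT_BACKEND:
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, prompt_attention_mask
```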
@a-r-r-o-w your comments have been addressed. @lawrence-cj could you review / test the training script if you want?
thanks!
Failing tests are unrelated and can safely be ignored. Will add a training test in a follow-up PR.
Working on it. Will fine-tune the model using your pokemon dataset.
* feat: lora support for SANA.
* make fix-copies
* rename test class.
* attention_kwargs -> cross_attention_kwargs.
* Revert "attention_kwargs -> cross_attention_kwargs." This reverts commit 23433bf.
* exhaust 119 max line limit
* sana lora fine-tuning script.
* readme
* add a note about the supported models.
* Apply suggestions from code review Co-authored-by: Aryan <[email protected]>
* style
* docs for attention_kwargs.
* remove lora_scale from pag pipeline.
* copy fix

Co-authored-by: Aryan <[email protected]>
What does this PR do?
Example LoRA fine-tuning command:
Notes
mixed_precision="fp16"
is leading to NaN loss values despite the recommendation to use FP16 for "Efficient-Large-Model/Sana_1600M_1024px_diffusers".Results
https://wandb.ai/sayakpaul/dreambooth-sana-lora/runs/tf9fo8o6
Pre-trained LoRA: https://huggingface.co/sayakpaul/yarn_art_lora_sana
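For anyone who wants to try the LoRA above, a minimal inference sketch (the dtype handling follows the SanaPipeline docs example linked earlier in this thread; the prompt and exact settings are assumptions):

```python
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
)
pipe.to("cuda")
pipe.text_encoder.to(torch.bfloat16)  # Gemma2 is particularly suited for bfloat16
pipe.vae.to(torch.bfloat16)

pipe.load_lora_weights("sayakpaul/yarn_art_lora_sana")

image = pipe(
    "a puppy, yarn art style",        # assumed trigger phrasing for this LoRA
    attention_kwargs={"scale": 1.0},  # LoRA scale via the `attention_kwargs` added in this PR
).images[0]
image.save("yarn_art_puppy.png")
```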