Refactor LoRA #3778
Why these changes? We try to avoid running pipelines in autocast if possible
I think I copied and pasted it from the regular dreambooth training script, which runs the validation inference under autocast. Will remove.
Actually, I might have added this because when we load the rest of the model in fp16, the LoRA weights stay in fp32, and we needed autocast for them to work together.
Yes, looking through the commit history, that's why I added it. Is that ok?
For running validation inference at intervals, I think keeping autocast is okay, as it helps keep things simple.
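For illustration, here is a minimal sketch of that pattern, assuming a diffusers text-to-image pipeline loaded in fp16 (the model id and prompt are just placeholders, not the actual training script): the periodic validation sampling simply runs inside an autocast context.

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder pipeline loaded in fp16; the real script builds its pipeline
# from the components being trained.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Run the periodic validation sampling under autocast so mixed fp16/fp32
# modules can interoperate for these few sampling steps.
with torch.autocast("cuda"):
    images = pipe("a photo of a dog", num_inference_steps=25).images
```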
Why didn't we need it before?
Ok, so I was wrong that we needed it because of the dtype difference in the LoRA weights; the LoRA layer casts manually internally.
The issue was the dtype of the unet output being passed to the vae.
The difference is that in the main version of the script, the unet is not wrapped in amp, so its output during validation has the same dtype as the unet, fp16. In this branch, the full unet is wrapped in amp, so even though the unet is loaded in fp16, its output is fp32, and an error is raised when the final fp32 latents are passed to the fp16 vae.
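A minimal, self-contained sketch of that failure mode (a single fp16 conv stands in for the fp16 vae; this is not the actual diffusers code): fp32 latents fed to fp16 weights raise a dtype error, while the same call under autocast succeeds because autocast casts the conv inputs down to fp16.

```python
import torch

# Stand-in for the fp16 vae: one conv layer with fp16 weights on the GPU.
fp16_vae = torch.nn.Conv2d(4, 3, kernel_size=3, padding=1).cuda().half()
# Stand-in for the amp-wrapped unet output: fp32 latents.
latents = torch.randn(1, 4, 64, 64, device="cuda")

try:
    fp16_vae(latents)  # fails: fp32 input vs fp16 weights
except RuntimeError as err:
    print(err)

with torch.autocast("cuda", dtype=torch.float16):
    out = fp16_vae(latents)  # works: autocast casts the conv input to fp16
print(out.dtype)  # torch.float16
```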
The alternative to wrapping this section in amp is to put a check, either in the pipeline or at the beginning of the vae, on the dtype of the incoming latents and manually cast them if necessary. I like to avoid manual casts like that in case the caller expects execution in the dtype of their input (which is a reasonable assumption imo). So I would prefer to leave the amp decorator in this case.
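For comparison, a rough sketch of that rejected alternative, assuming a diffusers-style vae with a .decode() method (the helper name is hypothetical): the incoming latents are silently cast to the vae's dtype before decoding, which is exactly the behavior the comment argues against because it overrides the caller's chosen dtype.

```python
import torch

def decode_with_cast(vae, latents: torch.Tensor) -> torch.Tensor:
    # Manual dtype check the comment prefers to avoid: silently align the
    # latents with the vae's dtype instead of relying on autocast.
    if latents.dtype != vae.dtype:
        latents = latents.to(vae.dtype)
    return vae.decode(latents).sample
```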