[Feat] add tiny Autoencoder for (almost) instant decoding #4384
Conversation
```python
def encode(
    self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[TinyAutoencoderOutput, Tuple[torch.FloatTensor]]:
    output = self.encoder(x)
```
If the input images are assumed to follow the [-1, 1] convention, I suspect this needs to be self.encoder(x.mul(0.5).add_(0.5)).
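For concreteness, a minimal sketch of the suggested rescaling, assuming the tiny encoder expects [0, 1] inputs while the rest of the pipeline follows the [-1, 1] convention (both assumptions from this thread, not verified against the merged code):

```python
import torch

# Fake batch in the usual diffusers [-1, 1] pixel convention.
x = torch.rand(1, 3, 512, 512) * 2.0 - 1.0

# The suggested fix: shift into [0, 1] before encoding, i.e. x * 0.5 + 0.5.
x_01 = x.mul(0.5).add_(0.5)
assert x_01.min() >= 0.0 and x_01.max() <= 1.0
```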
We usually call the encode() method from training scripts, where we ALWAYS ensure that the image pixel values are appropriately scaled :)
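For reference, a sketch of the kind of scaling a training script typically applies before images reach encode() — a standard torchvision transform, not code from this PR:

```python
from torchvision import transforms

# Typical preprocessing: PIL image -> tensor in [0, 1] -> normalized to
# [-1, 1] via (x - 0.5) / 0.5.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
```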
@sayakpaul after the hellscape I've been through with normalisation, I'd love for this to be fixed deeper in.
cool!

nice New York cheesecakes 😍😍😍

@stevhliu could you check the doc part of the PR and let me know your thoughts?
```python
        return init_dict, inputs_dict

    def test_outputs_equivalence(self):
        pass
```
Can we also add one integration test that checks a latent vector is correctly decoded to an image?
Added in d7ab16f.
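A hedged sketch of what such an integration test could look like; the checkpoint name is the published TAESD weight, but the exact tensors and tolerances in d7ab16f may differ:

```python
import torch
from diffusers import AutoencoderTiny

# Decode a fixed latent and sanity-check the resulting image tensor.
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
latents = torch.randn(1, 4, 64, 64, generator=torch.manual_seed(0))
with torch.no_grad():
    image = vae.decode(latents).sample

# 64x64 latents should decode to a 512x512 RGB image (8x upsampling).
assert image.shape == (1, 3, 512, 512)
```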
Looks good to me! Would be nice to add at least one integration test.

@patrickvonplaten I have added an integration test. Does that work for you? @madebyollin I know you're AFK, but as an FYI, I have opened the following:

Once these two PRs are merged, we will merge this PR.
Thanks for adding the tests!
Nice and clear docstrings! 👏

A small detail, but I found it a little curious that the model is called Tiny AutoEncoder while its API is AutoencoderTiny. Not a big deal, but I think it's easier to refer to when they share the same name (i.e., you don't have to mentally convert Tiny AutoEncoder to AutoencoderTiny).
```python
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value

    def scale_latents(self, x):
```
I think a sentence explaining what this function and unscale_latents do would be clearer.
Since it's not a highlighted method from the docs, I think it's fine as is for now.
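For readers of this thread, a hedged sketch of what these helpers conventionally do for TAESD-style latents: squash raw latents into [0, 1] (e.g. so they can be stored as images) and invert that mapping. The latent_magnitude / latent_shift attribute names are assumptions, not necessarily the merged code:

```python
def scale_latents(self, x):
    # Map raw latents into [0, 1] using the model's latent statistics.
    return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

def unscale_latents(self, x):
    # Invert scale_latents: map [0, 1] values back to raw latents.
    return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
```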
AutoencoderTiny was done to be in line with AutoencoderKL.
[Feat] add tiny Autoencoder for (almost) instant decoding (huggingface#4384)

* add: model implementation of tiny autoencoder.
* add: inits.
* push the latest devs.
* add: conversion script and finish.
* add: scaling factor args.
* debugging
* fix denormalization.
* fix: positional argument.
* handle use_torch_2_0_or_xformers.
* handle post_quant_conv
* handle dtype
* fix: sdxl image processor for tiny ae.
* fix: sdxl image processor for tiny ae.
* unify upcasting logic.
* copied from madness.
* remove trailing whitespace.
* set is_tiny_vae = False
* address PR comments.
* change to AutoencoderTiny
* make act_fn an str throughout
* fix: apply_forward_hook decorator call
* get rid of the special is_tiny_vae flag.
* directly scale the output.
* fix dummies?
* fix: act_fn.
* get rid of the Clamp() layer.
* bring back copied from.
* movement of the blocks to appropriate modules.
* add: docstrings to AutoencoderTiny
* add: documentation.
* changes to the conversion script.
* add doc entry.
* settle tests.
* style
* add one slow test.
* fix
* fix 2
* fix 2
* fix: 4
* fix: 5
* finish integration tests
* Apply suggestions from code review (Co-authored-by: Steven Liu <[email protected]>)
* style

Co-authored-by: Steven Liu <[email protected]>
Fixes #4233.
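For context, roughly how the new class is meant to be used once merged — a sketch based on the docs added in this PR; madebyollin/taesd is the published TAESD checkpoint, and the prompt and step count are illustrative:

```python
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Swap the pipeline's VAE for the tiny autoencoder to get
# (almost) instant decoding of latents into images.
pipe.vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesd", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe("a slice of New York cheesecake", num_inference_steps=25).images[0]
```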
Considerations
I decided to NOT modify the AutoencoderKL class for this one because the architecture is quite different from AutoencoderKL: this tiny Autoencoder just has a bunch of conv + ReLU blocks with residual connections.
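To illustrate, a hypothetical sketch of the building block just described — conv + ReLU layers wrapped in a residual connection; the class name and layer sizes are illustrative, not the PR's actual modules:

```python
import torch.nn as nn

class TinyResBlock(nn.Module):
    """A conv + ReLU block with a residual connection."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        self.act = nn.ReLU()

    def forward(self, x):
        # Residual connection: add the input back before the final activation.
        return self.act(self.conv(x) + x)
```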
Results

SD

Comparison

SDXL

Todos