
[Feat] add tiny Autoencoder for (almost) instant decoding #4384


Merged

45 commits merged into main from feat/tiny-autoenc on Aug 2, 2023.
Changes shown from 36 commits.

Commits (45, all by sayakpaul):
c15d27a  add: model implementation of tiny autoencoder. (Jul 26, 2023)
937eb19  add: inits. (Jul 26, 2023)
19f834e  push the latest devs. (Jul 26, 2023)
58b9bc9  Merge branch 'main' into feat/tiny-autoenc (Jul 31, 2023)
cddf45d  add: conversion script and finish. (Jul 31, 2023)
f111387  add: scaling factor args. (Jul 31, 2023)
258c074  debugging (Jul 31, 2023)
8be9d3d  fix denormalization. (Jul 31, 2023)
63a95a5  fix: positional argument. (Jul 31, 2023)
b470897  handle use_torch_2_0_or_xformers. (Jul 31, 2023)
2ce4bf3  handle post_quant_conv (Jul 31, 2023)
c8ff8e5  handle dtype (Jul 31, 2023)
0c099ca  fix: sdxl image processor for tiny ae. (Jul 31, 2023)
8d43bc8  fix: sdxl image processor for tiny ae. (Jul 31, 2023)
7d2dca0  unify upcasting logic. (Jul 31, 2023)
a7a8f6f  copied from madness. (Jul 31, 2023)
d11fd65  remove trailing whitespace. (Jul 31, 2023)
92c72e2  set is_tiny_vae = False (Aug 1, 2023)
d6808d1  address PR comments. (Aug 1, 2023)
5b7635b  change to AutoencoderTiny (Aug 1, 2023)
4bd124d  make act_fn an str throughout (Aug 1, 2023)
4ddc077  fix: apply_forward_hook decorator call (Aug 1, 2023)
ee35ebd  get rid of the special is_tiny_vae flag. (Aug 1, 2023)
4a606ef  directly scale the output. (Aug 1, 2023)
56eade8  fix dummies? (Aug 1, 2023)
ae3b107  fix: act_fn. (Aug 1, 2023)
40966b8  get rid of the Clamp() layer. (Aug 2, 2023)
0b04133  bring back copied from. (Aug 2, 2023)
13f06fa  movement of the blocks to appropriate modules. (Aug 2, 2023)
ef8772a  add: docstrings to AutoencoderTiny (Aug 2, 2023)
8c0852c  add: documentation. (Aug 2, 2023)
6f17fd3  changes to the conversion script. (Aug 2, 2023)
29e3370  add doc entry. (Aug 2, 2023)
817bb2b  settle tests. (Aug 2, 2023)
2995de1  Merge branch 'main' into feat/tiny-autoenc (Aug 2, 2023)
bbf0597  style (Aug 2, 2023)
9105fb5  add one slow test. (Aug 2, 2023)
ef3eae2  fix (Aug 2, 2023)
e05b730  fix 2 (Aug 2, 2023)
0980796  fix 2 (Aug 2, 2023)
163c035  fix: 4 (Aug 2, 2023)
dd0b673  fix: 5 (Aug 2, 2023)
d7ab16f  finish integration tests (Aug 2, 2023)
644d125  Apply suggestions from code review (Aug 2, 2023)
e2fcccb  style (Aug 2, 2023)
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -162,6 +162,8 @@
      title: AutoencoderKL
    - local: api/models/asymmetricautoencoderkl
      title: AsymmetricAutoencoderKL
    - local: api/models/autoencoder_tiny
      title: Tiny Autoencoder
    - local: api/models/transformer2d
      title: Transformer2D
    - local: api/models/transformer_temporal
45 changes: 45 additions & 0 deletions docs/source/en/api/models/autoencoder_tiny.md
@@ -0,0 +1,45 @@
# Tiny Autoencoder

This Autoencoder was introduced in [this repository](https://github.com/madebyollin/taesd). It is a tiny, distilled autoencoder that can decode the latents of a [`StableDiffusionPipeline`] or [`StableDiffusionXLPipeline`] almost instantly.

## Using with Stable Diffusion v2.1

```python
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("cheesecake.png")
```

## Using with Stable Diffusion XL 1.0

```python
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("cheesecake_sdxl.png")
```

## AutoencoderTiny

[[autodoc]] AutoencoderTiny

## AutoencoderTinyOutput

[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
77 changes: 77 additions & 0 deletions scripts/convert_tiny_autoencoder_to_diffusers.py
@@ -0,0 +1,77 @@
import argparse

from diffusers.utils import is_safetensors_available


if is_safetensors_available():
    import safetensors.torch
else:
    raise ImportError("Please install `safetensors`.")

from diffusers import AutoencoderTiny


"""
Example - From the diffusers root directory:

Download the weights:
```sh
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors
```

Convert the model:
```sh
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
    --encoder_ckpt_path taesd_encoder.safetensors \
    --decoder_ckpt_path taesd_decoder.safetensors \
    --dump_path taesd-diffusers
```
"""

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--encoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the encoder ckpt.",
    )
    parser.add_argument(
        "--decoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the decoder ckpt.",
    )
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format."
    )
    args = parser.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path)
    decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path)

    print("Populating the state_dicts in the diffusers format...")
    tiny_autoencoder = AutoencoderTiny()
    new_state_dict = {}

    # Modify the encoder state dict: the keys map one-to-one under the `encoder.layers.` prefix.
    for k in encoder_state_dict:
        new_state_dict.update({f"encoder.layers.{k}": encoder_state_dict[k]})

    # Modify the decoder state dict: the original decoder's first layer (a Clamp(), dropped in
    # diffusers) shifts the sequential numbering, so shift every layer index down by one.
    for k in decoder_state_dict:
        layer_id = int(k.split(".")[0]) - 1
        new_k = str(layer_id) + "." + ".".join(k.split(".")[1:])
        new_state_dict.update({f"decoder.layers.{new_k}": decoder_state_dict[k]})

    # Assertion tests with the original implementation can be found here:
    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
    tiny_autoencoder.load_state_dict(new_state_dict)
    print("Population successful, serializing...")
    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)
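
As a quick post-conversion sanity check, a minimal sketch (assuming the `taesd-diffusers` dump path from the example above; the thorough assertion tests against the original implementation live in the gist linked in the script):

```python
import torch

from diffusers import AutoencoderTiny

tiny = AutoencoderTiny.from_pretrained("taesd-diffusers")  # local directory written by the script
image = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for a real image in [-1, 1]
with torch.no_grad():
    latents = tiny.encode(image).latents  # expected shape: (1, 4, 64, 64)
    recon = tiny.decode(latents).sample  # expected shape: (1, 3, 512, 512)
print(latents.shape, recon.shape)
```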
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -38,6 +38,7 @@
from .models import (
    AsymmetricAutoencoderKL,
    AutoencoderKL,
    AutoencoderTiny,
    ControlNetModel,
    ModelMixin,
    MultiAdapter,
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -19,6 +19,7 @@
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_tiny import AutoencoderTiny
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
from .modeling_utils import ModelMixin
2 changes: 2 additions & 0 deletions src/diffusers/models/activations.py
@@ -8,5 +8,7 @@ def get_activation(act_fn):
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    elif act_fn == "relu":
        return nn.ReLU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")
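
A quick sketch exercising the new branch (assuming `get_activation` is importable from `diffusers.models.activations`, as the `unet_2d_blocks.py` change below does):

```python
import torch

from diffusers.models.activations import get_activation

act = get_activation("relu")  # returns nn.ReLU() via the branch added above
print(act(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])
```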
196 changes: 196 additions & 0 deletions src/diffusers/models/autoencoder_tiny.py
@@ -0,0 +1,196 @@
# Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, apply_forward_hook
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny


@dataclass
class AutoencoderTinyOutput(BaseOutput):
    """
    Output of AutoencoderTiny encoding method.

    Args:
        latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
    """

    latents: torch.Tensor


class AutoencoderTiny(ModelMixin, ConfigMixin):
    r"""
    A tiny VAE model for encoding images into latents and decoding latent representations into images. It was
    distilled by Ollin Boer Bohan as detailed in [https://github.com/madebyollin/taesd](https://github.com/madebyollin/taesd).

    [`AutoencoderTiny`] is just a wrapper around the original implementation of `TAESD` found in the above-mentioned
    repository.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each encoder block. The length of the
            tuple should be equal to the number of encoder blocks.
        decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each decoder block. The length of the
            tuple should be equal to the number of decoder blocks.
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function to be used throughout the model.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of channels in the latent representation. The latent space acts as a compressed representation of
            the input image.
        upsampling_scaling_factor (`int`, *optional*, defaults to 2):
            Scaling factor for upsampling in the decoder. It determines the size of the output image during the
            upsampling process.
        num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
            Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
            length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
            number of encoder blocks.
        num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
            Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
            length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
            number of decoder blocks.
        latent_magnitude (`int`, *optional*, defaults to 3):
            Magnitude of the latent representation. This parameter scales the latent representation values to control
            the extent of information preservation.
        latent_shift (`float`, *optional*, defaults to 0.5):
            Shift applied to the latent representation. This parameter controls the center of the latent space.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this autoencoder,
            however, no such scaling factor was used, hence the default value of 1.0.
        force_upcast (`bool`, *optional*, defaults to `False`):
            If enabled, it forces the VAE to run in float32 for high-resolution pipelines, such as SD-XL. The VAE can
            be fine-tuned/trained to a lower range without losing too much precision, in which case `force_upcast`
            can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix. This autoencoder,
            however, is already FP16-friendly. Check:
            https://github.com/huggingface/diffusers/pull/4384#discussion_r1279410232.
    """
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        act_fn: str = "relu",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
    ):
        super().__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTiny(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value
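
Since the model sets `_supports_gradient_checkpointing = True`, checkpointing can be toggled through the standard `ModelMixin` API; a minimal sketch (relying on `ModelMixin.enable_gradient_checkpointing`, which dispatches to the `_set_gradient_checkpointing` hook above):

```python
from diffusers import AutoencoderTiny

vae = AutoencoderTiny()
vae.enable_gradient_checkpointing()  # sets gradient_checkpointing = True on EncoderTiny and DecoderTiny
```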

    def scale_latents(self, x):
        """raw latents -> [0, 1]"""
        return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

    def unscale_latents(self, x):
        """[0, 1] -> raw latents"""
        return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

Review thread on `scale_latents`:
- Member: I think a sentence explaining what this function and `unscale_latents` does would be clearer.
- Member Author: Since it's not a highlighted method from the docs, I think it's fine as is for now.

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
        output = self.encoder(x)

        if not return_dict:
            return (output,)

        return AutoencoderTinyOutput(latents=output)

    @apply_forward_hook
    def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        output = self.decoder(x)
        # Refer to the following discussion to know why this is needed.
        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
        output = output.mul_(2).sub_(1)

        if not return_dict:
            return (output,)

        return DecoderOutput(sample=output)

    def forward(
        self,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        enc = self.encode(sample).latents
        # Scale the latents to [0, 1] and quantize them to uint8, as if storing them in an
        # 8-bit image, then dequantize back to [0, 1] and unscale before decoding.
        scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
        dec = self.decode(unscaled_enc).sample

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
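
For intuition on the byte round-trip in `forward`, a small self-contained sketch (assuming the default `latent_magnitude=3` and `latent_shift=0.5`) showing that uint8 quantization of the latents costs at most about half a quantization step:

```python
import torch

latent_magnitude, latent_shift = 3, 0.5

def scale_latents(x):  # raw latents -> [0, 1]
    return x.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)

def unscale_latents(x):  # [0, 1] -> raw latents
    return x.sub(latent_shift).mul(2 * latent_magnitude)

enc = torch.rand(1, 4, 64, 64) * 6 - 3  # latents spanning [-3, 3]
quantized = scale_latents(enc).mul_(255).round_().byte()
recovered = unscale_latents(quantized / 255.0)
print((enc - recovered).abs().max())  # at most 6 / 255 / 2 ≈ 0.012
```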
23 changes: 23 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
@@ -19,6 +19,7 @@
from torch import nn

from ..utils import is_torch_version, logging
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from .dual_transformer_2d import DualTransformer2DModel
@@ -423,6 +424,28 @@ def get_up_block(
    raise ValueError(f"{up_block_type} does not exist.")


class AutoencoderTinyBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, act_fn: str):
        super().__init__()
        act_fn = get_activation(act_fn)
        # Three 3x3 convs with activations in between; spatial dimensions are preserved.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            act_fn,
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            act_fn,
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        # Residual connection, with a 1x1 conv to match channels when they differ.
        self.skip = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
            if in_channels != out_channels
            else nn.Identity()
        )
        self.fuse = nn.ReLU()

    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))
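
A quick shape check for the new block (an illustrative sketch):

```python
import torch

block = AutoencoderTinyBlock(in_channels=3, out_channels=64, act_fn="relu")
out = block(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 64, 64, 64]); the spatial size is preserved
```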


class UNetMidBlock2D(nn.Module):
    def __init__(
        self,