
Commit 18fc40c

sayakpaul and stevhliu authored
[Feat] add tiny Autoencoder for (almost) instant decoding (#4384)
* add: model implementation of tiny autoencoder.
* add: inits.
* push the latest devs.
* add: conversion script and finish.
* add: scaling factor args.
* debugging
* fix denormalization.
* fix: positional argument.
* handle use_torch_2_0_or_xformers.
* handle post_quant_conv
* handle dtype
* fix: sdxl image processor for tiny ae.
* unify upcasting logic.
* copied from madness.
* remove trailing whitespace.
* set is_tiny_vae = False
* address PR comments.
* change to AutoencoderTiny
* make act_fn an str throughout
* fix: apply_forward_hook decorator call
* get rid of the special is_tiny_vae flag.
* directly scale the output.
* fix dummies?
* fix: act_fn.
* get rid of the Clamp() layer.
* bring back copied from.
* movement of the blocks to appropriate modules.
* add: docstrings to AutoencoderTiny
* add: documentation.
* changes to the conversion script.
* add doc entry.
* settle tests.
* style
* add one slow test.
* fix
* fix 2
* fix: 4
* fix: 5
* finish integration tests
* Apply suggestions from code review
* style

Co-authored-by: Steven Liu <[email protected]>
1 parent 615c04d commit 18fc40c

File tree

11 files changed: +543 −2 lines changed

docs/source/en/_toctree.yml

+2

```diff
@@ -164,6 +164,8 @@
     title: AutoencoderKL
   - local: api/models/asymmetricautoencoderkl
     title: AsymmetricAutoencoderKL
+  - local: api/models/autoencoder_tiny
+    title: Tiny AutoEncoder
   - local: api/models/transformer2d
     title: Transformer2D
   - local: api/models/transformer_temporal
```
docs/source/en/api/models/autoencoder_tiny.md (new file)

+45

# Tiny AutoEncoder

Tiny AutoEncoder for Stable Diffusion (TAESD) was introduced in [madebyollin/taesd](https://github.com/madebyollin/taesd) by Ollin Boer Bohan. It is a tiny distilled version of Stable Diffusion's VAE that can decode the latents in a [`StableDiffusionPipeline`] or [`StableDiffusionXLPipeline`] almost instantly.

To use with Stable Diffusion v2.1:

```python
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("cheesecake.png")
```

To use with Stable Diffusion XL 1.0:

```python
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("cheesecake_sdxl.png")
```

## AutoencoderTiny

[[autodoc]] AutoencoderTiny

## AutoencoderTinyOutput

[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
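Beyond swapping the VAE inside a pipeline, the model can also be driven directly. Here is a minimal sketch of a standalone encode/decode round trip; the checkpoint name follows the examples above, while the input resolution and the 8× spatial compression are assumptions based on SD-compatible latents:

```python
import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16).to("cuda")

# A dummy image batch in [-1, 1]; in practice this would be a preprocessed image,
# and the latents would come from a pipeline's denoising loop.
image = torch.rand(1, 3, 512, 512, dtype=torch.float16, device="cuda") * 2 - 1

with torch.no_grad():
    latents = vae.encode(image).latents   # assumed shape: (1, 4, 64, 64)
    decoded = vae.decode(latents).sample  # back to (1, 3, 512, 512), values in [-1, 1]
```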
scripts/convert_tiny_autoencoder_to_diffusers.py (new file)

+77

````python
import argparse

from diffusers.utils import is_safetensors_available


if is_safetensors_available():
    import safetensors.torch
else:
    raise ImportError("Please install `safetensors`.")

from diffusers import AutoencoderTiny


"""
Example - From the diffusers root directory:

Download the weights:
```sh
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors
```

Convert the model:
```sh
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
    --encoder_ckpt_path taesd_encoder.safetensors \
    --decoder_ckpt_path taesd_decoder.safetensors \
    --dump_path taesd-diffusers
```
"""

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--encoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the encoder ckpt.",
    )
    parser.add_argument(
        "--decoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the decoder ckpt.",
    )
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format."
    )
    args = parser.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path)
    decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path)

    print("Populating the state_dicts in the diffusers format...")
    tiny_autoencoder = AutoencoderTiny()
    new_state_dict = {}

    # Modify the encoder state dict: nest the original keys under `encoder.layers`.
    for k in encoder_state_dict:
        new_state_dict.update({f"encoder.layers.{k}": encoder_state_dict[k]})

    # Modify the decoder state dict: shift the layer indices down by one
    # and nest the keys under `decoder.layers`.
    for k in decoder_state_dict:
        layer_id = int(k.split(".")[0]) - 1
        new_k = str(layer_id) + "." + ".".join(k.split(".")[1:])
        new_state_dict.update({f"decoder.layers.{new_k}": decoder_state_dict[k]})

    # Assertion tests with the original implementation can be found here:
    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
    tiny_autoencoder.load_state_dict(new_state_dict)
    print("Population successful, serializing...")
    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)
````
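Once the script has run, the converted directory should load back like any other diffusers checkpoint. A quick smoke test, as a sketch (the `taesd-diffusers` path matches the `--dump_path` used in the docstring example above):

```python
import torch
from diffusers import AutoencoderTiny

# Load the checkpoint serialized by the conversion script.
vae = AutoencoderTiny.from_pretrained("taesd-diffusers")

# Round-trip a dummy batch to confirm both state dicts were wired up correctly.
sample = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    reconstruction = vae(sample).sample
print(reconstruction.shape)  # expected: torch.Size([1, 3, 256, 256])
```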

src/diffusers/__init__.py

+1

```diff
@@ -38,6 +38,7 @@
 from .models import (
     AsymmetricAutoencoderKL,
     AutoencoderKL,
+    AutoencoderTiny,
     ControlNetModel,
     ModelMixin,
     MultiAdapter,
```
src/diffusers/models/__init__.py

+1

```diff
@@ -19,6 +19,7 @@
 from .adapter import MultiAdapter, T2IAdapter
 from .autoencoder_asym_kl import AsymmetricAutoencoderKL
 from .autoencoder_kl import AutoencoderKL
+from .autoencoder_tiny import AutoencoderTiny
 from .controlnet import ControlNetModel
 from .dual_transformer_2d import DualTransformer2DModel
 from .modeling_utils import ModelMixin
```
src/diffusers/models/activations.py

+2

```diff
@@ -8,5 +8,7 @@ def get_activation(act_fn):
         return nn.Mish()
     elif act_fn == "gelu":
         return nn.GELU()
+    elif act_fn == "relu":
+        return nn.ReLU()
     else:
         raise ValueError(f"Unsupported activation function: {act_fn}")
```
src/diffusers/models/autoencoder_tiny.py (new file)

+193

```python
# Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, apply_forward_hook
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny


@dataclass
class AutoencoderTinyOutput(BaseOutput):
    """
    Output of AutoencoderTiny encoding method.

    Args:
        latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
    """

    latents: torch.Tensor


class AutoencoderTiny(ModelMixin, ConfigMixin):
    r"""
    A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.

    [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each encoder block. The length of the
            tuple should be equal to the number of encoder blocks.
        decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each decoder block. The length of the
            tuple should be equal to the number of decoder blocks.
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function to be used throughout the model.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of channels in the latent representation. The latent space acts as a compressed representation of
            the input image.
        upsampling_scaling_factor (`int`, *optional*, defaults to 2):
            Scaling factor for upsampling in the decoder. It determines the size of the output image during the
            upsampling process.
        num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
            Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
            length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
            number of encoder blocks.
        num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
            Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
            length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
            number of decoder blocks.
        latent_magnitude (`int`, *optional*, defaults to `3`):
            Magnitude of the latent representation. This parameter scales the latent representation values to control
            the extent of information preservation.
        latent_shift (`float`, *optional*, defaults to `0.5`):
            Shift applied to the latent representation. This parameter controls the center of the latent space.
        scaling_factor (`float`, *optional*, defaults to `1.0`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
            however, no such scaling factor was used, hence the value of 1.0 as the default.
        force_upcast (`bool`, *optional*, defaults to `False`):
            If enabled, it forces the VAE to run in float32 for high-resolution image pipelines, such as SD-XL. The
            VAE can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` (see this fp16-friendly
            [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        act_fn: str = "relu",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
    ):
        super().__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTiny(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value

    def scale_latents(self, x):
        """raw latents -> [0, 1]"""
        return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

    def unscale_latents(self, x):
        """[0, 1] -> raw latents"""
        return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
        output = self.encoder(x)

        if not return_dict:
            return (output,)

        return AutoencoderTinyOutput(latents=output)

    @apply_forward_hook
    def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        output = self.decoder(x)
        # Refer to the following discussion to know why this is needed:
        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
        output = output.mul_(2).sub_(1)

        if not return_dict:
            return (output,)

        return DecoderOutput(sample=output)

    def forward(
        self,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        enc = self.encode(sample).latents
        # Scale the latents to [0, 1] and quantize them to a byte tensor,
        # as if they were stored in an 8-bit image.
        scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
        # Dequantize back to [0, 1], then unscale to the raw latent range.
        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
        dec = self.decode(unscaled_enc).sample

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
```
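The `forward` pass above deliberately round-trips the latents through 8-bit quantization before decoding, so results match latents that were stored as an 8-bit image. A standalone sketch of that round trip with the default `latent_magnitude=3` and `latent_shift=0.5`:

```python
import torch

latent_magnitude, latent_shift = 3, 0.5
raw = torch.randn(1, 4, 64, 64)

# scale_latents: raw latents -> [0, 1]
scaled = raw.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)

# quantize to bytes, then dequantize back to [0, 1]
dequantized = scaled.mul(255).round().byte() / 255.0

# unscale_latents: [0, 1] -> raw latents
restored = dequantized.sub(latent_shift).mul(2 * latent_magnitude)

# the error stays within half a quantization step, i.e. 6 / 255 / 2 ≈ 0.0118
print((raw.clamp(-latent_magnitude, latent_magnitude) - restored).abs().max())
```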

src/diffusers/models/unet_2d_blocks.py

+23

```diff
@@ -19,6 +19,7 @@
 from torch import nn

 from ..utils import is_torch_version, logging
+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
@@ -423,6 +424,28 @@ def get_up_block(
     raise ValueError(f"{up_block_type} does not exist.")


+class AutoencoderTinyBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, act_fn: str):
+        super().__init__()
+        act_fn = get_activation(act_fn)
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            act_fn,
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            act_fn,
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+        )
+        self.skip = (
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+            if in_channels != out_channels
+            else nn.Identity()
+        )
+        self.fuse = nn.ReLU()
+
+    def forward(self, x):
+        return self.fuse(self.conv(x) + self.skip(x))
+
+
 class UNetMidBlock2D(nn.Module):
     def __init__(
         self,
```
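`AutoencoderTinyBlock` is a plain residual unit: three 3×3 convolutions on the main path, a 1×1 convolution (or identity) on the skip path, fused by a final ReLU. A brief shape check under assumed channel sizes:

```python
import torch
from diffusers.models.unet_2d_blocks import AutoencoderTinyBlock

# in_channels != out_channels exercises the 1x1 skip convolution.
block = AutoencoderTinyBlock(in_channels=32, out_channels=64, act_fn="relu")
out = block(torch.randn(1, 32, 24, 24))
print(out.shape)  # torch.Size([1, 64, 24, 24]); spatial dimensions are preserved
```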
