Skip to content

[Feat] add tiny Autoencoder for (almost) instant decoding #4384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
c15d27a
add: model implementation of tiny autoencoder.
sayakpaul Jul 26, 2023
937eb19
add: inits.
sayakpaul Jul 26, 2023
19f834e
push the latest devs.
sayakpaul Jul 26, 2023
58b9bc9
Merge branch 'main' into feat/tiny-autoenc
sayakpaul Jul 31, 2023
cddf45d
add: conversion script and finish.
sayakpaul Jul 31, 2023
f111387
add: scaling factor args.
sayakpaul Jul 31, 2023
258c074
debugging
sayakpaul Jul 31, 2023
8be9d3d
fix denormalization.
sayakpaul Jul 31, 2023
63a95a5
fix: positional argument.
sayakpaul Jul 31, 2023
b470897
handle use_torch_2_0_or_xformers.
sayakpaul Jul 31, 2023
2ce4bf3
handle post_quant_conv
sayakpaul Jul 31, 2023
c8ff8e5
handle dtype
sayakpaul Jul 31, 2023
0c099ca
fix: sdxl image processor for tiny ae.
sayakpaul Jul 31, 2023
8d43bc8
fix: sdxl image processor for tiny ae.
sayakpaul Jul 31, 2023
7d2dca0
unify upcasting logic.
sayakpaul Jul 31, 2023
a7a8f6f
copied from madness.
sayakpaul Jul 31, 2023
d11fd65
remove trailing whitespace.
sayakpaul Jul 31, 2023
92c72e2
set is_tiny_vae = False
sayakpaul Aug 1, 2023
d6808d1
address PR comments.
sayakpaul Aug 1, 2023
5b7635b
change to AutoencoderTiny
sayakpaul Aug 1, 2023
4bd124d
make act_fn an str throughout
sayakpaul Aug 1, 2023
4ddc077
fix: apply_forward_hook decorator call
sayakpaul Aug 1, 2023
ee35ebd
get rid of the special is_tiny_vae flag.
sayakpaul Aug 1, 2023
4a606ef
directly scale the output.
sayakpaul Aug 1, 2023
56eade8
fix dummies?
sayakpaul Aug 1, 2023
ae3b107
fix: act_fn.
sayakpaul Aug 1, 2023
40966b8
get rid of the Clamp() layer.
sayakpaul Aug 2, 2023
0b04133
bring back copied from.
sayakpaul Aug 2, 2023
13f06fa
movement of the blocks to appropriate modules.
sayakpaul Aug 2, 2023
ef8772a
add: docstrings to AutoencoderTiny
sayakpaul Aug 2, 2023
8c0852c
add: documentation.
sayakpaul Aug 2, 2023
6f17fd3
changes to the conversion script.
sayakpaul Aug 2, 2023
29e3370
add doc entry.
sayakpaul Aug 2, 2023
817bb2b
settle tests.
sayakpaul Aug 2, 2023
2995de1
Merge branch 'main' into feat/tiny-autoenc
sayakpaul Aug 2, 2023
bbf0597
style
sayakpaul Aug 2, 2023
9105fb5
add one slow test.
sayakpaul Aug 2, 2023
ef3eae2
fix
sayakpaul Aug 2, 2023
e05b730
fix 2
sayakpaul Aug 2, 2023
0980796
fix 2
sayakpaul Aug 2, 2023
163c035
fix: 4
sayakpaul Aug 2, 2023
dd0b673
fix: 5
sayakpaul Aug 2, 2023
d7ab16f
finish integration tests
sayakpaul Aug 2, 2023
644d125
Apply suggestions from code review
sayakpaul Aug 2, 2023
e2fcccb
style
sayakpaul Aug 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions scripts/convert_tiny_autoencoder_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import argparse

from diffusers.utils import is_safetensors_available


if is_safetensors_available():
import safetensors.torch
else:
raise ImportError("Please install `safetensors`.")

from diffusers import TinyAutoencoder


"""
Example - From the diffusers root directory:

Download the weights:
```sh
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors
```

Convert the model:
```sh
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
--encoder_ckpt_path taesd_encoder.safetensors \
--decoder_ckpt_path taesd_decoder.safetensors \
--dump_path taesd-diffusers
```
"""

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--encoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the encoder ckpt.",
    )
    parser.add_argument(
        "--decoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the decoder ckpt.",
    )
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format."
    )
    args = parser.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path)
    decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path)

    print("Populating the state_dicts in the diffusers format...")
    tiny_autoencoder = TinyAutoencoder()
    # Remap the flat original checkpoint keys into the diffusers module hierarchy
    # (`encoder.layers.*` / `decoder.layers.*`) in a single pass each, instead of
    # building a throwaway one-entry dict per key inside a loop.
    new_state_dict = {f"encoder.layers.{k}": v for k, v in encoder_state_dict.items()}
    new_state_dict.update({f"decoder.layers.{k}": v for k, v in decoder_state_dict.items()})

    # Assertion tests with the original implementation can be found here:
    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
    # `load_state_dict` is strict by default, so any missing or unexpected key fails loudly.
    tiny_autoencoder.load_state_dict(new_state_dict)
    print("Population successful, serializing...")
    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
PriorTransformer,
T2IAdapter,
T5FilmDecoder,
TinyAutoencoder,
Transformer2DModel,
UNet1DModel,
UNet2DConditionModel,
Expand Down
13 changes: 10 additions & 3 deletions src/diffusers/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,17 @@ def normalize(images):
return 2.0 * images - 1.0

@staticmethod
def denormalize(images):
def denormalize(images, is_tiny_vae):
"""
Denormalize an image array to [0,1].

Refer to https://github.com/madebyollin/taesd/issues/3#issuecomment-1657729279 to know why `is_tiny_vae`
exists.
"""
return (images / 2 + 0.5).clamp(0, 1)
if not is_tiny_vae:
return (images / 2 + 0.5).clamp(0, 1)
else:
return images.clamp(0, 1)

@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
Expand Down Expand Up @@ -217,6 +223,7 @@ def postprocess(
image: torch.FloatTensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
is_tiny_vae: bool = False,
):
if not isinstance(image, torch.Tensor):
raise ValueError(
Expand All @@ -237,7 +244,7 @@ def postprocess(
do_denormalize = [self.config.do_normalize] * image.shape[0]

image = torch.stack(
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
[self.denormalize(image[i], is_tiny_vae) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
)

if output_type == "pt":
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .modeling_utils import ModelMixin
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
from .tiny_autoencoder import TinyAutoencoder
from .transformer_2d import Transformer2DModel
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ def get_activation(act_fn):
return nn.Mish()
elif act_fn == "gelu":
return nn.GELU()
elif act_fn == "relu":
return nn.ReLU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
Loading