# Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, apply_forward_hook
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny


@dataclass
class AutoencoderTinyOutput(BaseOutput):
    """
    Output of AutoencoderTiny encoding method.

    Args:
        latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
    """

    latents: torch.Tensor


class AutoencoderTiny(ModelMixin, ConfigMixin):
    r"""
    A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.

    [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each encoder block. The length of the
            tuple should be equal to the number of encoder blocks.
        decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each decoder block. The length of the
            tuple should be equal to the number of decoder blocks.
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function to be used throughout the model.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of channels in the latent representation. The latent space acts as a compressed representation of
            the input image.
        upsampling_scaling_factor (`int`, *optional*, defaults to 2):
            Scaling factor for upsampling in the decoder. It determines the size of the output image during the
            upsampling process.
        num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
            Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
            length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
            number of encoder blocks.
        num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
            Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
            length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
            number of decoder blocks.
        latent_magnitude (`int`, *optional*, defaults to 3):
            Magnitude of the latent representation. This parameter scales the latent representation values to control
            the extent of information preservation.
        latent_shift (`float`, *optional*, defaults to 0.5):
            Shift applied to the latent representation. This parameter controls the center of the latent space.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this autoencoder,
            however, no such scaling factor was used, hence the default value of 1.0.
        force_upcast (`bool`, *optional*, defaults to `False`):
            If enabled, it forces the VAE to run in `float32` for high-resolution image pipelines, such as SD-XL. The
            VAE can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` (see this fp16-friendly
            [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
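
    Example (a minimal encode/decode sketch; the `"madebyollin/taesd"` checkpoint id below is an assumption and is
    not defined in this module):

    ```py
    >>> import torch
    >>> from diffusers import AutoencoderTiny

    >>> vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")  # hypothetical checkpoint id
    >>> image = torch.randn(1, 3, 512, 512)  # dummy batch; real images are expected in [-1, 1]
    >>> latents = vae.encode(image).latents  # (1, 4, 64, 64) with the default config
    >>> reconstruction = vae.decode(latents).sample  # back to image space, in [-1, 1]
    ```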
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        act_fn: str = "relu",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
    ):
        super().__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTiny(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggled by `ModelMixin.enable_gradient_checkpointing()` / `disable_gradient_checkpointing()`.
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value

    def scale_latents(self, x):
        """raw latents -> [0, 1]"""
        # Affine map x / (2 * latent_magnitude) + latent_shift. With the default config this maps latents
        # in [-3, 3] onto [0, 1]; values outside that range are clamped.
        return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

    def unscale_latents(self, x):
        """[0, 1] -> raw latents"""
        # Inverse of `scale_latents` (up to the clamping applied there).
        return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
        output = self.encoder(x)

        if not return_dict:
            return (output,)

        return AutoencoderTinyOutput(latents=output)

    @apply_forward_hook
    def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        output = self.decoder(x)
        # The decoder produces images in [0, 1]; map them to the [-1, 1] range expected elsewhere in diffusers.
        # Refer to the following discussion to know why this is needed:
        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
        output = output.mul_(2).sub_(1)

        if not return_dict:
            return (output,)

        return DecoderOutput(sample=output)

    def forward(
        self,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
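
        Example (illustrative sketch; uses a randomly initialised model with the default config just to show the
        lossy byte round-trip performed by `forward`):

        ```py
        >>> import torch
        >>> from diffusers import AutoencoderTiny

        >>> taesd = AutoencoderTiny()
        >>> images = torch.randn(1, 3, 512, 512)  # dummy batch in image space
        >>> round_tripped = taesd(images).sample  # encode -> 8-bit quantize -> dequantize -> decode
        ```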
        """
        enc = self.encode(sample).latents

        # Scale the latents to [0, 1] and quantize them to a byte tensor, as if they were being stored in an image
        # file, then dequantize and unscale them back to raw latents. This makes `forward` a lossy round-trip.
        scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)

        dec = self.decode(unscaled_enc).sample

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)