# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Union

import numpy as np
import PIL
import torch
from PIL import Image

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION


class VaeImageProcessor(ConfigMixin):
    """
    Image processor for VAE.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this
            factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1, 1].
    """
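
    # Illustrative note (an added comment, not from the original file): with the
    # defaults above, `preprocess` shrinks a 511x767 PIL image to 504x760 (the
    # nearest lower multiples of vae_scale_factor=8) and maps its pixel values
    # from [0, 1] to [-1, 1].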

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @staticmethod
    def numpy_to_pt(images):
        """
        Convert a numpy image to a pytorch tensor.
        """
        if images.ndim == 3:
            images = images[..., None]

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images

    @staticmethod
    def pt_to_numpy(images):
        """
        Convert a pytorch tensor to a numpy image.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def normalize(images):
        """
        Normalize an image array to [-1,1]
        """
        return 2.0 * images - 1.0

    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
        """
        Resize a PIL image. Both height and width are downscaled to the nearest lower integer multiple of
        `vae_scale_factor`.
        """
        w, h = images.size
        w, h = map(lambda x: x - x % self.vae_scale_factor, (w, h))  # resize to integer multiple of vae_scale_factor
        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
        return images

    def preprocess(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
    ) -> torch.Tensor:
        """
        Preprocess the image input. Accepted formats are PIL images, numpy arrays or pytorch tensors.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
        if isinstance(image, supported_formats):
            image = [image]
        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support "
                f"{', '.join(str(x) for x in supported_formats)}"
            )

        if isinstance(image[0], PIL.Image.Image):
            if self.do_resize:
                image = [self.resize(i) for i in image]
            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
            image = np.stack(image, axis=0)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            image = self.numpy_to_pt(image)
            _, _, height, width = image.shape
            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
                raise ValueError(
                    f"Currently we only support resizing for PIL images - please resize your numpy array to be divisible by {self.vae_scale_factor}; "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use the resize option in VaeImageProcessor."
                )

        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0) if image[0].ndim == 4 else torch.stack(image, dim=0)
            _, _, height, width = image.shape
            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
                raise ValueError(
                    f"Currently we only support resizing for PIL images - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}; "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use the resize option in VaeImageProcessor."
                )

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.do_normalize
        if image.min() < 0:
            warnings.warn(
                "Passing `image` as a torch tensor with value range in [-1,1] is deprecated. The expected value range for an image tensor is [0,1] "
                f"when passing it as a pytorch tensor or numpy array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            image = self.normalize(image)

        return image

    def postprocess(
        self,
        image,
        output_type: str = "pil",
    ):
        if isinstance(image, torch.Tensor) and output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image
        elif output_type == "pil":
            return self.numpy_to_pil(image)
        else:
            raise ValueError(f"Unsupported output_type {output_type}.")
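

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file; the file
# name and variables below are hypothetical). With a VAE that downsamples by
# a factor of 8, the processor round-trips between PIL, numpy and torch:
#
#     processor = VaeImageProcessor(vae_scale_factor=8)
#     pil_image = Image.open("input.png").convert("RGB")
#     tensor = processor.preprocess(pil_image)      # (1, 3, H, W), values in [-1, 1]
#     ...                                           # run the VAE / diffusion model
#     decoded = (tensor / 2 + 0.5).clamp(0, 1)      # postprocess expects values in [0, 1]
#     result = processor.postprocess(decoded, output_type="pil")[0]
# ---------------------------------------------------------------------------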