
Commit 9cee6c0

Authored by yiyixuxu, Patrick von Platen, and Pedro Cuenca
Add image_processor (huggingface#2617)
* add image_processor

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 3f462f3 commit 9cee6c0

File tree: 3 files changed (+235 / -39 lines)

image_processor.py (new file, +177 lines)
```python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Union

import numpy as np
import PIL
import torch
from PIL import Image

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION


class VaeImageProcessor(ConfigMixin):
    """
    Image processor for VAE.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this
            factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1, 1].
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @staticmethod
    def numpy_to_pt(images):
        """
        Convert a numpy image to a pytorch tensor.
        """
        if images.ndim == 3:
            images = images[..., None]

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images

    @staticmethod
    def pt_to_numpy(images):
        """
        Convert a pytorch tensor to a numpy image.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def normalize(images):
        """
        Normalize an image array to [-1, 1].
        """
        return 2.0 * images - 1.0

    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
        """
        Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
        """
        w, h = images.size
        w, h = map(lambda x: x - x % self.vae_scale_factor, (w, h))  # resize to integer multiple of vae_scale_factor
        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
        return images

    def preprocess(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
    ) -> torch.Tensor:
        """
        Preprocess the image input. Accepted formats are PIL images, numpy arrays, and pytorch tensors.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
        if isinstance(image, supported_formats):
            image = [image]
        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support "
                f"{', '.join(str(x) for x in supported_formats)}"
            )

        if isinstance(image[0], PIL.Image.Image):
            if self.do_resize:
                image = [self.resize(i) for i in image]
            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
            image = np.stack(image, axis=0)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            image = self.numpy_to_pt(image)
            _, _, height, width = image.shape
            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
                raise ValueError(
                    f"Currently we only support resizing for PIL images - please resize your numpy array to be divisible by {self.vae_scale_factor}; "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use the resize option in VaeImageProcessor."
                )

        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
            _, _, height, width = image.shape
            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
                raise ValueError(
                    f"Currently we only support resizing for PIL images - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}; "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use the resize option in VaeImageProcessor."
                )

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.do_normalize
        if image.min() < 0:
            warnings.warn(
                "Passing `image` as a torch tensor with value range in [-1,1] is deprecated. The expected value range for an image tensor is [0,1] "
                f"when passing it as a pytorch tensor or numpy array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            image = self.normalize(image)

        return image

    def postprocess(
        self,
        image,
        output_type: str = "pil",
    ):
        if isinstance(image, torch.Tensor) and output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image
        elif output_type == "pil":
            return self.numpy_to_pil(image)
        else:
            raise ValueError(f"Unsupported output_type {output_type}.")
```
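
For orientation, here is a minimal round-trip through the new class. This is a sketch, not part of the commit; the image size is chosen arbitrarily to show the multiple-of-8 downscaling, and it assumes the module layout introduced here (`diffusers.image_processor`):

```python
# Sketch (not part of this commit): round-trip an image through VaeImageProcessor.
from PIL import Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()  # do_resize=True, vae_scale_factor=8 by default

# 515x517 is not a multiple of 8, so resize() downscales it to 512x512
image = Image.new("RGB", (515, 517), color=(128, 64, 32))

tensor = processor.preprocess(image)
print(tensor.shape)  # torch.Size([1, 3, 512, 512]), values normalized to [-1, 1]

# postprocess expects values in [0, 1], as produced by the pipelines' decode_latents
pil_out = processor.postprocess((tensor + 1) / 2, output_type="pil")[0]
print(pil_out.size)  # (512, 512)
```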

pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py (+29 / -18)
```diff
@@ -24,6 +24,7 @@
 from diffusers.utils import is_accelerate_available, is_accelerate_version
 
 from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@@ -192,7 +193,6 @@ def __init__(
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -203,7 +203,11 @@ def __init__(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(
+            requires_safety_checker=requires_safety_checker,
+        )
 
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
@@ -415,21 +419,17 @@ def _encode_prompt(
         return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-            )
-        else:
-            has_nsfw_concept = None
+        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+        )
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -663,7 +663,7 @@ def __call__(
         )
 
         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -703,15 +703,26 @@ def __call__(
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)
 
-        # 9. Post-processing
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+
         image = self.decode_latents(latents)
 
-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if self.safety_checker is not None:
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            has_nsfw_concept = False
 
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
```
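
With this change, `__call__` accepts `output_type` values `"latent"`, `"pt"`, `"np"`, and `"pil"`; any other string triggers a deprecation warning and is coerced to `"np"`, and the safety checker is skipped when `self.safety_checker` is `None`. A minimal sketch of the new `"pt"` path (not part of the commit; the checkpoint id, prompt, and init image are placeholders, and this assumes the model weights are available):

```python
# Sketch only (not part of this commit): exercising the new output_type handling.
from diffusers import AltDiffusionImg2ImgPipeline
from PIL import Image

pipe = AltDiffusionImg2ImgPipeline.from_pretrained("BAAI/AltDiffusion")  # placeholder checkpoint
init_image = Image.new("RGB", (512, 512))

# "pt" now returns the decoded image tensor directly via image_processor.postprocess,
# skipping the numpy/PIL conversion; "np" and "pil" behave as before.
out = pipe("a placeholder prompt", image=init_image, output_type="pt")
print(type(out.images))  # torch.Tensor
```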

pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py (+29 / -21)
```diff
@@ -22,6 +22,7 @@
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -119,7 +120,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
     """
     _optional_components = ["safety_checker", "feature_extractor"]
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -196,7 +196,6 @@ def __init__(
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -207,7 +206,11 @@ def __init__(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(
+            requires_safety_checker=requires_safety_checker,
+        )
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -422,24 +425,18 @@ def _encode_prompt(
 
         return prompt_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-            )
-        else:
-            has_nsfw_concept = None
+        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+        )
         return image, has_nsfw_concept
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -674,7 +671,7 @@ def __call__(
         )
 
         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -714,15 +711,26 @@ def __call__(
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)
 
-        # 9. Post-processing
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+
         image = self.decode_latents(latents)
 
-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if self.safety_checker is not None:
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            has_nsfw_concept = False
 
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
```
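
Since both pipelines now funnel step "# 4. Preprocess image" through the shared processor, the img2img entry point accepts PIL images, numpy arrays, and torch tensors interchangeably. A small standalone sketch of that behavior, using only the new class (shapes and values are made up for illustration):

```python
# Sketch (not part of this commit): VaeImageProcessor.preprocess accepts all
# three input formats and always yields an NCHW torch tensor in [-1, 1].
import numpy as np
import torch
from PIL import Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()

pil = Image.new("RGB", (512, 512))
arr = np.random.rand(512, 512, 3).astype(np.float32)  # HWC in [0, 1]
ten = torch.rand(1, 3, 512, 512)                       # NCHW in [0, 1]

for candidate in (pil, arr, ten):
    out = processor.preprocess(candidate)
    print(type(candidate).__name__, "->", tuple(out.shape), out.dtype)
```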
