From e5ca2b28f70e9cc16a9b26390a7f21ac20f1cbe1 Mon Sep 17 00:00:00 2001 From: sunhs Date: Fri, 30 Jun 2023 15:26:13 +0800 Subject: [PATCH 01/15] add noise_sampler to StableDiffusionKDiffusionPipeline --- .../pipeline_stable_diffusion_k_diffusion.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index e2800342e578..283ddb6f19b3 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -13,12 +13,13 @@ # limitations under the License. import importlib +import inspect import warnings from typing import Callable, List, Optional, Union import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser -from k_diffusion.sampling import get_sigmas_karras +from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin @@ -464,6 +465,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, use_karras_sigmas: Optional[bool] = False, + noise_sampler_seed: Optional[int] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -524,6 +526,8 @@ def __call__( Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M Karras`. + noise_sampler_seed (`int`, *optional*, defaults to `None`): + The random seed to use for the noise sampler. If `None`, a random seed will be generated. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. @@ -608,7 +612,14 @@ def model_fn(x, t): return noise_pred # 8. 
Run k-diffusion solver - latents = self.sampler(model_fn, latents, sigmas) + sampler_kwargs = {} + + if "noise_sampler" in inspect.signature(self.sampler).parameters: + min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) + sampler_kwargs["noise_sampler"] = noise_sampler + + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] From af16bf6ac7882721827e25956ee53b495b662bb8 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Sat, 1 Jul 2023 16:07:59 +1000 Subject: [PATCH 02/15] fix/docs: Fix the broken doc links (#3897) * fix/docs: Fix the broken doc links Signed-off-by: GitHub * Update docs/source/en/using-diffusers/write_own_pipeline.mdx Co-authored-by: Pedro Cuenca --------- Signed-off-by: GitHub Co-authored-by: Pedro Cuenca --- docs/source/en/using-diffusers/write_own_pipeline.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/using-diffusers/write_own_pipeline.mdx b/docs/source/en/using-diffusers/write_own_pipeline.mdx index c7e257f4fa36..ca5ea38b4ad2 100644 --- a/docs/source/en/using-diffusers/write_own_pipeline.mdx +++ b/docs/source/en/using-diffusers/write_own_pipeline.mdx @@ -286,5 +286,5 @@ This is really what 🧨 Diffusers is designed for: to make it intuitive and eas For your next steps, feel free to: -* Learn how to [build and contribute a pipeline](using-diffusers/#contribute_pipeline) to 🧨 Diffusers. We can't wait and see what you'll come up with! -* Explore [existing pipelines](./api/pipelines/overview) in the library, and see if you can deconstruct and build a pipeline from scratch using the models and schedulers separately. +* Learn how to [build and contribute a pipeline](contribute_pipeline) to 🧨 Diffusers. We can't wait and see what you'll come up with! +* Explore [existing pipelines](../api/pipelines/overview) in the library, and see if you can deconstruct and build a pipeline from scratch using the models and schedulers separately. 
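As a usage reference for the first patch: `noise_sampler_seed` only takes effect when the selected k-diffusion sampler accepts a `noise_sampler` argument (stochastic/ancestral samplers such as `sample_dpmpp_sde`), in which case the pipeline builds a `BrownianTreeNoiseSampler` seeded with it. A minimal sketch, assuming `k-diffusion` is installed; the checkpoint and prompt are illustrative, not part of the patch:

```py
import torch
from diffusers import StableDiffusionKDiffusionPipeline

pipe = StableDiffusionKDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# A stochastic sampler whose signature contains `noise_sampler`, so the pipeline
# constructs a BrownianTreeNoiseSampler and seeds it with `noise_sampler_seed`.
pipe.set_scheduler("sample_dpmpp_sde")

image = pipe(
    "an astronaut riding a horse on mars",
    num_inference_steps=25,
    use_karras_sigmas=True,
    noise_sampler_seed=0,  # makes the Brownian-tree noise reproducible across runs
).images[0]
```

With a deterministic sampler such as `sample_dpmpp_2m`, the seed is simply ignored, because the `noise_sampler` keyword is never added to `sampler_kwargs`.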
From 6f61885139a8b28da18c0ea035611505b790dea1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 2 Jul 2023 13:19:27 +0200 Subject: [PATCH 03/15] Add video img2img (#3900) * Add image to image video * Improve * better naming * make fix copies * add docs * finish tests * trigger tests * make style * correct * finish * Fix more * make style * finish --- .../source/en/api/pipelines/text_to_video.mdx | 63 ++ src/diffusers/__init__.py | 1 + src/diffusers/models/autoencoder_kl.py | 7 +- src/diffusers/pipelines/__init__.py | 2 +- .../text_to_video_synthesis/__init__.py | 3 +- .../pipeline_text_to_video_synth.py | 3 + .../pipeline_text_to_video_synth_img2img.py | 770 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/test_pipelines_common.py | 4 +- .../text_to_video/test_video_to_video.py | 195 +++++ 10 files changed, 1058 insertions(+), 5 deletions(-) create mode 100644 src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py create mode 100644 tests/pipelines/text_to_video/test_video_to_video.py diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx index 82b2f19ce1b2..75868d7dd6ea 100644 --- a/docs/source/en/api/pipelines/text_to_video.mdx +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -37,9 +37,12 @@ Resources: | Pipeline | Tasks | Demo |---|---|:---:| | [TextToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [🤗 Spaces](https://huggingface.co/spaces/damo-vilab/modelscope-text-to-video-synthesis) +| [VideoToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py) | *Text-Guided Video-to-Video Generation* | [(TODO)🤗 Spaces]() ## Usage example +### `text-to-video-ms-1.7b` + Let's start by generating a short video with the default length of 16 frames (2s at 8 fps): ```python @@ -119,12 +122,72 @@ Here are some sample outputs: +### `cerspense/zeroscope_v2_576w` & `cerspense/zeroscope_v2_XL` + +Zeroscope are watermark-free model and have been trained on specific sizes such as `576x320` and `1024x576`. +One should first generate a video using the lower resolution checkpoint [`cerspense/zeroscope_v2_576w`](https://huggingface.co/cerspense/zeroscope_v2_576w) with [`TextToVideoSDPipeline`], +which can then be upscaled using [`VideoToVideoSDPipeline`] and [`cerspense/zeroscope_v2_XL`](https://huggingface.co/cerspense/zeroscope_v2_XL). 
+ + +```py +import torch +from diffusers import DiffusionPipeline +from diffusers.utils import export_to_video + +pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16) +pipe.enable_model_cpu_offload() + +# memory optimization +pipe.enable_vae_slicing() + +prompt = "Darth Vader surfing a wave" +video_frames = pipe(prompt, num_frames=24).frames +video_path = export_to_video(video_frames) +video_path +``` + +Now the video can be upscaled: + +```py +pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16) +pipe.vae.enable_slicing() +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe.enable_model_cpu_offload() + +video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames] + +video_frames = pipe(prompt, video=video, strength=0.6).frames +video_path = export_to_video(video_frames) +video_path +``` + +Here are some sample outputs: + + + + + +
+        Darth Vader surfing in waves (low-resolution sample from `cerspense/zeroscope_v2_576w`).
+        </center></td>
+    </tr>
+    <tr>
+        <td ><center>
+        Darth Vader surfing in waves (upscaled with `cerspense/zeroscope_v2_XL`).
+        </center></td>
+    </tr>
+</table>
+
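The upscaling snippet above uses `DPMSolverMultistepScheduler` and `Image` without importing them. Below is a self-contained sketch of the two-stage workflow; the second stage is loaded explicitly as the new `VideoToVideoSDPipeline` (as the slow test in this patch does) so it does not depend on which pipeline class the checkpoint's `model_index.json` points to, and the scheduler swap in the first stage follows the example docstring of the new pipeline:

```py
import torch
from PIL import Image
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, VideoToVideoSDPipeline
from diffusers.utils import export_to_video

# Stage 1: low-resolution text-to-video with the 576w checkpoint
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()

prompt = "Darth Vader surfing a wave"
video_frames = pipe(prompt, num_frames=24).frames

# Stage 2: text-guided video-to-video upscaling with the XL checkpoint
pipe = VideoToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()  # the VAE is the main memory consumer during decoding

video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]
video_frames = pipe(prompt, video=video, strength=0.6).frames
video_path = export_to_video(video_frames)
```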
+ ## Available checkpoints * [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/) * [damo-vilab/text-to-video-ms-1.7b-legacy](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b-legacy) +* [cerspense/zeroscope_v2_576w](https://huggingface.co/cerspense/zeroscope_v2_576w) +* [cerspense/zeroscope_v2_XL](https://huggingface.co/cerspense/zeroscope_v2_XL) ## TextToVideoSDPipeline [[autodoc]] TextToVideoSDPipeline - all - __call__ + +## VideoToVideoSDPipeline +[[autodoc]] VideoToVideoSDPipeline + - all + - __call__ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 02907075345e..764f9204dffb 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -173,6 +173,7 @@ VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, + VideoToVideoSDPipeline, VQDiffusionPipeline, ) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index d61281a53e7c..ddb9bde0ee0a 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -229,7 +229,12 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderK if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): return self.tiled_encode(x, return_dict=return_dict) - h = self.encoder(x) + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b1650240848a..ca57756c6aa4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -89,7 +89,7 @@ StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline + from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder from .versatile_diffusion import ( diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py index 165a1a0f0d98..d70c1c2ea2a8 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -28,5 +28,6 @@ class TextToVideoSDPipelineOutput(BaseOutput): except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .pipeline_text_to_video_synth import TextToVideoSDPipeline # noqa: F401 + from .pipeline_text_to_video_synth import TextToVideoSDPipeline + from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline # noqa: F401 from .pipeline_text_to_video_zero import TextToVideoZeroPipeline diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 8bf4bafa4fe5..e30f183808a5 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ 
b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -672,6 +672,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) + if output_type == "latent": + return TextToVideoSDPipelineOutput(frames=latents) + video_tensor = self.decode_latents(latents) if output_type == "pt": diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py new file mode 100644 index 000000000000..ce5109a58213 --- /dev/null +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -0,0 +1,770 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet3DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import TextToVideoSDPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler + >>> from diffusers.utils import export_to_video + + >>> pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.to("cuda") + + >>> prompt = "spiderman running in the desert" + >>> video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames + >>> # safe low-res video + >>> video_path = export_to_video(video_frames, output_video_path="./video_576_spiderman.mp4") + + >>> # let's offload the text-to-image model + >>> pipe.to("cpu") + + >>> # and load the image-to-image model + >>> pipe = DiffusionPipeline.from_pretrained( + ... "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/15" + ... 
) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # The VAE consumes A LOT of memory, let's make sure we run it in sliced mode + >>> pipe.vae.enable_slicing() + + >>> # now let's upscale it + >>> video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames] + + >>> # and denoise it + >>> video_frames = pipe(prompt, video=video, strength=0.6).frames + >>> video_path = export_to_video(video_frames, output_video_path="./video_1024_spiderman.mp4") + >>> video_path + ``` +""" + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +def preprocess_video(video): + supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image) + + if isinstance(video, supported_formats): + video = [video] + elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}" + ) + + if isinstance(video[0], PIL.Image.Image): + video = [np.array(frame) for frame in video] + + if isinstance(video[0], np.ndarray): + video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) + + if video.dtype == np.uint8: + video = np.array(video).astype(np.float32) / 255.0 + + if video.ndim == 4: + video = video[None, ...] + + video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3)) + + elif isinstance(video[0], torch.Tensor): + video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0) + + # don't need any preprocess if the video is latents + channel = video.shape[1] + if channel == 4: + return video + + # move channels before num_frames + video = video.permute(0, 2, 1, 3, 4) + + # normalize video + video = 2.0 * video - 1.0 + + return video + + +class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Same as Stable Diffusion 2. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 
+ unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded + to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a + submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. 
Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.vae, self.unet]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, video, timestep, batch_size, dtype, device, generator=None): + video = video.to(device=device, dtype=dtype) + + # change from (b, c, f, h, w) -> (b * f, c, w, h) + bsz, channel, frames, width, height = video.shape + video = video.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + if video.shape[1] == 4: + init_latents = video + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(video).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + latents = latents[None, :].reshape((bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4) + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + video: Union[List[np.ndarray], torch.FloatTensor] = None, + strength: float = 0.6, + num_inference_steps: int = 50, + guidance_scale: float = 15.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + video: (`List[np.ndarray]` or `torch.FloatTensor`): + `video` frames or tensor representing a video batch, that will be used as the starting point for the + process. Can also accpet video latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. 
Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `torch.FloatTensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + num_images_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess video + video = preprocess_video(video) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + return TextToVideoSDPipelineOutput(frames=latents) + + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + + video_tensor = self.decode_latents(latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + 
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3f0b17d879e5..0dbc8f1f6f99 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -782,6 +782,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class VideoToVideoSDPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class VQDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 008a8a2e6367..9fb3e167facc 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -652,11 +652,11 @@ def _test_xformers_attention_forwardGenerator_pass( pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(torch_device) - output_without_offload = pipe(**inputs)[0] + output_without_offload = pipe(**inputs)[0].cpu() pipe.enable_xformers_memory_efficient_attention() inputs = self.get_dummy_inputs(torch_device) - output_with_offload = pipe(**inputs)[0] + output_with_offload = pipe(**inputs)[0].cpu() if test_max_difference: max_diff = np.abs(output_with_offload - output_without_offload).max() diff --git a/tests/pipelines/text_to_video/test_video_to_video.py b/tests/pipelines/text_to_video/test_video_to_video.py new file mode 100644 index 000000000000..41e213c43dea --- /dev/null +++ b/tests/pipelines/text_to_video/test_video_to_video.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + UNet3DConditionModel, + VideoToVideoSDPipeline, +) +from diffusers.utils import floats_tensor, is_xformers_available, skip_mps +from diffusers.utils.testing_utils import enable_full_determinism, slow, torch_device + +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +@skip_mps +class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = VideoToVideoSDPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"video"}) - {"image", "width", "height"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"video"}) - {"image"} + required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} + test_attention_slicing = False + + # No `output_type`. + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback", + "callback_steps", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet3DConditionModel( + block_out_channels=(32, 64, 64, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"), + up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + cross_attention_dim=32, + attention_head_dim=4, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=512, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + # 3 frames + video = floats_tensor((1, 3, 3, 32, 32), rng=random.Random(seed)).to(device) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "video": video, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "pt", + } + return inputs + + def test_text_to_video_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = VideoToVideoSDPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = 
self.get_dummy_inputs(device) + inputs["output_type"] = "np" + frames = sd_pipe(**inputs).frames + image_slice = frames[0][-3:, -3:, -1] + + assert frames[0].shape == (32, 32, 3) + expected_slice = np.array([106, 117, 113, 174, 137, 112, 148, 151, 131]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=5e-3) + + # (todo): sayakpaul + @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.") + def test_inference_batch_consistent(self): + pass + + # (todo): sayakpaul + @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.") + def test_inference_batch_single_identical(self): + pass + + @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.") + def test_num_images_per_prompt(self): + pass + + def test_progress_bar(self): + return super().test_progress_bar() + + +@slow +@skip_mps +class VideoToVideoSDPipelineSlowTests(unittest.TestCase): + def test_two_step_model(self): + pipe = VideoToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + + # 10 frames + generator = torch.Generator(device="cpu").manual_seed(0) + video = torch.randn((1, 10, 3, 1024, 576), generator=generator) + video = video.to("cuda") + + prompt = "Spiderman is surfing" + + video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="pt").frames + + expected_array = np.array([-1.0458984, -1.1279297, -0.9663086, -0.91503906, -0.75097656]) + assert np.abs(video_frames.cpu().numpy()[0, 0, 0, 0, -5:] - expected_array).sum() < 1e-2 From a23f30cbd486434d636f54838ac91831cb358111 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Mon, 3 Jul 2023 20:28:05 +1000 Subject: [PATCH 04/15] fix/doc-code: Updating to the latest version parameters (#3924) fix/doc-code: update to use the new parameter Signed-off-by: GitHub --- docs/source/en/tutorials/basic_training.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/tutorials/basic_training.mdx b/docs/source/en/tutorials/basic_training.mdx index c8f5c7fac780..2cf9128f3dd4 100644 --- a/docs/source/en/tutorials/basic_training.mdx +++ b/docs/source/en/tutorials/basic_training.mdx @@ -313,7 +313,7 @@ Now you can wrap all these components together in a training loop with 🤗 Acce ... mixed_precision=config.mixed_precision, ... gradient_accumulation_steps=config.gradient_accumulation_steps, ... log_with="tensorboard", -... logging_dir=os.path.join(config.output_dir, "logs"), +... project_dir=os.path.join(config.output_dir, "logs"), ... ) ... if accelerator.is_main_process: ... 
if config.push_to_hub: From ccc329631333cfadd10695b57a4709904a9ece09 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Mon, 3 Jul 2023 20:28:42 +1000 Subject: [PATCH 05/15] fix/doc: no import torch issue (#3923) Ffix/doc: no import torch issue Signed-off-by: GitHub --- docs/source/en/stable_diffusion.mdx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/stable_diffusion.mdx b/docs/source/en/stable_diffusion.mdx index 78fa848421d8..7684052c313a 100644 --- a/docs/source/en/stable_diffusion.mdx +++ b/docs/source/en/stable_diffusion.mdx @@ -52,6 +52,8 @@ pipeline = pipeline.to("cuda") To make sure you can use the same image and improve on it, use a [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed for [reproducibility](./using-diffusers/reproducibility): ```python +import torch + generator = torch.Generator("cuda").manual_seed(0) ``` From 65e1433378c63365bec378f8463941e072ee7572 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Jul 2023 15:10:07 +0200 Subject: [PATCH 06/15] Correct controlnet out of list error (#3928) * Correct controlnet out of list error * Apply suggestions from code review * correct tests * correct tests * fix * test all * Apply suggestions from code review * test all * test all * Apply suggestions from code review * Apply suggestions from code review * fix more tests * Fix more * Apply suggestions from code review * finish * Apply suggestions from code review * Update src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py * finish --- .../controlnet/pipeline_controlnet.py | 4 +- .../controlnet/pipeline_controlnet_img2img.py | 4 +- .../controlnet/pipeline_controlnet_inpaint.py | 4 +- ...eline_stable_diffusion_instruct_pix2pix.py | 2 +- .../schedulers/scheduling_deis_multistep.py | 15 +++++- .../scheduling_dpmsolver_multistep.py | 7 +-- .../scheduling_dpmsolver_singlestep.py | 7 +-- .../schedulers/scheduling_unipc_multistep.py | 14 ++++++ .../altdiffusion/test_alt_diffusion.py | 6 ++- tests/pipelines/controlnet/test_controlnet.py | 14 ++++-- .../controlnet/test_controlnet_img2img.py | 14 ++++-- .../controlnet/test_controlnet_inpaint.py | 14 ++++-- .../stable_diffusion/test_stable_diffusion.py | 6 ++- .../test_stable_diffusion_image_variation.py | 4 +- .../test_stable_diffusion_img2img.py | 6 ++- .../test_stable_diffusion_inpaint.py | 6 ++- ...st_stable_diffusion_instruction_pix2pix.py | 4 +- .../test_stable_diffusion_model_editing.py | 6 ++- .../test_stable_diffusion_pix2pix_zero.py | 6 ++- .../test_stable_diffusion.py | 6 ++- ...test_stable_diffusion_attend_and_excite.py | 4 +- .../test_stable_diffusion_depth.py | 6 ++- .../test_stable_diffusion_inpaint.py | 6 ++- .../test_stable_diffusion_latent_upscale.py | 49 ++++++++++++++++++- .../stable_unclip/test_stable_unclip.py | 11 ++++- .../test_stable_unclip_img2img.py | 5 +- tests/pipelines/test_pipelines_common.py | 46 +++++++++++++++++ 27 files changed, 225 insertions(+), 51 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index dddfc3591b66..c266e8b20e74 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -947,9 +947,9 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] - for i in range(num_inference_steps): + for i in range(len(timesteps)): keeps = [ - 1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e) 
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index c7a0db96e8c0..fd013c4974f1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -1040,9 +1040,9 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] - for i in range(num_inference_steps): + for i in range(len(timesteps)): keeps = [ - 1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e) + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index bfaaaae49401..7de3f1dd9d88 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -1275,9 +1275,9 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] - for i in range(num_inference_steps): + for i in range(len(timesteps)): keeps = [ - 1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e) + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 25102ae7cf4a..367e401d57f8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -374,7 +374,7 @@ def __call__( # predicted_original_sample instead of the noise_pred. So we need to compute the # predicted_original_sample here if we are using a karras style scheduler. if scheduler_is_in_sigma_space: - step_index = (self.scheduler.timesteps == t).nonzero().item() + step_index = (self.scheduler.timesteps == t).nonzero()[0].item() sigma = self.scheduler.sigmas[step_index] noise_pred = latent_model_input - sigma * noise_pred diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 8ea001a882d0..56c362018c18 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -103,7 +103,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. - + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. 
If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -125,6 +128,7 @@ def __init__( algorithm_type: str = "deis", solver_type: str = "logrho", lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -188,6 +192,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic .astype(np.int64) ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. _, unique_indices = np.unique(timesteps, return_index=True) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index e72b1bdc23b5..d7c29d5488a5 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -203,7 +203,6 @@ def __init__( self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 - self.use_karras_sigmas = use_karras_sigmas def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ @@ -225,13 +224,15 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc .astype(np.int64) ) - if self.use_karras_sigmas: - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) + self.sigmas = torch.from_numpy(sigmas) + # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. 
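        # A minimal sketch (an assumption about the helper, not its verbatim code) of the
        # Karras et al. (2022) schedule that `_convert_to_karras` produces from the training
        # sigmas, following Eq. (5) of https://arxiv.org/abs/2206.00364 with rho = 7:
        #
        #     ramp = np.linspace(0, 1, num_inference_steps)
        #     sigmas = (sigma_max ** (1 / 7) + ramp * (sigma_min ** (1 / 7) - sigma_max ** (1 / 7))) ** 7
        #
        # so that comparatively more of the sampling steps are spent at low noise levels.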
_, unique_indices = np.unique(timesteps, return_index=True) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 7fa8eabb5a15..721dd5e5bb85 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -202,7 +202,6 @@ def __init__( self.model_outputs = [None] * solver_order self.sample = None self.order_list = self.get_order_list(num_train_timesteps) - self.use_karras_sigmas = use_karras_sigmas def get_order_list(self, num_inference_steps: int) -> List[int]: """ @@ -259,13 +258,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic .astype(np.int64) ) - if self.use_karras_sigmas: - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device) self.model_outputs = [None] * self.config.solver_order self.sample = None diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 2cce68f7d962..7233258a4766 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -117,6 +117,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): by disable the corrector at the first few steps (e.g., disable_corrector=[0]) solver_p (`SchedulerMixin`, default `None`): can be any other scheduler. If specified, the algorithm will become solver_p + UniC. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -140,6 +144,7 @@ def __init__( lower_order_final: bool = True, disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, + use_karras_sigmas: Optional[bool] = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -201,6 +206,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic .astype(np.int64) ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. 
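        # For illustration (hypothetical numbers): with num_train_timesteps = 1000 and
        # num_inference_steps = 1000, or with a Karras schedule rounded back to integer
        # timesteps, two adjacent entries can land on the same timestep; the
        # `np.unique(..., return_index=True)` lookup below drops those repeated entries,
        # keeping only the first occurrence of each.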
_, unique_indices = np.unique(timesteps, return_index=True) diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py index 1344d33a2552..3f16964fd567 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -29,13 +29,15 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() -class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class AltDiffusionPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = AltDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 906e1e7ee66f..a548983c3841 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -46,7 +46,11 @@ TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) enable_full_determinism() @@ -97,7 +101,9 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): out_queue.join() -class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class ControlNetPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -207,7 +213,9 @@ def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) -class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionMultiControlNetPipelineFastTests( + PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index bc0e96b2f92b..c46593f03e5e 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -42,13 +42,19 @@ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) enable_full_determinism() -class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class ControlNetImg2ImgPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): 
pipeline_class = StableDiffusionControlNetImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS @@ -161,7 +167,9 @@ def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) -class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionMultiControlNetPipelineFastTests( + PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index 81647d968b6b..cf423f4c49d0 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -42,13 +42,19 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) enable_full_determinism() -class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class ControlNetInpaintPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS @@ -237,7 +243,9 @@ def get_dummy_components(self): return components -class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class MultiControlNetInpaintPipelineFastTests( + PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 93abe7ae58bc..7daf3fcda4a2 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -50,7 +50,7 @@ from ...models.test_lora_layers import create_unet_lora_layers from ...models.test_models_unet_2d_condition import create_lora_layers from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @@ -88,7 +88,9 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): out_queue.join() -class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git 
a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index e16478f06112..580c78675a92 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -33,14 +33,14 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() class StableDiffusionImageVariationPipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionImageVariationPipeline params = IMAGE_VARIATION_PARAMS diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index eefbc83ce9d7..d1f7a49467e6 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -46,7 +46,7 @@ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @@ -84,7 +84,9 @@ def _test_img2img_compile(in_queue, out_queue, timeout): out_queue.join() -class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionImg2ImgPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index f761f245883f..e7b084acb280 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -42,7 +42,7 @@ from ...models.test_models_unet_2d_condition import create_lora_layers from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @@ -82,7 +82,9 @@ def _test_inpaint_compile(in_queue, out_queue, timeout): out_queue.join() -class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionInpaintPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = 
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py index 691427b1c6eb..513e11c105d5 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py @@ -40,14 +40,14 @@ TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() class StableDiffusionInstructPix2PixPipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionInstructPix2PixPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"} diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py index f47a70c4ece8..81d1baed5df6 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py @@ -32,14 +32,16 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @skip_mps -class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionModelEditingPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionModelEditingPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 6f41d2c43c8e..1b17f8b31be9 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -41,7 +41,11 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference +from ..test_pipelines_common import ( + PipelineLatentTesterMixin, + PipelineTesterMixin, + assert_mean_pixel_difference, +) enable_full_determinism() diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 33cc7f638ec2..a26abfa50096 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -36,13 +36,15 @@ from diffusers.utils.testing_utils import CaptureLogger, enable_full_determinism, require_torch_gpu from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, 
TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() -class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusion2PipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index 304ddacd2c36..b4d49c92425c 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False @@ -38,7 +38,7 @@ @skip_mps class StableDiffusionAttendAndExcitePipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionAttendAndExcitePipeline test_attention_slicing = False diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index f393967c7de4..fe2cf73da096 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -57,14 +57,16 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @skip_mps -class StableDiffusionDepth2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionDepth2ImgPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionDepth2ImgPipeline test_save_load_optional_components = False params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 37c254f367f3..68a4b5132375 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -27,13 +27,15 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from 
..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() -class StableDiffusion2InpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusion2InpaintPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py index b94aaca4258a..5e1d610efcaf 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -21,6 +21,7 @@ import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer +import diffusers from diffusers import ( AutoencoderKL, EulerDiscreteScheduler, @@ -28,17 +29,25 @@ StableDiffusionPipeline, UNet2DConditionModel, ) +from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() -class StableDiffusionLatentUpscalePipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +def check_same_shape(tensor_list): + shapes = [tensor.shape for tensor in tensor_list] + return all(shape == shapes[0] for shape in shapes[1:]) + + +class StableDiffusionLatentUpscalePipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionLatentUpscalePipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - { "height", @@ -185,6 +194,42 @@ def test_save_load_local(self): def test_save_load_optional_components(self): super().test_save_load_optional_components(expected_max_difference=3e-3) + def test_karras_schedulers_shape(self): + skip_schedulers = [ + "DDIMScheduler", + "DDPMScheduler", + "PNDMScheduler", + "HeunDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "KDPM2AncestralDiscreteScheduler", + "DPMSolverSDEScheduler", + ] + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + # make sure that PNDM does not need warm-up + pipe.scheduler.register_to_config(skip_prk_steps=True) + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + inputs = self.get_dummy_inputs(torch_device) + inputs["num_inference_steps"] = 2 + + outputs = [] + for scheduler_enum in KarrasDiffusionSchedulers: + if scheduler_enum.name in skip_schedulers: + # no sigma schedulers are not supported + # no schedulers + continue + + scheduler_cls = getattr(diffusers, scheduler_enum.name) + pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config) + output = pipe(**inputs)[0] + outputs.append(output) + + assert check_same_shape(outputs) + @require_torch_gpu @slow diff --git 
a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index 4bbbad757edf..8d5edda16904 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -16,13 +16,20 @@ from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, require_torch_gpu, slow, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + assert_mean_pixel_difference, +) enable_full_determinism() -class StableUnCLIPPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableUnCLIPPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableUnCLIPPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index 741343066133..52581eb574e0 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -30,6 +30,7 @@ from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference, @@ -39,7 +40,9 @@ enable_full_determinism() -class StableUnCLIPImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableUnCLIPImg2ImgPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableUnCLIPImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 9fb3e167facc..52dd4afd6b21 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -14,6 +14,7 @@ import diffusers from diffusers import DiffusionPipeline from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.testing_utils import require_torch, torch_device @@ -26,6 +27,11 @@ def to_np(tensor): return tensor +def check_same_shape(tensor_list): + shapes = [tensor.shape for tensor in tensor_list] + return all(shape == shapes[0] for shape in shapes[1:]) + + class PipelineLatentTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. @@ -155,6 +161,46 @@ def test_latents_input(self): self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image") +@require_torch +class PipelineKarrasSchedulerTesterMixin: + """ + This mixin is designed to be used with unittest.TestCase classes. 
+ It provides a set of common tests for each PyTorch pipeline that makes use of KarrasDiffusionSchedulers + equivalence of dict and tuple outputs, etc. + """ + + def test_karras_schedulers_shape(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + # make sure that PNDM does not need warm-up + pipe.scheduler.register_to_config(skip_prk_steps=True) + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + inputs = self.get_dummy_inputs(torch_device) + inputs["num_inference_steps"] = 2 + + if "strength" in inputs: + inputs["num_inference_steps"] = 4 + inputs["strength"] = 0.5 + + outputs = [] + for scheduler_enum in KarrasDiffusionSchedulers: + if "KDPM2" in scheduler_enum.name: + inputs["num_inference_steps"] = 5 + + scheduler_cls = getattr(diffusers, scheduler_enum.name) + pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config) + output = pipe(**inputs)[0] + outputs.append(output) + + if "KDPM2" in scheduler_enum.name: + inputs["num_inference_steps"] = 2 + + assert check_same_shape(outputs) + + @require_torch class PipelineTesterMixin: """ From 6140a5486479997985e91dd609a02e322850b8a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Mauricio=20Repetto=20Ferrero?= Date: Mon, 3 Jul 2023 10:55:45 -0500 Subject: [PATCH 07/15] Adding better way to define multiple concepts and also validation capabilities. (#3807) * - Added validation parameters - Changed some parameter descriptions to better explain their use. - Fixed a few typos. - Added concept_list parameter for better management of multiple subjects - changed logic for image validation * - Fixed bad logic for class data root directories * Defaulting validation_steps to None for an easier logic * Fixed multiple validation prompts * Fixed bug on validation negative prompt * Changed validation logic for tracker. * Added uuid for validation image labeling * Fix error when comparing validation prompts and validation negative prompts * Improved error message when negative prompts for validation are more than the number of prompts * - Changed image tracking number from epoch to global_step - Added Typing for functions * Added some validations more when using concept_list parameter and the regular ones. * Fixed error message * Added more validations for validation parameters * Improved messaging for errors * Fixed validation error for parameters with default values * - Added train step to image name for validation - reformatted code * - Added train step to image's name for validation - reformatted code * Updated README.md file. * reverted back original script of train_dreambooth.py * reverted back original script of train_dreambooth.py * left one blank line at the eof * reverted back setup.py * reverted back setup.py * added same logic for when parameters for prior preservation are used without enabling the flag while using concept_list parameter. * Ran black formatter. 
* fixed a few strings * fixed import sort with isort and removed fstrings without placeholder * fixed import order with ruff (since with isort wasn't ok) --------- Co-authored-by: Patrick von Platen --- .../multi_subject_dreambooth/README.md | 47 ++ .../train_multi_subject_dreambooth.py | 410 +++++++++++++++--- 2 files changed, 404 insertions(+), 53 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/README.md b/examples/research_projects/multi_subject_dreambooth/README.md index cf7dd31d0797..d1a7705cfebb 100644 --- a/examples/research_projects/multi_subject_dreambooth/README.md +++ b/examples/research_projects/multi_subject_dreambooth/README.md @@ -86,6 +86,53 @@ This example shows training for 2 subjects, but please note that the model can b Note also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed to be good when more intuitive words are used. +**Important**: New parameters are added to the script, making possible to validate the progress of the training by +generating images at specified steps. Taking also into account that a comma separated list in a text field for a prompt +it's never a good idea (simply because it is very common in prompts to have them as part of a regular text) we +introduce the `concept_list` parameter: allowing to specify a json-like file where you can define the different +configuration for each subject that you want to train. + +An example of how to generate the file: +```python +import json + +# here we are using parameters for prior-preservation and validation as well. +concepts_list = [ + { + "instance_prompt": "drawing of a t@y meme", + "class_prompt": "drawing of a meme", + "instance_data_dir": "/some_folder/meme_toy", + "class_data_dir": "/data/meme", + "validation_prompt": "drawing of a t@y meme about football in Uruguay", + "validation_negative_prompt": "black and white" + }, + { + "instance_prompt": "drawing of a sks sir", + "class_prompt": "drawing of a sir", + "instance_data_dir": "/some_other_folder/sir_sks", + "class_data_dir": "/data/sir", + "validation_prompt": "drawing of a sks sir with the Uruguayan sun in his chest", + "validation_negative_prompt": "an old man", + "validation_guidance_scale": 20, + "validation_number_images": 3, + "validation_inference_steps": 10 + } +] + +with open("concepts_list.json", "w") as f: + json.dump(concepts_list, f, indent=4) +``` +And then just point to the file when executing the script: + +```bash +# exports... +accelerate launch train_multi_subject_dreambooth.py \ +# more parameters... +--concepts_list="concepts_list.json" +``` + +You can use the helper from the script to get a better sense of each parameter. + ### Inference Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt. 
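A minimal sketch of such a call, assuming the trained pipeline was saved to the `--output_dir` used above (the model path below is a placeholder, and the prompt simply reuses the `sks` and `t@y` identifiers):

```python
import torch

from diffusers import StableDiffusionPipeline

model_path = "path-to-your-output-dir"  # placeholder for the --output_dir used during training

pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to("cuda")

prompt = "a drawing of a sks sir next to a t@y meme"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("multi-subject.png")
```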
diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index f24c6057fd8c..c75a0a9acc64 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -1,13 +1,18 @@ import argparse import hashlib import itertools +import json import logging import math -import os +import uuid import warnings +from os import environ, listdir, makedirs +from os.path import basename, join from pathlib import Path +from typing import List import datasets +import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -17,24 +22,140 @@ from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from PIL import Image +from torch import dtype +from torch.nn import Module from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +if is_wandb_available(): + import wandb + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.13.0.dev0") logger = get_logger(__name__) +def log_validation_images_to_tracker( + images: List[np.array], label: str, validation_prompt: str, accelerator: Accelerator, epoch: int +): + logger.info(f"Logging images to tracker for validation prompt: {validation_prompt}.") + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{label}_{epoch}_{i}: {validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + +# TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings` +# argument is implemented. +def generate_validation_images( + text_encoder: Module, + tokenizer: Module, + unet: Module, + vae: Module, + arguments: argparse.Namespace, + accelerator: Accelerator, + weight_dtype: dtype, +): + logger.info("Running validation images.") + + pipeline_args = {} + + if text_encoder is not None: + pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) + + if vae is not None: + pipeline_args["vae"] = vae + + # create pipeline (note: unet and vae are loaded again in float32) + pipeline = DiffusionPipeline.from_pretrained( + arguments.pretrained_model_name_or_path, + tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + revision=arguments.revision, + torch_dtype=weight_dtype, + **pipeline_args, + ) + + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the + # scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + generator = ( + None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed) + ) + + images_sets = [] + for vp, nvi, vnp, vis, vgs in zip( + arguments.validation_prompt, + arguments.validation_number_images, + arguments.validation_negative_prompt, + arguments.validation_inference_steps, + arguments.validation_guidance_scale, + ): + images = [] + if vp is not None: + logger.info( + f"Generating {nvi} images with prompt: '{vp}', negative prompt: '{vnp}', inference steps: {vis}, " + f"guidance scale: {vgs}." + ) + + pipeline_args = {"prompt": vp, "negative_prompt": vnp, "num_inference_steps": vis, "guidance_scale": vgs} + + # run inference + # TODO: it would be good to measure whether it's faster to run inference on all images at once, one at a + # time or in small batches + for _ in range(nvi): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_images_per_prompt=1, generator=generator).images[0] + images.append(image) + + images_sets.append(images) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return images_sets + + def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -81,7 +202,7 @@ def parse_args(input_args=None): "--instance_data_dir", type=str, default=None, - required=True, + required=False, help="A folder containing the training data of instance images.", ) parser.add_argument( @@ -95,7 +216,7 @@ def parse_args(input_args=None): "--instance_prompt", type=str, default=None, - required=True, + required=False, help="The prompt with identifier specifying the instance", ) parser.add_argument( @@ -272,6 +393,52 @@ def parse_args(input_args=None): ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) + parser.add_argument( + "--validation_steps", + type=int, + default=None, + help=( + "Run validation every X steps. Validation consists of running the prompt(s) `validation_prompt` " + "multiple times (`validation_number_images`) and logging the images." + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning. You can use commas to " + "define multiple negative prompts. This parameter can be defined also within the file given by " + "`concepts_list` parameter in the respective subject.", + ) + parser.add_argument( + "--validation_number_images", + type=int, + default=4, + help="Number of images that should be generated during validation with the validation parameters given. This " + "can be defined within the file given by `concepts_list` parameter in the respective subject.", + ) + parser.add_argument( + "--validation_negative_prompt", + type=str, + default=None, + help="A negative prompt that is used during validation to verify that the model is learning. 
You can use commas" + " to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can " + "be defined also within the file given by `concepts_list` parameter in the respective subject.", + ) + parser.add_argument( + "--validation_inference_steps", + type=int, + default=25, + help="Number of inference steps (denoising steps) to run during validation. This can be defined within the " + "file given by `concepts_list` parameter in the respective subject.", + ) + parser.add_argument( + "--validation_guidance_scale", + type=float, + default=7.5, + help="To control how much the image generation process follows the text prompt. This can be defined within the " + "file given by `concepts_list` parameter in the respective subject.", + ) parser.add_argument( "--mixed_precision", type=str, @@ -297,27 +464,80 @@ def parse_args(input_args=None): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--concepts_list", + type=str, + default=None, + help="Path to json file containing a list of multiple concepts, will overwrite parameters like instance_prompt," + " class_prompt, etc.", + ) - if input_args is not None: + if input_args: args = parser.parse_args(input_args) else: args = parser.parse_args() - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if not args.concepts_list and (not args.instance_data_dir or not args.instance_prompt): + raise ValueError( + "You must specify either instance parameters (data directory, prompt, etc.) or use " + "the `concept_list` parameter and specify them within the file." + ) + + if args.concepts_list: + if args.instance_prompt: + raise ValueError("If you are using `concepts_list` parameter, define the instance prompt within the file.") + if args.instance_data_dir: + raise ValueError( + "If you are using `concepts_list` parameter, define the instance data directory within the file." + ) + if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt): + raise ValueError( + "If you are using `concepts_list` parameter, define validation parameters for " + "each subject within the file:\n - `validation_prompt`." + "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`." + "\n - `validation_number_images`.\n - `validation_prompt`." + "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one " + "that needs to be defined outside the file." 
+ ) + + env_local_rank = int(environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank if args.with_prior_preservation: - if args.class_data_dir is None: - raise ValueError("You must specify a data directory for class images.") - if args.class_prompt is None: - raise ValueError("You must specify prompt for class images.") + if not args.concepts_list: + if not args.class_data_dir: + raise ValueError("You must specify a data directory for class images.") + if not args.class_prompt: + raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir: + raise ValueError( + "If you are using `concepts_list` parameter, define the class data directory within the file." + ) + if args.class_prompt: + raise ValueError( + "If you are using `concepts_list` parameter, define the class prompt within the file." + ) else: # logger is not available yet - if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") - if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + if not args.class_data_dir: + warnings.warn( + "Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`." + ) + if not args.class_prompt: + warnings.warn( + "Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`." + ) return args @@ -325,7 +545,7 @@ def parse_args(input_args=None): class DreamBoothDataset(Dataset): """ A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. + It pre-processes the images and then tokenizes prompts. """ def __init__( @@ -346,7 +566,7 @@ def __init__( self.instance_images_path = [] self.num_instance_images = [] self.instance_prompt = [] - self.class_data_root = [] + self.class_data_root = [] if class_data_root is not None else None self.class_images_path = [] self.num_class_images = [] self.class_prompt = [] @@ -371,8 +591,6 @@ def __init__( self._length -= self.num_instance_images[i] self._length += self.num_class_images[i] self.class_prompt.append(class_prompt[i]) - else: - self.class_data_root = None self.image_transforms = transforms.Compose( [ @@ -446,7 +664,7 @@ def collate_fn(num_instances, examples, with_prior_preservation=False): class PromptDataset(Dataset): - "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + """A simple dataset to prepare the prompts to generate class images on multiple GPUs.""" def __init__(self, prompt, num_samples): self.prompt = prompt @@ -474,6 +692,10 @@ def main(args): project_config=accelerator_project_config, ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. @@ -483,23 +705,84 @@ def main(args): "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 
) - # Parse instance and class inputs, and double check that lengths match - instance_data_dir = args.instance_data_dir.split(",") - instance_prompt = args.instance_prompt.split(",") - assert all( - x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)] - ), "Instance data dir and prompt inputs are not of the same length." + instance_data_dir = [] + instance_prompt = [] + class_data_dir = [] if args.with_prior_preservation else None + class_prompt = [] if args.with_prior_preservation else None + if args.concepts_list: + with open(args.concepts_list, "r") as f: + concepts_list = json.load(f) + + if args.validation_steps: + args.validation_prompt = [] + args.validation_number_images = [] + args.validation_negative_prompt = [] + args.validation_inference_steps = [] + args.validation_guidance_scale = [] + + for concept in concepts_list: + instance_data_dir.append(concept["instance_data_dir"]) + instance_prompt.append(concept["instance_prompt"]) + + if args.with_prior_preservation: + try: + class_data_dir.append(concept["class_data_dir"]) + class_prompt.append(concept["class_prompt"]) + except KeyError: + raise KeyError( + "`class_data_dir` or `class_prompt` not found in concepts_list while using " + "`with_prior_preservation`." + ) + else: + if "class_data_dir" in concept: + warnings.warn( + "Ignoring `class_data_dir` key, to use it you need to enable `with_prior_preservation`." + ) + if "class_prompt" in concept: + warnings.warn( + "Ignoring `class_prompt` key, to use it you need to enable `with_prior_preservation`." + ) - if args.with_prior_preservation: - class_data_dir = args.class_data_dir.split(",") - class_prompt = args.class_prompt.split(",") - assert all( - x == len(instance_data_dir) - for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)] - ), "Instance & class data dir or prompt inputs are not of the same length." + if args.validation_steps: + args.validation_prompt.append(concept.get("validation_prompt", None)) + args.validation_number_images.append(concept.get("validation_number_images", 4)) + args.validation_negative_prompt.append(concept.get("validation_negative_prompt", None)) + args.validation_inference_steps.append(concept.get("validation_inference_steps", 25)) + args.validation_guidance_scale.append(concept.get("validation_guidance_scale", 7.5)) else: - class_data_dir = args.class_data_dir - class_prompt = args.class_prompt + # Parse instance and class inputs, and double check that lengths match + instance_data_dir = args.instance_data_dir.split(",") + instance_prompt = args.instance_prompt.split(",") + assert all( + x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)] + ), "Instance data dir and prompt inputs are not of the same length." + + if args.with_prior_preservation: + class_data_dir = args.class_data_dir.split(",") + class_prompt = args.class_prompt.split(",") + assert all( + x == len(instance_data_dir) + for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)] + ), "Instance & class data dir or prompt inputs are not of the same length." 
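            # Hypothetical example of the comma-separated CLI form handled in this branch:
            #   --instance_prompt="drawing of a t@y meme,drawing of a sks sir"
            #   --validation_prompt="a t@y meme about football,a sks sir with a sun on his chest"
            # Instance data dirs and instance prompts must line up one-to-one (checked above), while
            # the scalar validation_number_images / validation_inference_steps /
            # validation_guidance_scale values are broadcast to one entry per validation prompt below.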
+ + if args.validation_steps: + validation_prompts = args.validation_prompt.split(",") + num_of_validation_prompts = len(validation_prompts) + args.validation_prompt = validation_prompts + args.validation_number_images = [args.validation_number_images] * num_of_validation_prompts + + negative_validation_prompts = [None] * num_of_validation_prompts + if args.validation_negative_prompt: + negative_validation_prompts = args.validation_negative_prompt.split(",") + while len(negative_validation_prompts) < num_of_validation_prompts: + negative_validation_prompts.append(None) + args.validation_negative_prompt = negative_validation_prompts + + assert num_of_validation_prompts == len( + negative_validation_prompts + ), "The length of negative prompts for validation is greater than the number of validation prompts." + args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts + args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -559,21 +842,24 @@ def main(args): ): images = pipeline(example["prompt"]).images - for i, image in enumerate(images): + for ii, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() image_filename = ( - class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg" ) image.save(image_filename) + # Clean up the memory deleting one-time-use variables. del pipeline + del sample_dataloader + del sample_dataset if torch.cuda.is_available(): torch.cuda.empty_cache() # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) + makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( @@ -581,6 +867,7 @@ def main(args): ).repo_id # Load the tokenizer + tokenizer = None if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: @@ -658,7 +945,7 @@ def main(args): train_dataset = DreamBoothDataset( instance_data_root=instance_data_dir, instance_prompt=instance_prompt, - class_data_root=class_data_dir if args.with_prior_preservation else None, + class_data_root=class_data_dir, class_prompt=class_prompt, tokenizer=tokenizer, size=args.resolution, @@ -720,7 +1007,7 @@ def main(args): args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. + # The trackers initialize automatically on the main process. 
if accelerator.is_main_process: accelerator.init_trackers("dreambooth", config=vars(args)) @@ -741,10 +1028,10 @@ def main(args): # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) + path = basename(args.resume_from_checkpoint) else: # Get the mos recent checkpoint - dirs = os.listdir(args.output_dir) + dirs = listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None @@ -756,7 +1043,7 @@ def main(args): args.resume_from_checkpoint = None else: accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) + accelerator.load_state(join(args.output_dir, path)) global_step = int(path.split("-")[1]) resume_global_step = global_step * args.gradient_accumulation_steps @@ -787,24 +1074,26 @@ def main(args): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) - timesteps = timesteps.long() + time_steps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + time_steps = time_steps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, time_steps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, time_steps, encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(latents, noise, time_steps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") @@ -834,19 +1123,34 @@ def main(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() - optimizer.zero_grad() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + save_path = join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if ( + args.validation_steps + and any(args.validation_prompt) + and global_step % args.validation_steps == 0 + ): + images_set = generate_validation_images( + text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype + ) + for images, validation_prompt in zip(images_set, args.validation_prompt): + if len(images) > 0: + label = str(uuid.uuid1())[:8] # generate an id for different set of images + 
log_validation_images_to_tracker( + images, label, validation_prompt, accelerator, global_step + ) + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) @@ -854,7 +1158,7 @@ def main(args): if global_step >= args.max_train_steps: break - # Create the pipeline using using the trained modules and save it. + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: pipeline = DiffusionPipeline.from_pretrained( From 22f9b92107a5128e37d6adbfa80b6df52846b7f1 Mon Sep 17 00:00:00 2001 From: estelleafl Date: Mon, 3 Jul 2023 19:15:46 +0300 Subject: [PATCH 08/15] [ldm3d] Update code to be functional with the new checkpoints (#3875) * fixed typo * updated doc to be consistent in naming * make style/quality * preprocessing for 4 channels and not 6 * make style * test for 4c * make style/quality * fixed test on cpu --------- Co-authored-by: Aflalo Co-authored-by: Aflalo Co-authored-by: Aflalo --- src/diffusers/image_processor.py | 21 ++++++++---- .../test_stable_diffusion_ldm3d.py | 32 +++++++++++++++---- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 2a433ee14d98..6ccf9b465ebd 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -312,12 +312,17 @@ def numpy_to_depth(self, images): """ if images.ndim == 3: images = images[None, ...] - images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - raise Exception("Not supported") + images_depth = images[:, :, :, 3:] + if images.shape[-1] == 6: + images_depth = (images_depth * 255).round().astype("uint8") + pil_images = [ + Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth + ] + elif images.shape[-1] == 4: + images_depth = (images_depth * 65535.0).astype(np.uint16) + pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth] else: - pil_images = [Image.fromarray(self.rgblike_to_depthmap(image[:, :, 3:]), mode="I;16") for image in images] + raise Exception("Not supported") return pil_images @@ -349,7 +354,11 @@ def postprocess( image = self.pt_to_numpy(image) if output_type == "np": - return image[:, :, :, :3], np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0) + if image.shape[-1] == 6: + image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0) + else: + image_depth = image[:, :, :, 3:] + return image[:, :, :, :3], image_depth if output_type == "pil": return self.numpy_to_pil(image), self.numpy_to_depth(image) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py index 933e4307a41b..e2164e8117ad 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py @@ -130,9 +130,9 @@ def test_stable_diffusion_ddim(self): assert depth.shape == (1, 64, 64) expected_slice_rgb = np.array( - [0.37301102, 0.7023895, 0.7418312, 0.5163375, 0.5825485, 0.60929704, 0.4188174, 0.48407027, 0.46555096] + [0.37338176, 0.70247, 0.74203193, 0.51643604, 0.58256793, 0.60932136, 0.4181095, 0.48355877, 0.46535262] ) - expected_slice_depth = np.array([103.4673, 85.81202, 87.84926]) + expected_slice_depth = np.array([103.46727, 
85.812004, 87.849236]) assert np.abs(image_slice_rgb.flatten() - expected_slice_rgb).max() < 1e-2 assert np.abs(image_slice_depth.flatten() - expected_slice_depth).max() < 1e-2 @@ -280,10 +280,30 @@ def test_ldm3d(self): output = ldm3d_pipe(**inputs) rgb, depth = output.rgb, output.depth - expected_rgb_mean = 0.54461557 - expected_rgb_std = 0.2806707 - expected_depth_mean = 143.64595 - expected_depth_std = 83.491776 + expected_rgb_mean = 0.495586 + expected_rgb_std = 0.33795515 + expected_depth_mean = 112.48518 + expected_depth_std = 98.489746 + assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3 + assert np.abs(expected_rgb_std - rgb.std()) < 1e-3 + assert np.abs(expected_depth_mean - depth.mean()) < 1e-3 + assert np.abs(expected_depth_std - depth.std()) < 1e-3 + + def test_ldm3d_v2(self): + ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c").to(torch_device) + ldm3d_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + output = ldm3d_pipe(**inputs) + rgb, depth = output.rgb, output.depth + + expected_rgb_mean = 0.4194127 + expected_rgb_std = 0.35375586 + expected_depth_mean = 0.5638502 + expected_depth_std = 0.34686103 + + assert rgb.shape == (1, 512, 512, 3) + assert depth.shape == (1, 512, 512, 1) assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3 assert np.abs(expected_rgb_std - rgb.std()) < 1e-3 assert np.abs(expected_depth_mean - depth.mean()) < 1e-3 From 4e494670f82a4af4496edb6a1c8d83bb8e0a5e0f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Jul 2023 18:17:34 +0200 Subject: [PATCH 09/15] Improve memory text to video (#3930) * Improve memory text to video * Apply suggestions from code review * add test * Apply suggestions from code review Co-authored-by: Pedro Cuenca * finish test setup --------- Co-authored-by: Pedro Cuenca --- src/diffusers/models/attention.py | 25 +++++++++++- src/diffusers/models/unet_3d_condition.py | 40 +++++++++++++++++++ .../pipeline_text_to_video_synth.py | 3 ++ .../pipeline_text_to_video_synth_img2img.py | 3 ++ tests/models/test_models_unet_3d_condition.py | 18 +++++++++ 5 files changed, 88 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8805257ebe9a..6b05bf35e87f 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -119,6 +119,15 @@ def __init__( self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + def forward( self, hidden_states: torch.FloatTensor, @@ -141,6 +150,7 @@ def forward( norm_hidden_states = self.norm1(hidden_states) cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, @@ -171,7 +181,20 @@ def forward( if self.use_ada_layer_norm_zero: norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ff_output = self.ff(norm_hidden_states) + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise 
ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 36dcaf21f827..9bc89c571c52 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -389,6 +389,46 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) + def enable_forward_chunking(self, chunk_size=None, dim=0): + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index e30f183808a5..680a524732e9 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -634,6 +634,9 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 6.1 Chunk feed-forward computation to save memory + self.unet.enable_forward_chunking(chunk_size=1, dim=1) + # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index ce5109a58213..1b6cd9c2b392 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -709,6 +709,9 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 6.1 Chunk feed-forward computation to save memory + self.unet.enable_forward_chunking(chunk_size=1, dim=1) + # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 3f29d0a41e18..72a33854bdcd 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -399,5 +399,23 @@ def test_lora_xformers_on_off(self): assert (sample - on_sample).abs().max() < 1e-4 assert (sample - off_sample).abs().max() < 1e-4 + def test_feed_forward_chunking(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict["norm_num_groups"] = 32 + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict)[0] + + model.enable_forward_chunking() + with torch.no_grad(): + output_2 = model(**inputs_dict)[0] + + self.assertEqual(output.shape, output_2.shape, "Shape doesn't match") + assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 + # (todo: sayakpaul) implement SLOW tests. 
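The feed-forward chunking introduced in the patch above trades a little scheduling overhead for lower peak memory: instead of pushing the full `(batch, frames, tokens, dim)` activation through the feed-forward layer at once, it splits the tensor along one dimension, runs the layer on each slice, and concatenates the results, which is numerically equivalent for a position-wise layer. Below is a minimal standalone sketch of that idea in plain PyTorch; the helper name `chunked_feed_forward`, the toy `nn.Sequential` feed-forward, and the tensor shape are illustrative only and are not part of the diffusers module changed above.

```py
import torch
import torch.nn as nn


def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
    # The chunked dimension must be divisible by the chunk size, mirroring the check added to BasicTransformerBlock.
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"Dimension {chunk_dim} of size {hidden_states.shape[chunk_dim]} is not divisible by chunk_size={chunk_size}."
        )
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    # Apply the position-wise feed-forward to each slice, then stitch the outputs back together.
    return torch.cat(
        [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )


ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 16, 77, 64)  # illustrative (batch, frames, tokens, dim) activation

full = ff(x)
chunked = chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=1)
assert torch.allclose(full, chunked, atol=1e-6)  # same result, lower peak activation memory
```

With the patch applied, the equivalent hook on the real model is `pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)`, as the documentation added in the following patch shows.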
From 74ab49116ab3db2fa7affc7503bfd50f3a3a1664 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Jul 2023 23:12:41 +0200 Subject: [PATCH 10/15] revert automatic chunking (#3934) * revert automatic chunking * Apply suggestions from code review * revert automatic chunking --- .../source/en/api/pipelines/text_to_video.mdx | 28 ++++++++++++++++++- .../pipeline_text_to_video_synth.py | 3 -- .../pipeline_text_to_video_synth_img2img.py | 3 -- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx index 75868d7dd6ea..583d461ea948 100644 --- a/docs/source/en/api/pipelines/text_to_video.mdx +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -138,6 +138,7 @@ pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dt pipe.enable_model_cpu_offload() # memory optimization +pipe.unet.enable_forward_chunking(chunk_size=1, dim=1) pipe.enable_vae_slicing() prompt = "Darth Vader surfing a wave" @@ -150,10 +151,13 @@ Now the video can be upscaled: ```py pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16) -pipe.vae.enable_slicing() pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) pipe.enable_model_cpu_offload() +# memory optimization +pipe.unet.enable_forward_chunking(chunk_size=1, dim=1) +pipe.enable_vae_slicing() + video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames] video_frames = pipe(prompt, video=video, strength=0.6).frames @@ -175,6 +179,28 @@ Here are some sample outputs: +### Memory optimizations + +Text-guided video generation with [`~TextToVideoSDPipeline`] and [`~VideoToVideoSDPipeline`] is very memory intensive both +when denoising with [`~UNet3DConditionModel`] and when decoding with [`~AutoencoderKL`]. It is possible though to reduce +memory usage at the cost of increased runtime to achieve the exact same result. To do so, it is recommended to enable +**forward chunking** and **vae slicing**: + +Forward chunking via [`~UNet3DConditionModel.enable_forward_chunking`]is explained in [this blog post](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) and +allows to significantly reduce the required memory for the unet. You can chunk the feed forward layer over the `num_frames` +dimension by doing: + +```py +pipe.unet.enable_forward_chunking(chunk_size=1, dim=1) +``` + +Vae slicing via [`~TextToVideoSDPipeline.enable_vae_slicing`] and [`~VideoToVideoSDPipeline.enable_vae_slicing`] also +gives significant memory savings since the two pipelines decode all image frames at once. + +```py +pipe.enable_vae_slicing() +``` + ## Available checkpoints * [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 680a524732e9..e30f183808a5 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -634,9 +634,6 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.1 Chunk feed-forward computation to save memory - self.unet.enable_forward_chunking(chunk_size=1, dim=1) - # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 1b6cd9c2b392..ce5109a58213 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -709,9 +709,6 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.1 Chunk feed-forward computation to save memory - self.unet.enable_forward_chunking(chunk_size=1, dim=1) - # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: From cec76d2aa89b2392b65e65086607a373f4b13149 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 3 Jul 2023 18:49:49 -0700 Subject: [PATCH 11/15] avoid upcasting by assigning dtype to noise tensor (#3713) * avoid upcasting by assigning dtype to noise tensor * make style * Update train_unconditional.py * Update train_unconditional.py * make style * add unit test for pickle * revert change --------- Co-authored-by: root Co-authored-by: Patrick von Platen Co-authored-by: Prathik Rao --- .../unconditional_image_generation/train_unconditional.py | 4 +++- .../unconditional_image_generation/train_unconditional.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index a42187fadea1..12ff40bbd680 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -568,7 +568,9 @@ def transform_images(examples): clean_images = batch["input"] # Sample noise that we'll add to the images - noise = torch.randn(clean_images.shape).to(clean_images.device) + noise = torch.randn( + clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16) + ).to(clean_images.device) bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint( diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index d6e4b17ba889..e10e6d302457 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -557,7 +557,9 @@ def transform_images(examples): clean_images = batch["input"] # Sample noise that we'll add to the images - noise = torch.randn(clean_images.shape).to(clean_images.device) + noise = torch.randn( + clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16) + ).to(clean_images.device) bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint( From c1e23de7877ba394c2f03fb46aac5d57ecd582db Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 4 Jul 2023 14:00:43 +0200 Subject: [PATCH 12/15] Fix failing np tests (#3942) * Fix failing 
np tests * Apply suggestions from code review * Update tests/pipelines/test_pipelines_common.py --- tests/pipelines/test_pipelines_common.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 52dd4afd6b21..e97bdb352b22 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -698,11 +698,13 @@ def _test_xformers_attention_forwardGenerator_pass( pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(torch_device) - output_without_offload = pipe(**inputs)[0].cpu() + output_without_offload = pipe(**inputs)[0] + output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload pipe.enable_xformers_memory_efficient_attention() inputs = self.get_dummy_inputs(torch_device) - output_with_offload = pipe(**inputs)[0].cpu() + output_with_offload = pipe(**inputs)[0] + output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload if test_max_difference: max_diff = np.abs(output_with_offload - output_without_offload).max() From a166fd654b5b9b9fe5cba48ba25665adf460aecb Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 5 Jul 2023 15:49:30 +0200 Subject: [PATCH 13/15] Add `timestep_spacing` and `steps_offset` to schedulers (#3947) * Add timestep_spacing to DDPM, LMSDiscrete, PNDM. * Remove spurious line. * More easy schedulers. * Add `linspace` to DDIM * Noise sigma for `trailing`. * Add timestep_spacing to DEISMultistepScheduler. Not sure the range is the way it was intended. * Fix: remove line used to debug. * Support timestep_spacing in DPMSolverMultistep, DPMSolverSDE, UniPC * Fix: convert to numpy. * Use sched. defaults when instantiating from_config For params not present in the original configuration. This makes it possible to switch pipeline schedulers even if they use different timestep_spacing (or any other param). * Apply suggestions from code review Co-authored-by: Patrick von Platen * Missing args in DPMSolverMultistep * Test: default args not in config * Style * Fix scheduler name in test * Remove duplicated entries * Add test for solver_type This test currently fails in main. When switching from DEIS to UniPC, solver_type is "logrho" (the default value from DEIS), which gets translated to "bh1" by UniPC. This is different to the default value for UniPC: "bh2". This is where the translation happens: https://github.com/huggingface/diffusers/blob/36d22d0709dc19776e3016fb3392d0f5578b0ab2/src/diffusers/schedulers/scheduling_unipc_multistep.py#L171 * UniPC: use same default for solver_type Fixes a bug when switching from UniPC from another scheduler (i.e., DEIS) that uses a different solver type. The solver is now the same as if we had instantiated the scheduler directly. * do not save use default values * fix more * fix all * fix schedulers * fix more * finish for real * finish for real * flaky tests * Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py * Default steps_offset to 0. 
* Add missing docstrings * Apply suggestions from code review --------- Co-authored-by: Patrick von Platen --- src/diffusers/configuration_utils.py | 16 +++++- src/diffusers/schedulers/scheduling_ddim.py | 11 +++- .../schedulers/scheduling_ddim_parallel.py | 11 +++- src/diffusers/schedulers/scheduling_ddpm.py | 37 +++++++++++-- .../schedulers/scheduling_ddpm_parallel.py | 37 +++++++++++-- .../schedulers/scheduling_deis_multistep.py | 39 +++++++++++--- .../scheduling_dpmsolver_multistep.py | 38 ++++++++++--- .../schedulers/scheduling_dpmsolver_sde.py | 40 ++++++++++++-- .../scheduling_euler_ancestral_discrete.py | 44 +++++++++++++-- .../schedulers/scheduling_euler_discrete.py | 43 +++++++++++++-- .../schedulers/scheduling_heun_discrete.py | 40 ++++++++++++-- .../scheduling_k_dpm_2_ancestral_discrete.py | 40 ++++++++++++-- .../schedulers/scheduling_k_dpm_2_discrete.py | 40 ++++++++++++-- .../schedulers/scheduling_lms_discrete.py | 42 +++++++++++++-- src/diffusers/schedulers/scheduling_pndm.py | 33 +++++++++--- .../schedulers/scheduling_unipc_multistep.py | 41 +++++++++++--- tests/others/test_config.py | 53 +++++++++++++++++++ .../test_stable_diffusion_panorama.py | 2 +- tests/schedulers/test_scheduler_euler.py | 4 +- .../test_scheduler_euler_ancestral.py | 4 +- tests/schedulers/test_scheduler_lms.py | 2 +- tests/schedulers/test_scheduler_unipc.py | 10 ++-- tests/schedulers/test_schedulers.py | 48 ++++++++++++++++- 23 files changed, 598 insertions(+), 77 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 1a030e467134..202905db52c6 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -423,6 +423,10 @@ def _get_init_keys(cls): @classmethod def extract_init_dict(cls, config_dict, **kwargs): + # Skip keys that were not present in the original config, so default __init__ values were used + used_defaults = config_dict.get("_use_default_values", []) + config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} + # 0. 
Copy origin config dict original_dict = dict(config_dict.items()) @@ -544,8 +548,9 @@ def to_json_saveable(value): return value config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} - # Don't save "_ignore_files" + # Don't save "_ignore_files" or "_use_default_values" config_dict.pop("_ignore_files", None) + config_dict.pop("_use_default_values", None) return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" @@ -599,6 +604,11 @@ def inner_init(self, *args, **kwargs): if k not in ignore and k not in new_kwargs } ) + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs) + new_kwargs = {**config_init_kwargs, **new_kwargs} getattr(self, "register_to_config")(**new_kwargs) init(self, *args, **init_kwargs) @@ -643,6 +653,10 @@ def init(self, *args, **kwargs): name = fields[i].name new_kwargs[name] = arg + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs) + getattr(self, "register_to_config")(**new_kwargs) original_init(self, *args, **kwargs) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index bab6f8acea03..99602d14038b 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -302,8 +302,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps - # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "leading": + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 22b7d8ec97dc..8875aa73208b 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -321,8 +321,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps - # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "leading": + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 5d24766d68c7..ddf27d409d88 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -114,6 +114,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -134,6 +141,8 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -228,11 +237,33 @@ def set_timesteps( ) self.num_inference_steps = num_inference_steps - - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.custom_timesteps = False + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) + self.timesteps = torch.from_numpy(timesteps).to(device) def _get_variance(self, t, predicted_variance=None, variance_type=None): diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index 2719d90b9314..e4d858efde8f 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -116,6 +116,13 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -138,6 +145,8 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -234,11 +243,33 @@ def set_timesteps( ) self.num_inference_steps = num_inference_steps - - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.custom_timesteps = False + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) + self.timesteps = torch.from_numpy(timesteps).to(device) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 56c362018c18..c504fb19231a 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -107,6 +107,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -129,6 +136,8 @@ def __init__( solver_type: str = "logrho", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -185,12 +194,30 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d7c29d5488a5..528b7b838b1c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -134,6 +134,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -158,6 +165,8 @@ def __init__( use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -217,12 +226,29 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # Clipping the minimum of all lambda(t) for numerical stability. # This is critical for cosine (squaredcos_cap_v2) noise schedule. clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index ae9229981152..da8b71788b75 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -133,6 +133,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. noise_sampler_seed (`int`, *optional*, defaults to `None`): The random seed to use for the noise sampler. If `None`, a random seed will be generated. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -149,6 +156,8 @@ def __init__( prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -187,6 +196,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -226,7 +243,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) @@ -242,9 +277,6 @@ def set_timesteps( sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps) second_order_timesteps = torch.from_numpy(second_order_timesteps) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 6b08e9bfc207..6b8c2f1a8a28 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -99,7 +99,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -114,6 +120,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -137,15 +145,20 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.is_scale_input_called = False + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -179,7 +192,28 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 7237128cbf07..fc52c50ebc7f 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -107,6 +107,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -123,6 +130,8 @@ def __init__( prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -146,9 +155,6 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() @@ -156,6 +162,14 @@ def __init__( self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -191,7 +205,28 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 100e2012ea20..28f29067a544 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -78,6 +78,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. 
+ timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -93,6 +100,8 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -128,6 +137,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -166,7 +183,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) @@ -180,9 +215,6 @@ def set_timesteps( sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 2fa0431e1292..d4a35ab82502 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -78,6 +78,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -92,6 +99,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -127,6 +136,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -169,7 +186,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) @@ -197,9 +232,6 @@ def set_timesteps( self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - if str(device).startswith("mps"): # mps does not support float64 timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index bb80c4a54bfe..39079fde10d2 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -77,6 +77,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. 
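For end users, these new options are typically selected when swapping a pipeline's scheduler, which is also how the test changes later in this patch exercise them. A hedged usage sketch (the checkpoint id is only an example):

```python
import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler

# Any Stable Diffusion checkpoint with a scheduler config should work here.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# Opt in to the "trailing" spacing from the referenced paper when switching schedulers;
# the override is stored in the config and survives further `from_config` round trips.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
assert pipe.scheduler.config.timestep_spacing == "trailing"
```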
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -91,6 +98,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -126,6 +135,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -168,7 +185,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) @@ -185,9 +220,6 @@ def set_timesteps( [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] ) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - if str(device).startswith("mps"): # mps does not support float64 timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 0656475c3093..1256660b843c 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -102,6 +102,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. 
+ steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -117,6 +124,8 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -140,9 +149,6 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None self.use_karras_sigmas = use_karras_sigmas @@ -150,6 +156,14 @@ def __init__( self.derivatives = [] self.is_scale_input_called = False + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -205,7 +219,27 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 01c02a21bbfc..70ee1301129c 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -85,11 +85,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -106,6 +108,7 @@ def __init__( skip_prk_steps: bool = False, set_alpha_to_one: bool = False, prediction_type: str = "epsilon", + timestep_spacing: str = "leading", steps_offset: int = 0, ): if trained_betas is not None: @@ -159,11 +162,29 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() - self._timesteps += self.config.steps_offset + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + self._timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + self._timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype( + np.int64 + ) + self._timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 7233258a4766..3caa01a58562 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -121,6 +121,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -145,6 +152,8 @@ def __init__( disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -173,7 +182,7 @@ def __init__( if solver_type not in ["bh1", "bh2"]: if solver_type in ["midpoint", "heun", "logrho"]: - self.register_to_config(solver_type="bh1") + self.register_to_config(solver_type="bh2") else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") @@ -199,12 +208,30 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. 
Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: diff --git a/tests/others/test_config.py b/tests/others/test_config.py index a29190c199ca..d1f8a6e054d4 100644 --- a/tests/others/test_config.py +++ b/tests/others/test_config.py @@ -75,6 +75,22 @@ def __init__( pass +class SampleObject4(ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 5], + f=[5, 4], + ): + pass + + class ConfigTester(unittest.TestCase): def test_load_not_from_mixin(self): with self.assertRaises(ValueError): @@ -137,6 +153,7 @@ def test_save_load(self): assert config.pop("c") == (2, 5) # instantiated as tuple assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json + config.pop("_use_default_values") assert config == new_config def test_load_ddim_from_pndm(self): @@ -233,3 +250,39 @@ def test_load_dpmsolver(self): assert dpm.__class__ == DPMSolverMultistepScheduler # no warning should be thrown assert cap_logger.out == "" + + def test_use_default_values(self): + # let's first save a config that should be in the form + # a=2, + # b=5, + # c=(2, 5), + # d="for diffusion", + # e=[1, 3], + + config = SampleObject() + + config_dict = {k: v for k, v in config.config.items() if not k.startswith("_")} + + # make sure that default config has all keys in `_use_default_values` + assert set(config_dict.keys()) == config.config._use_default_values + + with tempfile.TemporaryDirectory() as tmpdirname: + config.save_config(tmpdirname) + + # now loading it with SampleObject2 should put f into `_use_default_values` + config = SampleObject2.from_config(tmpdirname) + + assert "f" in config._use_default_values + assert config.f == [1, 3] + + # now loading the config, should **NOT** use [1, 3] for `f`, but the default [1, 4] value + # **BECAUSE** it is part of `config._use_default_values` + new_config = SampleObject4.from_config(config.config) + assert new_config.f == [5, 4] + + config.config._use_default_values.pop() + new_config_2 = SampleObject4.from_config(config.config) + assert new_config_2.f == [1, 3] + + # Nevertheless "e" should still be correctly loaded to [1, 3] from SampleObject2 instead of defaulting to [1, 5] + assert new_config_2.e == [1, 3] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 32541c980a15..080bd0091f4f 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -186,7 +186,7 @@ def test_stable_diffusion_panorama_euler(self): assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4886, 0.5586, 0.4476, 0.5053, 0.6013, 0.4737, 0.5538, 0.5100, 0.4927]) + expected_slice = np.array([0.4024, 0.6510, 0.4901, 0.5378, 0.5813, 0.5622, 0.4795, 0.4467, 0.4952]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index aa46ef31885a..0c3b065161db 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -101,7 +101,7 @@ def test_full_loop_device(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * 
scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for t in scheduler.timesteps: @@ -128,7 +128,7 @@ def test_full_loop_device_karras_sigmas(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for t in scheduler.timesteps: diff --git a/tests/schedulers/test_scheduler_euler_ancestral.py b/tests/schedulers/test_scheduler_euler_ancestral.py index 5fa36be6bc64..9866bd12d6af 100644 --- a/tests/schedulers/test_scheduler_euler_ancestral.py +++ b/tests/schedulers/test_scheduler_euler_ancestral.py @@ -47,7 +47,7 @@ def test_full_loop_no_noise(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for i, t in enumerate(scheduler.timesteps): @@ -100,7 +100,7 @@ def test_full_loop_device(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for t in scheduler.timesteps: diff --git a/tests/schedulers/test_scheduler_lms.py b/tests/schedulers/test_scheduler_lms.py index 2682886a788d..1e0a8212354d 100644 --- a/tests/schedulers/test_scheduler_lms.py +++ b/tests/schedulers/test_scheduler_lms.py @@ -97,7 +97,7 @@ def test_full_loop_device(self): scheduler.set_timesteps(self.num_inference_steps, device=torch_device) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for i, t in enumerate(scheduler.timesteps): diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 62cffc67388c..171ee85be1d3 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -23,7 +23,7 @@ def get_scheduler_config(self, **kwargs): "beta_end": 0.02, "beta_schedule": "linear", "solver_order": 2, - "solver_type": "bh1", + "solver_type": "bh2", } config.update(**kwargs) @@ -144,7 +144,7 @@ def test_switch(self): sample = self.full_loop(scheduler=scheduler) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2521) < 1e-3 + assert abs(result_mean.item() - 0.2464) < 1e-3 scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config) scheduler = DEISMultistepScheduler.from_config(scheduler.config) @@ -154,7 +154,7 @@ def test_switch(self): sample = self.full_loop(scheduler=scheduler) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2521) < 1e-3 + assert abs(result_mean.item() - 0.2464) < 1e-3 def test_timesteps(self): for timesteps in [25, 50, 100, 999, 1000]: @@ -206,13 +206,13 @@ def test_full_loop_no_noise(self): sample = self.full_loop() result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2521) < 1e-3 + assert abs(result_mean.item() - 0.2464) < 1e-3 def test_full_loop_with_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction") result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.1096) < 1e-3 + assert abs(result_mean.item() - 0.1014) < 1e-3 def test_fp16_support(self): scheduler_class = self.scheduler_classes[0] diff --git 
a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index a2d065f388bd..d1ae333c0cd2 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -24,10 +24,14 @@ import diffusers from diffusers import ( + DDIMScheduler, + DEISMultistepScheduler, + DiffusionPipeline, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, IPNDMScheduler, LMSDiscreteScheduler, + UniPCMultistepScheduler, VQDiffusionScheduler, logging, ) @@ -202,6 +206,44 @@ def test_save_load_from_different_config_comp_schedulers(self): assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n" assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n" + def test_default_arguments_not_in_config(self): + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16 + ) + assert pipe.scheduler.__class__ == DDIMScheduler + + # Default for DDIMScheduler + assert pipe.scheduler.config.timestep_spacing == "leading" + + # Switch to a different one, verify we use the default for that class + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.timestep_spacing == "linspace" + + # Override with kwargs + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + assert pipe.scheduler.config.timestep_spacing == "trailing" + + # Verify overridden kwargs stick + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.timestep_spacing == "trailing" + + # And stick + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.timestep_spacing == "trailing" + + def test_default_solver_type_after_switch(self): + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16 + ) + assert pipe.scheduler.__class__ == DDIMScheduler + + pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.solver_type == "logrho" + + # Switch to UniPC, verify the solver is the default + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.solver_type == "bh2" + class SchedulerCommonTest(unittest.TestCase): scheduler_classes = () @@ -414,7 +456,11 @@ def test_from_pretrained(self): scheduler.save_pretrained(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname) - assert scheduler.config == new_scheduler.config + # `_use_default_values` should not exist for just saved & loaded scheduler + scheduler_config = dict(scheduler.config) + del scheduler_config["_use_default_values"] + + assert scheduler_config == new_scheduler.config def test_step_shape(self): kwargs = dict(self.forward_default_kwargs) From 9ff77272fdd0f8c20a671526df83496d9ba0dc96 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Wed, 5 Jul 2023 10:33:58 -0700 Subject: [PATCH 14/15] Add Consistency Models Pipeline (#3492) * initial commit * Improve consistency models sampling implementation. * Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling. 
* Add Unet blocks for consistency models * Add conversion script for Unet * Fix bug in new unet blocks * Fix attention weight loading * Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests. * make style * Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints. * Add support for converting a test UNet and non-class-conditional UNets to the consistency models conversion script. * make style * Change num_class_embeds to 1000 to better match the original consistency models implementation. * Add support for distillation in pipeline_consistency_models.py. * Improve consistency model tests: - Get small testing checkpoints from hub - Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline - Add onestep, multistep tests for distillation and distillation + class conditional - Add expected image slices for onestep tests * make style * Improve ConsistencyModelPipeline: - Add initial support for class-conditional generation - Fix initial sigma for onestep generation - Fix some sigma shape issues * make style * Improve ConsistencyModelPipeline: - add latents __call__ argument and prepare_latents method - add check_inputs method - add initial docstrings for ConsistencyModelPipeline.__call__ * make style * Fix bug when randomly generating class labels for class-conditional generation. * Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests. * Remove some unused code and make style. * Fix small bug in CMStochasticIterativeScheduler. * Add expected slices for multistep sampling tests and make them pass. * Work on consistency model fast tests: - in pipeline, call self.scheduler.scale_model_input before denoising - get expected slices for Euler and Heun scheduler tests - make Euler test pass - mark Heun test as expected fail because it doesn't support prediction_type "sample" yet - remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas * make style * Refactor conversion script to make it easier to add more model architectures to convert in the future. * Work on ConsistencyModelPipeline tests: - Fix device bug when handling class labels in ConsistencyModelPipeline.__call__ - Add slow tests for onestep and multistep sampling and make them pass - Refactor fast tests - Refactor ConsistencyModelPipeline.__init__ * make style * Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now. * Run python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite to make dummy objects for new pipeline and scheduler. * Make fast tests from PipelineTesterMixin pass. * make style * Refactor consistency models pipeline and scheduler: - Remove support for Karras schedulers (only support CMStochasticIterativeScheduler) - Move sigma manipulation, input scaling, denoising from pipeline to scheduler - Make corresponding changes to tests and ensure they pass * make style * Add docstrings and further refactor pipeline and scheduler. * make style * Add initial version of the consistency models documentation. * Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging. * make style * Convert current slow tests to use fp16 and flash attention. * make style * Add slow tests for normal attention on cuda device. 
* make style * Fix attention weights loading * Update consistency model fast tests for new test checkpoints with attention fix. * make style * apply suggestions * Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler). * Conversion script now outputs pipeline instead of UNet and add support for LSUN-256 models and different schedulers. * When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence). * make style * Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example. * apply suggestions from review * make style * fix attention naming * Add tests for CMStochasticIterativeScheduler. * make style * Make CMStochasticIterativeScheduler tests pass. * make style * Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest. * make style * rename some models * Improve API * rename some models * Remove duplicated block * Add docstring and make torch compile work * More fixes * Fixes * Apply suggestions from code review * Apply suggestions from code review * add more docstring * update consistency conversion script --------- Co-authored-by: ayushmangal Co-authored-by: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Co-authored-by: Patrick von Platen --- docs/source/en/_toctree.yml | 4 + .../en/api/pipelines/consistency_models.mdx | 87 ++++ .../schedulers/cm_stochastic_iterative.mdx | 11 + scripts/convert_consistency_to_diffusers.py | 313 +++++++++++++++ src/diffusers/__init__.py | 2 + src/diffusers/models/unet_2d.py | 8 + src/diffusers/models/unet_2d_blocks.py | 74 +++- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/consistency_models/__init__.py | 1 + .../pipeline_consistency_models.py | 337 ++++++++++++++++ src/diffusers/schedulers/__init__.py | 1 + .../scheduling_consistency_models.py | 380 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 30 ++ .../pipelines/consistency_models/__init__.py | 0 .../test_consistency_models.py | 288 +++++++++++++ .../test_scheduler_consistency_model.py | 150 +++++++ tests/schedulers/test_schedulers.py | 36 +- 17 files changed, 1710 insertions(+), 13 deletions(-) create mode 100644 docs/source/en/api/pipelines/consistency_models.mdx create mode 100644 docs/source/en/api/schedulers/cm_stochastic_iterative.mdx create mode 100644 scripts/convert_consistency_to_diffusers.py create mode 100644 src/diffusers/pipelines/consistency_models/__init__.py create mode 100644 src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py create mode 100644 src/diffusers/schedulers/scheduling_consistency_models.py create mode 100644 tests/pipelines/consistency_models/__init__.py create mode 100644 tests/pipelines/consistency_models/test_consistency_models.py create mode 100644 tests/schedulers/test_scheduler_consistency_model.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 72808df049c9..db9e72a4ea20 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -184,6 +184,8 @@ title: Audio Diffusion - local: api/pipelines/audioldm title: AudioLDM + - local: api/pipelines/consistency_models + title: Consistency Models - local: api/pipelines/controlnet title: ControlNet - local: api/pipelines/cycle_diffusion @@ -274,6 +276,8 @@ - sections: - local: api/schedulers/overview title: Overview + - local: api/schedulers/cm_stochastic_iterative + title: Consistency Model Multistep Scheduler - local: 
api/schedulers/ddim title: DDIM - local: api/schedulers/ddim_inverse diff --git a/docs/source/en/api/pipelines/consistency_models.mdx b/docs/source/en/api/pipelines/consistency_models.mdx new file mode 100644 index 000000000000..715743b87a12 --- /dev/null +++ b/docs/source/en/api/pipelines/consistency_models.mdx @@ -0,0 +1,87 @@ +# Consistency Models + +Consistency Models were proposed in [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. + +The abstract of the [paper](https://arxiv.org/pdf/2303.01469.pdf) is as follows: + +*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256. 
* + +Resources: + +* [Paper](https://arxiv.org/abs/2303.01469) +* [Original Code](https://github.com/openai/consistency_models) + +Available Checkpoints are: +- *cd_imagenet64_l2 (64x64 resolution)* [openai/consistency-model-pipelines](https://huggingface.co/openai/consistency-model-pipelines) +- *cd_imagenet64_lpips (64x64 resolution)* [openai/diffusers-cd_imagenet64_lpips](https://huggingface.co/openai/diffusers-cd_imagenet64_lpips) +- *ct_imagenet64 (64x64 resolution)* [openai/diffusers-ct_imagenet64](https://huggingface.co/openai/diffusers-ct_imagenet64) +- *cd_bedroom256_l2 (256x256 resolution)* [openai/diffusers-cd_bedroom256_l2](https://huggingface.co/openai/diffusers-cd_bedroom256_l2) +- *cd_bedroom256_lpips (256x256 resolution)* [openai/diffusers-cd_bedroom256_lpips](https://huggingface.co/openai/diffusers-cd_bedroom256_lpips) +- *ct_bedroom256 (256x256 resolution)* [openai/diffusers-ct_bedroom256](https://huggingface.co/openai/diffusers-ct_bedroom256) +- *cd_cat256_l2 (256x256 resolution)* [openai/diffusers-cd_cat256_l2](https://huggingface.co/openai/diffusers-cd_cat256_l2) +- *cd_cat256_lpips (256x256 resolution)* [openai/diffusers-cd_cat256_lpips](https://huggingface.co/openai/diffusers-cd_cat256_lpips) +- *ct_cat256 (256x256 resolution)* [openai/diffusers-ct_cat256](https://huggingface.co/openai/diffusers-ct_cat256) + +## Available Pipelines + +| Pipeline | Tasks | Demo | Colab | +|:---:|:---:|:---:|:---:| +| [ConsistencyModelPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_consistency_models.py) | *Unconditional Image Generation* | | | + +This pipeline was contributed by our community members [dg845](https://github.com/dg845) and [ayushtues](https://huggingface.co/ayushtues) :heart: + +## Usage Example + +```python +import torch + +from diffusers import ConsistencyModelPipeline + +device = "cuda" +# Load the cd_imagenet64_l2 checkpoint. +model_id_or_path = "openai/diffusers-cd_imagenet64_l2" +pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Onestep Sampling +image = pipe(num_inference_steps=1).images[0] +image.save("consistency_model_onestep_sample.png") + +# Onestep sampling, class-conditional image generation +# ImageNet-64 class label 145 corresponds to king penguins +image = pipe(num_inference_steps=1, class_labels=145).images[0] +image.save("consistency_model_onestep_sample_penguin.png") + +# Multistep sampling, class-conditional image generation +# Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo. +# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 +image = pipe(timesteps=[22, 0], class_labels=145).images[0] +image.save("consistency_model_multistep_sample_penguin.png") +``` + +For an additional speed-up, one can also make use of `torch.compile`. Multiple images can be generated in <1 second as follows: + +```py +import torch +from diffusers import ConsistencyModelPipeline + +device = "cuda" +# Load the cd_bedroom256_lpips checkpoint. 
+model_id_or_path = "openai/diffusers-cd_bedroom256_lpips" +pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + +# Multistep sampling +# Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo: +# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L83 +for _ in range(10): + image = pipe(timesteps=[17, 0]).images[0] + image.show() +``` + +## ConsistencyModelPipeline +[[autodoc]] ConsistencyModelPipeline + - all + - __call__ diff --git a/docs/source/en/api/schedulers/cm_stochastic_iterative.mdx b/docs/source/en/api/schedulers/cm_stochastic_iterative.mdx new file mode 100644 index 000000000000..0cc40bde47a0 --- /dev/null +++ b/docs/source/en/api/schedulers/cm_stochastic_iterative.mdx @@ -0,0 +1,11 @@ +# Consistency Model Multistep Scheduler + +## Overview + +Multistep and onestep scheduler (Algorithm 1) introduced alongside consistency models in the paper [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. +Based on the [original consistency models implementation](https://github.com/openai/consistency_models). +Should generate good samples from [`ConsistencyModelPipeline`] in one or a small number of steps. + +## CMStochasticIterativeScheduler +[[autodoc]] CMStochasticIterativeScheduler + diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py new file mode 100644 index 000000000000..5a6158bb9867 --- /dev/null +++ b/scripts/convert_consistency_to_diffusers.py @@ -0,0 +1,313 @@ +import argparse +import os + +import torch + +from diffusers import ( + CMStochasticIterativeScheduler, + ConsistencyModelPipeline, + UNet2DModel, +) + + +TEST_UNET_CONFIG = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 2, + "num_class_embeds": 1000, + "block_out_channels": [32, 64], + "attention_head_dim": 8, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "scale_shift", + "upsample_type": "resnet", + "downsample_type": "resnet", +} + +IMAGENET_64_UNET_CONFIG = { + "sample_size": 64, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 3, + "num_class_embeds": 1000, + "block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "scale_shift", + "upsample_type": "resnet", + "downsample_type": "resnet", +} + +LSUN_256_UNET_CONFIG = { + "sample_size": 256, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 2, + "num_class_embeds": None, + "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "default", + "upsample_type": 
"resnet", + "downsample_type": "resnet", +} + +CD_SCHEDULER_CONFIG = { + "num_train_timesteps": 40, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + +CT_IMAGENET_64_SCHEDULER_CONFIG = { + "num_train_timesteps": 201, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + +CT_LSUN_256_SCHEDULER_CONFIG = { + "num_train_timesteps": 151, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + + +def str2bool(v): + """ + https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("boolean value expected") + + +def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False): + new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"] + new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"] + new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"] + new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"] + new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"] + new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"] + + if has_skip: + new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"] + new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"] + + return new_checkpoint + + +def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None): + weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0) + bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0) + + new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] + new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] + + new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) + + new_checkpoint[f"{new_prefix}.to_out.0.weight"] = ( + checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) + ) + new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) + + return new_checkpoint + + +def con_pt_to_diffuser(checkpoint_path: str, unet_config): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = 
checkpoint["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"] + + if unet_config["num_class_embeds"] is not None: + new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"] + + new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"] + + down_block_types = unet_config["down_block_types"] + layers_per_block = unet_config["layers_per_block"] + attention_head_dim = unet_config["attention_head_dim"] + channels_list = unet_config["block_out_channels"] + current_layer = 1 + prev_channels = channels_list[0] + + for i, layer_type in enumerate(down_block_types): + current_channels = channels_list[i] + downsample_block_has_skip = current_channels != prev_channels + if layer_type == "ResnetDownsampleBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + has_skip = True if j == 0 and downsample_block_has_skip else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) + current_layer += 1 + + elif layer_type == "AttnDownBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + has_skip = True if j == 0 and downsample_block_has_skip else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) + new_prefix = f"down_blocks.{i}.attentions.{j}" + old_prefix = f"input_blocks.{current_layer}.1" + new_checkpoint = convert_attention( + checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim + ) + current_layer += 1 + + if i != len(down_block_types) - 1: + new_prefix = f"down_blocks.{i}.downsamplers.0" + old_prefix = f"input_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + prev_channels = current_channels + + # hardcoded the mid-block for now + new_prefix = "mid_block.resnets.0" + old_prefix = "middle_block.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_prefix = "mid_block.attentions.0" + old_prefix = "middle_block.1" + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) + new_prefix = "mid_block.resnets.1" + old_prefix = "middle_block.2" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + current_layer = 0 + up_block_types = unet_config["up_block_types"] + + for i, layer_type in enumerate(up_block_types): + if layer_type == "ResnetUpsampleBlock2D": + for j in range(layers_per_block + 1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) + current_layer += 1 + + if i != len(up_block_types) - 1: + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.1" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + elif layer_type == "AttnUpBlock2D": + for j in range(layers_per_block + 1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, 
old_prefix, new_prefix, has_skip=True) + new_prefix = f"up_blocks.{i}.attentions.{j}" + old_prefix = f"output_blocks.{current_layer}.1" + new_checkpoint = convert_attention( + checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim + ) + current_layer += 1 + + if i != len(up_block_types) - 1: + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.2" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"] + new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"] + new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") + parser.add_argument( + "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model." + ) + parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.") + + args = parser.parse_args() + args.class_cond = str2bool(args.class_cond) + + ckpt_name = os.path.basename(args.unet_path) + print(f"Checkpoint: {ckpt_name}") + + # Get U-Net config + if "imagenet64" in ckpt_name: + unet_config = IMAGENET_64_UNET_CONFIG + elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)): + unet_config = LSUN_256_UNET_CONFIG + elif "test" in ckpt_name: + unet_config = TEST_UNET_CONFIG + else: + raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.") + + if not args.class_cond: + unet_config["num_class_embeds"] = None + + converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config) + + image_unet = UNet2DModel(**unet_config) + image_unet.load_state_dict(converted_unet_ckpt) + + # Get scheduler config + if "cd" in ckpt_name or "test" in ckpt_name: + scheduler_config = CD_SCHEDULER_CONFIG + elif "ct" in ckpt_name and "imagenet64" in ckpt_name: + scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG + elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)): + scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG + else: + raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.") + + cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config) + + consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler) + consistency_model.save_pretrained(args.dump_path) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 764f9204dffb..f0c25edd3fdc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -58,6 +58,7 @@ ) from .pipelines import ( AudioPipelineOutput, + ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, DDPMPipeline, @@ -72,6 +73,7 @@ ScoreSdeVePipeline, ) from .schedulers import ( + CMStochasticIterativeScheduler, DDIMInverseScheduler, DDIMParallelScheduler, DDIMScheduler, diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 7077aa889190..3b17acd3d829 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -66,6 +66,10 @@ class UNet2DModel(ModelMixin, ConfigMixin): layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. 
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + downsample_type (`str`, *optional*, defaults to `conv`): + The downsample type for downsampling layers. Choose between "conv" and "resnet" + upsample_type (`str`, *optional*, defaults to `conv`): + The upsample type for upsampling layers. Choose between "conv" and "resnet" act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. @@ -96,6 +100,8 @@ def __init__( layers_per_block: int = 2, mid_block_scale_factor: float = 1, downsample_padding: int = 1, + downsample_type: str = "conv", + upsample_type: str = "conv", act_fn: str = "silu", attention_head_dim: Optional[int] = 8, norm_num_groups: int = 32, @@ -168,6 +174,7 @@ def __init__( attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, ) self.down_blocks.append(down_block) @@ -207,6 +214,7 @@ def __init__( resnet_groups=norm_num_groups, attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index eee7e6023e88..d4e7bd4e03f7 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -51,6 +51,7 @@ def get_down_block( resnet_out_scale_factor=1.0, cross_attention_norm=None, attention_head_dim=None, + downsample_type=None, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -88,18 +89,22 @@ def get_down_block( output_scale_factor=resnet_out_scale_factor, ) elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' return AttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, - add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: @@ -239,6 +244,7 @@ def get_up_block( resnet_out_scale_factor=1.0, cross_attention_norm=None, attention_head_dim=None, + upsample_type=None, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -319,18 +325,23 @@ def get_up_block( cross_attention_norm=cross_attention_norm, ) elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + return AttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, - add_upsample=add_upsample, resnet_eps=resnet_eps, 
resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( @@ -747,11 +758,12 @@ def __init__( attention_head_dim=1, output_scale_factor=1.0, downsample_padding=1, - add_downsample=True, + downsample_type="conv", ): super().__init__() resnets = [] attentions = [] + self.downsample_type = downsample_type if attention_head_dim is None: logger.warn( @@ -793,7 +805,7 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - if add_downsample: + if downsample_type == "conv": self.downsamplers = nn.ModuleList( [ Downsample2D( @@ -801,6 +813,24 @@ def __init__( ) ] ) + elif downsample_type == "resnet": + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) else: self.downsamplers = None @@ -810,11 +840,14 @@ def forward(self, hidden_states, temb=None, upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if self.downsample_type == "resnet": + hidden_states = downsampler(hidden_states, temb=temb) + else: + hidden_states = downsampler(hidden_states) output_states += (hidden_states,) @@ -1860,12 +1893,14 @@ def __init__( resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, - add_upsample=True, + upsample_type="conv", ): super().__init__() resnets = [] attentions = [] + self.upsample_type = upsample_type + if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
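As an illustration of the new `downsample_type` / `upsample_type` options wired through above, here is a minimal sketch of constructing a `UNet2DModel` with resnet-based down/upsampling instead of the default `"conv"`; the sizes and block layout below are hypothetical and not taken from this patch:

```py
from diffusers import UNet2DModel

# Hypothetical toy config using resnet-based down/upsampling, the setting the
# consistency-model checkpoints rely on.
unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("ResnetDownsampleBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "ResnetUpsampleBlock2D"),
    downsample_type="resnet",
    upsample_type="resnet",
    attention_head_dim=8,
)
```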
@@ -1908,8 +1943,26 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - if add_upsample: + if upsample_type == "conv": self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) else: self.upsamplers = None @@ -1925,7 +1978,10 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb) + else: + hidden_states = upsampler(hidden_states) return hidden_states diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ca57756c6aa4..3926b3413e01 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -16,6 +16,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline diff --git a/src/diffusers/pipelines/consistency_models/__init__.py b/src/diffusers/pipelines/consistency_models/__init__.py new file mode 100644 index 000000000000..fd78ddb3aae2 --- /dev/null +++ b/src/diffusers/pipelines/consistency_models/__init__.py @@ -0,0 +1 @@ +from .pipeline_consistency_models import ConsistencyModelPipeline diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py new file mode 100644 index 000000000000..4e72e3fdbafe --- /dev/null +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -0,0 +1,337 @@ +from typing import Callable, List, Optional, Union + +import torch + +from ...models import UNet2DModel +from ...schedulers import CMStochasticIterativeScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import ConsistencyModelPipeline + + >>> device = "cuda" + >>> # Load the cd_imagenet64_l2 checkpoint. 
+ >>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2" + >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe.to(device) + + >>> # Onestep Sampling + >>> image = pipe(num_inference_steps=1).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample.png") + + >>> # Onestep sampling, class-conditional image generation + >>> # ImageNet-64 class label 145 corresponds to king penguins + >>> image = pipe(num_inference_steps=1, class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png") + + >>> # Multistep sampling, class-conditional image generation + >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo: + >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 + >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png") + ``` +""" + + +class ConsistencyModelPipeline(DiffusionPipeline): + r""" + Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1]. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" + https://arxiv.org/pdf/2303.01469 + + Args: + unet ([`UNet2DModel`]): + Unconditional or class-conditional U-Net architecture to denoise image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible + with [`CMStochasticIterativeScheduler`]. + """ + + def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None: + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + ) + + self.safety_checker = None + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. 
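As a usage sketch of these offloading helpers (assuming a CUDA device and the `openai/diffusers-cd_imagenet64_l2` checkpoint from the example above), either method can be enabled after loading the pipeline; they are alternatives rather than something to combine:

```py
import torch
from diffusers import ConsistencyModelPipeline

pipe = ConsistencyModelPipeline.from_pretrained(
    "openai/diffusers-cd_imagenet64_l2", torch_dtype=torch.float16
)

# Whole-model offloading: lower memory than keeping everything on GPU, small speed cost.
pipe.enable_model_cpu_offload()
# Alternatively, submodule-level offloading: lowest memory, slowest.
# pipe.enable_sequential_cpu_offload()

image = pipe(num_inference_steps=1).images[0]
```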
Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess_image(self, sample: torch.FloatTensor, output_type: str = "pil"): + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. 
Make sure to choose one of ['pt', 'np', or 'pil']" + ) + + # Equivalent to diffusers.VaeImageProcessor.denormalize + sample = (sample / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return sample + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return sample + + # Output_type must be 'pil' + sample = self.numpy_to_pil(sample) + return sample + + def prepare_class_labels(self, batch_size, device, class_labels=None): + if self.unet.config.num_class_embeds is not None: + if isinstance(class_labels, list): + class_labels = torch.tensor(class_labels, dtype=torch.int) + elif isinstance(class_labels, int): + assert batch_size == 1, "Batch size must be 1 if classes is an int" + class_labels = torch.tensor([class_labels], dtype=torch.int) + elif class_labels is None: + # Randomly generate batch_size class labels + # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils + class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,)) + class_labels = class_labels.to(device) + else: + class_labels = None + return class_labels + + def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps): + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + logger.warning( + f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;" + " `timesteps` will be used over `num_inference_steps`." + ) + + if latents is not None: + expected_shape = (batch_size, 3, img_size, img_size) + if latents.shape != expected_shape: + raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + batch_size: int = 1, + class_labels: Optional[Union[torch.Tensor, List[int], int]] = None, + num_inference_steps: int = 1, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*): + Optional class labels for conditioning class-conditional consistency models. Will not be used if the + model is not class-conditional. + num_inference_steps (`int`, *optional*, defaults to 1): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. 
+ generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Prepare call parameters + img_size = self.unet.config.sample_size + device = self._execution_device + + # 1. Check inputs + self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps) + + # 2. Prepare image latents + # Sample image latents x_0 ~ N(0, sigma_0^2 * I) + sample = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=img_size, + width=img_size, + dtype=self.unet.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 3. Handle class_labels for class-conditional models + class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Denoising loop + # Multistep sampling: implements Algorithm 1 in the paper + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + scaled_sample = self.scheduler.scale_model_input(sample, t) + model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0] + + sample = self.scheduler.step(model_output, t, sample, generator=generator)[0] + + # call the callback, if provided + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + # 6. 
Post-process image sample + image = self.postprocess_image(sample, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 935759bbb6af..0a07ce4baed2 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -28,6 +28,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py new file mode 100644 index 000000000000..fb296054d65b --- /dev/null +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -0,0 +1,380 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging, randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class CMStochasticIterativeSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): + """ + Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in the + paper [1]. + + [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" + https://arxiv.org/pdf/2303.01469 [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based + Generative Models." https://arxiv.org/abs/2206.00364 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. 
+ sigma_min (`float`): + Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the original implementation. + sigma_max (`float`): + Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the original implementation. + sigma_data (`float`): + The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the + original implementation, which is also the original value suggested in the EDM paper. + s_noise (`float`): + The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, + 1.011]. This was set to 1.0 in the original implementation. + rho (`float`): + The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This was + set to 7.0 in the original implementation, which is also the original value suggested in the EDM paper. + clip_denoised (`bool`): + Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`. + timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*): + Optionally, an explicit timestep schedule can be specified. The timesteps are expected to be in increasing + order. + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 40, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + sigma_data: float = 0.5, + s_noise: float = 1.0, + rho: float = 7.0, + clip_denoised: bool = True, + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + + ramp = np.linspace(0, 1, num_train_timesteps) + sigmas = self._convert_to_karras(ramp) + timesteps = self.sigma_to_t(sigmas) + + # setable values + self.num_inference_steps = None + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps) + self.custom_timesteps = False + self.is_scale_input_called = False + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + return indices.item() + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`, following the EDM model. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + Returns: + `torch.FloatTensor`: scaled input sample + """ + # Get sigma corresponding to timestep + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_idx = self.index_for_timestep(timestep) + sigma = self.sigmas[step_idx] + + sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + + self.is_scale_input_called = True + return sample + + def sigma_to_t(self, sigmas: Union[float, np.ndarray]): + """ + Gets scaled timesteps from the Karras sigmas, for input to the consistency model. 
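A small numeric sketch of this mapping (the formula itself is defined just below): with the default `sigma_max=80.0` and `sigma_min=0.002`, the boundary sigmas map to timesteps of about 1095.5 and about -1553.7.

```py
import numpy as np

# Standalone restatement of the mapping used below:
# t = 1000 * 0.25 * log(sigma + 1e-44)
def sigma_to_t(sigma):
    return 1000 * 0.25 * np.log(sigma + 1e-44)

print(sigma_to_t(80.0))   # ~1095.51, the timestep corresponding to sigma_max
print(sigma_to_t(0.002))  # ~-1553.65, the timestep corresponding to sigma_min
```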
+ + Args: + sigmas (`float` or `np.ndarray`): single Karras sigma or array of Karras sigmas + Returns: + `float` or `np.ndarray`: scaled input timestep or scaled input timestep array + """ + if not isinstance(sigmas, np.ndarray): + sigmas = np.array(sigmas, dtype=np.float64) + + timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44) + + return timesteps + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, optional): + custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps` + must be `None`. + """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.") + + # Follow DDPMScheduler custom timesteps logic + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.custom_timesteps = False + + # Map timesteps to Karras sigmas directly for multistep sampling + # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 + num_train_timesteps = self.config.num_train_timesteps + ramp = timesteps[::-1].copy() + ramp = ramp / (num_train_timesteps - 1) + sigmas = self._convert_to_karras(ramp) + timesteps = self.sigma_to_t(sigmas) + + sigmas = np.concatenate([sigmas, [self.sigma_min]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + # Modified _convert_to_karras implementation that takes in ramp as argument + def _convert_to_karras(self, ramp): + """Constructs the noise schedule of Karras et al. 
(2022).""" + + sigma_min: float = self.config.sigma_min + sigma_max: float = self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def get_scalings(self, sigma): + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + def get_scalings_for_boundary_condition(self, sigma): + """ + Gets the scalings used in the consistency model parameterization, following Appendix C of the original paper. + This enforces the consistency model boundary condition. + + Note that `epsilon` in the equations for c_skip and c_out is set to sigma_min. + + Args: + sigma (`torch.FloatTensor`): + The current sigma in the Karras sigma schedule. + Returns: + `tuple`: + A two-element tuple where c_skip (which weights the current sample) is the first element and c_out + (which weights the consistency model output) is the second element. + """ + sigma_min = self.config.sigma_min + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) + c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator (`torch.Generator`, *optional*): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + f" `{self.__class__}.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." 
+ ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + sigma_min = self.config.sigma_min + sigma_max = self.config.sigma_max + + step_index = self.index_for_timestep(timestep) + + # sigma_next corresponds to next_t in original implementation + sigma = self.sigmas[step_index] + if step_index + 1 < self.config.num_train_timesteps: + sigma_next = self.sigmas[step_index + 1] + else: + # Set sigma_next to sigma_min + sigma_next = self.sigmas[-1] + + # Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition(sigma) + + # 1. Denoise model output using boundary conditions + denoised = c_out * model_output + c_skip * sample + if self.config.clip_denoised: + denoised = denoised.clamp(-1, 1) + + # 2. Sample z ~ N(0, s_noise^2 * I) + # Noise is not used for onestep sampling. + if len(self.timesteps) > 1: + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + else: + noise = torch.zeros_like(model_output) + z = noise * self.config.s_noise + + sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max) + + # 3. Return noisy sample + # tau = sigma_hat, eps = sigma_min + prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5 + + if not return_dict: + return (prev_sample,) + + return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 7a13bc89e883..20dbf84681d3 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -210,6 +210,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ConsistencyModelPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DanceDiffusionPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -390,6 +405,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CMStochasticIterativeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, 
*args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/consistency_models/__init__.py b/tests/pipelines/consistency_models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py new file mode 100644 index 000000000000..8dce90318505 --- /dev/null +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -0,0 +1,288 @@ +import gc +import unittest + +import numpy as np +import torch +from torch.backends.cuda import sdp_kernel + +from diffusers import ( + CMStochasticIterativeScheduler, + ConsistencyModelPipeline, + UNet2DModel, +) +from diffusers.utils import randn_tensor, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_2, require_torch_gpu + +from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class ConsistencyModelPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = ConsistencyModelPipeline + params = UNCONDITIONAL_IMAGE_GENERATION_PARAMS + batch_params = UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS + + # Override required_optional_params to remove num_images_per_prompt + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "output_type", + "return_dict", + "callback", + "callback_steps", + ] + ) + + @property + def dummy_uncond_unet(self): + unet = UNet2DModel.from_pretrained( + "diffusers/consistency-models-test", + subfolder="test_unet", + ) + return unet + + @property + def dummy_cond_unet(self): + unet = UNet2DModel.from_pretrained( + "diffusers/consistency-models-test", + subfolder="test_unet_class_cond", + ) + return unet + + def get_dummy_components(self, class_cond=False): + if class_cond: + unet = self.dummy_cond_unet + else: + unet = self.dummy_uncond_unet + + # Default to CM multistep sampler + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + + components = { + "unet": unet, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "batch_size": 1, + "num_inference_steps": None, + "timesteps": [22, 0], + "generator": generator, + "output_type": "np", + } + + return inputs + + def test_consistency_model_pipeline_multistep(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def 
test_consistency_model_pipeline_multistep_class_cond(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(class_cond=True) + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["class_labels"] = 0 + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_onestep(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_onestep_class_cond(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(class_cond=True) + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None + inputs["class_labels"] = 0 + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + +@slow +@require_torch_gpu +class ConsistencyModelPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)): + generator = torch.manual_seed(seed) + + inputs = { + "num_inference_steps": None, + "timesteps": [22, 0], + "class_labels": 0, + "generator": generator, + "output_type": "np", + } + + if get_fixed_latents: + latents = self.get_fixed_latents(seed=seed, device=device, dtype=dtype, shape=shape) + inputs["latents"] = latents + + return inputs + + def get_fixed_latents(self, seed=0, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)): + if type(device) == str: + device = torch.device(device) + generator = torch.Generator(device=device).manual_seed(seed) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def test_consistency_model_cd_multistep(self): + unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs() + image = pipe(**inputs).images + assert 
image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + + expected_slice = np.array([0.0888, 0.0881, 0.0666, 0.0479, 0.0292, 0.0195, 0.0201, 0.0163, 0.0254]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2 + + def test_consistency_model_cd_onestep(self): + unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs() + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + + expected_slice = np.array([0.0340, 0.0152, 0.0063, 0.0267, 0.0221, 0.0107, 0.0416, 0.0186, 0.0217]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2 + + @require_torch_2 + def test_consistency_model_cd_multistep_flash_attn(self): + unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device, torch_dtype=torch.float16) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(get_fixed_latents=True, device=torch_device) + # Ensure usage of flash attention in torch 2.0 + with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + + expected_slice = np.array([0.1875, 0.1428, 0.1289, 0.2151, 0.2092, 0.1477, 0.1877, 0.1641, 0.1353]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + @require_torch_2 + def test_consistency_model_cd_onestep_flash_attn(self): + unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device, torch_dtype=torch.float16) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(get_fixed_latents=True, device=torch_device) + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None + # Ensure usage of flash attention in torch 2.0 + with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + + expected_slice = np.array([0.1663, 0.1948, 0.2275, 0.1680, 0.1204, 0.1245, 0.1858, 0.1338, 0.2095]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/schedulers/test_scheduler_consistency_model.py b/tests/schedulers/test_scheduler_consistency_model.py new file mode 100644 index 000000000000..66f07d024783 --- /dev/null +++ b/tests/schedulers/test_scheduler_consistency_model.py @@ -0,0 +1,150 @@ +import torch + +from diffusers import CMStochasticIterativeScheduler + +from .test_schedulers import SchedulerCommonTest + + +class CMStochasticIterativeSchedulerTest(SchedulerCommonTest): + scheduler_classes = 
(CMStochasticIterativeScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 201, + "sigma_min": 0.002, + "sigma_max": 80.0, + } + + config.update(**kwargs) + return config + + # Override test_step_shape to add CMStochasticIterativeScheduler-specific logic regarding timesteps + # Problem is that we don't know two timesteps that will always be in the timestep schedule from only the scheduler + # config; scaled sigma_max is always in the timestep schedule, but sigma_min is in the sigma schedule while scaled + # sigma_min is not in the timestep schedule + def test_step_shape(self): + num_inference_steps = 10 + + scheduler_config = self.get_scheduler_config() + scheduler = self.scheduler_classes[0](**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + + timestep_0 = scheduler.timesteps[0] + timestep_1 = scheduler.timesteps[1] + + sample = self.dummy_sample + residual = 0.1 * sample + + output_0 = scheduler.step(residual, timestep_0, sample).prev_sample + output_1 = scheduler.step(residual, timestep_1, sample).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_clip_denoised(self): + for clip_denoised in [True, False]: + self.check_over_configs(clip_denoised=clip_denoised) + + def test_full_loop_no_noise_onestep(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 1 + scheduler.set_timesteps(num_inference_steps) + timesteps = scheduler.timesteps + + generator = torch.manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for i, t in enumerate(timesteps): + # 1. scale model input + scaled_sample = scheduler.scale_model_input(sample, t) + + # 2. predict noise residual + residual = model(scaled_sample, t) + + # 3. predict previous sample x_t-1 + pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample + + sample = pred_prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 192.7614) < 1e-2 + assert abs(result_mean.item() - 0.2510) < 1e-3 + + def test_full_loop_no_noise_multistep(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [106, 0] + scheduler.set_timesteps(timesteps=timesteps) + timesteps = scheduler.timesteps + + generator = torch.manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for t in timesteps: + # 1. scale model input + scaled_sample = scheduler.scale_model_input(sample, t) + + # 2. predict noise residual + residual = model(scaled_sample, t) + + # 3. 
predict previous sample x_t-1 + pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample + + sample = pred_prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 347.6357) < 1e-2 + assert abs(result_mean.item() - 0.4527) < 1e-3 + + def test_custom_timesteps_increasing_order(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [39, 30, 12, 15, 0] + + with self.assertRaises(ValueError, msg="`timesteps` must be in descending order."): + scheduler.set_timesteps(timesteps=timesteps) + + def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [39, 30, 12, 1, 0] + num_inference_steps = len(timesteps) + + with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `timesteps`."): + scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps) + + def test_custom_timesteps_too_large(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [scheduler.config.num_train_timesteps] + + with self.assertRaises( + ValueError, + msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}", + ): + scheduler.set_timesteps(timesteps=timesteps) diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index d1ae333c0cd2..d9423d621966 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -24,6 +24,7 @@ import diffusers from diffusers import ( + CMStochasticIterativeScheduler, DDIMScheduler, DEISMultistepScheduler, DiffusionPipeline, @@ -303,6 +304,11 @@ def check_over_configs(self, time_step=0, **config): scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. + scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max) + time_step = scaled_sigma_max + if scheduler_class == VQDiffusionScheduler: num_vec_classes = scheduler_config["num_vec_classes"] sample = self.dummy_sample(num_vec_classes) @@ -323,7 +329,11 @@ def check_over_configs(self, time_step=0, **config): kwargs["num_inference_steps"] = num_inference_steps # Make sure `scale_model_input` is invoked to prevent a warning - if scheduler_class != VQDiffusionScheduler: + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. + _ = scheduler.scale_model_input(sample, scaled_sigma_max) + _ = new_scheduler.scale_model_input(sample, scaled_sigma_max) + elif scheduler_class != VQDiffusionScheduler: _ = scheduler.scale_model_input(sample, 0) _ = new_scheduler.scale_model_input(sample, 0) @@ -393,6 +403,10 @@ def test_from_save_pretrained(self): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. 
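The pattern used here can be reproduced in isolation; a minimal sketch, assuming the default test config (`sigma_min=0.002`, `sigma_max=80.0`):

```py
import torch
from diffusers import CMStochasticIterativeScheduler

scheduler = CMStochasticIterativeScheduler(num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0)

# The timestep derived from sigma_max is (up to float error) the first entry of
# the schedule, so it is always a safe input for scale_model_input and step.
t_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
print(t_max, scheduler.timesteps[0])  # both ~1095.5

sample = torch.randn(1, 3, 32, 32)
scaled = scheduler.scale_model_input(sample, scheduler.timesteps[0])
print(scaled.shape)  # torch.Size([1, 3, 32, 32])
```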
+ timestep = scheduler.sigma_to_t(scheduler.config.sigma_max) + if scheduler_class == VQDiffusionScheduler: num_vec_classes = scheduler_config["num_vec_classes"] sample = self.dummy_sample(num_vec_classes) @@ -539,6 +553,10 @@ def recursive_check(tuple_object, dict_object): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. + timestep = scheduler.sigma_to_t(scheduler.config.sigma_max) + if scheduler_class == VQDiffusionScheduler: num_vec_classes = scheduler_config["num_vec_classes"] sample = self.dummy_sample(num_vec_classes) @@ -594,7 +612,12 @@ def test_scheduler_public_api(self): if scheduler_class != VQDiffusionScheduler: sample = self.dummy_sample - scaled_sample = scheduler.scale_model_input(sample, 0.0) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. + scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max) + scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max) + else: + scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) def test_add_noise_device(self): @@ -606,7 +629,12 @@ def test_add_noise_device(self): scheduler.set_timesteps(100) sample = self.dummy_sample.to(torch_device) - scaled_sample = scheduler.scale_model_input(sample, 0.0) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. + scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max) + scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max) + else: + scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) noise = torch.randn_like(scaled_sample).to(torch_device) @@ -637,7 +665,7 @@ def test_deprecated_kwargs(self): def test_trained_betas(self): for scheduler_class in self.scheduler_classes: - if scheduler_class == VQDiffusionScheduler: + if scheduler_class in (VQDiffusionScheduler, CMStochasticIterativeScheduler): continue scheduler_config = self.get_scheduler_config() From cd1c5a16457084df3b46cac9b3df42fefd5f4dd8 Mon Sep 17 00:00:00 2001 From: sunhs Date: Thu, 6 Jul 2023 11:21:52 +0800 Subject: [PATCH 15/15] add test case for StableDiffusionKDiffusionPipeline noise_sampler --- .../test_stable_diffusion_k_diffusion.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py index 4eccb871a0cb..25da13d9f922 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py @@ -104,3 +104,33 @@ def test_stable_diffusion_karras_sigmas(self): ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_noise_sampler_seed(self): + sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + sd_pipe.set_scheduler("sample_dpmpp_sde") + + prompt = "A painting of a squirrel eating a burger" + seed = 0 + images1 = sd_pipe( + [prompt], + generator=torch.manual_seed(seed), + noise_sampler_seed=seed, + 
guidance_scale=9.0, + num_inference_steps=20, + output_type="np", + ).images + images2 = sd_pipe( + [prompt], + generator=torch.manual_seed(seed), + noise_sampler_seed=seed, + guidance_scale=9.0, + num_inference_steps=20, + output_type="np", + ).images + + assert images1.shape == (1, 512, 512, 3) + assert images2.shape == (1, 512, 512, 3) + assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2
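For completeness, a minimal usage sketch of the behaviour this test pins down, namely reproducible `sample_dpmpp_sde` sampling via `noise_sampler_seed`; it assumes a CUDA device and simply mirrors the test inputs above:

```py
import torch
from diffusers import StableDiffusionKDiffusionPipeline

pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
pipe.set_scheduler("sample_dpmpp_sde")

# Fixing both the generator seed and the Brownian-tree noise sampler seed
# should make repeated calls produce (near-)identical images.
image = pipe(
    "A painting of a squirrel eating a burger",
    generator=torch.manual_seed(0),
    noise_sampler_seed=0,
    guidance_scale=9.0,
    num_inference_steps=20,
).images[0]
image.save("dpmpp_sde_seeded.png")
```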